跳转至

5. mhlo & hlo

/home/ken/workspace/test/jax/test_lower

Python
import jax
import jax.numpy as jnp
import numpy as np

def f(x, y):
  """Return twice ``x`` plus ``y`` (i.e. ``2 * x + y``)."""
  doubled = 2 * x
  return doubled + y

# Concrete example arguments passed to the lowering call below.
x = 3
y = 4

# Wrap f with jax.jit, then lower it for these arguments.
jitted_f = jax.jit(f)
lowered = jitted_f.lower(x, y)

# Print the lowered module as MHLO text;
# equivalent to print(lowered.as_text("mhlo")).
mhlo_text = lowered.as_text()
print(mhlo_text)

# To get the HLO text instead, use:
# print(lowered.as_text("hlo"))

HLO

Text
HloModule jit_f, entry_computation_layout={(s32[],s32[])->s32[]}

ENTRY main.6 {
  Arg_0.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(Arg_0.1, constant.3)
  Arg_1.2 = s32[] parameter(1)
  ROOT add.5 = s32[] add(multiply.4, Arg_1.2)
}

MHLO

MLIR
module @jit_f {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<2> : tensor<i32>
    %1 = mhlo.multiply %0, %arg0 : tensor<i32>
    %2 = mhlo.add %1, %arg1 : tensor<i32>
    return %2 : tensor<i32>
  }
}