5. MHLO & HLO
/home/ken/workspace/test/jax/test_lower
Python
import jax
import jax.numpy as jnp
import numpy as np
def f(x, y):
    return 2 * x + y
x, y = 3, 4
lowered = jax.jit(f).lower(x, y)
# Print the lowered MHLO
print(lowered.as_text())  # equal to print(lowered.as_text("mhlo"))
# (true for the JAX version used here; newer releases default to StableHLO)
# To print HLO instead, use:
# print(lowered.as_text("hlo"))
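As a quick sanity check on the lowering above, the sketch below confirms that the HLO dump carries an `HloModule` header and that the compiled computation returns 2 * 3 + 4 = 10. It assumes a reasonably recent JAX release in which `Lowered.as_text` accepts `"hlo"` and `Lowered.compile` is available. The names `hlo_text`, `compiled`, and `result` are introduced here for illustration only.

```python
import jax


def f(x, y):
    return 2 * x + y


x, y = 3, 4
lowered = jax.jit(f).lower(x, y)

# Request the HLO text dump explicitly; it contains an "HloModule" header.
hlo_text = lowered.as_text("hlo")

# Compile the lowered computation and call it with the same arguments.
compiled = lowered.compile()
result = compiled(x, y)
print(int(result))  # 10
```

Calling the compiled object with arguments of a different type or shape than those passed to `lower` is expected to fail, since the lowering is specialized to the original inputs.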
HLO
Python
HloModule jit_f, entry_computation_layout={(s32[],s32[])->s32[]}
ENTRY main.6 {
  Arg_0.1 = s32[] parameter(0)
  constant.3 = s32[] constant(2)
  multiply.4 = s32[] multiply(Arg_0.1, constant.3)
  Arg_1.2 = s32[] parameter(1)
  ROOT add.5 = s32[] add(multiply.4, Arg_1.2)
}
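To read the dump above, it can help to evaluate it by hand. The following is a toy sketch, not how XLA actually executes HLO: it hard-codes the five instructions of `main.6` as tuples and walks them in order with x = 3, y = 4. The tuple layout, the `INSTRUCTIONS` and `BINARY_OPS` tables, and the `evaluate` helper are all hypothetical, introduced only for illustration.

```python
import operator

# Each entry mirrors one line of the ENTRY computation:
# (name, opcode, operands). This layout is an illustrative assumption.
INSTRUCTIONS = [
    ("Arg_0.1", "parameter", (0,)),
    ("constant.3", "constant", (2,)),
    ("multiply.4", "multiply", ("Arg_0.1", "constant.3")),
    ("Arg_1.2", "parameter", (1,)),
    ("add.5", "add", ("multiply.4", "Arg_1.2")),  # marked ROOT in the dump
]

BINARY_OPS = {"multiply": operator.mul, "add": operator.add}


def evaluate(instructions, args):
    """Evaluate instructions in order; return the value of the last (ROOT) one."""
    values = {}
    for name, opcode, operands in instructions:
        if opcode == "parameter":
            # parameter(i) reads the i-th argument of the computation.
            values[name] = args[operands[0]]
        elif opcode == "constant":
            values[name] = operands[0]
        else:
            # Binary ops look up previously computed values by name.
            lhs, rhs = (values[o] for o in operands)
            values[name] = BINARY_OPS[opcode](lhs, rhs)
    return values[name]


result = evaluate(INSTRUCTIONS, (3, 4))
print(result)  # 10
```

Each instruction only refers to names defined on earlier lines, which is why a single in-order pass is enough here.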