10. jaxpr转mhlohlo
源码中出现的ir一般是指
jax_src/lib/mlir中的
jax/interpreters/mlir.py
Python
def lower_jaxpr_to_fun( # 该函数会调用jaxpr_subcomp将jaxpr转为MHLO
ctx: ModuleContext,
name: str,
jaxpr: core.ClosedJaxpr,
effects: Sequence[core.Effect],
*,
create_tokens: bool = False,
public: bool = False,
replace_tokens_with_dummy: bool = False,
replicated_args: Optional[Sequence[bool]] = None,
arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None,
use_sharding_annotations: bool = True,
input_output_aliases: Optional[Sequence[Optional[int]]] = None,
num_output_tokens: int = 0,
) -> func_dialect.FuncOp:
"""Lowers jaxpr and its callees to an IR function.
Assumes that an MLIR context, location, and insertion point are set.
Args:
"""
...
out_vals, tokens_out = jaxpr_subcomp(ctx.replace(name_stack=callee_name_stack),
jaxpr.jaxpr, tokens_in, map(ir_constants, jaxpr.consts),
*args)
...
# jaxpr转为HMLO
def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr,
tokens: TokenSet,
consts: Sequence[Sequence[ir.Value]],
*args: Sequence[ir.Value]
) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]:
"""Lowers a jaxpr into mHLO, inlined into an existing function.
Assumes that an MLIR context, location, and insertion point are set.
"""
即时编译时,生成jaxpr后, jaxpr可以由xla_callable调用jaxpr_subcomp将其编译为XLA HLO程序。
(参考链接:Autodidax: JAX core from scratch — JAX documentation)
https://jax.readthedocs.io/en/latest/autodidax.html?highlight=hlo