dialect
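MLIR-HLO is an end-to-end compiler built on the MLIR infrastructure. It can compile for different hardware and is independent of XLA, although it can also be used inside the XLA project. Three dialects (the "Dialect" items in the figure above) support HLO compilation here: chlo, mhlo, and func. Of these, chlo and mhlo are defined by MLIR-HLO, while func comes from upstream MLIR.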
func_dialect
_src/lib/mlir/dialects/__init__.py
Python
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.func as func
In jax/interpreters/mlir.py:
Python
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import func as func_dialect
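According to the official change log, JAX has used MHLO as its default compiler IR since jaxlib 0.1.76 (January 27, 2022). The lowering API accepts only two dialect choices here: mhlo and hlo.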
In jax/_src/stages.py:
Python
class XlaLowering(Lowering):
  """Adapts our various internal XLA-backed computations into a ``Lowering``."""

  def hlo(self) -> xc.XlaComputation:
    """Return an HLO representation of this computation."""
    raise NotImplementedError("must override")

  def mhlo(self) -> mlir.ir.Module:
    """Return an MHLO representation of this computation."""
    raise NotImplementedError("must override")

  def compile(self) -> Executable:
    raise NotImplementedError("must override")

  def as_text(self, dialect: Optional[str] = None) -> str:
    # MHLO is the default when no dialect is given.
    if dialect is None or dialect == "mhlo":
      return str(self.mhlo())
    elif dialect == "hlo":
      return self.hlo().as_hlo_text()
    else:
      raise ValueError(f"unknown dialect: {dialect}")

  def compiler_ir(self, dialect: Optional[str] = None) -> Any:
    # Same dispatch as as_text, but returns the IR object rather than text.
    if dialect is None or dialect == "mhlo":
      return self.mhlo()
    elif dialect == "hlo":
      return self.hlo()
    else:
      raise ValueError(f"unknown dialect: {dialect}")
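To make the dispatch rule in as_text concrete without depending on JAX internals, here is a minimal standalone sketch. ToyLowering, mhlo_text, hlo_text, and the placeholder strings they return are hypothetical names invented for illustration. Only the None/"mhlo"/"hlo" branching mirrors the excerpt above.

```python
from typing import Optional


class ToyLowering:
    """Hypothetical stand-in that mimics the dialect dispatch of XlaLowering."""

    def mhlo_text(self) -> str:
        # Placeholder; a real lowering would return the printed MHLO module.
        return "module @toy"

    def hlo_text(self) -> str:
        # Placeholder; a real lowering would return HLO text.
        return "HloModule toy"

    def as_text(self, dialect: Optional[str] = None) -> str:
        # No dialect means MHLO, matching the default in the excerpt.
        if dialect is None or dialect == "mhlo":
            return self.mhlo_text()
        elif dialect == "hlo":
            return self.hlo_text()
        else:
            raise ValueError(f"unknown dialect: {dialect}")


low = ToyLowering()
print(low.as_text())       # same result as low.as_text("mhlo")
print(low.as_text("hlo"))
```

Any other dialect string falls through to the ValueError branch, which is why only mhlo and hlo are available at this point in the code.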
(Reference: the Change log page of the JAX documentation)
jaxlib 0.1.76 (Jan 27, 2022)
With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.
jax 0.2.28 (Feb 1, 2022)
jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.
The jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.
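The two change-log entries above can be tried out with a short sketch, assuming a working jax installation. The function f and its input are made up for illustration. No dialect= is passed, so the sketch relies on the library default; this is deliberate, since later JAX releases moved the default to StableHLO and the set of accepted dialect names depends on the installed version.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Arbitrary example function; any jittable function would do.
    return jnp.sin(x) * 2.0


# Lower without compiling; the argument only supplies shape and dtype.
lowered = jax.jit(f).lower(jnp.ones((3,), jnp.float32))

# Textual form of the default-dialect IR.
text = lowered.as_text()

# IR object for the default dialect (an MLIR module object rather than
# a string since jax 0.2.28, per the change log).
ir_module = lowered.compiler_ir()

print(type(ir_module))
print(text)
```

The printed text shows the module with its main entry function, and the printed type shows that compiler_ir() returns an object rather than a string.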