8. dialect

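MLIR-HLO is an end-to-end compiler built on the MLIR infrastructure that can compile for different hardware. It is independent of XLA, although it can also be used inside the XLA project. Three dialects (the "Dialect" in the figure above) support HLO compilation here: chlo, mhlo, and func. Of these, chlo and mhlo are defined by MLIR-HLO, while func comes from upstream MLIR.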

func_dialect

_src/lib/mlir/dialects/__init__.py

Python
import jaxlib.mlir.dialects.builtin as builtin
import jaxlib.mlir.dialects.chlo as chlo
import jaxlib.mlir.dialects.mhlo as mhlo
import jaxlib.mlir.dialects.func as func

jax/interpreters/mlir.py

Python
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import mhlo
from jax._src.lib.mlir.dialects import func as func_dialect

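According to the official change log, JAX has used MHLO as its default compiler IR since January 27, 2022 (jaxlib 0.1.76). Only two dialect choices are available here: mhlo and hlo.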
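To see the default dialect from the user side, the sketch below lowers a small function and inspects the result without passing `dialect=`. The function `f` and the variable names are made up for illustration. On the versions discussed here the default text is MHLO; other JAX releases may print a different default dialect or reject `dialect="mhlo"`, so the sketch relies only on the default.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Hypothetical example function; any jit-compatible function works.
    return jnp.sin(x) + 1.0


# Lower ahead of time without compiling or running.
lowered = jax.jit(f).lower(1.0)

# With no dialect= argument, both calls use the default dialect.
ir_text = lowered.as_text()        # textual form of the IR
ir_module = lowered.compiler_ir()  # IR object (an MLIR module by default)

print(type(ir_module).__name__)
print(ir_text)
```

Passing `dialect="hlo"` or `dialect="mhlo"` explicitly selects one of the two branches described above, subject to what the installed JAX version accepts.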

jax/_src/stages.py

Python
class XlaLowering(Lowering):
  """Adapts our various internal XLA-backed computations into a ``Lowering``."""

  def hlo(self) -> xc.XlaComputation:
    """Return an HLO representation of this computation."""
    raise NotImplementedError("must override")

  def mhlo(self) -> mlir.ir.Module:
    """Return an MHLO representation of this computation."""
    raise NotImplementedError("must override")

  def compile(self) -> Executable:
    raise NotImplementedError("must override")

  def as_text(self, dialect: Optional[str] = None) -> str:
    if dialect is None or dialect == "mhlo":
      return str(self.mhlo())
    elif dialect == "hlo":
      return self.hlo().as_hlo_text()
    else:
      raise ValueError(f"unknown dialect: {dialect}")

  def compiler_ir(self, dialect: Optional[str] = None) -> Any:
    if dialect is None or dialect == "mhlo":
      return self.mhlo()
    elif dialect == "hlo":
      return self.hlo()
    else:
      raise ValueError(f"unknown dialect: {dialect}")
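The dispatch in `as_text` can be tried without installing JAX. The following is a minimal sketch, not JAX code: `ToyLowering` and `FakeLowering` are hypothetical stand-ins that return plain strings instead of real `XlaComputation` or `ir.Module` objects. It shows that `None` and `"mhlo"` take the same branch, `"hlo"` takes the other, and anything else is rejected.

```python
from typing import Any, Optional


class ToyLowering:
    """Hypothetical stand-in that mirrors the dispatch in XlaLowering."""

    def hlo(self) -> Any:
        raise NotImplementedError("must override")

    def mhlo(self) -> Any:
        raise NotImplementedError("must override")

    def as_text(self, dialect: Optional[str] = None) -> str:
        # None falls through to MHLO, which makes MHLO the default.
        if dialect is None or dialect == "mhlo":
            return str(self.mhlo())
        elif dialect == "hlo":
            return str(self.hlo())
        else:
            raise ValueError(f"unknown dialect: {dialect}")


class FakeLowering(ToyLowering):
    """Hypothetical subclass returning placeholder strings, not real IR."""

    def hlo(self) -> str:
        return "HloModule fake"

    def mhlo(self) -> str:
        return "module @fake"


lowering = FakeLowering()
text_default = lowering.as_text()      # same branch as dialect="mhlo"
text_mhlo = lowering.as_text("mhlo")
text_hlo = lowering.as_text("hlo")

try:
    lowering.as_text("lmhlo")          # not one of the two accepted names
    rejected = False
except ValueError:
    rejected = True

print(text_default, text_hlo, rejected)
```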

(Reference: "Change log" in the JAX documentation

jaxlib 0.1.76 (Jan 27, 2022)

With jaxlib 0.1.76, JAX uses the MHLO MLIR dialect as its primary target compiler IR by default.

jax 0.2.28 (Feb 1, 2022)

jax.jit(f).lower(...).compiler_ir() now defaults to the MHLO dialect if no dialect= is passed.

The jax.jit(f).lower(...).compiler_ir(dialect='mhlo') now returns an MLIR ir.Module object instead of its string representation.