跳转至

hlo与mhlo的相互转换

JAX 中 HLO 与 MHLO 的相互转换

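jax/_src/stages.py 中的 XlaLowering 只定义了接口:hlo()、mhlo() 和 compile() 都直接抛出 NotImplementedError("must override")。真正的实现在 jax/interpreters/pxla.py 中。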

Python
class XlaLowering(Lowering):
  """Adapts our various internal XLA-backed computations into a ``Lowering``.

  This base class only declares the interface: every method below raises
  ``NotImplementedError("must override")``, so concrete subclasses must
  provide all of them.
  """

  def hlo(self) -> xc.XlaComputation:
    """Return an HLO (``xc.XlaComputation``) representation of this computation.

    Raises:
      NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError("must override")

  def mhlo(self) -> mlir.ir.Module:
    """Return an MHLO (MLIR ``ir.Module``) representation of this computation.

    Raises:
      NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError("must override")

  def compile(self) -> Executable:
    """Compile this lowering and return the resulting ``Executable``.

    Raises:
      NotImplementedError: always; subclasses must override.
    """
    raise NotImplementedError("must override")

jax/interpreters/pxla.py

Python
class MeshComputation(stages.XlaLowering):
  """A lowered computation whose program is held as MLIR or as XLA HLO.

  The program passed to ``__init__`` is stored unchanged in ``_hlo``; the
  ``hlo()`` and ``mhlo()`` accessors convert it to the requested form on
  demand. Any extra keyword arguments are kept in ``compile_args``.
  """
  _hlo: Union[ir.Module, xc.XlaComputation]
  _executable: Optional[MeshExecutable]

  def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
               donated_invars: Sequence[bool], **compile_args):
    # No executable exists until this lowering is compiled.
    self._executable = None
    self._name = name
    self._hlo = hlo
    self._donated_invars = donated_invars
    # ``hlo()`` reads the "tuple_args" entry from this mapping.
    self.compile_args = compile_args

  # -- stages.XlaLowering overrides

  def hlo(self) -> xc.XlaComputation:
    """Return the program as an ``xc.XlaComputation``.

    This is a method (not an attribute) for API consistency with
    ``dispatch.XlaComputation``. A stored ``XlaComputation`` is returned as
    is. A stored MLIR module is printed to text and converted, with
    ``use_tuple_args`` taken from ``compile_args["tuple_args"]``.

    Raises:
      KeyError: if conversion is needed and ``compile_args`` has no
        "tuple_args" entry.
    """
    program = self._hlo
    if not isinstance(program, xc.XlaComputation):
      # Print the module first, then look up the flag (same order as before).
      module_text = mlir.module_to_string(program)
      tuple_flag = self.compile_args["tuple_args"]
      return xe.mlir.mlir_module_to_xla_computation(
          module_text, use_tuple_args=tuple_flag)
    return program

  def mhlo(self) -> ir.Module:
    """Return the program as an MLIR ``ir.Module``.

    A stored MLIR module is returned as is (the same object, not a copy).
    A stored ``XlaComputation`` is converted to MLIR text and parsed while
    the context from ``mlir.make_ir_context()`` is active.
    """
    program = self._hlo
    if not isinstance(program, xc.XlaComputation):
      return program
    mlir_text = xe.mlir.xla_computation_to_mlir_module(program)
    with mlir.make_ir_context():
      return ir.Module.parse(mlir_text)
Python
class PmapComputation(stages.XlaLowering):
  """A lowered computation whose program is held as MLIR or as XLA HLO.

  The program passed to ``__init__`` is stored unchanged in ``_hlo``; the
  ``hlo()`` and ``mhlo()`` accessors convert it to the requested form on
  demand. Any extra keyword arguments are kept in ``compile_args``.
  """
  _hlo: Union[ir.Module, xc.XlaComputation]
  _executable: Optional[PmapExecutable]

  def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
    self._hlo = hlo
    # ``hlo()`` reads the "tuple_args" entry from this mapping.
    self.compile_args = compile_args
    # No executable exists until this lowering is compiled.
    self._executable = None

  # -- stages.XlaLowering overrides

  def hlo(self) -> xc.XlaComputation:
    """Return the program as an ``xc.XlaComputation``.

    This is a method (not an attribute) for API consistency with
    ``dispatch.XlaComputation``. A stored ``XlaComputation`` is returned as
    is. A stored MLIR module is printed to text and converted, with
    ``use_tuple_args`` taken from ``compile_args["tuple_args"]``.

    Raises:
      KeyError: if conversion is needed and ``compile_args`` has no
        "tuple_args" entry.
    """
    program = self._hlo
    if isinstance(program, xc.XlaComputation):
      return program
    # Print the module first, then look up the flag (same order as before).
    module_text = mlir.module_to_string(program)
    tuple_flag = self.compile_args["tuple_args"]
    return xe.mlir.mlir_module_to_xla_computation(
        module_text, use_tuple_args=tuple_flag)

  def mhlo(self) -> ir.Module:
    """Return the program as an MLIR ``ir.Module``.

    A stored MLIR module is returned as is (the same object, not a copy).
    A stored ``XlaComputation`` is converted to MLIR text and parsed while
    the context from ``mlir.make_ir_context()`` is active.
    """
    program = self._hlo
    if not isinstance(program, xc.XlaComputation):
      return program
    mlir_text = xe.mlir.xla_computation_to_mlir_module(program)
    with mlir.make_ir_context():
      return ir.Module.parse(mlir_text)