9. Converting between HLO and MHLO
HLO ↔ MHLO conversion in JAX
The XlaLowering base class in jax/_src/stages.py only declares the interface (every method raises NotImplementedError); the concrete implementations live in jax/interpreters/pxla.py.
jax/_src/stages.py
Python
class XlaLowering(Lowering):
  """Adapts our various internal XLA-backed computations into a ``Lowering``."""

  def hlo(self) -> xc.XlaComputation:
    """Return an HLO representation of this computation."""
    raise NotImplementedError("must override")

  def mhlo(self) -> mlir.ir.Module:
    """Return an MHLO representation of this computation."""
    raise NotImplementedError("must override")

  def compile(self) -> Executable:
    raise NotImplementedError("must override")
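These lowerings are reached through the public staging API: jax.jit(f).lower(...) returns a Lowered wrapper around an XlaLowering, and Lowered.compiler_ir() dispatches to the hlo()/mhlo() methods above. A minimal usage sketch, assuming an MHLO-era JAX release where dialect="mhlo" is still accepted:

Python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0

lowered = jax.jit(f).lower(jnp.ones((4,), jnp.float32))

# MHLO view: an mlir.ir.Module, produced by XlaLowering.mhlo()
mhlo_module = lowered.compiler_ir(dialect="mhlo")
print(mhlo_module)

# HLO view: an xc.XlaComputation, produced by XlaLowering.hlo()
hlo_comp = lowered.compiler_ir(dialect="hlo")
print(hlo_comp.as_hlo_text())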
jax/interpreters/pxla.py
Python
class MeshComputation(stages.XlaLowering):
  _hlo: Union[ir.Module, xc.XlaComputation]
  _executable: Optional[MeshExecutable]

  def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
               donated_invars: Sequence[bool], **compile_args):
    self._name = name
    self._hlo = hlo
    self._donated_invars = donated_invars
    self.compile_args = compile_args
    self._executable = None

  # -- stages.XlaLowering overrides

  def hlo(self) -> xc.XlaComputation:
    # this is a method for api consistency with dispatch.XlaComputation
    if isinstance(self._hlo, xc.XlaComputation):
      return self._hlo
    return xe.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(self._hlo),
        use_tuple_args=self.compile_args["tuple_args"])

  def mhlo(self) -> ir.Module:
    if isinstance(self._hlo, xc.XlaComputation):
      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
      with mlir.make_ir_context():
        return ir.Module.parse(module_str)
    return self._hlo
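Note that MeshComputation stores whichever representation lowering produced and converts lazily: MHLO → HLO goes through jaxlib's xe.mlir.mlir_module_to_xla_computation, and HLO → MHLO goes through xe.mlir.xla_computation_to_mlir_module plus a re-parse of the emitted module text. A rough standalone sketch of that round trip follows; the import paths below are assumptions and depend on the jax/jaxlib version (the helper functions themselves are the ones called in the code above):

Python
from jax._src.lib import xla_client as xc          # assumed import path
from jax._src.lib import xla_extension as xe        # assumed import path
from jax._src.lib.mlir import ir                    # assumed import path
from jax.interpreters import mlir

def hlo_to_mhlo(computation: xc.XlaComputation) -> ir.Module:
    # HLO -> MHLO: jaxlib emits the MLIR module as text, which is re-parsed
    # inside a fresh MLIR context (mirrors MeshComputation.mhlo()).
    module_str = xe.mlir.xla_computation_to_mlir_module(computation)
    with mlir.make_ir_context():
        return ir.Module.parse(module_str)

def mhlo_to_hlo(module: ir.Module, tuple_args: bool = False) -> xc.XlaComputation:
    # MHLO -> HLO: serialize the module and hand it back to jaxlib
    # (mirrors MeshComputation.hlo()).
    return xe.mlir.mlir_module_to_xla_computation(
        mlir.module_to_string(module), use_tuple_args=tuple_args)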
Python
class PmapComputation(stages.XlaLowering):
  _hlo: Union[ir.Module, xc.XlaComputation]
  _executable: Optional[PmapExecutable]

  def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
    self._executable = None
    self._hlo = hlo
    self.compile_args = compile_args

  # -- stages.XlaLowering overrides

  def hlo(self) -> xc.XlaComputation:
    # this is a method for api consistency with dispatch.XlaComputation
    if isinstance(self._hlo, xc.XlaComputation):
      return self._hlo
    else:
      return xe.mlir.mlir_module_to_xla_computation(
          mlir.module_to_string(self._hlo),
          use_tuple_args=self.compile_args["tuple_args"])

  def mhlo(self) -> ir.Module:
    if isinstance(self._hlo, xc.XlaComputation):
      module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
      with mlir.make_ir_context():
        return ir.Module.parse(module_str)
    return self._hlo
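PmapComputation follows the same pattern on the jax.pmap path: store whatever the lowering produced, convert on demand. A small usage sketch, again assuming an MHLO-era JAX where Lowered.compiler_ir accepts dialect="mhlo":

Python
import jax
import jax.numpy as jnp

def g(x):
    return x * x

n = jax.local_device_count()
lowered = jax.pmap(g).lower(jnp.ones((n, 8), jnp.float32))

# Both views come from PmapComputation's hlo()/mhlo() overrides.
print(lowered.compiler_ir(dialect="mhlo"))                # mlir.ir.Module
print(lowered.compiler_ir(dialect="hlo").as_hlo_text())   # xc.XlaComputation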