跳转至

4. make_jaxpr

Text Only
graph TB
  jax_code ==> jaxpr
  jaxpr ==> HLO1(HLO xla的输入IR)
  HLO1 ==> MHLO
  MHLO == 优化 ==> MHLO
  MHLO ==> HLO
  HLO ==XLA编译器==> llvm(LLVM_IR等)
  llvm ==> hard(各种硬件加速器)

JAXPR

JAX执行函数转换(JIT、自动微分、自动并行、向量化)时先将python代码转换成中间语言jaxpr,然后再对jaxpr进行转换。JAX使用进程“tracing”生成jaxpr时,JAX通过tracer对象对每一个函数参数进行包装(wrap),然后tracer记录函数运行过程中对这些参数进行的操作。然后依据记录,JAX生成jaxpr。某个函数fun对应的jaxpr可以通过JAX的接口jax.make_jaxpr(fun)进行查看。

jaxpr的语法和教程

https://jax.readthedocs.io/en/latest/jaxpr.html

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html

Python
jaxpr ::= { lambda Var* ; Var+.
            let Eqn*
            in  [Expr+] }

case1

Python
from jax import make_jaxpr
import jax.numpy as jnp
import numpy as np


def f(x, y):
  return jnp.dot(x + 1, y + 1)

x = np.random.randn(3, 4)
y = np.random.randn(4)
print(make_jaxpr(f)(x, y))

result

C++
{ lambda ; a:f32[3,4] b:f32[4]. let
    c:f32[3,4] = add a 1.0
    d:f32[4] = add b 1.0
    e:f32[3] = dot_general[
      dimension_numbers=(((1,), (0,)), ((), ()))
      precision=None
      preferred_element_type=None
    ] c d
  in (e,) }

case2

Python
import jax

def f(x):
    return jax.numpy.sin(jax.numpy.cos(x))

print(f(3.0))
print(jax.make_jaxpr(f)(3.0))

result

GAS
-0.83602184
{ lambda ; a:f32[]. let b:f32[] = cos a; c:f32[] = sin b in (c,) }