AOT
https://jax.readthedocs.io/en/latest/aot.html
- Stage out a specialized version of the original Python callable `F` to an internal representation. The specialization reflects a restriction of `F` to input types inferred from properties of the arguments `x` and `y` (usually their shape and element type).
- Lower this specialized, staged-out computation to the XLA compiler's input language, MHLO.
- Compile the lowered HLO program to produce an optimized executable for the target device (CPU, GPU, or TPU).
- Execute the compiled executable with the arrays `x` and `y` as arguments.
JAX’s AOT API gives you direct control over steps #2, #3, and #4 (but not #1), plus some other features along the way. An example:
```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np

>>> def f(x, y): return 2 * x + y

>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f.0 {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<2> : tensor<i32>
    %1 = mhlo.multiply %0, %arg0 : tensor<i32>
    %2 = mhlo.add %1, %arg1 : tensor<i32>
    return %2 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
DeviceArray(10, dtype=int32)
```
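Lowering does not require concrete input values: since specialization depends only on properties like shape and element type, you can describe the inputs abstractly with `jax.ShapeDtypeStruct`. A minimal sketch of the same pipeline driven this way (the scalar-`int32` spec is an illustrative choice):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

# Describe each input abstractly: a rank-0 int32 array, no concrete value.
spec = jax.ShapeDtypeStruct((), jnp.dtype('int32'))

lowered = jax.jit(f).lower(spec, spec)  # lower against abstract inputs
compiled = lowered.compile()            # compile the lowered computation

# Execute with concrete arguments matching the spec.
result = compiled(jnp.int32(3), jnp.int32(4))
print(result)  # 10
```

Note that the compiled executable is specialized to these input types: calling it with arguments of a different shape or dtype raises an error rather than triggering recompilation.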