jax_to_pb
Exporting a pb file that TensorFlow can recognize
The header comment of examples/jax_cpp/main.cc describes the workflow: export an HLO proto from JAX, then load and run it from C++.
C++
// An example for reading a HloModule from a HloProto file and execute the
// module on PJRT CPU client.
//
// To build a HloModule,
//
// $ python3 jax/tools/jax_to_hlo.py \
// --fn examples.jax_cpp.prog.fn \
// --input_shapes '[("x", "f32[2,2]"), ("y", "f32[2,2]")]' \
// --constants '{"z": 2.0}' \
// --hlo_text_dest /tmp/fn_hlo.txt \
// --hlo_proto_dest /tmp/fn_hlo.pb
//
// To load and run the HloModule,
//
// $ bazel build examples/jax_cpp:main --experimental_repo_remote_exec --check_visibility=false
// $ bazel-bin/examples/jax_cpp/main
// 2021-01-12 15:35:28.316880: I examples/jax_cpp/main.cc:65] result = (
// f32[2,2] {
// { 1.5, 1.5 },
// { 3.5, 3.5 }
// }
// )
Modify jax_to_ir.py
Define the function in prog.py and run jax_to_ir.py to emit the HLO text and the serialized HLO proto:
Bash
$ cat prog.py
import jax.numpy as jnp

def fn(x, y, z):
  return jnp.dot(x, y) / z

$ python jax_to_ir.py \
    --fn prog.fn \
    --input_shapes '[("y", "f32[128,32]"), ("x", "f32[8,128]")]' \
    --constants '{"z": 3.14159}' \
    --ir_format HLO \
    --ir_human_dest tf_tmp/fn_hlo.txt \
    --ir_dest tf_tmp/fn_hlo.pb
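For reference, here is a minimal sketch of what jax_to_ir.py does under the hood, assuming a JAX version that still provides jax.xla_computation; it reproduces the same two output files directly from Python (parameter ordering may differ from what the script emits).
Python
import os
import jax
import jax.numpy as jnp

def fn(x, y, z):
    return jnp.dot(x, y) / z

# Bake in the constant z and trace fn for the shapes used in the command above.
comp = jax.xla_computation(lambda x, y: fn(x, y, 3.14159))(
    jnp.zeros((8, 128), jnp.float32), jnp.zeros((128, 32), jnp.float32))

os.makedirs("tf_tmp", exist_ok=True)

# Write the human-readable HLO text and the serialized HloModuleProto (the .pb file).
with open("tf_tmp/fn_hlo.txt", "w") as f:
    f.write(comp.as_hlo_text())
with open("tf_tmp/fn_hlo.pb", "wb") as f:
    f.write(comp.as_serialized_hlo_module_proto())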
After running jax_to_ir.py, the working directory looks like this:
Bash
(tf2.1) ken@lynxi:~/workspace/test/jax/pb_jax_to_tf$ tree
.
├── ir
│   ├── jax_ir0_jit_prim_fun.mlir
│   └── jax_ir1_jit_prim_fun.mlir
├── jax_to_ir.py
├── prog.py
├── __pycache__
│   └── prog.cpython-37.pyc
├── tf_tmp
│   ├── fn_hlo.pb
│   └── fn_hlo.txt
└── tmp
    ├── foo
    │   ├── module_0000.jit_prim_fun.before_optimizations.dot
    │   ├── module_0000.jit_prim_fun.before_optimizations.hlo.pb
    │   ├── module_0000.jit_prim_fun.before_optimizations.html
    │   ├── module_0000.jit_prim_fun.before_optimizations.txt
    │   ├── module_0000.jit_prim_fun.cpu_after_optimizations-buffer-assignment.txt
    │   ├── module_0000.jit_prim_fun.cpu_after_optimizations.dot
    │   ├── module_0000.jit_prim_fun.cpu_after_optimizations.hlo.pb
    │   ├── module_0000.jit_prim_fun.cpu_after_optimizations.html
    │   ├── module_0000.jit_prim_fun.cpu_after_optimizations.top_level.html
    │   ├── module_0000.jit_prim_fun.cpu_after_optimizations.txt
    │   ├── module_0000.jit_prim_fun.ir-no-opt.ll
    │   ├── module_0000.jit_prim_fun.ir-no-opt-noconst.ll
    │   ├── module_0000.jit_prim_fun.ir-with-opt.ll
    │   ├── module_0000.jit_prim_fun.ir-with-opt-noconst.ll
    │   ├── module_0000.jit_prim_fun.o
    │   ├── module_0001.jit_prim_fun.before_optimizations.dot
    │   ├── module_0001.jit_prim_fun.before_optimizations.hlo.pb
    │   ├── module_0001.jit_prim_fun.before_optimizations.html
    │   ├── module_0001.jit_prim_fun.before_optimizations.txt
    │   ├── module_0001.jit_prim_fun.cpu_after_optimizations-buffer-assignment.txt
    │   ├── module_0001.jit_prim_fun.cpu_after_optimizations.dot
    │   ├── module_0001.jit_prim_fun.cpu_after_optimizations.hlo.pb
    │   ├── module_0001.jit_prim_fun.cpu_after_optimizations.html
    │   ├── module_0001.jit_prim_fun.cpu_after_optimizations.top_level.html
    │   ├── module_0001.jit_prim_fun.cpu_after_optimizations.txt
    │   ├── module_0001.jit_prim_fun.ir-no-opt.ll
    │   ├── module_0001.jit_prim_fun.ir-no-opt-noconst.ll
    │   ├── module_0001.jit_prim_fun.ir-with-opt.ll
    │   ├── module_0001.jit_prim_fun.ir-with-opt-noconst.ll
    │   └── module_0001.jit_prim_fun.o
    └── tf_dump_graph
        ├── before_increase_dynamism_for_auto_jit_pass.pbtxt
        ├── before_mark_for_compilation.pbtxt
        ├── mark_for_compilation_annotated.pbtxt
        └── mark_for_compilation.pbtxt

6 directories, 41 files
The following is the content of fn_hlo.pb:

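One way to inspect fn_hlo.pb is to deserialize it back into an XlaComputation and print its HLO text. This is a minimal sketch, assuming the jaxlib xla_client bindings accept a serialized HloModuleProto in the XlaComputation constructor; its output should match fn_hlo.txt.
Python
from jax.lib import xla_client

# Read the serialized HloModuleProto produced by jax_to_ir.py.
with open("tf_tmp/fn_hlo.pb", "rb") as f:
    serialized = f.read()

# Rebuild the computation and dump its human-readable HLO.
comp = xla_client.XlaComputation(serialized)
print(comp.as_hlo_text())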
Reading the pb file exported by jax
Read the pb file with the TF1 compatibility (v1) API.
Working directory: /home/ken/workspace/test/jax/pb_test
Python
import tensorflow as tf

# Read the serialized file and parse it into a GraphDef via the v1 compat API.
with tf.io.gfile.GFile("./tf_tmp/fn_hlo.pb", "rb") as f:
    graph_def = tf.compat.v1.GraphDef()
    loaded = graph_def.ParseFromString(f.read())

# Import the parsed GraphDef into the default graph and list its node names.
tf.import_graph_def(graph_def, name='')
graph_nodes = [n for n in graph_def.node]
names = []
for t in graph_nodes:
    names.append(t.name)
print(names)
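As a follow-up, here is a hedged variant of the same load, assuming the file parses as a GraphDef as above: importing it into an explicit tf.Graph lets the operations be inspected through the graph object instead of the raw proto.
Python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    graph_def = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile("./tf_tmp/fn_hlo.pb", "rb") as f:
        graph_def.ParseFromString(f.read())
    # Import into this graph rather than the global default graph.
    tf.import_graph_def(graph_def, name='')

# List every imported operation with its name and type.
for op in graph.get_operations():
    print(op.name, op.type)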