Jax的加速层的伪代码/中间层代码的生成和查看

地址:

https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jit-mechanics-tracing-and-static-variables



from jax import make_jaxpr

def f(x, y):
  return jnp.dot(x + 1, y + 1)

make_jaxpr(f)(x, y)


image

posted on 2024-01-09 12:31  Angry_Panda  阅读(20)  评论(0)    收藏  举报

导航