Jax框架:通过显存分析判断操作是否进行jit编译
相关:
https://jax.readthedocs.io/en/latest/device_memory_profiling.html
代码:
import jax
import jax.numpy as jnp
import jax.profiler
def func1(x):
return jnp.tile(x, 10) * 0.5
def func2(x):
y = func1(x)
return y, jnp.tile(x, 10) + 1
x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)
z.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
显存分析的示意图:

jax.random.normal 操作,经过jit编译:


jnp.tile 操作,不经过jit编译:

本博客是博主个人学习时的一些记录,不保证是为原创,个别文章加入了转载的源地址,还有个别文章是汇总网上多份资料所成,在这之中也必有疏漏未加标注处,如有侵权请与博主联系。
如果未特殊标注则为原创,遵循 CC 4.0 BY-SA 版权协议。
posted on 2024-01-19 15:26 Angry_Panda 阅读(69) 评论(0) 收藏 举报
浙公网安备 33010602011771号