Jax框架:通过显存分析判断操作是否进行jit编译

相关:

https://jax.readthedocs.io/en/latest/device_memory_profiling.html



代码:

import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

显存分析的示意图:

image


jax.random.normal 操作,经过jit编译:

image



image




jnp.tile 操作,不经过jit编译:

image



posted on 2024-01-19 15:26  Angry_Panda  阅读(69)  评论(0)    收藏  举报

导航