jax中对单步操作的缓存对性能造成的影响
代码:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
def selu(x, alpha=1.65, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000005,))
%timeit selu(x).block_until_ready()
运行结果:

再次运行:

修改array的shape:
代码:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
def selu(x, alpha=1.65, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000003,))
%timeit selu(x).block_until_ready()
运行结果:

再次运行:

PS. 由此可以看出,jax对单步运行其实也是使用缓存操作的,对单步操作也可以通过缓存来进行多次调用的速度提升的。
本博客是博主个人学习时的一些记录,不保证是为原创,个别文章加入了转载的源地址,还有个别文章是汇总网上多份资料所成,在这之中也必有疏漏未加标注处,如有侵权请与博主联系。
如果未特殊标注则为原创,遵循 CC 4.0 BY-SA 版权协议。
posted on 2024-01-09 10:59 Angry_Panda 阅读(30) 评论(0) 收藏 举报
浙公网安备 33010602011771号