如何在jax框架的jit中快速的实现循环结构 —— Jax框架的jit编译是否可以使用循环结构
相关:
Jax框架的jit编译是否可以使用循环结构,如果使用循环结构需要注意什么
前文中给出了jit下使用python做循环结构的代码,下面再次给出这个代码,这个代码为jupyter-notebook环境,并且在jit编译时需要60秒,运行9.4秒左右:
from jax import jit, random
import jax.numpy as jnp
from functools import partial
@partial(jit, static_argnums=(2,))
def f(x, y, z):
print("Running f():")
print(f" x = {x}")
print(f" y = {y}")
print(f" z = {z}")
for _ in range(z):
y = jnp.dot(x + 0.0001, y + 0.0001)
print(f" result = {y}")
return y
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
z = 10000
%timeit f(x, y, z).block_until_ready()
Jax作为TensorFlow的改进版,在一定程度上和TensorFlow的特性保持一致,比如在循环、判断这样的结构上,如果在jit中使用python做控制就会使效率比较低(当然我们可以像pytorch那边把循环结构提出来后再用python来实现,这样的效率就会几乎无差别),但是jax和TensorFlow一样提供了可以进行原生编译的循环和判断结构,下面给出jax.lax.scan来实现上面代码的循环结构的实现:
import jax.numpy as jnp
from jax import random
key = random.PRNGKey(0)
x = random.normal(key, (10000, 10000))
y = random.normal(key, (10000, ))
def body(arr, extra):
t = jnp.dot(arr[0] + 0.0001, arr[1] + 0.0001)
return (arr[0], t), extra
%timeit lax.scan(body, (x, y), jnp.ones(10000))
运行上面这个用jax.lax.scan实现的循环结构可以使用几乎不用考虑的编译时间然后直接进入运算阶段,因为这里并不需要像在jit中将python循环结构展开编译的过程,而是直接使用原生的编译好的循环操作来实现。
运行结果:

使用jax的原生循环或判断语句进行流程控制,可以避免掉大量的编译时间(比如这里的例子就避免掉了60秒的编译时间而且我们可以灵活的控制循环次数),效率更高,并且更加的灵活。但是,使用jax的原生控制语句和TensorFlow中使用TensorFlow的原生控制语句一样,都面临着增加学习成本的问题,并且到时使用难度加大,这一点估计是编译型计算框架无法摆脱的native issue问题。
比如这里给出30000次循环次数的代码:(几乎无编译时间)

本博客是博主个人学习时的一些记录,不保证是为原创,个别文章加入了转载的源地址,还有个别文章是汇总网上多份资料所成,在这之中也必有疏漏未加标注处,如有侵权请与博主联系。
如果未特殊标注则为原创,遵循 CC 4.0 BY-SA 版权协议。
posted on 2024-01-10 22:33 Angry_Panda 阅读(131) 评论(0) 收藏 举报
浙公网安备 33010602011771号