一个简单的例子测试numpy和Jax的性能对比 (续)
numpy代码:
import numpy as np
import time
x = np.random.random([10000, 10000]).astype(np.float32)
try:
st = time.time()
y = np.matmul(x, x)
print(time.time() - st)
print(y)
except Exception as e:
print(f"error: {e}")

Jax代码:
import jax.numpy as np
from jax import random
import time
x = random.uniform(random.PRNGKey(0), [10000, 10000])
st = time.time()
try:
y = np.matmul(x, x)
print(time.time() - st)
print(y)
except Exception as e:
print(f"error: {e}")

可以说,在这个例子里面,Jax和numpy的性能基本持平。
本博客是博主个人学习时的一些记录,不保证是为原创,个别文章加入了转载的源地址,还有个别文章是汇总网上多份资料所成,在这之中也必有疏漏未加标注处,如有侵权请与博主联系。
如果未特殊标注则为原创,遵循 CC 4.0 BY-SA 版权协议。
posted on 2024-01-03 23:48 Angry_Panda 阅读(55) 评论(0) 收藏 举报
浙公网安备 33010602011771号