jax框架的 Pallas 方式的GPU扩展不可用
说下深度学习框架的GPU扩展功能的部分,也就是使用个人定制化的GPU代码编写方式来为深度学习框架做扩展。
深度学习框架本身就是一种对GPU功能的一种封装和调用,但是由于太high-level,因此就会摒弃掉一些原有的GPU底层的编程功能,为此可以使用GPU原始功能的代码来为深度学习编写扩展函数。
我们现在常用的深度学习的核函数最初都是以扩展包的扩展函数来出现的,然后再被合并到深度学习框架的原生代码中的。
使用深度学习框架的核函数太high level,用不到很多GPU的细节功能,但是直接使用GPU的原始编程接口,如:CUDA等等,又会由于编程难度较大,无法通用,因此也就有了介于两者中间的GPU扩展功能的编写封装语言,如pytorch所使用的triton,jax框架为了实现同样的中间水平的GPU扩展功能,就给出了jax-triton,通过在jax中使用jax-triton编写GPU扩展功能的代码,翻译成triton,然后再由triton翻译成CUDA。
与其说这种抽象程度介于深度学习框架和底层CUDA语言中间的triton是一种语言不如说是一种翻译器,或者是转换器。
Triton的文档网站:
https://triton-lang.org/main/getting-started/installation.html

但是,现有的深度学习框架在功能上比较完善的是只有两个,即pytorch和TensorFlow,但是TensorFlow由于逐渐被弃用,所以真正意义上的完善功能的深度学习框架只有pytorch,而其他的深度学习框架也都是处于experimental阶段的。
因此在jax框架中使用jax-Triton也不是被主要支持的,可以说在jax中使用jax-triton功能是需要额外安装pip包的,并且需要各个版本都对应,而anaconda中也并没有对其进行支持。
通过外网的讨论可以知道,jax的jax-Triton基本是不可用的,因为能配置出可用的jax-Triton是极难的。
网址:
https://github.com/google/jax/issues/18603

网址:
https://github.com/NVIDIA/JAX-Toolbox/issues/470

posted on 2024-01-17 18:18 Angry_Panda 阅读(124) 评论(0) 收藏 举报
浙公网安备 33010602011771号