Jax框架 —— 如何在没有GPU和TPU的设备上debug代码 —— 在CPU上使用GPU仿真设置 —— Jax框架在多卡设备上的自动并行特性的仿真体验

Jax计算框架是Google用来取代Tensorflow的新一代计算框架,这个框架使用类似pytorch的技术,但是在pytorch技术之上加入了更加强大的技术,但是这也导致该框架使用起来要比pytorch难一些,但是该框架的计算性能又比较优秀,因此依旧具有较大的吸引力。

Jax框架的性能优势主要体现在单机多卡GPU设备的自动并行上,Jax和TensorFlow同样有着编译器的优势,Jax可以通过简单的代码编写实现多卡并行的方式,在数据并行和模型并行的基础之上又实现了数据切割并行的方式,可以说Jax框架在原生水平上支持目前所有的并行操作,并且在并行的代码编写上又比pytorch更加简单,但是由于并行的复杂性,Jax的自动并行往往有着不好掌握的特点,虽然在编写上代码更加精炼,但是逻辑难度加大了,不容易掌握。


可以说,Jax对于pytorch来说最大的优点就是对于多卡设备,或者说是大规模计算的原生支持上,比如在Jax官方的讲解文档之中都是默认对8个TPU进行操作的,为实现同样的操作(代码可运行),我们也需要使用8个卡的GPU,这个要求虽然对于企业来说都是极为容易实现的,但是对于个人用户这个要求却是极为困难的。如果没有多卡的运行环境,我们也就无法体验Jax框架的特性。

为了方便体验和debug Jax框架的多卡自动计算的特性,Jax框架给出了多卡仿真的设置参数:
官方文档:
https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html

image




不过要注意的是,这个仿真设置的前提是本机没有GPU或TPU设备,然后才可以进行GPU和TPU的仿真设置,比如本地主机有两个GPU,那么是无法实现该仿真设置的:

image

image

可以看到,在本地存在GPU或TPU设备的时候,是无法实现仿真环境设置的。比如我这个环境下就有两个GPU,但是由于Jax的运行特性往往是在多卡(8卡以上的主机环境)下,因此我们要运行调试一个8卡的Jax环境我们就需要Jax的多卡仿真,使用CPU仿真出8个GPU和TPU,这时我们就需要把真实存在的GPU或TPU屏蔽掉(因为物理真实存在的GPU数太少,无法调试Jax的多卡大规模计算代码):

屏蔽已有的真实物理GPU或TPU:

export CUDA_VISIBLE_DEVICES=-1

屏蔽掉真实存在的两个物理GPU后,用CPU仿真出8个卡的运行环境:(注意,这里是使用CPU仿真8个卡的环境,只能作为调试之用,判断Jax的多卡运行的代码的正取与否,并不能获得真实8卡的计算性能,要注意,这里只是使用CPU仿真8个GPU):

image


给出 Jax 框架的CPU仿真多卡GPU的代码:

import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
jax.devices()



posted on 2024-01-08 16:05  Angry_Panda  阅读(364)  评论(0)    收藏  举报

导航