背景
需要对3B模型进行蒸馏,一张4090的卡无法完成实验。完成这个实验的前提是需要两张卡,一张用来加载学生模型,一张用来加载教师模型。
多卡使用
这里的多卡使用并不是像以往的方式,使用dataloaderparallel等方式,这种是数据并行的策略,不适合蒸馏的场景,因为蒸馏是一个模型做推理,一个模型做训练,并非数据并行计算。因此分开加载模型,一个用来训练,一个用来推理,训练的数据和训练卡放在同一个设备上即可。
device_stu = "cuda:0"
device_teh = "cuda:1"
# 模型加载
student_model.to(device_stu)
teacher_model.to(device_teh)
student_model.train()
teacher_model.eval()
# 数据加载
for batch_stu in dataloader(text):
batch_teh = copy.deepcopy(batch_stu)
batch_stu.to(device_stu)
batch_teh.to(device_teh)
logits_stu = student_model(**batch_stu)
logits_teh = teacher_molde(**batch_teh)
loss = kl(logits_stu, logits_teh, device_stu)
loss.backend()
代码分析:
- 学生模型和教师模型分开加载
- 数据需要深度拷贝,否会出现设备不一致的错误
- 把logits放在相同的设备,并计算损失
- 反向传播
浙公网安备 33010602011771号