背景

需要对3B模型进行蒸馏,一张4090的卡无法完成实验。完成这个实验的前提是需要两张卡,一张用来加载学生模型,一张用来加载教师模型。

多卡使用

这里的多卡使用并不是像以往的方式,使用dataloaderparallel等方式,这种是数据并行的策略,不适合蒸馏的场景,因为蒸馏是一个模型做推理,一个模型做训练,并非数据并行计算。因此分开加载模型,一个用来训练,一个用来推理,训练的数据和训练卡放在同一个设备上即可。

device_stu = "cuda:0"
device_teh = "cuda:1"
# 模型加载
student_model.to(device_stu)
teacher_model.to(device_teh)
student_model.train()
teacher_model.eval()
# 数据加载
for batch_stu in dataloader(text):
	batch_teh = copy.deepcopy(batch_stu)
	batch_stu.to(device_stu)
	batch_teh.to(device_teh)
	logits_stu = student_model(**batch_stu)
	logits_teh = teacher_molde(**batch_teh)
	loss = kl(logits_stu, logits_teh, device_stu)
	loss.backend()

代码分析:

  • 学生模型和教师模型分开加载
  • 数据需要深度拷贝,否会出现设备不一致的错误
  • 把logits放在相同的设备,并计算损失
  • 反向传播
posted on 2025-01-20 14:43  蔚蓝色の天空  阅读(49)  评论(0)    收藏  举报