# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import torch
import torch.nn.functional as F
import time
import deepspeed
from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
from deepspeed.accelerator import get_accelerator
from deepspeed_chat.dschat.utils.utils import print_rank_0
def print_all_ranks(tag, value, rank):
world_size = torch.distributed.get_world_size()
all_tensor = torch.zeros(world_size, dtype=torch.float32).to(
get_accelerator().current_device_name())
all_tensor[rank] = value
torch.distributed.all_reduce(all_tensor, op=torch.distributed.ReduceOp.SUM)
print_rank_0(f'{tag} {all_tensor}', rank)
def get_model_norm(model):
with torch.no_grad():
total = 0.0
for param in model.parameters():
should_gather = hasattr(
param,
'ds_id') and param.ds_status == ZeroParamStatus.NOT_AVAILABLE
with deepspeed.zero.GatheredParameters(param,
enabled=should_gather):
total += float(param.float().norm())
return total
def gather_log_probs(logits, labels):#lohit:[2, 511, 50272]. label[2, 511]
log_probs = F.log_softmax(logits, dim=-1)
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) #log_probs.gather(dim=-1, index=labels.unsqueeze(-1)) 用于从 log_probs 中提取每个 token 对应 labels 的 log 概率。这里 labels 的形状是 [2, 511],但为了匹配 gather 的维度要求,我们使用 unsqueeze(-1) 将 labels 的形状扩展为 [2, 511, 1]。
return log_probs_labels.squeeze(-1)
class DeepSpeedPPOTrainer():
def __init__(self, rlhf_engine, args):
self.rlhf_engine = rlhf_engine
self.actor_model = self.rlhf_engine.actor
self.critic_model = self.rlhf_engine.critic
self.ref_model = self.rlhf_engine.ref
self.reward_model = self.rlhf_engine.reward
self.tokenizer = self.rlhf_engine.tokenizer
self.args = args
self.max_answer_seq_len = args.max_answer_seq_len
self.end_of_conversation_token_id = self.tokenizer(
args.end_of_conversation_token)['input_ids'][-1]
self.z3_enabled = args.actor_zero_stage == 3
self.compute_fp32_loss = self.args.compute_fp32_loss
self.last_generated_experience = None
self.kl_ctl = 0.1
self.clip_reward_value = 5
self.cliprange = 0.2
self.cliprange_value = 0.2
self.gamma = 1.0
self.lam = 0.95
self.generate_time = 0.0
def _generate_sequence(self, prompts, mask, step):
max_min_length = self.max_answer_seq_len + prompts.shape[1] #最大回复长度256+prompt长度256
# 由于在启用 do_sample 后发生了概率/nan 错误,已添加此项修复:
# https://huggingface.co/meta-llama/Llama-2-7b-hf/commit/6fdf2e60f86ff2481f2241aaee459f85b5b0bbb9
if self.actor_model.module.config.model_type == "llama":
kwargs = dict(do_sample=False)
else:
kwargs = dict()
# 演员生成序列
with torch.no_grad():
seq = self.actor_model.module.generate(
prompts,
attention_mask=mask,
max_length=max_min_length,
pad_token_id=self.tokenizer.pad_token_id,
synced_gpus=self.z3_enabled,
**kwargs)
# 过滤掉没有答案(或非常短)的序列。这种情况发生在用户直接使用预训练模型检查点而没有进行有监督微调时。
# 注意:这会导致每个 GPU 拥有不同数量的样本。
batch_size = seq.shape[0] #2
prompt_length = prompts.shape[1] # 256
self.prompt_length = prompt_length # 256
ans = seq[:, prompt_length:] # 从256到512部分是生成的序列,它直接在原始数据续上了
valid_ans_len = (ans != self.tokenizer.pad_token_id).sum(dim=-1)
if self.args.print_answers and (step % self.args.print_answers_interval== 0):
print(f"--- prompt --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(prompts, skip_special_tokens=True)}")
print(f"--- ans --> step={step}, rank={torch.distributed.get_rank()}, {self.tokenizer.batch_decode(ans, skip_special_tokens=True)}")
out_seq = []
for i in range(batch_size):
if valid_ans_len[i] <= 1: # if the answer is shorter than 1 token, drop it
print(
f'Dropping too short generated answer: {step=}: \n'
f'prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
f'answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
)
continue
else:
out_seq.append(seq[i:i + 1])# seq[b, 512],取seq[0:1], seq[1:2]。
if not out_seq:
print(
f'All generated results are too short for rank={self.args.local_rank} step={step}\n'
f'-> prompts: {self.tokenizer.batch_decode(prompts, skip_special_tokens=False)}\n'
f'-> answers: {self.tokenizer.batch_decode(ans, skip_special_tokens=False)}'
)
return None
out_seq = torch.cat(out_seq, dim=0) # out_seq是list,每个元素[1,512], cat之后变成了[b, 512]
return out_seq
def generate_experience(self, prompts, mask, step):
self.eval() #生成经验过程全部参数固定不动
generate_start = time.time()
#演员模型推理了一下变成标签了
seq = self._generate_sequence(prompts, mask, step) #由batch个prompt获得batch个回答,注意seq是问题和回答的拼接,
generate_end = time.time()
if seq is None:
assert self.last_generated_experience is not None, f'Invalid generated experience at {step=}'
prompts = self.last_generated_experience['prompts']
seq = self.last_generated_experience['seq']
else:
self.last_generated_experience = {'prompts': prompts, 'seq': seq}
self.train() #转训练了?没有,为了获得分数
pad_token_id = self.tokenizer.pad_token_id
attention_mask = seq.not_equal(pad_token_id).long()#[2, 512],不等于pad的mask=1
with torch.no_grad(): #上一个generate,这个是forword
output = self.actor_model(seq, attention_mask=attention_mask)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
reward_score = self.reward_model.forward_value(seq, attention_mask,prompt_length=self.prompt_length)['chosen_end_scores'].detach()
values = self.critic_model.forward_value(seq, attention_mask, return_value_only=True).detach()[:, :-1]
logits = output.logits
logits_ref = output_ref.logits
if self.compute_fp32_loss:
logits = logits.to(torch.float)
logits_ref = logits_ref.to(torch.float)
self.generate_time = generate_end - generate_start
return {
'prompts': prompts,
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),#logits[:, :-1, :]指的是除了最后一个字,输出是[b, 512, 30000], #[b, 512]除了第一个字
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,1:]), #用于从 log_probs 中提取每个 token 对应 labels 的 log 概率
'value': values,
'rewards': reward_score,
'input_ids': seq,
"attention_mask": attention_mask
}
def compute_rewards(self, prompts, log_probs, ref_log_probs, reward_score,
action_mask):
"""
reward_function:计算最终的reward分数
复习一下几个相关参数的默认值:
self.kl_ctl = 0.1
self.clip_reward_value = 5
对于batch中的某个prompt来说,它最终的reward分数为:
(1) 先计算actor和ref_model的logit相似度: -self.kl_ctl * (log_probs - ref_log_probs)
其实写成self.kl_ctl * (ref_log_probs - log_probs)更好理解些
这个值越大,说明ref_model对actor生成的结果的认可度越高(即表明rlhf没有训歪),
没有训歪的情况下我们也应该给模型一些奖励,这个奖励就是self.kl_ctl * (ref_log_probs - log_probs)
(2)由于我们只取最后一个token对应位置的分数作为reward_score,因此我们只需要:
self.kl_ctl * (ref_log_probs - log_probs)的最后一位 + reward_score
(3) 同时我们对reward_score也做了大小限制,最大不超过self.clip_reward_value(超过统一给成self.clip_reward_value),
最小不低于-self.clip_reward_value(低于统一给成-self.clip_reward_value)
(4) 最后返回的rewards大小为:(batch_size, 各条数据的长度),对batch中的每条数据来说:
- response的最后一位:self.kl_ctl * (ref_log_probs - log_probs)的最后一位 + reward_score
- response的其余位置:self.kl_ctl * (ref_log_probs - log_probs)
"""
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs) #[2, 511]
rewards = kl_divergence_estimate #[2, 511]
start = prompts.shape[1] - 1 #回复的开始位置:255
ends = start + action_mask[:, start:].sum(1) + 1#回复的结束位置:[512, 512],注意计算有效的回复
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
self.clip_reward_value)#将两个分数值裁剪,[0.14, 0.59]
batch_size = log_probs.shape[0]
for j in range(batch_size):
rewards[j, start:ends[j]][-1] += reward_clip[j] #这部分表示将裁剪后的奖励值(reward_clip[j])加到提取到的最后一个奖励值上。因此,奖励就是参考模型与演员的KL再加奖励
return rewards
def train_rlhf(self, inputs):
# # 在这里训练 RLHF 模型
### process the old outputs
prompts = inputs['prompts']#原始数据
log_probs = inputs['logprobs']#计算 token 在其实际标签位置上的对数概率:
ref_log_probs = inputs['ref_logprobs']# 推理一次输出与seq标签在试剂标签的对数概率
reward_score = inputs['rewards']#[2],batch中每句话的分数
values = inputs['value']#[2,511],seq去掉最后一个字,511个字的价值
attention_mask = inputs['attention_mask']##seq中[2, 512],不等于pad的mask=1
seq = inputs['input_ids'] #seq是问题和回答的拼接,标签数据
start = prompts.size()[-1] - 1# 256-1=255
action_mask = attention_mask[:, 1:]#去掉第一个字的mask,[2,511]
old_values = values
with torch.no_grad():
old_rewards = self.compute_rewards(prompts, log_probs, ref_log_probs, reward_score, action_mask)#【2, 511】,计算奖励是为了训练演员,要求kl散度+奖励模型最后一个字的奖励分数
ends = start + action_mask[:, start:].sum(1) + 1 #有效回复的最终位置
# 我们需要在对话结束后将奖励和价值清零
# 否则优势/回报将会出错
for i in range(old_rewards.shape[0]):
old_rewards[i, ends[i]:] = 0 # old_rewards[0, 512:]
old_values[i, ends[i]:] = 0
advantages, returns = self.get_advantages_and_returns(old_values, old_rewards, start)
### process the new outputs
batch = {'input_ids': seq, "attention_mask": attention_mask}
actor_prob = self.actor_model(**batch, use_cache=False).logits
actor_log_prob = gather_log_probs(actor_prob[:, :-1, :], seq[:, 1:]) #输入序列是除了最后一个字,标签是除了第一个字
actor_loss = self.actor_loss_fn(actor_log_prob[:, start:],
log_probs[:, start:], advantages,
action_mask[:, start:])
self.actor_model.backward(actor_loss)
# T5走到这一步应该就够了
if not self.args.align_overflow:
self.actor_model.step()
# 批评家模型也需要训练,[2, 511]
value = self.critic_model.forward_value(**batch,
return_value_only=True,
use_cache=False)[:, :-1]
critic_loss = self.critic_loss_fn(value[:, start:], old_values[:,start:],
returns, action_mask[:, start:])
self.critic_model.backward(critic_loss)
if self.args.align_overflow:
actor_overflow = self.actor_model.optimizer.check_overflow(
external=True)
critic_overflow = self.critic_model.optimizer.check_overflow(
external=True)
rank = torch.distributed.get_rank()
if actor_overflow and not critic_overflow:
self.critic_model.optimizer.skip_step = True
print_rank_0(
"OVERFLOW: actor overflow, skipping both actor and critic steps",
rank)
elif not actor_overflow and critic_overflow:
self.actor_model.optimizer.skip_step = True
print_rank_0(
"OVERFLOW: critic overflow, skipping both actor and critic steps",
rank)
elif actor_overflow and critic_overflow:
print_rank_0(
"OVERFLOW: actor and critic overflow, skipping both actor and critic steps",
rank)
self.actor_model.step()
self.critic_model.step()
return actor_loss, critic_loss
def get_overflow(self):
# Overflow is not expected when using bf16
# Therefore, DeepSpeed's BF16_Optimizer does not maintain an overflow indication
if self.args.dtype == "bf16":
return False, False
actor_overflow = self.actor_model.optimizer.overflow
critic_overflow = self.critic_model.optimizer.overflow
return actor_overflow, critic_overflow
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
## policy gradient loss
"""
logprobs: 实时计算的,response部分的prob(只有这个是随着actor实时更新而改变的)
old_logprobs:老策略中,response部分的prob (这个是固定的,不随actor实时更新而改变)
advantages: 老策略中,response部分每个token对应的优势(这个是固定的,不随actor实时更新而改变)
mask:老策略中,response部分对应的mask情况这个是固定的,不随actor实时更新而改变)
之所以要引入logprobs计算actor_loss,是因为我们不希望策略每次更新的幅度太大,防止模型训歪
self.cliprange: 默认值是0.2
"""
log_ratio = (logprobs - old_logprobs) * mask
ratio = torch.exp(log_ratio) #e^log(a/b)=a/b , [2, 256]
pg_loss1 = -advantages * ratio #loss = -Adv * (a/b)
pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
1.0 + self.cliprange)
pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
return pg_loss
def critic_loss_fn(self, values, old_values, returns, mask):
# 用旧的value去约束新的value
## value loss
values_clipped = torch.clamp(#【2,256】
values,
old_values - self.cliprange_value,
old_values + self.cliprange_value,
)
if self.compute_fp32_loss:
values = values.float()
values_clipped = values_clipped.float()
# critic模型的loss定义为(预估预期收益-实际预期收益)**2
vf_loss1 = (values - returns)**2#MSE
vf_loss2 = (values_clipped - returns)**2
vf_loss = 0.5 * torch.sum(
torch.max(vf_loss1, vf_loss2) * mask) / mask.sum() # 同样,最后也是把critic loss平均到每个token上
return vf_loss
def get_advantages_and_returns(self, values, rewards, start):
"""
Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
没有引入GAE前的t时刻的优势值:
detal_t = r_t + gamma * V_t+1 - V_t
其中:
- r_t表示t时刻的即时收益
- V_t+1表示未来时刻的预期收益
- r_t + gamma * V_t+1可理解成t时刻的实际预期收益
- V_t可理解成t时刻的预估预期收益(是模型,例如critic model自己估算出来的)
引入GAE后的t时刻的优势值:
A_t = delta_t + gamma * lambda * A_t+1
粗暴理解为在t时刻时,不仅考虑当下优势,还考虑了未来的优势
为了知道A_t, 我们得知道A_t+1,所以在本算法中采取了从后往前做动态规划求解的方法,也即:
假设T是最后一个时刻,则有A_T+1 = 0, 所以有: A_T = delta_T
知道了A_T, 就可以依次往前倒推,把A_t-1, A_t-2之类都算出来了
引入GAE后t时刻的实际预期收益
returns_t = A_t + V_t
= delta_t + gamma * lambda * A_t+1 + V_t
= r_t + gamma * V_t+1 - V_t + gamma * lambda * A_t+1 + V_t
= r_t + gamma * (V_t+1 + lambda * A_t+1)
注意,这里不管是advantages还是returns,都只算response的部分
"""
# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
lastgaelam = 0
advantages_reversed = []
length = rewards.size()[-1]
for t in reversed(range(start, length)):
nextvalues = values[:, t + 1] if t < length - 1 else 0.0
delta = rewards[:, t] + self.gamma * nextvalues - values[:, t]
lastgaelam = delta + self.gamma * self.lam * lastgaelam
advantages_reversed.append(lastgaelam)
advantages = torch.stack(advantages_reversed[::-1], dim=1)#优势
returns = advantages + values[:, start:] #实际收益=优势+价值
return advantages.detach(), returns
def _validate_training_mode(self):
assert self.actor_model.module.training
assert self.critic_model.module.training
def _validate_evaluation_mode(self):
assert not self.actor_model.module.training
assert not self.critic_model.module.training
assert not self.ref_model.module.training
assert not self.reward_model.module.training
def train(self):
self.actor_model.train()
self.critic_model.train()
def eval(self):
self.actor_model.eval()
self.critic_model.eval()
self.reward_model.eval()
self.ref_model.eval()
def dump_model_norms(self, tag):
actor_model_norm = get_model_norm(self.actor_model)
ref_model_norm = get_model_norm(self.ref_model)
critic_model_norm = get_model_norm(self.critic_model)
reward_model_norm = get_model_norm(self.reward_model)
print_all_ranks(f'{tag} global_actor_model_norm', actor_model_norm,
self.args.local_rank)
print_all_ranks(f'{tag} global_ref_model_norm', ref_model_norm,
self.args.local_rank)
print_all_ranks(f'{tag} global_critic_model_norm', critic_model_norm,
self.args.local_rank)
print_all_ranks(f'{tag} global_reward_model_norm', reward_model_norm,
self.args.local_rank)
class DeepSpeedPPOTrainerUnsupervised(DeepSpeedPPOTrainer):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def train_unsupervised(self, inputs, unsup_coef):
# Train the unsupervised model here
self._validate_training_mode()
outputs = self.actor_model(**inputs, use_cache=False)
loss = outputs.loss
self.actor_model.backward(unsup_coef * loss)
self.actor_model.step()
return loss