【PP-RL Code Implementation】rllm SFT vs Parallel SFT
1. Parallel-R1 SFT
1.1 Parallel SFT data generation and preparation
The preprocessing scripts live in verl/data_preprocess_scripts.
Data format:
{
"index": 0,
"output": "Mimi picked up 2 dozen seashells, which is 2 * 12 = 24 seashells.\n\nNow, Kyle found twice as many shells as Mimi, so we need to calculate Kyle's shells. At this point, let's consider different ways to approach this.\n\n<Parallel>\n<Path>Directly multiply Mimi's shells by 2: 24 * 2 = 48 shells for Kyle.</Path>\n<Path>Use dozens for calculation: Mimi has 2 dozen, so Kyle has 4 dozen (since twice 2 dozen is 4 dozen). Then, convert back to individual shells if needed, but we can keep it in dozens for now.</Path>\n</Parallel>\n<Summary>These parallel paths demonstrate that whether you calculate Kyle's shells directly or use dozens, the result is the same: 48 shells.</Summary>\n\nNow that we have Kyle's shells, we can find Leigh's. Leigh grabbed one-third of Kyle's shells.\n\n<Parallel>\n<Path>Divide Kyle's shells by 3: 48 / 3 = 16 shells for Leigh.</Path>\n<Path>Express Leigh's shells in terms of Mimi's shells: since Kyle has twice Mimi's, and Leigh has one-third of Kyle's, so Leigh has (1/3) * (2 * Mimi) = (2/3) * Mimi. With Mimi having 24, (2/3) * 24 = 16 shells.</Path>\n</Parallel>\n<Summary>These parallel paths show that Leigh's shells can be found either by dividing Kyle's amount by 3 or by directly relating it to Mimi's shells using a fraction.</Summary>\n\nTherefore, Leigh has 16 seashells.\n\nFinal Answer: 16",
"data_source": "TongZheng1999/Deepseek-Qwen-3-8B-GSM8k-Parallel-data-Filtered",
"prompt": [
{
"content": "Solve the following problem step by step.\nDuring the reasoning process, whenever you encounter a step that may benefit from multiple perspectives or independent reasoning, insert a parallel block at that point.\n\nWithin each parallel block:\nBegin the block with <Parallel>.\nInclude at least two distinct and independent reasoning paths.\nEach path must be enclosed within <Path> and </Path> tags.\nDo not include any ordering information or cross-references between paths, as they are generated simultaneously and independently.\nClose the block with </Parallel>.\nImmediately after each </Parallel>, write a concise summary of insights or conclusions drawn from all paths, enclosed in <Summary> and </Summary> tags.\n\nRepeat this process adaptively as needed throughout the reasoning.\nDo not explicitly mention that you are triggering parallel thinking—just insert the parallel block naturally within the reasoning chain.\n\nEnd your response with a line starting with Final Answer: followed by the final result.\n\nProblem: Mimi picked up 2 dozen seashells on the beach. Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found. How many seashells did Leigh have?",
"role": "user"
}
],
"ability": "math",
"reward_model": {
"ground_truth": "16",
"style": "rule"
},
"extra_info": {
"answer": "Mimi picked up 2 dozen seashells, which is 2 * 12 = 24 seashells.\n\nNow, Kyle found twice as many shells as Mimi, so we need to calculate Kyle's shells. At this point, let's consider different ways to approach this.\n\n<Parallel><Path>Directly multiply Mimi's shells by 2: 24 * 2 = 48 shells for Kyle.</Path><Path>Use dozens for calculation: Mimi has 2 dozen, so Kyle has 4 dozen (since twice 2 dozen is 4 dozen). Then, convert back to individual shells if needed, but we can keep it in dozens for now.</Path></Parallel>\n<Summary>These parallel paths demonstrate that whether you calculate Kyle's shells directly or use dozens, the result is the same: 48 shells.</Summary>\n\nNow that we have Kyle's shells, we can find Leigh's. Leigh grabbed one-third of Kyle's shells.\n\n<Parallel><Path>Divide Kyle's shells by 3: 48 / 3 = 16 shells for Leigh.</Path><Path>Express Leigh's shells in terms of Mimi's shells: since Kyle has twice Mimi's, and Leigh has one-third of Kyle's, so Leigh has (1/3) * (2 * Mimi) = (2/3) * Mimi. With Mimi having 24, (2/3) * 24 = 16 shells.</Path></Parallel>\n<Summary>These parallel paths show that Leigh's shells can be found either by dividing Kyle's amount by 3 or by directly relating it to Mimi's shells using a fraction.</Summary>\n\nTherefore, Leigh has 16 seashells.\n\nFinal Answer: 16",
"index": 0,
"question": "Solve the following problem step by step.\nDuring the reasoning process, whenever you encounter a step that may benefit from multiple perspectives or independent reasoning, insert a parallel block at that point.\n\nWithin each parallel block:\nBegin the block with <Parallel>.\nInclude at least two distinct and independent reasoning paths.\nEach path must be enclosed within <Path> and </Path> tags.\nDo not include any ordering information or cross-references between paths, as they are generated simultaneously and independently.\nClose the block with </Parallel>.\nImmediately after each </Parallel>, write a concise summary of insights or conclusions drawn from all paths, enclosed in <Summary> and </Summary> tags.\n\nRepeat this process adaptively as needed throughout the reasoning.\nDo not explicitly mention that you are triggering parallel thinking—just insert the parallel block naturally within the reasoning chain.\n\nEnd your response with a line starting with Final Answer: followed by the final result.\n\nProblem: Mimi picked up 2 dozen seashells on the beach. Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found. How many seashells did Leigh have?",
"reward_method": "accuracy_add_parallel_reward",
"split": "train"
}
}
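As a rough illustration of what the preprocessing step produces, below is a minimal sketch of assembling one such record and writing it to parquet. The helper name build_record and the abbreviated prompt constant are hypothetical; the actual scripts in verl/data_preprocess_scripts may organize this differently.
import pandas as pd

# Abbreviated here; the full instruction text is shown in the sample record above.
PARALLEL_SFT_INSTRUCTIONS = "Solve the following problem step by step. ... Final Answer: followed by the final result."

def build_record(index, problem, solution, ground_truth, data_source, split="train"):
    """Hypothetical helper: one row of the parallel-SFT dataset, mirroring the sample above."""
    question = f"{PARALLEL_SFT_INSTRUCTIONS}\n\nProblem: {problem}"
    return {
        "index": index,
        "output": solution,  # response text containing <Parallel>/<Path>/<Summary> tags
        "data_source": data_source,
        "prompt": [{"role": "user", "content": question}],
        "ability": "math",
        "reward_model": {"ground_truth": ground_truth, "style": "rule"},
        "extra_info": {
            "answer": solution,
            "index": index,
            "question": question,
            "reward_method": "accuracy_add_parallel_reward",
            "split": split,
        },
    }

# records = [build_record(...) for ...]
# pd.DataFrame(records).to_parquet("parallel_sft_train.parquet")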
1.2 Parallel-R1's parallel data processor (ParallelThinkingSFTDataset)
The main differences from an ordinary SFT dataset lie in the two methods below:
def generate_parallel_thinking_reasponse_mask(self, response_ids: torch.Tensor) -> torch.Tensor:
    """
    Build a 2-D (seq_len x seq_len) attention mask: causal by default, except that
    tokens in different <Path> spans of the same <Parallel> block cannot see each other.
    """
    seq_len = response_ids.size(0)
    # start from an ordinary lower-triangular (causal) mask
    mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    path_ranges = []
    in_parallel = False
    i = 0
    while i < seq_len:
        tok = response_ids[i].item()
        if tok == self.start_parallel_token:
            in_parallel = True
            path_ranges = []
        elif in_parallel and tok == self.start_path_token:
            # locate the start/end positions of this <Path> ... </Path> span
            path_start = i
            ...
            path_end = i
            path_ranges.append((path_start, path_end))
        elif tok == self.end_parallel_token:
            in_parallel = False
            # when a <Parallel> block closes, erase attention between different paths
            for idx1, (s1, e1) in enumerate(path_ranges):
                for idx2, (s2, e2) in enumerate(path_ranges):
                    if idx1 != idx2:
                        mask[s1:e1 + 1, s2:e2 + 1] = False
            path_ranges = []
        i += 1
    return mask
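The elided part simply scans forward to the matching </Path>. To make the effect concrete, here is a self-contained toy version of the same idea; the tag token IDs here are made up, whereas the real dataset resolves them from the tokenizer's special tokens:
import torch

# Hypothetical special-token IDs for <Parallel>, <Path>, </Path>, </Parallel>
START_PARALLEL, START_PATH, END_PATH, END_PARALLEL = 100, 101, 102, 103

def parallel_mask(ids: torch.Tensor) -> torch.Tensor:
    seq_len = ids.size(0)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))  # causal base
    paths, in_parallel, i = [], False, 0
    while i < seq_len:
        tok = ids[i].item()
        if tok == START_PARALLEL:
            in_parallel, paths = True, []
        elif in_parallel and tok == START_PATH:
            start = i
            while i < seq_len and ids[i].item() != END_PATH:
                i += 1
            paths.append((start, i))               # inclusive [start, end] of one path
        elif tok == END_PARALLEL:
            in_parallel = False
            for a, (s1, e1) in enumerate(paths):
                for b, (s2, e2) in enumerate(paths):
                    if a != b:                     # paths cannot see each other
                        mask[s1:e1 + 1, s2:e2 + 1] = False
        i += 1
    return mask

# toy sequence: text <Parallel> <Path> x </Path> <Path> y </Path> </Parallel> text
ids = torch.tensor([1, 100, 101, 5, 102, 101, 6, 102, 103, 2])
print(parallel_mask(ids).int())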
def compute_structured_position_ids(self, response_ids: torch.Tensor) -> torch.Tensor:
    """
    Tokens inside the paths of a <Parallel> block share a local position range starting
    from curr_pos; the first token after </Parallel> has its position set to
    curr_pos + max_path_len + 1.
    """
    pad_token_id = self.tokenizer.pad_token_id
    seq_len = response_ids.size(0)
    pos_ids = torch.zeros(seq_len, dtype=torch.long)
    nonpad_mask = (response_ids != pad_token_id)
    curr_pos = 0
    i = 0
    while i < seq_len:
        if not nonpad_mask[i]:
            i += 1
            continue
        tok = response_ids[i].item()
        if tok == self.start_parallel_token:
            pos_ids[i] = curr_pos
            i += 1
            path_lengths = []
            ...
            # collect the length of every path in this block
            while temp_i < seq_len:
                ...
                if response_ids[temp_i].item() == self.start_path_token:
                    ...
                    # assign local positions inside this path, restarting from curr_pos
                    while temp_i < seq_len:
                        ...
                        pos_ids[temp_i] = local_pos
                        ...
                    path_len = pos_ids[path_end] - curr_pos
                    path_lengths.append(path_len)
                ...
            max_path_len = max(path_lengths) if path_lengths else 0
            if i < seq_len and nonpad_mask[i]:
                # the first token after the block resumes right after the longest path
                pos_ids[i] = curr_pos + max_path_len + 1
                curr_pos += max_path_len + 1 + 1
        else:
            pos_ids[i] = curr_pos
            curr_pos += 1
        i += 1
    return pos_ids
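Again, a simplified self-contained sketch (same made-up tag IDs as before) shows what these positions look like. It glosses over exactly how the tag tokens themselves are positioned, which the elided parts of the real method handle:
import torch

START_PARALLEL, START_PATH, END_PATH, END_PARALLEL = 100, 101, 102, 103  # same toy IDs as above

def structured_position_ids(ids: torch.Tensor) -> torch.Tensor:
    seq_len = ids.size(0)
    pos = torch.zeros(seq_len, dtype=torch.long)
    curr, i = 0, 0
    while i < seq_len:
        tok = ids[i].item()
        if tok == START_PARALLEL:
            pos[i] = curr                 # the <Parallel> tag keeps the current position
            i += 1
            max_path_len = 0
            while i < seq_len and ids[i].item() != END_PARALLEL:
                if ids[i].item() == START_PATH:
                    local = 1             # every path restarts its positions at curr + 1
                    while i < seq_len and ids[i].item() != END_PATH:
                        pos[i] = curr + local
                        local += 1
                        i += 1
                    if i < seq_len:       # the </Path> token
                        pos[i] = curr + local
                        i += 1
                    max_path_len = max(max_path_len, local)
                else:
                    i += 1
            if i < seq_len:               # close the block and resume after the longest path
                pos[i] = curr + max_path_len + 1
                i += 1
            curr += max_path_len + 2
        else:
            pos[i] = curr
            curr += 1
            i += 1
    return pos

# toy sequence: text <Parallel> <Path> x </Path> <Path> y y </Path> </Parallel> text
ids = torch.tensor([1, 100, 101, 5, 102, 101, 6, 6, 102, 103, 2])
print(structured_position_ids(ids))   # tensor([0, 1, 2, 3, 4, 2, 3, 4, 5, 6, 7])
Both paths occupy the same position range, so under RoPE they look to the model as if they were generated at the same place in the sequence, while the cross-path attention mask keeps them from seeing each other.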
These two methods are tied together in the main entry point, __getitem__:
def __getitem__(self, item):
    tokenizer = self.tokenizer

    prompt = self.prompts[item]
    response = self.responses[item]

    # apply chat template
    prompt_chat = [{"role": "user", "content": prompt}]
    # print(prompt_chat)
    # string
    prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False)
    response_chat_str = response + tokenizer.eos_token

    # tokenize
    prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False)
    prompt_ids = prompt_ids_output["input_ids"][0]
    prompt_attention_mask = prompt_ids_output["attention_mask"][0]

    response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False)
    response_ids = response_ids_output["input_ids"][0]
    response_attention_mask = response_ids_output["attention_mask"][0]
    attention_mask_1d = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
    # response_attention_mask = self.generate_parallel_thinking_reasponse_mask(response_ids)
    # print(response_attention_mask.shape)
    # print(prompt_attention_mask)

    prompt_length = prompt_ids.shape[0]
    response_length = response_ids.shape[0]

    input_ids = torch.cat((prompt_ids, response_ids), dim=-1)

    # padding to max length
    sequence_length = input_ids.shape[0]
    if sequence_length < self.max_length:
        padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * self.tokenizer.pad_token_id
        padded_attention_mask_1d = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask_1d.dtype)

        input_ids = torch.cat((input_ids, padded_input_ids))
        attention_mask_1d = torch.cat((attention_mask_1d, padded_attention_mask_1d))
    elif sequence_length > self.max_length:
        if self.truncation == "left":
            # actually, left truncation may not be reasonable
            input_ids = input_ids[-self.max_length :]
            attention_mask_1d = attention_mask_1d[-self.max_length :]
        elif self.truncation == "right":
            input_ids = input_ids[: self.max_length]
            attention_mask_1d = attention_mask_1d[: self.max_length]
        elif self.truncation == "error":
            raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}")
        else:
            raise NotImplementedError(f"Unknown truncation method {self.truncation}")

    # !!! key difference 1: build the 2-D parallel-thinking attention mask over the padded sequence
    attention_mask = self.generate_parallel_thinking_reasponse_mask(input_ids)
    # print(attention_mask)
    if sequence_length < self.max_length:
        # padding tokens can neither attend nor be attended to
        attention_mask[:, -(self.max_length - sequence_length):] = False
        attention_mask[-(self.max_length - sequence_length):, :] = False
    # print(attention_mask)
    attention_mask = attention_mask.unsqueeze(0)
    # attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
    # convert the boolean mask into an additive float mask: 0 where visible, -inf where blocked
    float_attention_mask = torch.full_like(attention_mask, -torch.inf, dtype=torch.float)
    float_attention_mask = float_attention_mask.masked_fill(attention_mask, 0.0)

    # !!! key difference 2: structured position ids that let parallel paths share positions
    position_ids = self.compute_structured_position_ids(input_ids)

    loss_mask = attention_mask_1d.clone()
    if prompt_length > 1:
        # mask out prompt for SFT.
        loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0
    # mask out the last token in response
    loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0
    # print(float_attention_mask.shape)
    # print(loss_mask.shape)
    # print(position_ids.shape)

    return {
        "input_ids": input_ids,
        "attention_mask": float_attention_mask,
        "bool_attention_mask": attention_mask,
        "position_ids": position_ids,
        "loss_mask": loss_mask,
    }
1.3 Parallel-R1's parallel SFT trainer (FSDPParallelThinkingSFTTrainer)
Compared with an ordinary trainer, it does not explicitly compute the attention mask itself; it fully trusts the attention mask (and position ids) handed over by the dataset, as sketched below.
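A minimal sketch of what that means for the training step, assuming the model accepts a pre-built 4-D additive attention mask (recent transformers versions do for eager/SDPA attention). The real FSDPParallelThinkingSFTTrainer wraps this in FSDP, micro-batching, and so on; the function name sft_step is illustrative only.
import torch.nn.functional as F

def sft_step(model, batch):
    out = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],   # (B, 1, L, L) additive mask, used as-is
        position_ids=batch["position_ids"],       # (B, L) structured positions
        use_cache=False,
    )
    # In the dataset above, loss_mask[i] == 1 means "the logit at position i predicts
    # token i+1 and should be trained" (prompt and final token are already zeroed out).
    logits = out.logits[:, :-1, :]
    labels = batch["input_ids"][:, 1:]
    mask = batch["loss_mask"][:, :-1].float()
    token_loss = F.cross_entropy(
        logits.reshape(-1, logits.size(-1)), labels.reshape(-1), reduction="none"
    ).view_as(labels)
    return (token_loss * mask).sum() / mask.sum().clamp(min=1)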
2. Tool-calling SFT
2.1 Tool-calling SFT data generation
See this article.
2.2 Tool-calling data processor (see 文稿/rllm-sft-dataset)
Loss is computed only on assistant turns. The two modes below differ in which assistant turns count; a toy comparison of the resulting masks follows the code.
- cumulative mode: every assistant turn contributes to the loss (the default, and also the standard approach for ordinary SFT and multi-turn conversation SFT):
def _tokenize_and_mask_cumulative(self, messages):
    tokens, loss_mask = [], []
    for i in range(len(messages)):
        # render each message with the chat template, one message at a time
        parsed = self.parser.parse([messages[i]], is_first_msg=(i == 0), add_generation_prompt=False)
        ids = self.tokenizer.encode(parsed, add_special_tokens=False)
        if messages[i]["role"] == "assistant":
            loss_mask.extend([1] * len(ids))  # train on every assistant token
        else:
            loss_mask.extend([0] * len(ids))  # user/system/tool tokens are context only
        tokens.extend(ids)
    return tokens, loss_mask
- stepwise mode: only the last assistant turn is trained (e.g. only the final answer):
# find the index of the last assistant message
...
for i in range(len(messages)):
    parsed = self.parser.parse([messages[i]], is_first_msg=(i == 0), add_generation_prompt=False)
    ids = self.tokenizer.encode(parsed, add_special_tokens=False)
    if i == last_assistant_idx and messages[i]["role"] == "assistant":
        loss_mask.extend([1] * len(ids))  # train only the last assistant turn
    else:
        loss_mask.extend([0] * len(ids))
    tokens.extend(ids)
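A hypothetical toy example of how the two modes differ on the same conversation; tokenization is faked with whitespace splitting purely to show the mask shapes, and toy_masks is not part of rllm:
messages = [
    {"role": "user", "content": "What is 2 + 2 ?"},
    {"role": "assistant", "content": "I should call the calculator tool ."},
    {"role": "tool", "content": "4"},
    {"role": "assistant", "content": "The answer is 4 ."},
]

def toy_masks(messages, mode):
    tokens, loss_mask = [], []
    last_assistant_idx = max(i for i, m in enumerate(messages) if m["role"] == "assistant")
    for i, m in enumerate(messages):
        ids = m["content"].split()                      # stand-in for tokenizer.encode
        if m["role"] == "assistant" and (mode == "cumulative" or i == last_assistant_idx):
            loss_mask += [1] * len(ids)
        else:
            loss_mask += [0] * len(ids)
        tokens += ids
    return tokens, loss_mask

print(toy_masks(messages, "cumulative")[1])  # both assistant turns get 1s
print(toy_masks(messages, "stepwise")[1])    # only the final assistant turn gets 1s
The dataset's __getitem__ then pads or truncates the token list and packages everything into tensors: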
def __getitem__(self, item):
    messages = self.messages[item]

    # dispatches to the cumulative or stepwise tokenize-and-mask method depending on the configured mode
    tokens, loss_mask = self._tokenize_and_mask(messages)

    input_ids = torch.tensor(tokens, dtype=torch.long)
    loss_mask = torch.tensor(loss_mask, dtype=torch.long)
    attention_mask = torch.tensor([1] * len(tokens), dtype=torch.long)

    # Handle sequence length
    sequence_length = input_ids.shape[0]
    if sequence_length < self.max_length:
        # Pad sequences
        pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
        padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype)
        padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype)
        padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype)

        input_ids = torch.cat((input_ids, padded_input_ids))
        attention_mask = torch.cat((attention_mask, padded_attention_mask))
        loss_mask = torch.cat((loss_mask, padded_loss_mask))
    elif sequence_length > self.max_length:
        if self.truncation == "left":
            input_ids = input_ids[-self.max_length :]
            attention_mask = attention_mask[-self.max_length :]
            loss_mask = loss_mask[-self.max_length :]
        elif self.truncation == "right":
            input_ids = input_ids[: self.max_length]
            attention_mask = attention_mask[: self.max_length]
            loss_mask = loss_mask[: self.max_length]
        elif self.truncation == "error":
            raise ValueError(f"{sequence_length=} is larger than {self.max_length=}")
        else:
            raise ValueError(f"Unknown truncation method {self.truncation}")

    # Create position IDs
    position_ids = torch.arange(len(input_ids), dtype=torch.long)
    # Zero out position IDs for padding
    position_ids = position_ids * attention_mask

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "loss_mask": loss_mask,
    }
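Note that here loss_mask marks the assistant (label) tokens themselves, so when the loss is computed with shifted logits the mask travels with the labels. A minimal sketch of one common way to consume such a mask (not rllm's actual trainer code):
import torch.nn.functional as F

def masked_sft_loss(logits, input_ids, loss_mask):
    # logits: (B, L, V); input_ids, loss_mask: (B, L); loss_mask[i] == 1 marks label tokens
    shift_logits = logits[:, :-1, :]
    shift_labels = input_ids[:, 1:]
    shift_mask = loss_mask[:, 1:].float()          # the mask is shifted together with the labels
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="none",
    ).view_as(shift_labels)
    return (loss * shift_mask).sum() / shift_mask.sum().clamp(min=1)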
2.3 Tool-calling SFT trainer
No different from a normal SFT trainer.
