【PP-RL Code Implementation】rllm SFT vs parallel SFT

1. parallel-r1 SFT

1.1 Parallel SFT data generation and preparation

The preprocessing scripts live under verl/data_preprocess_scripts.
Data format (one record):

{
  "index": 0,
  "output": "Mimi picked up 2 dozen seashells, which is 2 * 12 = 24 seashells.\n\nNow, Kyle found twice as many shells as Mimi, so we need to calculate Kyle's shells. At this point, let's consider different ways to approach this.\n\n<Parallel>\n<Path>Directly multiply Mimi's shells by 2: 24 * 2 = 48 shells for Kyle.</Path>\n<Path>Use dozens for calculation: Mimi has 2 dozen, so Kyle has 4 dozen (since twice 2 dozen is 4 dozen). Then, convert back to individual shells if needed, but we can keep it in dozens for now.</Path>\n</Parallel>\n<Summary>These parallel paths demonstrate that whether you calculate Kyle's shells directly or use dozens, the result is the same: 48 shells.</Summary>\n\nNow that we have Kyle's shells, we can find Leigh's. Leigh grabbed one-third of Kyle's shells.\n\n<Parallel>\n<Path>Divide Kyle's shells by 3: 48 / 3 = 16 shells for Leigh.</Path>\n<Path>Express Leigh's shells in terms of Mimi's shells: since Kyle has twice Mimi's, and Leigh has one-third of Kyle's, so Leigh has (1/3) * (2 * Mimi) = (2/3) * Mimi. With Mimi having 24, (2/3) * 24 = 16 shells.</Path>\n</Parallel>\n<Summary>These parallel paths show that Leigh's shells can be found either by dividing Kyle's amount by 3 or by directly relating it to Mimi's shells using a fraction.</Summary>\n\nTherefore, Leigh has 16 seashells.\n\nFinal Answer: 16",
  "data_source": "TongZheng1999/Deepseek-Qwen-3-8B-GSM8k-Parallel-data-Filtered",
  "prompt": [
    {
      "content": "Solve the following problem step by step.\nDuring the reasoning process, whenever you encounter a step that may benefit from multiple perspectives or independent reasoning, insert a parallel block at that point.\n\nWithin each parallel block:\nBegin the block with <Parallel>.\nInclude at least two distinct and independent reasoning paths.\nEach path must be enclosed within <Path> and </Path> tags.\nDo not include any ordering information or cross-references between paths, as they are generated simultaneously and independently.\nClose the block with </Parallel>.\nImmediately after each </Parallel>, write a concise summary of insights or conclusions drawn from all paths, enclosed in <Summary> and </Summary> tags.\n\nRepeat this process adaptively as needed throughout the reasoning.\nDo not explicitly mention that you are triggering parallel thinking—just insert the parallel block naturally within the reasoning chain.\n\nEnd your response with a line starting with Final Answer: followed by the final result.\n\nProblem: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?",
      "role": "user"
    }
  ],
  "ability": "math",
  "reward_model": {
    "ground_truth": "16",
    "style": "rule"
  },
  "extra_info": {
    "answer": "Mimi picked up 2 dozen seashells, which is 2 * 12 = 24 seashells.\n\nNow, Kyle found twice as many shells as Mimi, so we need to calculate Kyle's shells. At this point, let's consider different ways to approach this.\n\n<Parallel><Path>Directly multiply Mimi's shells by 2: 24 * 2 = 48 shells for Kyle.</Path><Path>Use dozens for calculation: Mimi has 2 dozen, so Kyle has 4 dozen (since twice 2 dozen is 4 dozen). Then, convert back to individual shells if needed, but we can keep it in dozens for now.</Path></Parallel>\n<Summary>These parallel paths demonstrate that whether you calculate Kyle's shells directly or use dozens, the result is the same: 48 shells.</Summary>\n\nNow that we have Kyle's shells, we can find Leigh's. Leigh grabbed one-third of Kyle's shells.\n\n<Parallel><Path>Divide Kyle's shells by 3: 48 / 3 = 16 shells for Leigh.</Path><Path>Express Leigh's shells in terms of Mimi's shells: since Kyle has twice Mimi's, and Leigh has one-third of Kyle's, so Leigh has (1/3) * (2 * Mimi) = (2/3) * Mimi. With Mimi having 24, (2/3) * 24 = 16 shells.</Path></Parallel>\n<Summary>These parallel paths show that Leigh's shells can be found either by dividing Kyle's amount by 3 or by directly relating it to Mimi's shells using a fraction.</Summary>\n\nTherefore, Leigh has 16 seashells.\n\nFinal Answer: 16",
    "index": 0,
    "question": "Solve the following problem step by step.\nDuring the reasoning process, whenever you encounter a step that may benefit from multiple perspectives or independent reasoning, insert a parallel block at that point.\n\nWithin each parallel block:\nBegin the block with <Parallel>.\nInclude at least two distinct and independent reasoning paths.\nEach path must be enclosed within <Path> and </Path> tags.\nDo not include any ordering information or cross-references between paths, as they are generated simultaneously and independently.\nClose the block with </Parallel>.\nImmediately after each </Parallel>, write a concise summary of insights or conclusions drawn from all paths, enclosed in <Summary> and </Summary> tags.\n\nRepeat this process adaptively as needed throughout the reasoning.\nDo not explicitly mention that you are triggering parallel thinking—just insert the parallel block naturally within the reasoning chain.\n\nEnd your response with a line starting with Final Answer: followed by the final result.\n\nProblem: Mimi picked up 2 dozen seashells on the beach.  Kyle found twice as many shells as Mimi and put them in his pocket. Leigh grabbed one-third of the shells that Kyle found.  How many seashells did Leigh have?",
    "reward_method": "accuracy_add_parallel_reward",
    "split": "train"
  }
}
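
The actual preprocessing script is not reproduced here. Below is a minimal sketch, assuming the record is assembled in Python and written to parquet; the helper name build_record, the abbreviated prompt template, and the pandas/parquet output path are all assumptions for illustration, not the verl script itself.

import pandas as pd

# Hypothetical prompt template; the real one is the full instruction text shown in the
# sample record above (abbreviated here).
PARALLEL_PROMPT_TEMPLATE = "Solve the following problem step by step.\n...\n\nProblem: {problem}"

def build_record(idx, problem, parallel_solution, ground_truth, source):
    # Assemble one record matching the fields of the sample above.
    prompt = PARALLEL_PROMPT_TEMPLATE.format(problem=problem)
    return {
        "index": idx,
        "output": parallel_solution,  # response containing <Parallel>/<Path>/<Summary> tags
        "data_source": source,
        "prompt": [{"role": "user", "content": prompt}],
        "ability": "math",
        "reward_model": {"ground_truth": ground_truth, "style": "rule"},
        "extra_info": {
            "answer": parallel_solution,
            "index": idx,
            "question": prompt,
            "reward_method": "accuracy_add_parallel_reward",
            "split": "train",
        },
    }

records = [build_record(0, "Mimi picked up 2 dozen seashells ...", "... Final Answer: 16", "16",
                        "TongZheng1999/Deepseek-Qwen-3-8B-GSM8k-Parallel-data-Filtered")]
pd.DataFrame(records).to_parquet("train.parquet")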

1.2 parallel-r1's parallel data processor (ParallelThinkingSFTDataset)

The main differences from an ordinary data processor are the following two methods:

def generate_parallel_thinking_reasponse_mask(self, response_ids: torch.Tensor) -> torch.Tensor:
    """
    Build a 2D (seq_len x seq_len) attention mask: start from a causal lower-triangular
    mask, then remove attention between tokens that belong to different <Path> spans
    of the same <Parallel> block.
    """
    seq_len = response_ids.size(0)
    mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))
    path_ranges = []
    in_parallel = False
    i = 0
    while i < seq_len:
        tok = response_ids[i].item()
        if tok == self.start_parallel_token:
            in_parallel = True
            path_ranges = []
        elif in_parallel and tok == self.start_path_token:
            # locate the start and end of this <Path> ... </Path> span
            path_start = i
            ...
            path_end = i
            path_ranges.append((path_start, path_end))
        elif tok == self.end_parallel_token:
            in_parallel = False
            # when a <Parallel> block closes, erase attention between different paths
            for idx1, (s1, e1) in enumerate(path_ranges):
                for idx2, (s2, e2) in enumerate(path_ranges):
                    if idx1 != idx2:
                        mask[s1:e1 + 1, s2:e2 + 1] = False
            path_ranges = []
        i += 1
    return mask
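
For intuition, here is a self-contained toy illustration of the cross-path masking. The token IDs 100-103 and the hard-coded path spans are assumptions for the example; the real method discovers the spans by scanning, as sketched above.

import torch

START_PARALLEL, START_PATH, END_PATH, END_PARALLEL = 100, 101, 102, 103  # assumed IDs

# <Parallel> <Path> a b </Path> <Path> c d </Path> </Parallel> e
toy = torch.tensor([100, 101, 1, 2, 102, 101, 3, 4, 102, 103, 5])
seq_len = toy.size(0)
mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))

path_ranges = [(1, 4), (5, 8)]  # the two (<Path> ... </Path>) spans
for i1, (s1, e1) in enumerate(path_ranges):
    for i2, (s2, e2) in enumerate(path_ranges):
        if i1 != i2:
            mask[s1:e1 + 1, s2:e2 + 1] = False

print(mask.int())
# Rows 5..8 (the second path) now have zeros in columns 1..4: the second path cannot
# attend to the first path even though it appears later in the sequence.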


def compute_structured_position_ids(self, response_ids: torch.Tensor) -> torch.Tensor:
    """
    Assign structured position ids: every <Path> span inside a <Parallel> block restarts
    from the same local position, and the first token after </Parallel> gets position
    curr_pos + max_path_len + 1, as if only the longest path had been generated.
    """
    pad_token_id = self.tokenizer.pad_token_id
    seq_len = response_ids.size(0)
    pos_ids = torch.zeros(seq_len, dtype=torch.long)
    nonpad_mask = (response_ids != pad_token_id)

    curr_pos = 0
    i = 0
    while i < seq_len:
        if not nonpad_mask[i]:
            i += 1
            continue

        tok = response_ids[i].item()
        if tok == self.start_parallel_token:
            pos_ids[i] = curr_pos
            i += 1
            path_lengths = []
            ...
            # collect the length of every path in this block
            while temp_i < seq_len:
                ...
                if response_ids[temp_i].item() == self.start_path_token:
                    ...
                    while temp_i < seq_len:
                        ...
                        pos_ids[temp_i] = local_pos
                        ...
                    path_len = pos_ids[path_end] - curr_pos
                    path_lengths.append(path_len)
            ...
            max_path_len = max(path_lengths) if path_lengths else 0
            if i < seq_len and nonpad_mask[i]:
                pos_ids[i] = curr_pos + max_path_len + 1
                curr_pos += max_path_len + 1 + 1
        else:
            pos_ids[i] = curr_pos
            curr_pos += 1
        i += 1
    return pos_ids
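
A hand-worked toy example of the resulting positions (the exact handling of the </Parallel> token itself is elided above, so treat that detail as an assumption): suppose curr_pos = 10 when <Parallel> is reached and the block contains a 3-token path and a 5-token path.

    <Parallel>   path 1 tokens      path 2 tokens              first token after the block
    pos = 10     pos = 11, 12, 13   pos = 11, 12, 13, 14, 15   pos = 10 + 5 + 1 = 16

Both paths share the same local positions (mirroring the fact that they are generated in parallel), and the continuation resumes as if only the longest path (max_path_len = 5) had been emitted.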

The main entry point is __getitem__:

    def __getitem__(self, item):
        tokenizer = self.tokenizer

        prompt = self.prompts[item]
        response = self.responses[item]

        # apply chat template
        prompt_chat = [{"role": "user", "content": prompt}]
        # print(prompt_chat)
        # string
        prompt_chat_str = tokenizer.apply_chat_template(prompt_chat, add_generation_prompt=True, tokenize=False)
        response_chat_str = response + tokenizer.eos_token

        # tokenize
        prompt_ids_output = tokenizer(prompt_chat_str, return_tensors="pt", add_special_tokens=False)
        prompt_ids = prompt_ids_output["input_ids"][0]
        prompt_attention_mask = prompt_ids_output["attention_mask"][0]

        response_ids_output = tokenizer(response_chat_str, return_tensors="pt", add_special_tokens=False)
        response_ids = response_ids_output["input_ids"][0]
        response_attention_mask = response_ids_output["attention_mask"][0]

        attention_mask_1d = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)


        # response_attention_mask = self.generate_parallel_thinking_reasponse_mask(response_ids)
        # print(response_attention_mask.shape)
        # print(prompt_attention_mask)
        prompt_length = prompt_ids.shape[0]
        response_length = response_ids.shape[0]

        input_ids = torch.cat((prompt_ids, response_ids), dim=-1)

        # padding to max length
        sequence_length = input_ids.shape[0]
        if sequence_length < self.max_length:
            padded_input_ids = torch.ones(size=(self.max_length - sequence_length,), dtype=input_ids.dtype) * self.tokenizer.pad_token_id
            padded_attention_mask_1d = torch.zeros(size=(self.max_length - sequence_length,), dtype=attention_mask_1d.dtype)
            input_ids = torch.cat((input_ids, padded_input_ids))
            attention_mask_1d = torch.cat((attention_mask_1d, padded_attention_mask_1d))
        elif sequence_length > self.max_length:
            if self.truncation == "left":
                # actually, left truncation may not be reasonable
                input_ids = input_ids[-self.max_length :]
                attention_mask_1d = attention_mask_1d[-self.max_length :]
            elif self.truncation == "right":
                input_ids = input_ids[: self.max_length]
                attention_mask_1d = attention_mask_1d[: self.max_length]
            elif self.truncation == "error":
                raise NotImplementedError(f"{sequence_length=} is larger than {self.max_length=}")
            else:
                raise NotImplementedError(f"Unknown truncation method {self.truncation}")

        # Key difference 1: build the 2D parallel-thinking attention mask over the whole padded sequence
        attention_mask = self.generate_parallel_thinking_reasponse_mask(input_ids)
        # print(attention_mask)

        if sequence_length < self.max_length:
            # set the attention mask of the padding region to False
            attention_mask[:,-(self.max_length - sequence_length):] = False
            attention_mask[-(self.max_length - sequence_length):,:] = False
        
        # print(attention_mask)
        attention_mask = attention_mask.unsqueeze(0)
        # attention_mask = torch.cat((prompt_attention_mask, response_attention_mask), dim=-1)
        float_attention_mask = torch.full_like(attention_mask, -torch.inf, dtype=torch.float)
        float_attention_mask = float_attention_mask.masked_fill(attention_mask, 0.0)


        # Key difference 2: structured position ids that restart inside parallel paths
        position_ids = self.compute_structured_position_ids(input_ids)

        loss_mask = attention_mask_1d.clone()
        if prompt_length > 1:
            # mask out prompt for SFT.
            loss_mask[: min(prompt_length, loss_mask.size(0)) - 1] = 0
        # mask out the last token in response
        loss_mask[min(prompt_length + response_length, loss_mask.size(0)) - 1] = 0

        # print(float_attention_mask.shape)
        # print(loss_mask.shape)
        # print(position_ids.shape)

        return {
            "input_ids": input_ids,
            "attention_mask": float_attention_mask,
            "bool_attention_mask": attention_mask,
            "position_ids": position_ids,
            "loss_mask": loss_mask,
        }
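
A minimal usage sketch (the constructor arguments and the base model are assumptions; the real signature in verl may differ): fetch one item and check the shapes, in particular that the attention mask is a per-token 2D mask rather than the usual 1D padding mask.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")  # assumed base model
dataset = ParallelThinkingSFTDataset(
    parquet_files="train.parquet",  # argument names are assumptions
    tokenizer=tokenizer,
    max_length=4096,
)

sample = dataset[0]
print(sample["input_ids"].shape)            # torch.Size([4096])
print(sample["attention_mask"].shape)       # torch.Size([1, 4096, 4096]), additive float mask
print(sample["bool_attention_mask"].shape)  # torch.Size([1, 4096, 4096])
print(sample["position_ids"].shape)         # torch.Size([4096])
print(sample["loss_mask"].shape)            # torch.Size([4096])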

1.3 parallel-r1's parallel SFT trainer (FSDPParallelThinkingSFTTrainer)

Compared with the ordinary trainer, it does not explicitly recompute the attention mask; it fully trusts the attention mask passed in by the dataset.
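
A schematic sketch (not the actual FSDPParallelThinkingSFTTrainer code) of what "trusting the dataset's mask" means: the forward step passes the batched 4D additive mask and the structured position ids through to the model unchanged instead of rebuilding a causal mask itself. compute_sft_loss is a hypothetical helper.

def forward_step(model, batch, compute_sft_loss):
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],  # (bsz, 1, seq, seq) float mask from the dataset
        position_ids=batch["position_ids"],      # structured positions from the dataset
        use_cache=False,
    )
    # loss is averaged only over tokens where loss_mask == 1 (response tokens)
    return compute_sft_loss(output.logits, batch["input_ids"], batch["loss_mask"])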

2. Tool-calling SFT

2.1 Tool-calling SFT data generation

See this article.

2.2 Tool-calling data processor (see 文稿/rllm-sft-dataset)

Loss is computed only on assistant turns. Two masking modes are supported; a toy comparison of both follows the two snippets below:

  1. cumulative mode: every assistant turn contributes to the loss (the default; this is also the standard practice for plain SFT and multi-turn dialogue SFT):
def _tokenize_and_mask_cumulative(self, messages):
    tokens, loss_mask = [], []
    for i in range(len(messages)):
        parsed = self.parser.parse([messages[i]], is_first_msg=(i == 0), add_generation_prompt=False)
        ids = self.tokenizer.encode(parsed, add_special_tokens=False)
        if messages[i]["role"] == "assistant":
            loss_mask.extend([1] * len(ids))   # train on every assistant token
        else:
            loss_mask.extend([0] * len(ids))
        tokens.extend(ids)
    return tokens, loss_mask
  2. stepwise mode: only the last assistant turn is trained (e.g., train only the final answer):
# find the index of the last assistant message
...
for i in range(len(messages)):
    parsed = self.parser.parse([messages[i]], is_first_msg=(i == 0), add_generation_prompt=False)
    ids = self.tokenizer.encode(parsed, add_special_tokens=False)
    if i == last_assistant_idx and messages[i]["role"] == "assistant":
        loss_mask.extend([1] * len(ids))      # train only on the last assistant turn
    else:
        loss_mask.extend([0] * len(ids))
    tokens.extend(ids)
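
A self-contained toy comparison of the two modes (strings stand in for token ids; the real code first renders each message with the chat template and then tokenizes it):

messages = [
    {"role": "user",      "content": "What is 2+3?"},
    {"role": "assistant", "content": "<tool_call>calc(2+3)</tool_call>"},
    {"role": "tool",      "content": "5"},
    {"role": "assistant", "content": "The answer is 5."},
]

def toy_loss_mask(messages, mode="cumulative"):
    last_assistant_idx = max(i for i, m in enumerate(messages) if m["role"] == "assistant")
    mask = []
    for i, m in enumerate(messages):
        ids = m["content"].split()  # stand-in for tokenizer.encode(...)
        train_this_turn = (m["role"] == "assistant" and
                           (mode == "cumulative" or i == last_assistant_idx))
        mask.extend([1 if train_this_turn else 0] * len(ids))
    return mask

print(toy_loss_mask(messages, "cumulative"))  # 1s on both assistant turns
print(toy_loss_mask(messages, "stepwise"))    # 1s only on the final assistant turn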

The dataset's __getitem__ then dispatches to one of these two methods:

    def __getitem__(self, item):
        messages = self.messages[item]

        # calls _tokenize_and_mask_cumulative or the stepwise variant, depending on the configured mode
        tokens, loss_mask = self._tokenize_and_mask(messages)

        input_ids = torch.tensor(tokens, dtype=torch.long)
        loss_mask = torch.tensor(loss_mask, dtype=torch.long)
        attention_mask = torch.tensor([1] * len(tokens), dtype=torch.long)

        # Handle sequence length
        sequence_length = input_ids.shape[0]
        if sequence_length < self.max_length:
            # Pad sequences
            pad_token_id = self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else 0
            padded_input_ids = torch.full((self.max_length - sequence_length,), pad_token_id, dtype=input_ids.dtype)
            padded_attention_mask = torch.zeros((self.max_length - sequence_length,), dtype=attention_mask.dtype)
            padded_loss_mask = torch.zeros((self.max_length - sequence_length,), dtype=loss_mask.dtype)

            input_ids = torch.cat((input_ids, padded_input_ids))
            attention_mask = torch.cat((attention_mask, padded_attention_mask))
            loss_mask = torch.cat((loss_mask, padded_loss_mask))

        elif sequence_length > self.max_length:
            if self.truncation == "left":
                input_ids = input_ids[-self.max_length :]
                attention_mask = attention_mask[-self.max_length :]
                loss_mask = loss_mask[-self.max_length :]
            elif self.truncation == "right":
                input_ids = input_ids[: self.max_length]
                attention_mask = attention_mask[: self.max_length]
                loss_mask = loss_mask[: self.max_length]
            elif self.truncation == "error":
                raise ValueError(f"{sequence_length=} is larger than {self.max_length=}")
            else:
                raise ValueError(f"Unknown truncation method {self.truncation}")

        # Create position IDs
        position_ids = torch.arange(len(input_ids), dtype=torch.long)
        # Zero out position IDs for padding
        position_ids = position_ids * attention_mask

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "loss_mask": loss_mask,
        }
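
For contrast with the parallel-r1 dataset above (here dataset stands for an instance of this rllm tool-calling SFT dataset; its construction is omitted): the returned attention mask is the ordinary 1D padding mask and the position ids are a plain arange, zeroed on padding.

sample = dataset[0]
print(sample["attention_mask"].shape)  # torch.Size([max_length]), 1D rather than max_length x max_length
print(sample["position_ids"][:5])      # tensor([0, 1, 2, 3, 4]) on a non-padded prefix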

2.3 Tool-calling SFT trainer

No different from the standard SFT trainer.
posted @ 2025-12-19 22:42  Brain404