AgentLoop (Verl) vs ParallelThinkingAgentLoopV3 (Parallel-R1) vs ToRL

I recently came across an interesting paper, Parallel-R1, which uses RL to train a model that reasons in parallel. The output format is roughly:

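<model reasoning>
the model suddenly emits <parallel> and switches to multi-path reasoning
<parallel>
<path> ... </path>   the reasoning paths cannot see each other (masked out via the attention mask)
<path> ... </path>
</parallel>
<summary>
a summary of the paths above
</summary>
<model continues reasoning>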
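To make "the paths cannot see each other" concrete, here is a minimal sketch in plain Python of the kind of attention mask this implies: start from an ordinary causal mask, then block attention between tokens that sit in different `<path>` spans, while tokens after the block (such as the summary) can still see every path. The function name and the half-open `(start, end)` span convention are my own assumptions for illustration, not code from Parallel-R1.

```python
def build_parallel_attention_mask(seq_len: int, path_spans: list[tuple[int, int]]) -> list[list[bool]]:
    """Return mask[q][k] == True when query token q may attend to key token k.

    Starts from a causal mask, then blocks every (q, k) pair where q and k
    belong to two *different* path spans. Spans are half-open [start, end).
    """
    # Causal mask: each token sees itself and everything before it.
    mask = [[k <= q for k in range(seq_len)] for q in range(seq_len)]
    for i, (qs, qe) in enumerate(path_spans):
        for j, (ks, ke) in enumerate(path_spans):
            if i == j:
                continue  # tokens inside the same path keep normal causal attention
            for q in range(qs, qe):
                for k in range(ks, ke):
                    mask[q][k] = False  # path i must not see path j
    return mask


# prefix = tokens 0-1, path A = tokens 2-3, path B = tokens 4-5, summary = token 6
mask = build_parallel_attention_mask(7, [(2, 4), (4, 6)])
print(mask[5][1], mask[5][3], mask[5][4], mask[6][3])  # True False True True
```

In a real training setup this would be a boolean tensor per sequence; nested lists just keep the sketch dependency-free.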

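The paper was accepted as an Oral at NeurIPS, and after reading it through I think it is genuinely strong work. I also reproduced it myself, and my results roughly match the logs released by the authors. Overall, it is a paper well worth studying.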

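The work is built on the verl framework and modifies it in two main places. The first, and the topic of this post, is the AgentLoop used during RL; I will compare Verl's multi-turn agent loop with the loop in this paper in detail.
The second is the SFT Dataset implementation, used for SFT.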

Verl AgentLoop

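AgentLoop is a feature Verl added around June/July this year to support agentic RL. It follows the classic ReAct design: the model emits a tool-call request, the tool is executed, and the result is inserted back into the conversation.
The code is as follows: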

@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        ...

    @rollout_trace_op
    async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
        metrics = {}
        request_id = uuid4().hex
        # 1. Tokenize the prompt into token ids
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True
            ),
        )
        response_mask = []

        user_turns, assistant_turns = 0, 0
        while True:
            with simple_timer("generate_sequences", metrics):
                # 2. Generate the response token ids
                response_ids = await self.server_manager.generate(
                    request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
                )
            prompt_ids += response_ids
            response_mask += [1] * len(response_ids)
            assistant_turns += 1

            # reach max response length
            if len(response_mask) >= self.response_length:
                break

            # reach max assistant turns
            if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns:
                break

            # reach max user turns
            if self.max_user_turns and user_turns >= self.max_user_turns:
                break

            # 3. Try to parse tool calls from the response (presumably by matching the <tool_call> ... </tool_call> tags)
            _, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
            # No tool call: exit the loop
            if not tool_calls:
                break

            # 4. Otherwise execute the tools and append their results to the conversation
            tasks = []
            for tool_call in tool_calls[: self.max_parallel_calls]:
                tasks.append(self._call_tool(tool_call))
            with simple_timer("tool_calls", metrics):
                tool_responses = await asyncio.gather(*tasks)
            if any(isinstance(item, Exception) for item in tool_responses):
                break

            tool_response_ids = await self.loop.run_in_executor(
                None,
                lambda messages=tool_responses: self.tokenizer.apply_chat_template(
                    messages, add_generation_prompt=True, tokenize=True
                ),
            )
            tool_response_ids = tool_response_ids[len(self.system_prompt) :]

            # NOTE: last turn should not be user turn, or the EOS token reward
            # can't be propagated to previous token in GAE.
            if len(response_mask) + len(tool_response_ids) >= self.response_length:
                break

            prompt_ids += tool_response_ids
            response_mask += [0] * len(tool_response_ids)
            user_turns += 1
            # 5. Repeat until no tool call is made or a length/turn limit is hit

        response_ids = prompt_ids[-len(response_mask) :]
        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]

        output = AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids[: self.response_length],
            response_mask=response_mask[: self.response_length],
            num_turns=user_turns + assistant_turns + 1,
            metrics=metrics,
        )
        return output

    async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]:
        """Execute a single tool call."""
        tool, instance_id = None, None
        try:
            # TODO: append malformed tool_call to the prompt: invalid function name or arguments
            tool_name = tool_call.name
            tool_args = json.loads(tool_call.arguments)
            tool = self.tools[tool_name]

            instance_id = await tool.create()
            tool_response, _, _ = await tool.execute(instance_id, tool_args)
        except Exception as e:
            logger.exception(f"Error when executing tool: {e}")
            return e
        finally:
            if tool and instance_id:
                await tool.release(instance_id)

        if len(tool_response) > self.max_tool_response_length:
            if self.tool_response_truncate_side == "left":
                tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
            elif self.tool_response_truncate_side == "right":
                tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
            else:
                length = self.max_tool_response_length // 2
                tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]

        return {
            "role": "tool",
            "content": tool_response,
        }
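The core bookkeeping of `ToolAgentLoop` is that tokens sampled by the model get `response_mask = 1` and tool-output tokens get `0`, so only the model's own tokens count toward the policy loss. Below is a minimal synchronous sketch of just that bookkeeping. `run_tool_loop` is a hypothetical helper, and the pre-tokenized `turns` / `tool_outputs` lists stand in for `server_manager.generate`, the tool parser, and the tools themselves.

```python
def run_tool_loop(prompt_ids, turns, tool_outputs, response_length=100):
    """Mimic how a ReAct-style loop builds prompt_ids and response_mask.

    turns:        token ids the model produces in each assistant turn
    tool_outputs: tokenized tool results; turn t "calls a tool" only if
                  tool_outputs[t] exists (stand-in for the tool parser)
    """
    prompt_ids = list(prompt_ids)
    init_len = len(prompt_ids)
    response_mask = []
    for t, response_ids in enumerate(turns):
        prompt_ids += response_ids
        response_mask += [1] * len(response_ids)   # model tokens: trained on
        if len(response_mask) >= response_length:
            break                                  # max response length reached
        if t >= len(tool_outputs):
            break                                  # no tool call in this turn
        tool_ids = tool_outputs[t]
        if len(response_mask) + len(tool_ids) >= response_length:
            break                                  # never end the rollout on a tool turn
        prompt_ids += tool_ids
        response_mask += [0] * len(tool_ids)       # tool tokens: excluded from the loss
    return prompt_ids[:init_len], prompt_ids[init_len:], response_mask


prompt, resp, resp_mask = run_tool_loop([1, 2], turns=[[10, 11], [12]], tool_outputs=[[90, 91, 92]])
print(resp, resp_mask)
```

The "never end on a tool turn" check mirrors the NOTE in the original: if the last tokens were tool output, the reward attached to the final token could not propagate back to the model's tokens in GAE.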

ParallelThinkingAgentLoopV3

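Reading the paper alone, this loop looks like a single-pass generation, but the code shows that it is also multi-turn.
Roughly:

  1. On the main path, generation stops as soon as the model emits <Parallel>, and parallel thinking starts; if no <Parallel> was generated, the loop exits and returns the result.
  2. Each thinking path is started by inserting a <Path> tag.
  3. Each path stops when it emits </Path>.
  4. When parallel thinking is done, control returns to the main path and </Parallel><Summary> is appended.
  5. Generation continues on the main path and stops at </Summary>.
  6. Go back to step 1, until the maximum number of iterations or the maximum length is reached.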
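The six steps above can be sketched as a small synchronous loop. This is a simplification under stated assumptions: tokens are plain strings, `gen_main` / `gen_path` / `gen_summary` are hypothetical callables standing in for `server_manager.generate` with different `stop_token_ids`, and position ids, length limits, and the concurrent `asyncio.gather` over paths are all left out.

```python
PARALLEL, END_PARALLEL = "<Parallel>", "</Parallel>"
PATH, END_PATH = "<Path>", "</Path>"
SUMMARY, END_SUMMARY = "<Summary>", "</Summary>"
EOS = "<eos>"


def parallel_thinking_loop(prompt, gen_main, gen_path, gen_summary,
                           num_paths=2, max_iterations=3):
    """Sketch of the multi-turn control flow (no position ids, no length limits).

    gen_main(ctx)     -> tokens ending with <Parallel> or <eos>
    gen_path(ctx, i)  -> tokens of path i ending with </Path> (no opening tag)
    gen_summary(ctx)  -> tokens ending with </Summary>
    """
    context = list(prompt)
    iterations = 0
    while True:
        ids = gen_main(context)                  # main path: stop at <Parallel> or <eos>
        context += ids
        if ids[-1] != PARALLEL:                  # step 1: no <Parallel>, we are done
            break
        block = []
        for i in range(num_paths):               # steps 2-3: every path sees the same prefix
            block += [PATH] + gen_path(context, i)
        block += [END_PARALLEL, "\n", SUMMARY]   # step 4: close the block, open the summary
        block += gen_summary(context + block)    # step 5: summary stops at </Summary>
        context += block
        iterations += 1
        if iterations >= max_iterations:         # step 6: iteration limit
            break
    return context[len(prompt):]


main_turns = iter([["Let", "me", "try", PARALLEL], ["So", "x=2", EOS]])
out = parallel_thinking_loop(
    ["Q"],
    gen_main=lambda ctx: next(main_turns),
    gen_path=lambda ctx, i: [f"idea{i}", END_PATH],
    gen_summary=lambda ctx: ["both", "agree", END_SUMMARY],
)
print(out)
```

As in the original, every path is generated from the same prefix, so paths never see each other's tokens during generation, and hitting the iteration limit ends the rollout right after `</Summary>`.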

Besides this multi-turn interaction, the authors also modified the attention mask and the position encoding (position ids).

The code is as follows:

def append_tokens(new_ids: list[int],
                  is_parallel: bool = False,
                  path_spans: list[tuple[int, int]] | None = None,
                  manual_mask_positions=None):
    # Closure inside run(): prompt_ids, response_mask, position_ids, etc. live in the enclosing scope
    nonlocal prompt_ids, position_ids, parallel_stack, position_required_masks, left_pad_len

    start = len(prompt_ids)  # correct
    prompt_ids.extend(new_ids)  # correct -> merge new_ids into prompt_ids
    response_mask.extend([1] * len(new_ids))  # correct -> extend response_mask with 1s for new_ids
    incr = torch.arange(1, len(new_ids) + 1, dtype=torch.long) + position_ids[-1]  # -> correct
    position_ids = torch.cat([position_ids, incr])  # correct -> extend position_ids with new_ids' positions
    # Not a parallel block: nothing more to do
    if not is_parallel:
        return

    blk = parallel_stack.pop()
    base_p = blk["base"]
    longest = blk["longest"]

    # a) Correct the positions inside each Path: (1) inside a path, position = path start
    #    position + offset within the path; (2) after the paths, the first token's position =
    #    path start position + length of the longest path
    # b) Record the cross-Path spans whose mutual attention must be masked
    for i, (s_rel, e_rel) in enumerate(path_spans):
        s_abs, e_abs = start + s_rel, start + e_rel
        # the first path we do not need to shift
        shift = position_ids[s_abs] - (base_p)
        # print("shift", shift.item())
        if shift.item():
            position_ids[s_abs:e_abs] -= shift
        path_max_pos = position_ids[e_abs - 1]
        longest = max(longest, (path_max_pos - base_p).item())
        # Build the cross-path attention mask entries so that paths cannot attend to each other
        for j, (sj_rel, ej_rel) in enumerate(path_spans):
            if j == i: continue
            sj_abs, ej_abs = left_pad_len + start + sj_rel, left_pad_len + start + ej_rel

            if e_abs + left_pad_len > ej_abs:
                if ej_abs < self.prompt_length + self.response_length:
                    position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                    print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} with {sj_abs}:{ej_abs} attention")
            elif e_abs + left_pad_len < ej_abs:
                if sj_abs < self.prompt_length + self.response_length:
                    position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                    print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} with {sj_abs}:{ej_abs} attention")

    # NOTE: path_spans[1] is the last path only when num_paths == 2;
    # with more paths, the end of path_spans[-1] would be the correct anchor here
    s2, e2 = path_spans[1]

    p_end_abs = start + e2

    desired_end = base_p + longest + 1
    shift2 = position_ids[p_end_abs] - desired_end
    if shift2.item():
        position_ids[p_end_abs:] -= shift2
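The net effect of this position correction can be written as a pure function that computes the target layout directly, instead of shifting a tensor in place. A minimal sketch, assuming the prefix ends with the `<Parallel>` token and every path starts with `<Path>`; `parallel_position_ids` is a hypothetical name, not part of the repository.

```python
def parallel_position_ids(prefix_len: int, path_lengths: list[int], tail_len: int) -> list[int]:
    """Position ids for: prefix, then several paths, then the tokens after the paths.

    Every path restarts at the same base position (the one right after the
    prefix); the first token after the paths continues at
    base + length of the longest path.
    """
    positions = list(range(prefix_len))           # ordinary positions 0..prefix_len-1
    base = prefix_len                             # position of each path's <Path> token
    for length in path_lengths:
        positions += range(base, base + length)   # each path reuses base, base+1, ...
    resume = base + max(path_lengths, default=0)  # continue after the longest path
    positions += range(resume, resume + tail_len)
    return positions


# 3 prefix tokens, two paths of length 3 and 2, then 2 tokens (e.g. </Parallel> and \n)
print(parallel_position_ids(3, [3, 2], 2))  # [0, 1, 2, 3, 4, 5, 3, 4, 6, 7]
```

From the model's point of view each path starts at the same place, and the text after the paths continues as if only the longest path had been written. Unlike a fixed index into `path_spans`, this version handles any number of paths.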

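These are the two main modifications.

This shows how strong the authors' engineering is. The first part alone, the multi-turn interaction logic, is hard to reason through and hard to implement; add the second part and it is far beyond my current level. Changing machinery this low-level inside a training framework, and getting it to work, is the clearest sign of strong engineering ability.

The complete loop class follows. It is written very carefully and is worth reading closely.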

@register("parallel_thinking_agent_v3")
class ParallelThinkingAgentLoopV3(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level ParallelThinkingV3AgentLoop initialization")

        # Initialize tools from config file
        cls.tokenizer = tokenizer
        cls.add_diverse_prefix = config.actor_rollout_ref.rollout.agent.add_diverse_prefix
        cls.max_iterations_for_parallel_thinking = config.actor_rollout_ref.rollout.agent.max_iterations_for_parallel_thinking
        cls.num_paths = config.actor_rollout_ref.rollout.agent.num_paths
        cls.max_path_response_length = config.actor_rollout_ref.rollout.agent.max_path_response_length

        cls.eos_token_id = cls.tokenizer.eos_token_id
        cls.start_parallel_token = cls.tokenizer.encode('<Parallel>')[0]
        cls.end_parallel_token = cls.tokenizer.encode('</Parallel>')[0]
        cls.start_path_token = cls.tokenizer.encode('<Path>')[0]
        cls.end_path_token = cls.tokenizer.encode('</Path>')[0]
        cls.start_summary_token = cls.tokenizer.encode('<Summary>')[0]
        cls.end_summary_token = cls.tokenizer.encode('</Summary>')[0]
        cls.new_line_token = cls.tokenizer.encode('\n')

        cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
        cls.response_length = config.actor_rollout_ref.rollout.response_length
        cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)

    @rollout_trace_op
    async def run(
        self,
        messages: list[dict[str, Any]],
        sampling_params: dict[str, Any],
    ) -> AgentLoopOutput:

        
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True
            ),
        )
        init_len      = len(prompt_ids)                         
        position_ids  = torch.arange(init_len, dtype=torch.long)

        response_mask = []                                     
        iterations    = 0
        request_id    = uuid4().hex

        left_pad_len = self.prompt_length - init_len
        if left_pad_len < 0:
            print(f"Warning: prompt length {self.prompt_length} is less than initial prompt length {init_len}, truncating prompt.")
            left_pad_len = 0

        position_required_masks = []

        
        parallel_stack: list[dict] = []

        
        def append_tokens(new_ids: list[int],
                          is_parallel: bool = False,
                          path_spans: list[tuple[int, int]] | None = None,
                          manual_mask_positions=None):
            # a) correct the positions inside each Path & b) record the cross-Path mask entries

            nonlocal prompt_ids, position_ids, parallel_stack, position_required_masks, left_pad_len

            start = len(prompt_ids) # correct
            prompt_ids.extend(new_ids) # correct -> merge new_ids into prompt_ids
            response_mask.extend([1] * len(new_ids)) # correct -> extend response_mask with 1s for new_ids
            incr = torch.arange(1, len(new_ids) + 1, dtype=torch.long) + position_ids[-1] # -> correct
            position_ids = torch.cat([position_ids, incr]) # correct -> extend position_ids with new_ids' positions

            if not is_parallel:
                return

            blk     = parallel_stack.pop()
            base_p  = blk["base"]
            longest = blk["longest"]
            

            # a) correct the positions inside each Path & b) record the cross-Path mask entries
            for i, (s_rel, e_rel) in enumerate(path_spans):
                
            
                s_abs, e_abs = start + s_rel, start + e_rel
                # the first path we do not need to shift
                shift = position_ids[s_abs] - (base_p)
                # print("shift", shift.item())
                if shift.item():
                    position_ids[s_abs:e_abs] -= shift
                path_max_pos = position_ids[e_abs - 1]  
                longest = max(longest, (path_max_pos - base_p).item())
                # Build the cross-path mask entries
                for j, (sj_rel, ej_rel) in enumerate(path_spans):
                    if j == i: continue
                    sj_abs, ej_abs = left_pad_len + start + sj_rel, left_pad_len + start + ej_rel

                    if e_abs + left_pad_len > ej_abs:
                        if ej_abs < self.prompt_length + self.response_length:
                            position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                            print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} with {sj_abs}:{ej_abs} attention")
                    elif e_abs + left_pad_len < ej_abs:
                        if sj_abs < self.prompt_length + self.response_length:
                            position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                            print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} with {sj_abs}:{ej_abs} attention")
                    

            s2, e2 = path_spans[1]

            p_end_abs   = start + e2
            
            desired_end = base_p + longest + 1
            shift2      = position_ids[p_end_abs] - desired_end
            if shift2.item():
                position_ids[p_end_abs:] -= shift2
        def should_stop() -> bool:
            gen_len = len(position_ids) - init_len
            if gen_len >= self.response_length:
                return True
            if (
                self.max_iterations_for_parallel_thinking
                and iterations >= self.max_iterations_for_parallel_thinking
            ):
                return True
            return False

        
        while True:
            sp_main = {**sampling_params, "stop_token_ids": [self.start_parallel_token, self.eos_token_id]}
            ids = await self.server_manager.generate(
                request_id=request_id,
                prompt_ids=prompt_ids,
                sampling_params=sp_main,
            )
            # append_tokens extends prompt_ids / response_mask and updates the position ids
            append_tokens(ids)
            # Stop if a length/iteration limit is reached or no <Parallel> was generated
            if should_stop() or not await self.check_parallel(ids):
                break

            assert ids[-1] != self.eos_token_id

            # elicit parallel thinking
            base_pos = position_ids[-1].item() + 1 # the position of the <Path> token
            parallel_stack.append({"base": base_pos, "longest": 0})

            # Call _call_parallel_thinking to generate the parallel paths (and the summary) dynamically
            parallel_ids, path_spans, manual_mask_positions = await self._call_parallel_thinking(
                prompt_ids, sampling_params
            )
            append_tokens(parallel_ids, is_parallel=True, path_spans=path_spans, manual_mask_positions=manual_mask_positions)

            iterations += 1
            if should_stop():
                break

        response_ids = prompt_ids[-len(response_mask) :]
        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
        assert init_len == (len(prompt_ids))
        
        if left_pad_len > 0:
            position_ids = torch.cat([
                torch.zeros(left_pad_len, dtype=torch.long),
                position_ids
            ])
        if len(response_ids) < self.response_length:
            right_pad_len = self.response_length - len(response_ids)
            position_ids = torch.cat([
                position_ids,
                torch.zeros(right_pad_len, dtype=torch.long)])
            # bool_mask[:, -right_pad_len:] = False
        else:
            position_ids = position_ids[:self.prompt_length + self.response_length]
            response_ids = response_ids[: self.response_length]
            response_mask = response_mask[: self.response_length]

        return AgentLoopOutput(
            prompt_ids          = prompt_ids,
            response_ids        = response_ids,
            response_mask       = response_mask,
            position_required_mask       = position_required_masks,
            multiverse_pos_ids  = position_ids.cpu(),        # 1-D
            num_turns           = iterations + 1,
            metrics             = {},
        )



    async def check_parallel(self, response_ids):
        """
        check whether need to conduct parallel thinking
        """
        if response_ids[-1] == self.start_parallel_token:
            return True
        else:
            return False

    # Dynamically generate the parallel block (paths + summary)
    async def _call_parallel_thinking(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
    ):
        print("Conducting parallel thinking...")
        num_paths = self.num_paths
        max_len   = self.max_path_response_length
        PATH_OPEN, PATH_CLOSE = self.start_path_token, self.end_path_token
        manual_mask_positions = []

        
        async def _gen_single_path(seed_i: int, prompt_i: list[int]):
            sp = copy.deepcopy(sampling_params)
            sp.update({"seed": seed_i, "n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
                    "temperature": 1.0})
            # sp.update({"n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
            #         "temperature": 1.0})
            manual_append = False
            ids = await self.server_manager.generate(
                request_id=uuid4().hex,
                prompt_ids=prompt_i + [PATH_OPEN],
                sampling_params=sp,
            )
            if max_len and len(ids) > max_len:
                ids = ids[:max_len]
            
            if ids[-1] != PATH_CLOSE:
                if ids[-1] == self.eos_token_id:
                    ids[-1] = PATH_CLOSE
                else:
                    ids.append(PATH_CLOSE)
                manual_append = True
            return [PATH_OPEN] + ids, manual_append

        tasks = []
        for i in range(num_paths):
            # prompt_i no longer contains PATH_OPEN (it is added in _gen_single_path), to avoid a duplicated tag
            prompt_i = prompt_ids
            tasks.append(
                asyncio.create_task(
                    _gen_single_path(random.randint(0, 2**31 - 1), prompt_i)
                )
            )

        path_token_lists = await asyncio.gather(*tasks)     

        parallel_ids = []             
        path_spans   = []
        cursor       = 0              

        #### <Path> </Path><Path> </Path> -> path_spans (0, len(path_1)) (len(path_1), len(path_1) + len(path_2))
        for (ids, manual_append_flag) in path_token_lists:
            assert ids[0] == PATH_OPEN and ids[-1] == PATH_CLOSE, \
                f"Path tokens must start with {PATH_OPEN} and end with {PATH_CLOSE}, got {ids}"
            span_start = cursor 
            parallel_ids.extend(ids)
            cursor     += len(ids)
            span_end   = cursor
            path_spans.append((span_start, span_end))
            manual_mask_positions.append(span_start)
            if manual_append_flag:
                manual_mask_positions.append(span_end-1)

        # ---------- 3) add </Parallel><Summary> ----------
        parallel_ids.append(self.end_parallel_token)           # </Parallel>
        manual_mask_positions.append(cursor)
        assert parallel_ids[cursor] == self.end_parallel_token
        cursor += 1
        parallel_ids.extend(self.new_line_token)          # \n
        manual_mask_positions.append(cursor)
        cursor += len(self.new_line_token)
        parallel_ids.append(self.start_summary_token)          # <Summary>
        manual_mask_positions.append(cursor)
        assert parallel_ids[cursor] == self.start_summary_token
        cursor += 1
        
        # if the exploration stage already exceeds the max length, we do not generate a summary
        if len(prompt_ids) + 2 + len(parallel_ids) + 1 < self.prompt_length + self.response_length:
            print("Conducting Summary...")
            manual_append_summary = False
            sp_sum = copy.deepcopy(sampling_params)
            sp_sum.update({"n": 1, "stop_token_ids": [self.end_summary_token, self.eos_token_id]})
            summary_ids = await self.server_manager.generate(
                request_id=uuid4().hex,
                prompt_ids=prompt_ids + parallel_ids,   # the full prompt so far
                sampling_params=sp_sum,
            )
            if summary_ids[-1] != self.end_summary_token:
                if summary_ids[-1] == self.eos_token_id:
                    summary_ids[-1] = self.end_summary_token
                else:
                    summary_ids.append(self.end_summary_token)
                manual_append_summary = True
            parallel_ids.extend(summary_ids)
            cursor += len(summary_ids)
            if manual_append_summary:
                manual_mask_positions.append(cursor-1)
            # print("asadas", parallel_ids[cursor-1])
            assert parallel_ids[cursor-1] == self.end_summary_token
            assert parallel_ids[-1] == self.end_summary_token, \
                f"Parallel thinking did not end with {self.end_summary_token}, got {parallel_ids[-1]}"
        # print(self.tokenizer.decode(parallel_ids, skip_special_tokens=False))
        return parallel_ids, path_spans, manual_mask_positions        # path_spans are already relative to parallel_ids
        



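A small difference between the two:

In ToolAgentLoop, the model first generates a whole stretch of text, and the tool call is parsed out of it afterwards. In Parallel-R1, generation stops as soon as the model emits the trigger token, and the loop makes a state-machine transition.

For example, suppose the tool-call format is <tool_call></tool_call> (note that the model …)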
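A toy token stream makes the difference visible. In this sketch (hypothetical helpers, strings instead of token ids), the first function lets the turn run to EOS and only then checks for a tool call, while the second stops the moment a stop token appears, which is what passing `stop_token_ids=[<Parallel>, eos]` to the generator achieves in Parallel-R1.

```python
TRIGGER, EOS = "<Parallel>", "<eos>"


def generate_then_parse(stream):
    """ToolAgentLoop style: finish the whole turn, then scan it for a tool call."""
    turn = []
    for tok in stream:
        turn.append(tok)
        if tok == EOS:
            break
    has_call = "<tool_call>" in turn and "</tool_call>" in turn
    return turn, has_call


def generate_until_trigger(stream, stop_tokens=(TRIGGER, EOS)):
    """Parallel-R1 style: stop the moment a stop token is sampled."""
    turn = []
    for tok in stream:
        turn.append(tok)
        if tok in stop_tokens:
            break
    switch_state = bool(turn) and turn[-1] == TRIGGER
    return turn, switch_state


turn, has_call = generate_then_parse(iter(["think", "<tool_call>", "f(1)", "</tool_call>", EOS]))
stream = iter(["think", TRIGGER, "path", "text"])
turn2, switch = generate_until_trigger(stream)
print(turn2, switch)
```

The tokens left in `stream` after the trigger are never consumed: in the stop-token style the caller takes over at exactly that point and decides what to generate next.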

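A state machine is a mathematical model that describes how a system moves between states. In computer science it is typically used for processes that go through several stages or states.
Basic concepts of state machines
Components

  1. States: the different situations the system can be in
  2. Events: the conditions or inputs that trigger a transition
  3. Actions: the operations performed during a transition
  4. Transitions: the rules for moving from one state to another

State diagram
State A ──event──→ State B
   ↑                 ↓
   └──event─────── State C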
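Mapped onto Parallel-R1, the events are the special tags and the states describe where the rollout currently is. Here is a minimal table-driven sketch that walks a token sequence and rejects tags that arrive out of order; the state names are my own labels, and the actions (launching path or summary generation) are omitted.

```python
# (state, event) -> next state; the events are the special tags
TRANSITIONS = {
    ("MAIN", "<Parallel>"): "PARALLEL",
    ("PARALLEL", "<Path>"): "PATH",
    ("PATH", "</Path>"): "PARALLEL",
    ("PARALLEL", "</Parallel>"): "AFTER_PARALLEL",
    ("AFTER_PARALLEL", "<Summary>"): "SUMMARY",
    ("SUMMARY", "</Summary>"): "MAIN",
}
TAGS = {event for _, event in TRANSITIONS}


def run_state_machine(tokens, start="MAIN"):
    """Walk a token sequence and return the list of visited states.

    Ordinary tokens leave the state unchanged; a tag with no transition from
    the current state raises ValueError (malformed sequence).
    """
    state, trace = start, [start]
    for tok in tokens:
        if tok not in TAGS:
            continue  # plain text token: stay in the current state
        key = (state, tok)
        if key not in TRANSITIONS:
            raise ValueError(f"unexpected {tok!r} in state {state}")
        state = TRANSITIONS[key]
        trace.append(state)
    return trace


seq = ["x", "<Parallel>", "<Path>", "a", "</Path>", "<Path>", "b", "</Path>",
       "</Parallel>", "\n", "<Summary>", "s", "</Summary>", "y"]
print(run_state_machine(seq))
```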

ToRL

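Similarly, ToRL, the paper ByteDance released in March this year, has a similar state-machine-style flow.

The main file is ToRL/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py, and the main function is generate.
Because verl did not have the AgentLoop feature yet when ToRL was open-sourced, the Python tool parsing is defined directly inside the rollout.
**Note: unlike AgentLoop, where the model directly generates the Python tool's tool schema (in JSON), in ToRL the model writes a markdown Python code block, and the code is then wrapped into JSON by hand.**
Let's walk through it: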

# Extract the code inside a ```python ... ``` block
def extract_program(result: str, last_only=True):
    """
    extract the program after "```python", and before "```"
    """
    program = ''
    start = False
    for line in result.split('\n'):
        if line.startswith('```python') or line.endswith('```python'):
            if last_only:
                program = ''  # only extract the last program
            else:
                program += '\n# ========\n'
            start = True
        elif line.startswith('```'):
            start = False
        elif start:
            program += line + '\n'
    if start:
        # the code is incomplete
        program = ''
    return program

# Extract the Python code and wrap it in a JSON object under the "code" field
def _detect_tool(text: str) -> Tuple[bool, str, str, str]:
    program = extract_program(text)
    if program:
        program = json.dumps({'code': program}, ensure_ascii=False)
    return (program != ''), PythonExecutor.name, program, text
  
  
  
  
    def _tir_generate(self, prompts=None, sampling_params=None, prompt_token_ids=None, use_tqdm=False):
        sampling_params=copy.deepcopy(sampling_params)
        # prompts=self.tokenizer.batch_decode(prompt_token_ids, skip_special_tokens=True)
        prompts=[self.tokenizer.decode(prompt['prompt_token_ids'], skip_special_tokens=False) for prompt in prompts]
        prompts=[prompt for prompt in prompts for _ in range(sampling_params.n) ]
        sampling_params.n=1
        sampling_params.detokenize=True
        sampling_params.stop=["```output"]
        samples_info=[{"prompt": prompt, "sequence": prompt, "response": "", "stop": False, "finish_reason": None,"index": index, "mask_info": [], "execution_pass": 0} for index, prompt in enumerate(prompts)]
        program2output=[]
        num_llm_calls_available=copy.deepcopy(self.config.num_llm_calls_available)
        # Keep going while LLM calls remain
        while num_llm_calls_available >= 0:
            if num_llm_calls_available==0: sampling_params.stop=None
            num_llm_calls_available-=1
            # llm generate response, stop at eos token or ```output
            input_prompts, indices=self._get_prompts_and_indices(samples_info)
            input_prompts = [{
                'prompt_token_ids': self.tokenizer.encode(x, add_special_tokens=False)[:self.config.prompt_length+self.config.response_length]} for x in input_prompts]
            ## !!! Generation stops when the model emits ```output, and the returned text does not contain ```output. Because ```output is produced by the model itself, it should still receive loss later.
            outputs = self.inference_engine.generate(prompts=input_prompts, sampling_params=sampling_params, use_tqdm=use_tqdm)
            sorted_outputs = sorted(outputs, key=lambda output: int(output.request_id))
            responses=[x.outputs[0].text for x in sorted_outputs]
            finish_reason=[x.outputs[0].finish_reason for x in sorted_outputs]
            stop_reason=[x.outputs[0].stop_reason for x in sorted_outputs]
            # If this was the last call (the stop words were removed so the model could finish), record the response and exit
            if num_llm_calls_available==-1:
                for i ,index in enumerate(indices):
                    samples_info[index]['response']+=responses[i]
                    samples_info[index]['sequence']+=responses[i]
                    samples_info[index]['stop']=True
                    samples_info[index]['finish_reason']=finish_reason[i]
                break
            # Decide whether Python needs to be executed
            def _python_execution(finish_reason, stop_reason):
                if finish_reason=='stop' and stop_reason is None: return False
                if finish_reason=='stop' and stop_reason=='```output': return True
                if finish_reason=='length': return False
                return False
                return False
            # Check which samples need Python execution
            is_execution=[_python_execution(finish_reason[i], stop_reason[i]) for i in range(len(finish_reason))]
            # If no sample needs code execution, exit the loop
            if all([not x for x in is_execution]): break

            # Detect the tool call (extract the code) in each response
            tool_infos=[ _detect_tool(response) for response in responses]

            # Collect the indices and inputs of the samples whose code should be executed
            tool_indices=[]
            tool_inputs=[]
            for i, tool_info in enumerate(tool_infos):
                if tool_info[0] and is_execution[i]: # a tool call was detected and execution is needed
                    tool_indices.append(i)
                    tool_inputs.append(tool_info[2]) # tool_info[2] holds the code (as JSON)
            # Post-processing: wrap the execution result, including error messages, in ```output ```
            def postproc_observation(observation):
                execution_pass=0
                try:
                    observation_list=observation
                    if observation_list[-1] == 'Done':
                        observation = observation_list[0]
                        execution_pass=1
                    else:
                        observation = observation_list[-1]
                except Exception:
                    observation="Error"
                if "Error" in observation: observation=observation.strip().split("\n")[-1]
                if len(observation.strip())==0: observation="timeout_decorator.timeout_decorator.TimeoutError: 'Timed Out'"
                observation = observation.strip()
                if len(observation)>=256:
                    observation = observation[:128]+"..."+observation[-128:]
                observation = f'{OBS_START}\n{observation}{OBS_END}'
                return observation, execution_pass

            # execute python code

            # observations=self.executor.batch_apply([json5.loads(x)['code'] for x in tool_inputs])
            observations=self.code_interpreter_batch_call([json5.loads(x)['code'] for x in tool_inputs])
            
            # construction responses from observations
            responses=[response+"\n" if not response.endswith('\n') else response for response in responses]
            responses_w_res=copy.deepcopy(responses)
            execution_passes=[0 for _ in range(len(responses))]
            for i, index in enumerate(tool_indices):
                processed_observation=postproc_observation(observations[i])
                responses_w_res[index]+=processed_observation[0]
                execution_passes[index]=processed_observation[1]
            
            # program2output.append([{"code": tool_input, "answer": postproc_observation(observations[idx])} for idx, tool_input in enumerate(tool_inputs)])
            # update samples_info
            for i ,index in enumerate(indices):
                # !!!! Mask the loss from right after ```output to the end of this turn's text (the ```output tag itself still gets loss, since the model generated it)
                mask=[ len(responses[i]) + len('```output'), len(responses_w_res[i]) ]
                samples_info[index]['mask_info'].append(mask)
                samples_info[index]['response']+=responses_w_res[i]
                samples_info[index]['sequence']+=responses_w_res[i]
                samples_info[index]['stop']=not is_execution[i]
                samples_info[index]['finish_reason']=finish_reason[i]
                samples_info[index]['execution_pass']=execution_passes[i]
        
        for i, line in enumerate(samples_info):
            if samples_info[i]['finish_reason']!='length': samples_info[i]['response']+=self.tokenizer.eos_token
        
        responses_ids=[]
        tool_output_masks=[]
        execution_passes=[]
        for idx, sample_info in enumerate(samples_info):
            response_id, tool_output_mask = self._tokenize_and_find_mask_token_indices(sample_info)
            responses_ids.append(response_id[:self.config.response_length])
            tool_output_masks.append(tool_output_mask[:self.config.response_length])
            execution_passes.append(sample_info['execution_pass'])
            # save id and mask to check correctness
        #     samples_info[idx]['responses_id']=response_id[:self.config.response_length]
        #     samples_info[idx]['tool_output_mask']=tool_output_mask[:self.config.response_length].tolist()
        

        # with open("/mnt/bn/seedllm3-lixuefeng-2/code/o1/verl-tir/sample_infos.json", 'w', encoding='utf-8') as f:
        #     json.dump(samples_info, f, ensure_ascii=False, indent=2)
        return responses_ids, tool_output_masks, torch.tensor(execution_passes, dtype=torch.long)
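`mask_info` stores character spans, so they have to be turned into a per-token mask at some point; in ToRL that happens in `_tokenize_and_find_mask_token_indices`, which is not shown here. Below is one possible way to do it, as a sketch only: it assumes the spans are already offsets into the full response string and that the tokenizer reports character offsets per token (Hugging Face fast tokenizers can return these via `return_offsets_mapping=True`). A toy whitespace tokenizer stands in so the example runs without dependencies; both helper names are hypothetical.

```python
import re


def whitespace_tokenize_with_offsets(text):
    """Toy tokenizer: one token per run of non-whitespace, as (start, end) char offsets."""
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def char_spans_to_token_mask(offsets, spans):
    """Return 1 for tokens that keep their loss, 0 for tokens overlapping any [s, e) span."""
    mask = []
    for start, end in offsets:
        overlaps = any(start < e and end > s for s, e in spans)
        mask.append(0 if overlaps else 1)
    return mask


text = "code ```output 42 ``` done"
tag = "```output"
span_start = text.index(tag) + len(tag)        # the tag itself keeps its loss
span_end = text.index("```", span_start) + 3   # mask through the closing fence
offsets = whitespace_tokenize_with_offsets(text)
print(char_spans_to_token_mask(offsets, [(span_start, span_end)]))  # [1, 1, 0, 0, 1]
```

Any token that overlaps a tool-output span is masked, so a subword token straddling the boundary is dropped from the loss rather than kept.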

posted @ 2025-12-20 21:49  Brain404