[PP-RL Code Implementation] rllm inference engine vs. parallel inference engine

For rllm's ReAct-style inference, see the earlier article (it mainly covers how the Agent interacts with the environment).

This post covers lower-level inference details, including the attention mask, position embeddings, and so on.

Note: the inference engine is not used in the SFT stage; it is only used in the RL stage.

rllm

Defined in rllm/rllm/engine/rollout/verl_engine.py

import uuid

from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
from rllm.parser import ChatTemplateParser
from rllm.workflows import TerminationEvent, TerminationReason
from verl.experimental.agent_loop.agent_loop import AsyncLLMServerManager


class VerlEngine(RolloutEngine):
    def __init__(self, config, rollout_manager, tokenizer, **kwargs):
        # body omitted in this excerpt; it sets up self.tokenizer, self.chat_parser,
        # self.server_manager, self.validate, the train/val sampling params, and the length limits
        ...

    async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
        application_id = kwargs.pop("application_id", str(uuid.uuid4()))
        validate = self.validate or kwargs.pop("validate", False)
        enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True)

        # these go to the parser
        tools = kwargs.pop("tools", [])
        accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning)

        sampling_params = self.val_sampling_params.copy() if self.validate or validate else self.train_sampling_params.copy()
        sampling_params.update(kwargs)

        max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length))

        # build the prompt string via the chat template parser
        prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning)
        # tokenize the prompt
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        prompt_length = len(prompt_ids)

        if enforce_max_prompt_length and prompt_length > self.max_prompt_length:
            raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED)

        # generate the response via the async server manager
        completion_ids: list[int] = await self.server_manager.generate(request_id=application_id, prompt_ids=prompt_ids, sampling_params=sampling_params)

        finish_reason = "stop"
        if len(completion_ids) >= max_tokens:
            finish_reason = "length"
            completion_ids = completion_ids[:max_tokens]

        completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        parsed_output = self.chat_parser.parse_completion(completion_ids)

        return ModelOutput(
            text=completion_text,
            content=parsed_output["content"],
            reasoning=parsed_output["reasoning"],
            tool_calls=parsed_output["tool_calls"],
            prompt_ids=prompt_ids,
            completion_ids=completion_ids,
            prompt_length=prompt_length,
            completion_length=len(completion_ids),
            finish_reason=finish_reason,
        )
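
As a reference point, here is a minimal usage sketch of get_model_response. It assumes an already-constructed VerlEngine instance (the __init__ wiring is omitted in the excerpt above) and that ModelOutput exposes its fields as attributes; the helper name `ask`, the message contents, and the sampling kwargs are illustrative, not from the repo.

import asyncio

# Hedged usage sketch: `engine` is assumed to be an already-constructed VerlEngine.
async def ask(engine) -> None:
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 17 * 24?"},
    ]
    # extra kwargs not consumed explicitly (e.g. temperature) are merged into sampling_params
    out = await engine.get_model_response(messages, temperature=0.7, max_tokens=512)
    print(out.finish_reason, out.prompt_length, out.completion_length)
    print(out.text)

# asyncio.run(ask(engine))  # run inside whatever event loop the trainer provides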


Parallel-R1 inference engine

It modifies the AgentLoop in the verl framework and works at a lower level, changing the attention mask and the position encoding.
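
Before looking at the code, here is a toy illustration of the position-id scheme that the append_tokens logic below implements: every <Path> branch is re-based to start at the same position as the first branch, and the tokens after the parallel block continue right after the end of the longest branch. This is a simplified sketch written for this post (the function name `rebase_parallel_positions` and the two-path setup are my assumptions), not the repo's code.

import torch

# Toy sketch of the position re-basing used for parallel paths (simplified,
# ignores padding and the cross-path attention masks).
def rebase_parallel_positions(prompt_len: int, path_lens: list[int]) -> torch.Tensor:
    # the prompt (including the <Parallel> trigger) keeps ordinary sequential positions
    positions = list(range(prompt_len))
    base = prompt_len          # first position inside the parallel block
    longest = 0
    for plen in path_lens:
        # every path restarts at `base`, so paths overlap in position space
        positions.extend(range(base, base + plen))
        longest = max(longest, plen)
    # </Parallel>, <Summary>, ... continue right after the longest path
    positions.append(base + longest)   # e.g. the </Parallel> token
    return torch.tensor(positions, dtype=torch.long)

# Example: a 5-token prompt ending in <Parallel>, then paths of length 4 and 6.
# print(rebase_parallel_positions(5, [4, 6]))
# -> prompt: 0..4, path 1: 5..8, path 2: 5..10, </Parallel>: 11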

Defined in verl/verl/parallel_thinking_generation_v3/parallel_thinking_loop_v3.py

@register("parallel_thinking_agent_v3")
class ParallelThinkingAgentLoopV3(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level ParallelThinkingV3AgentLoop initialization")

        # Read the parallel-thinking settings from the rollout agent config
        cls.tokenizer = tokenizer
        cls.add_diverse_prefix = config.actor_rollout_ref.rollout.agent.add_diverse_prefix
        cls.max_iterations_for_parallel_thinking = config.actor_rollout_ref.rollout.agent.max_iterations_for_parallel_thinking
        cls.num_paths = config.actor_rollout_ref.rollout.agent.num_paths
        cls.max_path_response_length = config.actor_rollout_ref.rollout.agent.max_path_response_length

        cls.eos_token_id = cls.tokenizer.eos_token_id
        # each special tag is expected to encode to a single token (only the first id is kept);
        # the newline is kept as a full id list
        cls.start_parallel_token = cls.tokenizer.encode('<Parallel>')[0]
        cls.end_parallel_token = cls.tokenizer.encode('</Parallel>')[0]
        cls.start_path_token = cls.tokenizer.encode('<Path>')[0]
        cls.end_path_token = cls.tokenizer.encode('</Path>')[0]
        cls.start_summary_token = cls.tokenizer.encode('<Summary>')[0]
        cls.end_summary_token = cls.tokenizer.encode('</Summary>')[0]
        cls.new_line_token = cls.tokenizer.encode('\n')

        cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
        cls.response_length = config.actor_rollout_ref.rollout.response_length
        cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)

    @rollout_trace_op
    async def run(
        self,
        messages: list[dict[str, Any]],
        sampling_params: dict[str, Any],
    ) -> AgentLoopOutput:

        
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True
            ),
        )
        init_len      = len(prompt_ids)                         
        position_ids  = torch.arange(init_len, dtype=torch.long)

        response_mask = []                                     
        iterations    = 0
        request_id    = uuid4().hex

        left_pad_len = self.prompt_length - init_len
        if left_pad_len < 0:
            print(f"Warning: prompt length {self.prompt_length} is less than initial prompt length {init_len}, truncating prompt.")
            left_pad_len = 0

        position_required_masks = []

        
        parallel_stack: list[dict] = []

        
        def append_tokens(new_ids: list[int],
                          is_parallel: bool = False,
                          path_spans: list[tuple[int, int]] | None = None,
                          manual_mask_positions=None):
            # Appends new_ids to the running sequence; when is_parallel=True it also
            # a) re-bases positions inside each <Path> span and b) records cross-Path attention masks.
            nonlocal prompt_ids, position_ids, parallel_stack, position_required_masks, left_pad_len

            start = len(prompt_ids)                     # absolute offset of the newly appended tokens
            prompt_ids.extend(new_ids)                  # merge new_ids into the running token sequence
            response_mask.extend([1] * len(new_ids))    # all appended tokens count as response tokens
            incr = torch.arange(1, len(new_ids) + 1, dtype=torch.long) + position_ids[-1]
            position_ids = torch.cat([position_ids, incr])  # default: plain sequential positions

            if not is_parallel:
                return

            blk     = parallel_stack.pop()
            base_p  = blk["base"]        # position where every <Path> span should start
            longest = blk["longest"]

            # a) re-base positions inside each Path span & b) record cross-Path attention masks
            for i, (s_rel, e_rel) in enumerate(path_spans):
                s_abs, e_abs = start + s_rel, start + e_rel
                # the first path needs no shift; later paths are shifted back to base_p
                shift = position_ids[s_abs] - base_p
                if shift.item():
                    position_ids[s_abs:e_abs] -= shift
                path_max_pos = position_ids[e_abs - 1]
                longest = max(longest, (path_max_pos - base_p).item())
                # build the cross-path masks: path i and path j must not attend to each other
                for j, (sj_rel, ej_rel) in enumerate(path_spans):
                    if j == i:
                        continue
                    sj_abs, ej_abs = left_pad_len + start + sj_rel, left_pad_len + start + ej_rel

                    if e_abs + left_pad_len > ej_abs:
                        if ej_abs < self.prompt_length + self.response_length:
                            position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                            print(f"mask attention between {s_abs + left_pad_len}:{e_abs + left_pad_len} and {sj_abs}:{ej_abs}")
                    elif e_abs + left_pad_len < ej_abs:
                        if sj_abs < self.prompt_length + self.response_length:
                            position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                            print(f"mask attention between {s_abs + left_pad_len}:{e_abs + left_pad_len} and {sj_abs}:{ej_abs}")

            # shift everything after the parallel block (</Parallel>, \n, <Summary>, ...) so it
            # continues right after the longest path; using path_spans[1] as the last path assumes two paths
            s2, e2 = path_spans[1]
            p_end_abs   = start + e2                  # absolute index of the </Parallel> token
            desired_end = base_p + longest + 1        # position right after the longest path
            shift2      = position_ids[p_end_abs] - desired_end
            if shift2.item():
                position_ids[p_end_abs:] -= shift2

        def should_stop() -> bool:
            gen_len = len(position_ids) - init_len
            if gen_len >= self.response_length:
                return True
            if (
                self.max_iterations_for_parallel_thinking
                and iterations >= self.max_iterations_for_parallel_thinking
            ):
                return True
            return False

        
        while True:
            sp_main = {**sampling_params, "stop_token_ids": [self.start_parallel_token, self.eos_token_id]}
            ids = await self.server_manager.generate(
                request_id=request_id,
                prompt_ids=prompt_ids,
                sampling_params=sp_main,
            )
            # append_tokens updates prompt_ids, response_mask and position_ids
            append_tokens(ids)
            # stop if the length budget is reached or no <Parallel> trigger was emitted
            if should_stop() or not await self.check_parallel(ids):
                break

            assert ids[-1] != self.eos_token_id

            # elicit parallel thinking
            base_pos = position_ids[-1].item() + 1  # position where the first <Path> token will land
            parallel_stack.append({"base": base_pos, "longest": 0})

            # _call_parallel_thinking generates the <Path> branches and the <Summary> on the fly
            parallel_ids, path_spans, manual_mask_positions = await self._call_parallel_thinking(
                prompt_ids, sampling_params
            )
            append_tokens(parallel_ids, is_parallel=True, path_spans=path_spans, manual_mask_positions=manual_mask_positions)

            iterations += 1
            if should_stop():
                break

        # split the running sequence back into the original prompt and the generated response
        response_ids = prompt_ids[-len(response_mask):]
        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
        assert init_len == len(prompt_ids)

        # left-pad position_ids to prompt_length, then right-pad or truncate to response_length
        if left_pad_len > 0:
            position_ids = torch.cat([
                torch.zeros(left_pad_len, dtype=torch.long),
                position_ids
            ])
        if len(response_ids) < self.response_length:
            right_pad_len = self.response_length - len(response_ids)
            position_ids = torch.cat([
                position_ids,
                torch.zeros(right_pad_len, dtype=torch.long)])
            # bool_mask[:, -right_pad_len:] = False
        else:
            position_ids = position_ids[:self.prompt_length + self.response_length]
            response_ids = response_ids[: self.response_length]
            response_mask = response_mask[: self.response_length]

        return AgentLoopOutput(
            prompt_ids          = prompt_ids,
            response_ids        = response_ids,
            response_mask       = response_mask,
            position_required_mask       = position_required_masks,
            multiverse_pos_ids  = position_ids.cpu(),        # 1-D
            num_turns           = iterations + 1,
            metrics             = {},
        )



    async def check_parallel(self, response_ids):
        """
        Check whether parallel thinking should be triggered,
        i.e. whether the main generation stopped on the <Parallel> token.
        """
        return response_ids[-1] == self.start_parallel_token
	
  
    # dynamically generate the parallel block (<Path> branches + <Summary>)
    async def _call_parallel_thinking(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
    ):
        print("Conducting parallel thinking...")
        num_paths = self.num_paths
        max_len   = self.max_path_response_length
        PATH_OPEN, PATH_CLOSE = self.start_path_token, self.end_path_token
        manual_mask_positions = []

        
        async def _gen_single_path(seed_i: int, prompt_i: list[int]):
            sp = copy.deepcopy(sampling_params)
            sp.update({"seed": seed_i, "n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
                    "temperature": 1.0})
            # sp.update({"n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
            #         "temperature": 1.0})
            manual_append = False
            ids = await self.server_manager.generate(
                request_id=uuid4().hex,
                prompt_ids=prompt_i + [PATH_OPEN],
                sampling_params=sp,
            )
            if max_len and len(ids) > max_len:
                ids = ids[:max_len]
            
            # if the path did not close itself with </Path>, force-append the closing tag and
            # record that it was inserted manually (tracked later via manual_mask_positions)
            if ids[-1] != PATH_CLOSE:
                if ids[-1] == self.eos_token_id:
                    ids[-1] = PATH_CLOSE
                else:
                    ids.append(PATH_CLOSE)
                manual_append = True
            return [PATH_OPEN] + ids, manual_append

        tasks = []
        for i in range(num_paths):
            # prompt_i does NOT include PATH_OPEN, to avoid duplicating the tag
            # (PATH_OPEN is appended inside _gen_single_path)
            prompt_i = prompt_ids
            tasks.append(
                asyncio.create_task(
                    _gen_single_path(random.randint(0, 2**31 - 1), prompt_i)
                )
            )

        path_token_lists = await asyncio.gather(*tasks)     

        parallel_ids = []             
        path_spans   = []
        cursor       = 0              

        # <Path>...</Path><Path>...</Path>  ->  path_spans = [(0, len(path_1)), (len(path_1), len(path_1) + len(path_2))]
        for (ids, manual_append_flag) in path_token_lists:
            assert ids[0] == PATH_OPEN and ids[-1] == PATH_CLOSE, \
                f"Path tokens must start with {PATH_OPEN} and end with {PATH_CLOSE}, got {ids}"
            span_start = cursor 
            parallel_ids.extend(ids)
            cursor     += len(ids)
            span_end   = cursor
            path_spans.append((span_start, span_end))
            manual_mask_positions.append(span_start)
            if manual_append_flag:
                manual_mask_positions.append(span_end-1)

        # ---------- 3) add </Parallel><Summary> ----------
        parallel_ids.append(self.end_parallel_token)           # </Parallel>
        manual_mask_positions.append(cursor)
        assert parallel_ids[cursor] == self.end_parallel_token
        cursor += 1
        parallel_ids.extend(self.new_line_token)          # \n
        manual_mask_positions.append(cursor)
        cursor += len(self.new_line_token)
        parallel_ids.append(self.start_summary_token)          # <Summary>
        manual_mask_positions.append(cursor)
        assert parallel_ids[cursor] == self.start_summary_token
        cursor += 1
        
        # if the exploration stage already exceeds the max length, we do not generate a summary
        if len(prompt_ids) + 2 + len(parallel_ids) + 1 < self.prompt_length + self.response_length:
            print("Conducting Summary...")
            manual_append_summary = False
            sp_sum = copy.deepcopy(sampling_params)
            sp_sum.update({"n": 1, "stop_token_ids": [self.end_summary_token, self.eos_token_id]})
            summary_ids = await self.server_manager.generate(
                request_id=uuid4().hex,
                prompt_ids=prompt_ids + parallel_ids,   # the full prompt so far (original prompt + parallel block)
                sampling_params=sp_sum,
            )
            if summary_ids[-1] != self.end_summary_token:
                if summary_ids[-1] == self.eos_token_id:
                    summary_ids[-1] = self.end_summary_token
                else:
                    summary_ids.append(self.end_summary_token)
                manual_append_summary = True
            parallel_ids.extend(summary_ids)
            cursor += len(summary_ids)
            if manual_append_summary:
                manual_mask_positions.append(cursor-1)
            # print("asadas", parallel_ids[cursor-1])
            assert parallel_ids[cursor-1] == self.end_summary_token
            assert parallel_ids[-1] == self.end_summary_token, \
                f"Parallel thinking did not end with {self.end_summary_token}, got {parallel_ids[-1]}"
        # print(self.tokenizer.decode(parallel_ids, skip_special_tokens=False))
        return parallel_ids, path_spans, manual_mask_positions   # path_spans are already relative to parallel_ids
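
The run loop returns position_required_masks as a list of (left_pad_len, s_i, e_i, s_j, e_j) tuples; how they are consumed on the training side is not shown in this post. A minimal sketch of turning them into a per-sample boolean attention mask might look like the following (the function name `build_cross_path_mask`, the square-mask layout, and `seq_len` are my assumptions for illustration, not the repo's API):

import torch

# Hedged sketch: combine a causal mask with the recorded cross-path mask spans.
def build_cross_path_mask(seq_len: int, position_required_masks: list[tuple[int, int, int, int, int]]) -> torch.Tensor:
    # start from a full causal mask (True = attention allowed)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    for _left_pad, s_i, e_i, s_j, e_j in position_required_masks:
        # tokens of path i must not attend to tokens of path j (and vice versa)
        mask[s_i:e_i, s_j:e_j] = False
        mask[s_j:e_j, s_i:e_i] = False
    return mask

# Example: a 12-token sequence where spans [4, 7) and [7, 10) are two parallel paths.
# m = build_cross_path_mask(12, [(0, 4, 7, 7, 10)])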
        


