[PP-RL code walkthrough] rllm inference engine vs. Parallel-R1 inference engine
rllm's ReAct-style rollout is covered in a separate article (it mainly deals with the agent-environment interaction loop).
This post goes one level lower, into the rollout itself: how attention masks, position embeddings, and related details are handled.
Note: the inference (rollout) engine is not used during SFT; it only comes into play in the RL stage.
rllm
Defined in rllm/rllm/engine/rollout/verl_engine.py:
import uuid

from rllm.engine.rollout.rollout_engine import ModelOutput, RolloutEngine
from rllm.parser import ChatTemplateParser
from rllm.workflows import TerminationEvent, TerminationReason
from verl.experimental.agent_loop.agent_loop import AsyncLLMServerManager


class VerlEngine(RolloutEngine):
    def __init__(self, config, rollout_manager, tokenizer, **kwargs):
        ...  # initialization body omitted in this excerpt

    async def get_model_response(self, messages: list[dict], **kwargs) -> ModelOutput:
        application_id = kwargs.pop("application_id", str(uuid.uuid4()))
        validate = self.validate or kwargs.pop("validate", False)
        enforce_max_prompt_length = kwargs.pop("enforce_max_prompt_length", True)
        # these go to the parser
        tools = kwargs.pop("tools", [])
        accumulate_reasoning = kwargs.pop("accumulate_reasoning", self.accumulate_reasoning)
        sampling_params = self.val_sampling_params.copy() if self.validate or validate else self.train_sampling_params.copy()
        sampling_params.update(kwargs)
        max_tokens = sampling_params.pop("max_tokens", sampling_params.pop("max_new_tokens", self.max_response_length))
        # build the prompt
        prompt = self.chat_parser.parse(messages, add_generation_prompt=True, is_first_msg=True, tools=tools, accumulate_reasoning=accumulate_reasoning)
        # encode the prompt
        prompt_ids = self.tokenizer.encode(prompt, add_special_tokens=False)
        prompt_length = len(prompt_ids)
        if enforce_max_prompt_length and prompt_length > self.max_prompt_length:
            raise TerminationEvent(TerminationReason.MAX_PROMPT_LENGTH_EXCEEDED)
        # generate the completion
        completion_ids: list[int] = await self.server_manager.generate(request_id=application_id, prompt_ids=prompt_ids, sampling_params=sampling_params)
        finish_reason = "stop"
        if len(completion_ids) >= max_tokens:
            finish_reason = "length"
            completion_ids = completion_ids[:max_tokens]
        completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        parsed_output = self.chat_parser.parse_completion(completion_ids)
        return ModelOutput(
            text=completion_text,
            content=parsed_output["content"],
            reasoning=parsed_output["reasoning"],
            tool_calls=parsed_output["tool_calls"],
            prompt_ids=prompt_ids,
            completion_ids=completion_ids,
            prompt_length=prompt_length,
            completion_length=len(completion_ids),
            finish_reason=finish_reason,
        )
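For orientation, here is a minimal usage sketch. The caller below is hypothetical (it is not part of rllm); only the get_model_response signature and the ModelOutput fields above are taken from the excerpt, and the engine is assumed to be constructed elsewhere.

# Hypothetical caller (for illustration only): `engine` is an already-constructed
# VerlEngine; extra keyword arguments are merged into sampling_params by the code above.
async def agent_step(engine: VerlEngine, question: str) -> ModelOutput:
    messages = [{"role": "user", "content": question}]
    out = await engine.get_model_response(
        messages,
        temperature=1.0,                 # forwarded into sampling_params via **kwargs
        enforce_max_prompt_length=True,  # raise TerminationEvent if the prompt is too long
    )
    # ModelOutput carries both the text view and the token-level view of the completion
    print(out.finish_reason, out.prompt_length, out.completion_length)
    return out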
Parallel-R1's inference engine
Parallel-R1 modifies the AgentLoop in the verl framework. This is more low-level: it changes the attention mask and the position encoding.
Defined in verl/verl/parallel_thinking_generation_v3/parallel_thinking_loop_v3.py:
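Before going through the code, this is the structure one round of parallel thinking produces. The tag strings come from init_class below; the path and summary texts are made-up placeholders.

# The shape of one parallel-thinking block (placeholder contents, tags taken from init_class below):
parallel_block = (
    "<Parallel>"
    "<Path> try substitution ... </Path>"
    "<Path> try induction ... </Path>"
    "</Parallel>\n"
    "<Summary> both paths lead to ... </Summary>"
)

The loop gives the sibling <Path> spans overlapping position ids (each path restarts right after <Parallel>) and records their span pairs so they can be masked from attending to each other; that is exactly what the position_ids surgery and position_required_masks in the code below implement.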
# (Imports are omitted in this excerpt; the loop relies on asyncio, copy, random, torch,
#  uuid.uuid4, typing.Any, and verl's agent-loop utilities such as AgentLoopBase,
#  AgentLoopOutput, register, and rollout_trace_op.)
@register("parallel_thinking_agent_v3")
class ParallelThinkingAgentLoopV3(AgentLoopBase):
    @classmethod
    def init_class(cls, config, tokenizer, **kwargs):
        if cls._class_initialized:
            return
        cls._class_initialized = True
        print("Performing class-level ParallelThinkingV3AgentLoop initialization")
        # Initialize tools from config file
        cls.tokenizer = tokenizer
        cls.add_diverse_prefix = config.actor_rollout_ref.rollout.agent.add_diverse_prefix
        cls.max_iterations_for_parallel_thinking = config.actor_rollout_ref.rollout.agent.max_iterations_for_parallel_thinking
        cls.num_paths = config.actor_rollout_ref.rollout.agent.num_paths
        cls.max_path_response_length = config.actor_rollout_ref.rollout.agent.max_path_response_length
        cls.eos_token_id = cls.tokenizer.eos_token_id
        cls.start_parallel_token = cls.tokenizer.encode('<Parallel>')[0]
        cls.end_parallel_token = cls.tokenizer.encode('</Parallel>')[0]
        cls.start_path_token = cls.tokenizer.encode('<Path>')[0]
        cls.end_path_token = cls.tokenizer.encode('</Path>')[0]
        cls.start_summary_token = cls.tokenizer.encode('<Summary>')[0]
        cls.end_summary_token = cls.tokenizer.encode('</Summary>')[0]
        cls.new_line_token = cls.tokenizer.encode('\n')
        cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
        cls.response_length = config.actor_rollout_ref.rollout.response_length
        cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)

    @rollout_trace_op
    async def run(
        self,
        messages: list[dict[str, Any]],
        sampling_params: dict[str, Any],
    ) -> AgentLoopOutput:
        prompt_ids = await self.loop.run_in_executor(
            None,
            lambda: self.tokenizer.apply_chat_template(
                messages, add_generation_prompt=True, tokenize=True
            ),
        )
        init_len = len(prompt_ids)
        position_ids = torch.arange(init_len, dtype=torch.long)
        response_mask = []
        iterations = 0
        request_id = uuid4().hex
        left_pad_len = self.prompt_length - init_len
        if left_pad_len < 0:
            print(f"Warning: prompt length {self.prompt_length} is less than initial prompt length {init_len}, truncating prompt.")
            left_pad_len = 0
        position_required_masks = []
        parallel_stack: list[dict] = []

        def append_tokens(new_ids: list[int],
                          is_parallel: bool = False,
                          path_spans: list[tuple[int, int]] | None = None, manual_mask_positions=None):
            # a) correct the positions inside each Path; b) record the cross-Path masks
            nonlocal prompt_ids, position_ids, parallel_stack, position_required_masks, left_pad_len
            start = len(prompt_ids)
            prompt_ids.extend(new_ids)  # merge new_ids into prompt_ids
            response_mask.extend([1] * len(new_ids))  # extend response_mask with 1s for new_ids
            incr = torch.arange(1, len(new_ids) + 1, dtype=torch.long) + position_ids[-1]
            position_ids = torch.cat([position_ids, incr])  # extend position_ids with new_ids' positions
            if not is_parallel:
                return
            blk = parallel_stack.pop()
            base_p = blk["base"]
            longest = blk["longest"]
            # a) correct the positions inside each Path; b) record the cross-Path masks
            for i, (s_rel, e_rel) in enumerate(path_spans):
                s_abs, e_abs = start + s_rel, start + e_rel
                # the first path does not need to be shifted
                shift = position_ids[s_abs] - base_p
                # print("shift", shift.item())
                if shift.item():
                    position_ids[s_abs:e_abs] -= shift
                path_max_pos = position_ids[e_abs - 1]
                longest = max(longest, (path_max_pos - base_p).item())
                # record the cross-path attention mask
                for j, (sj_rel, ej_rel) in enumerate(path_spans):
                    if j == i:
                        continue
                    sj_abs, ej_abs = left_pad_len + start + sj_rel, left_pad_len + start + ej_rel
                    if e_abs + left_pad_len > ej_abs:
                        if ej_abs < self.prompt_length + self.response_length:
                            position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                            print(f"mask attention between {s_abs + left_pad_len}:{e_abs + left_pad_len} and {sj_abs}:{ej_abs}")
                    elif e_abs + left_pad_len < ej_abs:
                        if sj_abs < self.prompt_length + self.response_length:
                            position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
                            print(f"mask attention between {s_abs + left_pad_len}:{e_abs + left_pad_len} and {sj_abs}:{ej_abs}")
            # shift </Parallel> (and everything after it) so that it directly follows the longest path;
            # note: path_spans[1] is the last span when num_paths == 2
            s2, e2 = path_spans[1]
            p_end_abs = start + e2
            desired_end = base_p + longest + 1
            shift2 = position_ids[p_end_abs] - desired_end
            if shift2.item():
                position_ids[p_end_abs:] -= shift2
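        # Worked example of the position correction above, with made-up numbers.
        # Suppose <Parallel> lands at position 9, so base_p = 10, and two paths of
        # 4 and 5 tokens are generated:
        #   tokens:      <Path> A1 A2 </Path>   <Path> B1 B2 B3 </Path>   </Parallel>
        #   raw pos:       10   11 12   13        14   15 16 17   18          19
        #   corrected:     10   11 12   13        10   11 12 13   14          15
        # Each path is shifted back so it restarts at base_p = 10; longest = 4 (path B),
        # so </Parallel> is moved to base_p + longest + 1 = 15 and all later tokens
        # inherit that shift.  Because the two paths now occupy overlapping position
        # ranges, their spans are also recorded in position_required_masks so that
        # they cannot attend to each other.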
        def should_stop() -> bool:
            gen_len = len(position_ids) - init_len
            if gen_len >= self.response_length:
                return True
            if (
                self.max_iterations_for_parallel_thinking
                and iterations >= self.max_iterations_for_parallel_thinking
            ):
                return True
            return False

        while True:
            sp_main = {**sampling_params, "stop_token_ids": [self.start_parallel_token, self.eos_token_id]}
            ids = await self.server_manager.generate(
                request_id=request_id,
                prompt_ids=prompt_ids,
                sampling_params=sp_main,
            )
            # call append_tokens to update the position ids and the attention-mask bookkeeping
            append_tokens(ids)
            # stop if the maximum length is reached or no <Parallel> trigger is detected
            if should_stop() or not await self.check_parallel(ids):
                break
            assert ids[-1] != self.eos_token_id
            # elicit parallel thinking
            base_pos = position_ids[-1].item() + 1  # the position of the <Path> token
            parallel_stack.append({"base": base_pos, "longest": 0})
            # call _call_parallel_thinking to generate the parallel block on the fly
            parallel_ids, path_spans, manual_mask_positions = await self._call_parallel_thinking(
                prompt_ids, sampling_params
            )
            append_tokens(parallel_ids, is_parallel=True, path_spans=path_spans, manual_mask_positions=manual_mask_positions)
            iterations += 1
            if should_stop():
                break

        response_ids = prompt_ids[-len(response_mask):]
        prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
        assert init_len == len(prompt_ids)
        if left_pad_len > 0:
            position_ids = torch.cat([
                torch.zeros(left_pad_len, dtype=torch.long),
                position_ids
            ])
        if len(response_ids) < self.response_length:
            right_pad_len = self.response_length - len(response_ids)
            position_ids = torch.cat([
                position_ids,
                torch.zeros(right_pad_len, dtype=torch.long)])
            # bool_mask[:, -right_pad_len:] = False
        else:
            position_ids = position_ids[:self.prompt_length + self.response_length]
            response_ids = response_ids[: self.response_length]
            response_mask = response_mask[: self.response_length]
        return AgentLoopOutput(
            prompt_ids=prompt_ids,
            response_ids=response_ids,
            response_mask=response_mask,
            position_required_mask=position_required_masks,
            multiverse_pos_ids=position_ids.cpu(),  # 1-D
            num_turns=iterations + 1,
            metrics={},
        )
    async def check_parallel(self, response_ids):
        """
        Check whether parallel thinking should be triggered.
        """
        if response_ids[-1] == self.start_parallel_token:
            return True
        else:
            return False

    # generate the parallel block on the fly
    async def _call_parallel_thinking(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
    ):
        print("Conducting parallel thinking...")
        num_paths = self.num_paths
        max_len = self.max_path_response_length
        PATH_OPEN, PATH_CLOSE = self.start_path_token, self.end_path_token
        manual_mask_positions = []

        async def _gen_single_path(seed_i: int, prompt_i: list[int]):
            sp = copy.deepcopy(sampling_params)
            sp.update({"seed": seed_i, "n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
                       "temperature": 1.0})
            # sp.update({"n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
            #            "temperature": 1.0})
            manual_append = False
            ids = await self.server_manager.generate(
                request_id=uuid4().hex,
                prompt_ids=prompt_i + [PATH_OPEN],
                sampling_params=sp,
            )
            if max_len and len(ids) > max_len:
                ids = ids[:max_len]
            if ids[-1] != PATH_CLOSE:
                if ids[-1] == self.eos_token_id:
                    ids[-1] = PATH_CLOSE
                else:
                    ids.append(PATH_CLOSE)
                manual_append = True
            return [PATH_OPEN] + ids, manual_append

        tasks = []
        for i in range(num_paths):
            # prompt_i no longer contains PATH_OPEN (it is added inside _gen_single_path), to avoid duplicating the tag
            prompt_i = prompt_ids
            tasks.append(
                asyncio.create_task(
                    _gen_single_path(random.randint(0, 2**31 - 1), prompt_i)
                )
            )
        path_token_lists = await asyncio.gather(*tasks)

        parallel_ids = []
        path_spans = []
        cursor = 0
        # <Path>...</Path><Path>...</Path> -> path_spans: (0, len(path_1)), (len(path_1), len(path_1) + len(path_2))
        for (ids, manual_append_flag) in path_token_lists:
            assert ids[0] == PATH_OPEN and ids[-1] == PATH_CLOSE, \
                f"Path tokens must start with {PATH_OPEN} and end with {PATH_CLOSE}, got {ids}"
            span_start = cursor
            parallel_ids.extend(ids)
            cursor += len(ids)
            span_end = cursor
            path_spans.append((span_start, span_end))
            manual_mask_positions.append(span_start)
            if manual_append_flag:
                manual_mask_positions.append(span_end - 1)

        # ---------- 3) add </Parallel><Summary> ----------
        parallel_ids.append(self.end_parallel_token)  # </Parallel>
        manual_mask_positions.append(cursor)
        assert parallel_ids[cursor] == self.end_parallel_token
        cursor += 1
        parallel_ids.extend(self.new_line_token)  # \n
        manual_mask_positions.append(cursor)
        cursor += len(self.new_line_token)
        parallel_ids.append(self.start_summary_token)  # <Summary>
        manual_mask_positions.append(cursor)
        assert parallel_ids[cursor] == self.start_summary_token
        cursor += 1
        # if the exploration stage already exceeds the maximum length, do not generate a summary
        if len(prompt_ids) + 2 + len(parallel_ids) + 1 < self.prompt_length + self.response_length:
            print("Conducting Summary...")
            manual_append_summary = False
            sp_sum = copy.deepcopy(sampling_params)
            sp_sum.update({"n": 1, "stop_token_ids": [self.end_summary_token, self.eos_token_id]})
            summary_ids = await self.server_manager.generate(
                request_id=uuid4().hex,
                prompt_ids=prompt_ids + parallel_ids,  # the full prompt so far
                sampling_params=sp_sum,
            )
            if summary_ids[-1] != self.end_summary_token:
                if summary_ids[-1] == self.eos_token_id:
                    summary_ids[-1] = self.end_summary_token
                else:
                    summary_ids.append(self.end_summary_token)
                manual_append_summary = True
            parallel_ids.extend(summary_ids)
            cursor += len(summary_ids)
            if manual_append_summary:
                manual_mask_positions.append(cursor - 1)
                assert parallel_ids[cursor - 1] == self.end_summary_token
            assert parallel_ids[-1] == self.end_summary_token, \
                f"Parallel thinking did not end with {self.end_summary_token}, got {parallel_ids[-1]}"
        # print(self.tokenizer.decode(parallel_ids, skip_special_tokens=False))
        return parallel_ids, path_spans, manual_mask_positions  # path_spans are already relative to parallel_ids
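run() only records which spans must be hidden from each other; the training side (not shown in this post) still has to expand the recorded (left_pad_len, src_start, src_end, tgt_start, tgt_end) tuples into a 2-D attention mask. Below is a minimal sketch of that expansion, not the repository's actual consumer code, under the assumption that the recorded indices already include the left padding (as they do in append_tokens above) and that a per-token pad mask is available.

# Sketch only (hypothetical consumer): expand position_required_masks into a 2-D boolean attention mask.
import torch

def expand_position_required_masks(seq_len: int,
                                   pad_mask: torch.Tensor,
                                   position_required_masks: list[tuple[int, int, int, int, int]]) -> torch.Tensor:
    # start from a standard causal mask restricted to non-pad tokens
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    mask = causal & pad_mask.bool().unsqueeze(0) & pad_mask.bool().unsqueeze(1)
    # forbid attention between sibling <Path> spans, in both directions
    for _, s0, e0, s1, e1 in position_required_masks:
        mask[s0:e0, s1:e1] = False
        mask[s1:e1, s0:e0] = False
    return mask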
