AgentLoop (verl) vs ParallelThinkingAgentLoopV3 (Parallel-R1) vs ToRL
I recently came across a very interesting paper, Parallel-R1, which uses RL to train a model to do parallel reasoning. The output format looks roughly like this:
<model reasoning>
at some point the model emits <Parallel> and enters multi-path reasoning
<Parallel>
<Path> ... </Path>   (the reasoning paths cannot see each other: cross-path attention is masked out)
<Path> ... </Path>
</Parallel>
<Summary>
a summary of the paths above
</Summary>
<model continues reasoning>
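To make this concrete, here is a made-up trace in this format (an illustrative example, not an actual sample from the paper):
... From x + y = 10 and x - y = 2, let me try two approaches.
<Parallel>
<Path> Adding the two equations gives 2x = 12, so x = 6 and y = 4. </Path>
<Path> Substituting y = x - 2 into x + y = 10 gives 2x - 2 = 10, so x = 6 and y = 4. </Path>
</Parallel>
<Summary> Both paths agree that x = 6 and y = 4. </Summary>
So the answer is x = 6, y = 4.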
The paper was accepted as a NeurIPS Oral, and reading through it, it really is a high-quality piece of work. I also reproduced it myself, and my results roughly match the logs released by the authors. Overall, it is well worth studying.
The paper is built on top of the verl framework and modifies it in two main places. The first is the AgentLoop, used for RL, which is what this post covers: I will compare verl's multi-turn agent interaction loop with the loop implemented in this paper in some detail.
The second place is the Dataset implementation used for SFT.
Verl AgentLoop
This is the AgentLoop implementation from Agent RL, a new feature verl added around June/July this year. It is a classic ReAct-style design: generate a tool-call request, execute the tool, and insert the execution result back into the conversation.
The code is as follows:
@register("tool_agent")
class ToolAgentLoop(AgentLoopBase):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
...
@rollout_trace_op
async def run(self, messages: list[dict[str, Any]], sampling_params: dict[str, Any]) -> AgentLoopOutput:
metrics = {}
request_id = uuid4().hex
# 1. Tokenize the prompt into token ids
prompt_ids = await self.loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
messages, tools=self.tool_schemas, add_generation_prompt=True, tokenize=True
),
)
response_mask = []
user_turns, assistant_turns = 0, 0
while True:
with simple_timer("generate_sequences", metrics):
# 2. Generate the response token ids
response_ids = await self.server_manager.generate(
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
)
prompt_ids += response_ids
response_mask += [1] * len(response_ids)
assistant_turns += 1
# reach max response length
if len(response_mask) >= self.response_length:
break
# reach max assistant turns
if self.max_assistant_turns and assistant_turns >= self.max_assistant_turns:
break
# reach max user turns
if self.max_user_turns and user_turns >= self.max_user_turns:
break
# 3. Try to parse tool calls from the response (presumably by matching <tool_call> ... </tool_call> markers)
_, tool_calls = await self.tool_parser.extract_tool_calls(response_ids)
# If there are no tool calls, break out of the loop
if not tool_calls:
break
# 4. If there are tool calls, execute them and append the results to the message sequence
tasks = []
for tool_call in tool_calls[: self.max_parallel_calls]:
tasks.append(self._call_tool(tool_call))
with simple_timer("tool_calls", metrics):
tool_responses = await asyncio.gather(*tasks)
if any(isinstance(item, Exception) for item in tool_responses):
break
tool_response_ids = await self.loop.run_in_executor(
None,
lambda messages=tool_responses: self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True
),
)
tool_response_ids = tool_response_ids[len(self.system_prompt) :]
# NOTE: last turn should not be user turn, or the EOS token reward
# can't be propagated to previous token in GAE.
if len(response_mask) + len(tool_response_ids) >= self.response_length:
break
prompt_ids += tool_response_ids
response_mask += [0] * len(tool_response_ids)
user_turns += 1
# 5. Repeat until the maximum number of turns is reached
response_ids = prompt_ids[-len(response_mask) :]
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
output = AgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
response_mask=response_mask[: self.response_length],
num_turns=user_turns + assistant_turns + 1,
metrics=metrics,
)
return output
async def _call_tool(self, tool_call: FunctionCall) -> dict[str, str]:
"""执行工具"""
tool, instance_id = None, None
try:
# TODO: append malformed tool_call to the prompt: invalid function name or arguments
tool_name = tool_call.name
tool_args = json.loads(tool_call.arguments)
tool = self.tools[tool_name]
instance_id = await tool.create()
tool_response, _, _ = await tool.execute(instance_id, tool_args)
except Exception as e:
logger.exception(f"Error when executing tool: {e}")
return e
finally:
if tool and instance_id:
await tool.release(instance_id)
if len(tool_response) > self.max_tool_response_length:
if self.tool_response_truncate_side == "left":
tool_response = tool_response[: self.max_tool_response_length] + "...(truncated)"
elif self.tool_response_truncate_side == "right":
tool_response = "(truncated)..." + tool_response[-self.max_tool_response_length :]
else:
length = self.max_tool_response_length // 2
tool_response = tool_response[:length] + "...(truncated)..." + tool_response[-length:]
return {
"role": "tool",
"content": tool_response,
}
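One detail worth noting in the loop above: assistant-generated tokens are appended with response_mask value 1, while tool-response tokens are appended with value 0. For example (made-up numbers), after one assistant turn of 50 tokens followed by a tool message of 20 tokens, response_mask is [1] * 50 + [0] * 20, so downstream only the tokens the model actually generated are treated as trainable response tokens.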
ParallelThinkingAgentLoopV3
Looking at the paper alone, this loop might seem like a single-pass generation, but the code shows that it is also a multi-turn interaction.
The rough flow is as follows (see the sketch after this list):
- On the main chain, once the model generates <Parallel>, generation stops and parallel thinking begins; if no <Parallel> is generated, the loop simply exits and returns the result.
- For each thinking path, a <Path> tag is first inserted manually.
- Each path stops generating when it emits </Path>.
- After parallel thinking finishes, control returns to the main chain and the </Parallel> and <Summary> tags are appended.
- Generation then continues on the main chain and stops when </Summary> is emitted.
- Return to step 1, and stop once the maximum number of iterations or the maximum length is reached.
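Below is a minimal, runnable sketch of this control flow. Here generate(prompt, stop) is an abstract stand-in for the async server_manager.generate call, tags are plain strings instead of token ids, and all names are illustrative rather than the real Parallel-R1 API:
def parallel_thinking_rollout(generate, prompt, num_paths=2, max_iters=3):
    out = list(prompt)
    for _ in range(max_iters):
        chunk = generate(out, stop=["<Parallel>", "<eos>"])            # 1. main chain
        out += chunk
        if chunk[-1] != "<Parallel>":                                   # 2. no trigger -> done
            break
        paths = []
        for _ in range(num_paths):                                      # 3. each path is sampled
            paths += ["<Path>"] + generate(out + ["<Path>"],            #    from the same prefix,
                                           stop=["</Path>", "<eos>"])   #    blind to the other paths
        out += paths + ["</Parallel>", "\n", "<Summary>"]               # 4. close block, open summary
        out += generate(out, stop=["</Summary>", "<eos>"])              # 5. summary sees all paths
    return out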
Besides the multi-turn interaction, the author also modifies the attention mask and the position encoding:
The code is as follows:
def append_tokens(new_ids: list[int],
is_parallel: bool = False,
path_spans: list[tuple[int, int]] | None = None, manual_mask_positions= None):
nonlocal prompt_ids, position_ids, parallel_stack, position_required_masks, left_pad_len
start = len(prompt_ids) # correct
prompt_ids.extend(new_ids) # correct -> merge new_ids into prompt_ids
response_mask.extend([1] * len(new_ids)) # correct -> extend response_mask with 1s for new_ids
incr = torch.arange(1, len(new_ids) + 1, dtype=torch.long) + position_ids[-1] # -> correct
position_ids = torch.cat([position_ids, incr]) # correct -> extend position_ids with new_ids' positions
# If this is not a parallel block, we are done
if not is_parallel:
return
blk = parallel_stack.pop()
base_p = blk["base"]
longest = blk["longest"]
# a) Correct the positions inside each Path (1. within a path: position = path start position + relative offset inside the path;
#    2. after the paths finish, the first following token's position = the path start position + the longest path length)
# & b) record cross-Path mask entries (attention between different paths gets masked out)
for i, (s_rel, e_rel) in enumerate(path_spans):
s_abs, e_abs = start + s_rel, start + e_rel
# the first path we do not need to shift
shift = position_ids[s_abs] - (base_p)
# print("shift", shift.item())
if shift.item():
position_ids[s_abs:e_abs] -= shift
path_max_pos = position_ids[e_abs - 1]
longest = max(longest, (path_max_pos - base_p).item())
# Record cross-path attention-mask entries: attention between different paths will be masked out
for j, (sj_rel, ej_rel) in enumerate(path_spans):
if j == i: continue
sj_abs, ej_abs = left_pad_len + start + sj_rel, left_pad_len + start + ej_rel
if e_abs + left_pad_len > ej_abs:
if ej_abs < self.prompt_length + self.response_length:
position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} 与 {sj_abs}:{ej_abs} attention")
elif e_abs + left_pad_len < ej_abs:
if sj_abs < self.prompt_length + self.response_length:
position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} 与 {sj_abs}:{ej_abs} attention")
s2, e2 = path_spans[1]
p_end_abs = start + e2
desired_end = base_p + longest + 1
shift2 = position_ids[p_end_abs] - desired_end
if shift2.item():
position_ids[p_end_abs:] -= shift2
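To make the position arithmetic concrete, here is a small hand-worked example (the numbers are made up): suppose the main chain ends with <Parallel> at position 20, so base_p = 21, and the two paths have 5 and 8 tokens. The naive sequential assignment puts Path 1 at positions 21-25 and Path 2 at 26-33; the correction leaves Path 1 unchanged (shift 0, since it already starts at base_p) and shifts Path 2 back to 21-28, so both paths start at base_p. The longest path then spans 7 positions beyond base_p, so the token right after the last </Path> (the </Parallel> token) is moved from its naive position 34 to base_p + longest + 1 = 29, and everything after it is shifted accordingly. At the same time, the spans of Path 1 and Path 2 are recorded in position_required_masks so that attention between the two paths can be masked out later.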
These are the two main modifications.
From this alone you can tell the author's engineering ability is very strong. Just the first point, the multi-turn interaction logic, is already hard to reason through and to implement; the second point goes well beyond my current level. Making such low-level changes inside a training framework, and getting them to work, is the clearest demonstration of strong engineering skill.
The complete loop class is below. The code is written very carefully and is well worth reading closely.
@register("parallel_thinking_agent_v3")
class ParallelThinkingAgentLoopV3(AgentLoopBase):
@classmethod
def init_class(cls, config, tokenizer, **kwargs):
if cls._class_initialized:
return
cls._class_initialized = True
print("Performing class-level ParallelThinkingV3AgentLoop initialization")
# Initialize tools from config file
cls.tokenizer = tokenizer
cls.add_diverse_prefix = config.actor_rollout_ref.rollout.agent.add_diverse_prefix
cls.max_iterations_for_parallel_thinking = config.actor_rollout_ref.rollout.agent.max_iterations_for_parallel_thinking
cls.num_paths = config.actor_rollout_ref.rollout.agent.num_paths
cls.max_path_response_length = config.actor_rollout_ref.rollout.agent.max_path_response_length
cls.eos_token_id = cls.tokenizer.eos_token_id
cls.start_parallel_token = cls.tokenizer.encode('<Parallel>')[0]
cls.end_parallel_token = cls.tokenizer.encode('</Parallel>')[0]
cls.start_path_token = cls.tokenizer.encode('<Path>')[0]
cls.end_path_token = cls.tokenizer.encode('</Path>')[0]
cls.start_summary_token = cls.tokenizer.encode('<Summary>')[0]
cls.end_summary_token = cls.tokenizer.encode('</Summary>')[0]
cls.new_line_token = cls.tokenizer.encode('\n')
cls.prompt_length = config.actor_rollout_ref.rollout.prompt_length
cls.response_length = config.actor_rollout_ref.rollout.response_length
cls.system_prompt = tokenizer.apply_chat_template([{}], add_generation_prompt=False, tokenize=True)
@rollout_trace_op
async def run(
self,
messages: list[dict[str, Any]],
sampling_params: dict[str, Any],
) -> AgentLoopOutput:
prompt_ids = await self.loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True
),
)
init_len = len(prompt_ids)
position_ids = torch.arange(init_len, dtype=torch.long)
response_mask = []
iterations = 0
request_id = uuid4().hex
left_pad_len = self.prompt_length - init_len
if left_pad_len < 0:
print(f"Warning: prompt length {self.prompt_length} is less than initial prompt length {init_len}, truncating prompt.")
left_pad_len = 0
position_required_masks = []
parallel_stack: list[dict] = []
def append_tokens(new_ids: list[int],
is_parallel: bool = False,
path_spans: list[tuple[int, int]] | None = None, manual_mask_positions= None):
# a) 校正 Path 内位置 & b) 写跨 Path 掩码
nonlocal prompt_ids, position_ids, parallel_stack, position_required_masks, left_pad_len
start = len(prompt_ids) # correct
prompt_ids.extend(new_ids) # correct -> merge new_ids into prompt_ids
response_mask.extend([1] * len(new_ids)) # correct -> extend response_mask with 1s for new_ids
incr = torch.arange(1, len(new_ids) + 1, dtype=torch.long) + position_ids[-1] # -> correct
position_ids = torch.cat([position_ids, incr]) # correct -> extend position_ids with new_ids' positions
if not is_parallel:
return
blk = parallel_stack.pop()
base_p = blk["base"]
longest = blk["longest"]
# a) 校正 Path 内位置 & b) 写跨 Path 掩码
for i, (s_rel, e_rel) in enumerate(path_spans):
s_abs, e_abs = start + s_rel, start + e_rel
# the first path we do not need to shift
shift = position_ids[s_abs] - (base_p)
# print("shift", shift.item())
if shift.item():
position_ids[s_abs:e_abs] -= shift
path_max_pos = position_ids[e_abs - 1]
longest = max(longest, (path_max_pos - base_p).item())
# build cross-path mask entries
for j, (sj_rel, ej_rel) in enumerate(path_spans):
if j == i: continue
sj_abs, ej_abs = left_pad_len + start + sj_rel, left_pad_len + start + ej_rel
if e_abs + left_pad_len > ej_abs:
if ej_abs < self.prompt_length + self.response_length:
position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} 与 {sj_abs}:{ej_abs} attention")
elif e_abs + left_pad_len < ej_abs:
if sj_abs < self.prompt_length + self.response_length:
position_required_masks.append((left_pad_len, s_abs + left_pad_len, e_abs + left_pad_len, sj_abs, ej_abs))
print(f"mask {s_abs+left_pad_len}:{e_abs+left_pad_len} 与 {sj_abs}:{ej_abs} attention")
s2, e2 = path_spans[1]
p_end_abs = start + e2
desired_end = base_p + longest + 1
shift2 = position_ids[p_end_abs] - desired_end
if shift2.item():
position_ids[p_end_abs:] -= shift2
def should_stop() -> bool:
gen_len = len(position_ids) - init_len
if gen_len >= self.response_length:
return True
if (
self.max_iterations_for_parallel_thinking
and iterations >= self.max_iterations_for_parallel_thinking
):
return True
return False
while True:
sp_main = {**sampling_params, "stop_token_ids": [self.start_parallel_token, self.eos_token_id]}
ids = await self.server_manager.generate(
request_id=request_id,
prompt_ids=prompt_ids,
sampling_params=sp_main,
)
# Call append_tokens to compute the attention-mask info and position ids
append_tokens(ids)
# Stop if the max length is reached or no <Parallel> was generated
if should_stop() or not await self.check_parallel(ids):
break
assert ids[-1] != self.eos_token_id
# elicit parallel thinking
base_pos = position_ids[-1].item() + 1 # the position of the <Path> token
parallel_stack.append({"base": base_pos, "longest": 0})
# Call _call_parallel_thinking to generate the parallel block on the fly
parallel_ids, path_spans, manual_mask_positions = await self._call_parallel_thinking(
prompt_ids, sampling_params
)
append_tokens(parallel_ids, is_parallel=True, path_spans=path_spans, manual_mask_positions=manual_mask_positions)
iterations += 1
if should_stop():
break
response_ids = prompt_ids[-len(response_mask) :]
prompt_ids = prompt_ids[: len(prompt_ids) - len(response_mask)]
assert init_len == (len(prompt_ids))
if left_pad_len > 0:
position_ids = torch.cat([
torch.zeros(left_pad_len, dtype=torch.long),
position_ids
])
if len(response_ids) < self.response_length:
right_pad_len = self.response_length - len(response_ids)
position_ids = torch.cat([
position_ids,
torch.zeros(right_pad_len, dtype=torch.long)])
# bool_mask[:, -right_pad_len:] = False
else:
position_ids = position_ids[:self.prompt_length + self.response_length]
response_ids = response_ids[: self.response_length]
response_mask = response_mask[: self.response_length]
return AgentLoopOutput(
prompt_ids = prompt_ids,
response_ids = response_ids,
response_mask = response_mask,
position_required_mask = position_required_masks,
multiverse_pos_ids = position_ids.cpu(), # 1-D
num_turns = iterations + 1,
metrics = {},
)
async def check_parallel(self, response_ids):
"""
check whether need to conduct parallel thinking
"""
if response_ids[-1] == self.start_parallel_token:
return True
else:
return False
# Generate the parallel block on the fly
async def _call_parallel_thinking(
self,
prompt_ids: list[int],
sampling_params: dict[str, Any],
):
print("Conducting parallel thinking...")
num_paths = self.num_paths
max_len = self.max_path_response_length
PATH_OPEN, PATH_CLOSE = self.start_path_token, self.end_path_token
manual_mask_positions = []
async def _gen_single_path(seed_i: int, prompt_i: list[int]):
sp = copy.deepcopy(sampling_params)
sp.update({"seed": seed_i, "n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
"temperature": 1.0})
# sp.update({"n": 1, "stop_token_ids": [PATH_CLOSE, self.eos_token_id],
# "temperature": 1.0})
manual_append = False
ids = await self.server_manager.generate(
request_id=uuid4().hex,
prompt_ids=prompt_i + [PATH_OPEN],
sampling_params=sp,
)
if max_len and len(ids) > max_len:
ids = ids[:max_len]
if ids[-1] != PATH_CLOSE:
if ids[-1] == self.eos_token_id:
ids[-1] = PATH_CLOSE
else:
ids.append(PATH_CLOSE)
manual_append = True
return [PATH_OPEN] + ids, manual_append
tasks = []
for i in range(num_paths):
# prompt_i deliberately does NOT contain PATH_OPEN, to avoid duplicating the tag
prompt_i = prompt_ids
tasks.append(
asyncio.create_task(
_gen_single_path(random.randint(0, 2**31 - 1), prompt_i)
)
)
path_token_lists = await asyncio.gather(*tasks)
parallel_ids = []
path_spans = []
cursor = 0
#### <Path> </Path><Path> </Path> -> path_spans (0, len(path_1)) (len(path_1), len(path_1) + len(path_2))
for (ids, manual_append_flag) in path_token_lists:
assert ids[0] == PATH_OPEN and ids[-1] == PATH_CLOSE, \
f"Path tokens must start with {PATH_OPEN} and end with {PATH_CLOSE}, got {ids}"
span_start = cursor
parallel_ids.extend(ids)
cursor += len(ids)
span_end = cursor
path_spans.append((span_start, span_end))
manual_mask_positions.append(span_start)
if manual_append_flag:
manual_mask_positions.append(span_end-1)
# ---------- 3) add </Parallel><Summary> ----------
parallel_ids.append(self.end_parallel_token) # </Parallel>
manual_mask_positions.append(cursor)
assert parallel_ids[cursor] == self.end_parallel_token
cursor += 1
parallel_ids.extend(self.new_line_token) # \n
manual_mask_positions.append(cursor)
cursor += len(self.new_line_token)
parallel_ids.append(self.start_summary_token) # <Summary>
manual_mask_positions.append(cursor)
assert parallel_ids[cursor] == self.start_summary_token
cursor += 1
# if the exploration stage also surpasses the max length, we will not generate a summary
if len(prompt_ids) + 2 + len(parallel_ids) + 1 < self.prompt_length + self.response_length:
print("Conducting Summary...")
manual_append_summary = False
sp_sum = copy.deepcopy(sampling_params)
sp_sum.update({"n": 1, "stop_token_ids": [self.end_summary_token, self.eos_token_id]})
summary_ids = await self.server_manager.generate(
request_id=uuid4().hex,
prompt_ids=prompt_ids + parallel_ids, # the full prompt so far
sampling_params=sp_sum,
)
if summary_ids[-1] != self.end_summary_token:
if summary_ids[-1] == self.eos_token_id:
summary_ids[-1] = self.end_summary_token
else:
summary_ids.append(self.end_summary_token)
manual_append_summary = True
parallel_ids.extend(summary_ids)
cursor += len(summary_ids)
if manual_append_summary:
manual_mask_positions.append(cursor-1)
# print("asadas", parallel_ids[cursor-1])
assert parallel_ids[cursor-1] == self.end_summary_token
assert parallel_ids[-1] == self.end_summary_token, \
f"Parallel thinking did not end with {self.end_summary_token}, got {parallel_ids[-1]}"
# print(self.tokenizer.decode(parallel_ids, skip_special_tokens=False))
return parallel_ids, path_spans, manual_mask_positions # path_spans are already relative to parallel_ids
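To get a feel for what _call_parallel_thinking returns, take a made-up case with num_paths = 2 where the two generated paths (each wrapped in <Path> ... </Path>) are 5 and 8 tokens long. Then parallel_ids is path1 + path2 + </Parallel> + \n + <Summary> + the summary tokens; path_spans is [(0, 5), (5, 13)], relative to parallel_ids; and manual_mask_positions collects the indices of tokens inserted by the loop rather than sampled by the model: the two <Path> openings at 0 and 5, any manually appended </Path>, then </Parallel> at 13, the newline at 14 and <Summary> at 15 (assuming \n is a single token), plus the manually appended </Summary> if the summary did not terminate on its own.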
A small difference between the two:
In ToolAgentLoop, the model first generates a long stretch of text and the tool calls are parsed out of it afterwards. In Parallel-R1, the moment the model emits the trigger token, generation stops immediately and the loop performs a state-machine transition.
For example, with a <tool_call> ... </tool_call> format, the model produces the whole block (including the closing tag) before the loop parses it; Parallel-R1 instead registers tags such as <Parallel> as stop tokens, so decoding halts as soon as one is generated and the loop switches state.
A state machine is a mathematical model that describes how a system transitions between different states. In computer science it is commonly used to handle processes that have multiple stages or states.
Basic concepts of a state machine
Components
- States: the different situations the system can be in
- Events: the conditions or inputs that trigger a state transition
- Actions: the operations performed when a transition occurs
- Transitions: the rules for moving from one state to another
State diagram
State A ──event──→ State B
↑ ↓
└──event── State C
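Mapped onto the Parallel-R1 rollout, a toy version of such a state machine could look like the sketch below (a simplified illustration of the idea, not the actual implementation; it ignores the fact that several paths are generated concurrently):
from enum import Enum, auto

class State(Enum):
    MAIN = auto()       # generating on the main chain
    PATH = auto()       # generating one reasoning path
    SUMMARY = auto()    # generating the summary
    DONE = auto()

# (state, event) -> next state; the "events" here are simply the stop tokens hit during decoding
TRANSITIONS = {
    (State.MAIN, "<Parallel>"): State.PATH,
    (State.MAIN, "<eos>"): State.DONE,
    (State.PATH, "</Path>"): State.SUMMARY,      # simplification: jump to the summary after the paths
    (State.SUMMARY, "</Summary>"): State.MAIN,   # back to the main chain, i.e. step 1 above
}

def step(state: State, event: str) -> State:
    # Any (state, event) pair not listed above (e.g. hitting <eos> inside a path) ends the rollout.
    return TRANSITIONS.get((state, event), State.DONE)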
ToRL
Similarly, ToRL, a paper ByteDance released this March, has a comparable state-machine transition flow.
The main file is ToRL/verl/workers/rollout/vllm_rollout/vllm_rollout_spmd.py, and the entry point is generate.
Because verl had not yet introduced the AgentLoop feature when ToRL was open-sourced, the Python tool-parsing logic is defined directly inside the rollout.
**Note: unlike AgentLoop, where the model directly emits the Python tool's tool schema (in JSON), in ToRL the model emits a Python markdown block, and the code is then parsed into JSON format by hand.**
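To make the contrast concrete, here are illustrative (made-up, not verbatim) examples of the two formats; the tool name and schema below are invented for illustration:
AgentLoop style — the model emits a structured tool call directly:
<tool_call>{"name": "code_interpreter", "arguments": {"code": "print(1 + 1)"}}</tool_call>
ToRL style — the model emits a markdown code block; _detect_tool later wraps the extracted code into a JSON string, and the execution result is appended back wrapped in ```output ... ```:
```python
print(1 + 1)
```
```output
2
```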
Let's take a look:
# Extract the code between ```python and ```
def extract_program(result: str, last_only=True):
"""
extract the program after "```python", and before "```"
"""
program = ''
start = False
for line in result.split('\n'):
if line.startswith('```python') or line.endswith('```python'):
if last_only:
program = '' # only extract the last program
else:
program += '\n# ========\n'
start = True
elif line.startswith('```'):
start = False
elif start:
program += line + '\n'
if start:
# the code is incomplete
program = ''
return program
# Extract the Python code and wrap it in a 'code' field
def _detect_tool(text: str) -> Tuple[bool, str, str, str]:
program = extract_program(text)
if program:
program = json.dumps({'code': program}, ensure_ascii=False)
return (program != ''), PythonExecutor.name, program, text
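A quick illustrative usage of the two helpers above (assuming they are in scope):
response = "Let me compute this.\n```python\nx = 6 * 7\nprint(x)\n```\n"
print(extract_program(response))   # -> "x = 6 * 7\nprint(x)\n"
print(_detect_tool(response))      # -> (True, <PythonExecutor.name>, '{"code": "x = 6 * 7\\nprint(x)\\n"}', response)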
def _tir_generate(self, prompts=None, sampling_params=None, prompt_token_ids=None, use_tqdm=False):
sampling_params=copy.deepcopy(sampling_params)
# prompts=self.tokenizer.batch_decode(prompt_token_ids, skip_special_tokens=True)
prompts=[self.tokenizer.decode(prompt['prompt_token_ids'], skip_special_tokens=False) for prompt in prompts]
prompts=[prompt for prompt in prompts for _ in range(sampling_params.n) ]
sampling_params.n=1
sampling_params.detokenize=True
sampling_params.stop=["```output"]
samples_info=[{"prompt": prompt, "sequence": prompt, "response": "", "stop": False, "finish_reason": None,"index": index, "mask_info": [], "execution_pass": 0} for index, prompt in enumerate(prompts)]
program2output=[]
num_llm_calls_available=copy.deepcopy(self.config.num_llm_calls_available)
# While there are remaining LLM/tool calls available
while num_llm_calls_available >= 0:
if num_llm_calls_available==0: sampling_params.stop=None
num_llm_calls_available-=1
# llm generate response, stop at eos token or ```output
input_prompts, indices=self._get_prompts_and_indices(samples_info)
input_prompts = [{
'prompt_token_ids': self.tokenizer.encode(x, add_special_tokens=False)[:self.config.prompt_length+self.config.response_length]} for x in input_prompts]
## !!! When the model generates ```output, generation stops (the returned outputs do not contain ```output); since ```output is generated by the model itself, the loss on the ```output tag still has to be computed
outputs = self.inference_engine.generate(prompts=input_prompts, sampling_params=sampling_params, use_tqdm=use_tqdm)
sorted_outputs = sorted(outputs, key=lambda output: int(output.request_id))
responses=[x.outputs[0].text for x in sorted_outputs]
finish_reason=[x.outputs[0].finish_reason for x in sorted_outputs]
stop_reason=[x.outputs[0].stop_reason for x in sorted_outputs]
# If this is the last LLM call, the stop word has been removed above (so the model generates to completion); save the responses and stop
if num_llm_calls_available==-1:
for i ,index in enumerate(indices):
samples_info[index]['response']+=responses[i]
samples_info[index]['sequence']+=responses[i]
samples_info[index]['stop']=True
samples_info[index]['finish_reason']=finish_reason[i]
break
# Decide whether Python execution is needed
def _python_execution(finish_reason, stop_reason):
if finish_reason=='stop' and stop_reason==None: return False
if finish_reason=='stop' and stop_reason=='```output': return True
if finish_reason=='length': return False
return False
# Check which samples need Python execution
is_execution=[_python_execution(finish_reason[i], stop_reason[i]) for i in range(len(finish_reason))]
# If no sample needs code execution, exit the loop
if all([not x for x in is_execution]): break
# Detect tool calls (extract code) from each response
tool_infos=[ _detect_tool(response) for response in responses]
# Collect indices and inputs of samples that need code execution
tool_indices=[]
tool_inputs=[]
for i, tool_info in enumerate(tool_infos):
if tool_info[0] and is_execution[i]: # a tool call was detected and execution is needed
tool_indices.append(i)
tool_inputs.append(tool_info[2]) # tool_info[2] contains the code
# Define a post-processing function: wrap the code execution result in ```output ... ```, including error messages
def postproc_observation(observation):
execution_pass=0
try:
observation_list=observation
if observation_list[-1] == 'Done':
observation = observation_list[0]
execution_pass=1
else:
observation = observation_list[-1]
except Exception:
observation="Error"
if "Error" in observation: observation=observation.strip().split("\n")[-1]
if len(observation.strip())==0: observation="timeout_decorator.timeout_decorator.TimeoutError: 'Timed Out'"
observation = observation.strip()
if len(observation)>=256:
observation = observation[:128]+"..."+observation[-128:]
observation = f'{OBS_START}\n{observation}{OBS_END}'
return observation, execution_pass
# execute python code
# observations=self.executor.batch_apply([json5.loads(x)['code'] for x in tool_inputs])
observations=self.code_interpreter_batch_call([json5.loads(x)['code'] for x in tool_inputs])
# construct responses from observations
responses=[response+"\n" if not response.endswith('\n') else response for response in responses]
responses_w_res=copy.deepcopy(responses)
execution_passes=[0 for _ in range(len(responses))]
for i, index in enumerate(tool_indices):
processed_observation=postproc_observation(observations[i])
responses_w_res[index]+=processed_observation[0]
execution_passes[index]=processed_observation[1]
# program2output.append([{"code": tool_input, "answer": postproc_observation(observations[idx])} for idx, tool_input in enumerate(tool_inputs)])
# update samples_info
for i ,index in enumerate(indices):
# !!!! Mask the loss from right after ```output up to the end of the result-augmented response (note: the loss on ```output itself IS computed), because ```output was generated by the model itself
mask=[ len(responses[i]) + len('```output'), len(responses_w_res[i]) ]
samples_info[index]['mask_info'].append(mask)
samples_info[index]['response']+=responses_w_res[i]
samples_info[index]['sequence']+=responses_w_res[i]
samples_info[index]['stop']=not is_execution[i]
samples_info[index]['finish_reason']=finish_reason[i]
samples_info[index]['execution_pass']=execution_passes[i]
for i, line in enumerate(samples_info):
if samples_info[i]['finish_reason']!='length': samples_info[i]['response']+=self.tokenizer.eos_token
responses_ids=[]
tool_output_masks=[]
execution_passes=[]
for idx, sample_info in enumerate(samples_info):
response_id, tool_output_mask = self._tokenize_and_find_mask_token_indices(sample_info)
responses_ids.append(response_id[:self.config.response_length])
tool_output_masks.append(tool_output_mask[:self.config.response_length])
execution_passes.append(sample_info['execution_pass'])
# save id and mask to check correctness
# samples_info[idx]['responses_id']=response_id[:self.config.response_length]
# samples_info[idx]['tool_output_mask']=tool_output_mask[:self.config.response_length].tolist()
# with open("/mnt/bn/seedllm3-lixuefeng-2/code/o1/verl-tir/sample_infos.json", 'w', encoding='utf-8') as f:
# json.dump(samples_info, f, ensure_ascii=False, indent=2)
return responses_ids, tool_output_masks, torch.tensor(execution_passes, dtype=torch.long)
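The character-level mask_info spans recorded above still have to be converted into a token-level mask by _tokenize_and_find_mask_token_indices, which is not shown in this excerpt. A minimal sketch of how such a conversion could work, assuming a HuggingFace fast tokenizer and that mask_info holds character offsets into the accumulated response string (this is a hypothetical helper, not ToRL's actual code):
def char_spans_to_token_mask(tokenizer, response: str, mask_info: list[list[int]]):
    # Tokenize with offset mapping so each token can be traced back to a character span.
    enc = tokenizer(response, add_special_tokens=False, return_offsets_mapping=True)
    token_ids = enc["input_ids"]
    tool_output_mask = [0] * len(token_ids)
    for i, (char_start, char_end) in enumerate(enc["offset_mapping"]):
        # Mark tokens that fall inside any ```output ... ``` span so they can be excluded from the loss.
        if any(char_start >= s and char_end <= e for s, e in mask_info):
            tool_output_mask[i] = 1
    return token_ids, tool_output_mask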
