LB.4
推理框架
主求解函数(Solver Function)
# Solver Function
# Main solver: generate samples over several rounds and extract the final answer
from tqdm import tqdm
# Default number of generation rounds run by solve(); each round requests
# 128 samples by default, so up to num_generations * 128 candidate answers
# are pooled for the vote.
num_generations = 2

# Instruction appended to every question: asks the model to reason step by
# step and to put its final answer inside \boxed{}, which find_answer() parses.
tool_instruction = '\nPlease reason step by step, and put your final answer within \\boxed{}.'


def solve(question, *, samples_per_round=128, rounds=None):
    """Answer *question* by repeated sampling followed by a majority vote.

    The question is suffixed with ``tool_instruction`` and the same prompt is
    sent ``samples_per_round`` times per round for ``rounds`` rounds. Each
    batch of completions is parsed with ``extract_answer`` and all parsed
    answers are pooled and reduced to one prediction with ``fin_pred``.

    Args:
        question: The problem statement text.
        samples_per_round: Number of copies of the prompt passed to
            ``generate_text_vllm`` in each round (default 128, the original
            hard-coded value).
        rounds: Number of generation rounds. ``None`` (the default) means
            "use the module-level ``num_generations`` at call time".

    Returns:
        The result of ``fin_pred`` on the pooled answers.
    """
    if rounds is None:
        rounds = num_generations
    # The prompt is identical for every round, so build it once.
    prompt = question + tool_instruction
    answers = []  # parsed answers collected across all rounds
    for _ in range(rounds):
        # NOTE(review): `tokenizer` and `llm` are module-level globals that are
        # not defined in this excerpt; confirm they exist before calling.
        completions = generate_text_vllm([prompt] * samples_per_round, tokenizer, llm)
        answers.extend(extract_answer(completions))
    # Reduce the pooled answers to the single most consistent one.
    return fin_pred(answers)
文本生成(generate_text_vllm)
# Function to generate text using vLLM
# Temperature sampling used. You can try experimenting with different values
def generate_text_vllm(requests, tokenizer, model, *, temperature=0.7,
                       top_p=0.8, min_p=0.01, max_tokens=2048):
    """Generate one sampled completion per prompt with vLLM.

    Args:
        requests: Sequence of prompt strings passed to ``model.generate``.
        tokenizer: Accepted for call-site compatibility; not used here.
        model: Object exposing a vLLM-style ``generate`` method
            (presumably a ``vllm.LLM`` instance; confirm at the call site).
        temperature, top_p, min_p, max_tokens: Forwarded to
            ``vllm.SamplingParams``. The defaults are the values that were
            previously hard-coded, so existing callers are unaffected.

    Returns:
        A list with the text of the first output of each response returned
        by ``model.generate``.
    """
    # NOTE(review): `vllm` is not imported in this excerpt; it is assumed to
    # be imported elsewhere in the file.
    sampling_params = vllm.SamplingParams(
        temperature=temperature,
        top_p=top_p,
        min_p=min_p,
        max_tokens=max_tokens,
    )
    responses = model.generate(requests, sampling_params=sampling_params, use_tqdm=False)
    # Keep only the first candidate of each request.
    return [response.outputs[0].text for response in responses]
#Functions to find the final answer
def extract_answer(texts):
    """Parse every generated text and keep the non-negative answers.

    Args:
        texts: Iterable of generated completion strings.

    Returns:
        List of the values returned by ``find_answer`` that are ``>= 0``, in
        input order. Texts for which ``find_answer`` returns a negative value
        (no answer found) or raises are skipped.
    """
    sols = []
    for text in texts:
        try:
            ans = find_answer(text)
        except Exception:
            # Best-effort: one unparsable completion must not abort the vote.
            # (Unlike the previous bare `except:`, this lets KeyboardInterrupt
            # and SystemExit propagate.)
            continue
        if ans >= 0:
            sols.append(ans)
    return sols
(find_answer)
#Function to extract the \boxed{} answer from the generated text
import re
def find_answer(generate_text):
    r"""Extract the number from the first ``\boxed{<digits>}`` in a completion.

    Args:
        generate_text: Model completion text.

    Returns:
        The boxed number reduced modulo 1000, or -1 when the input is not a
        ``str`` or contains no ``\boxed{}`` whose contents yield ASCII digits.
    """
    try:
        boxed = re.findall(r'\\boxed\{(\d+)\}', generate_text)
    except TypeError:
        # re.findall rejects non-str input (e.g. None or bytes) with TypeError.
        return -1
    if not boxed:
        return -1
    # NOTE(review): the FIRST \boxed{} is used; if the model also boxes
    # intermediate results, the last one may be the intended final answer.
    digits = naive_parse(boxed[0])
    if not digits:
        return -1
    # Modulo 1000, presumably because answers are expected in 0..999 — confirm.
    return int(digits) % 1000
naive_parse
# Function to extract the numerical answer from the text output
def naive_parse(answer):
    """Return the last contiguous run of ASCII digits in *answer*.

    Characters after that run are ignored, as is everything before it;
    an empty string is returned when *answer* contains no ASCII digit.
    Example: ``naive_parse('x12y345z') == '345'``.
    """
    digit_chars = '0123456789'
    chars = list(answer)
    cursor = len(chars)
    # Step backwards over any trailing non-digit characters.
    while cursor > 0 and chars[cursor - 1] not in digit_chars:
        cursor -= 1
    run_end = cursor
    # Then step backwards over the digit run itself.
    while cursor > 0 and chars[cursor - 1] in digit_chars:
        cursor -= 1
    return ''.join(chars[cursor:run_end])
fin_pred
def fin_pred(sols):
    """Reduce the candidate answers to a single prediction.

    Returns 0 when *sols* is empty; otherwise delegates to
    ``get_majority_vote``.
    """
    if not len(sols):
        # No parsable answer was produced: fall back to 0.
        return 0
    return get_majority_vote(sols)
get_majority_vote
# Function to determine the most consistent answer. Our final answer will be the most consistent one among the possible ones
from collections import Counter
def get_majority_vote(answers):
    """Return the most frequent answer and print the top candidates.

    Counting is done with ``collections.Counter``; among answers with equal
    counts, the one encountered first wins (``Counter.most_common`` order).
    The ten most common answers and a separator line are printed.

    Args:
        answers: Sequence of hashable candidate answers.

    Returns:
        0 for an empty input; otherwise ``abs()`` of the winning answer, or
        the answer unchanged when it does not support ``abs()`` (e.g. a str).
    """
    if not len(answers):
        return 0
    counts = Counter(answers)
    # most_common(10) equals most_common()[:10] but avoids sorting everything,
    # and computing it once avoids the previous second full sort.
    top = counts.most_common(10)
    value, _ = top[0]
    print("Most Common answers : ", top)
    print("="*50)
    try:
        return abs(value)
    except TypeError:
        # Non-numeric answers have no abs(); return them as they are.
        return value