vLLM 推理框架

LB.4
推理框架

主求解函数(Solver Function)

# Solver Function
# 主求解函数,循环生成多轮样本并提取最终答案
from tqdm import tqdm
num_generations = 2  # Number of sampling rounds that solve() runs for each question
def solve(question, samples_per_round=128):
    """Sample many completions for ``question`` and return the consensus answer.

    The question is suffixed with ``tool_instruction`` and submitted to
    ``generate_text_vllm`` in ``num_generations`` rounds, each round sending
    ``samples_per_round`` copies of the same prompt. The answers that
    ``extract_answer`` recovers from every round are pooled and handed to
    ``fin_pred``, whose result is returned.

    Args:
        question: Problem statement as a string.
        samples_per_round: Number of copies of the prompt sent to the model
            in each round. Defaults to 128, the batch size that used to be
            hard-coded here.

    Returns:
        Whatever ``fin_pred`` returns for the pooled answers.
    """
    # The prompt is identical in every round, so build it once up front.
    prompt = question + tool_instruction
    ans = []  # answers pooled across all rounds
    for _ in range(num_generations):
        generate_text = generate_text_vllm(
            [prompt] * samples_per_round, tokenizer, llm
        )
        ans.extend(extract_answer(generate_text))
    return fin_pred(ans)

tool_instruction

# Suffix that solve() appends to every question: it asks the model to reason
# step by step and to wrap its final answer in \boxed{}, which is the marker
# find_answer() searches for.
tool_instruction = '\nPlease reason step by step, and put your final answer within \\boxed{}.'

文本生成(generate_text_vllm)

# Function to generate text using vLLM
# Temperature sampling used. You can try experimenting with different values

def generate_text_vllm(requests, tokenizer, model):
    """Generate text for every prompt in ``requests`` using temperature sampling.

    Args:
        requests: Prompts passed straight to ``model.generate``.
        tokenizer: Accepted for call-site compatibility; not used here.
        model: Object exposing ``generate(prompts, sampling_params=...,
            use_tqdm=...)`` -- presumably a ``vllm.LLM`` instance; confirm
            against the caller.

    Returns:
        A list holding, for each response returned by the model, the text of
        its first output.
    """
    decoding = vllm.SamplingParams(
        temperature=0.7,
        top_p=0.8,
        min_p=0.01,
        max_tokens=2048,
    )
    results = model.generate(requests, sampling_params=decoding, use_tqdm=False)
    # Only the first candidate of each response is kept.
    return [result.outputs[0].text for result in results]

提取答案(extract_answer)

#Functions to find the final answer

def extract_answer(texts):
    """Collect the usable answers found in a batch of generated texts.

    Every text is passed to ``find_answer``. Results that are negative are
    discarded, and a text whose processing raises is skipped so that a single
    bad sample cannot abort the whole batch.

    Args:
        texts: Iterable of generated strings.

    Returns:
        A list of the non-negative answers, in the order their texts were
        seen. It may be shorter than ``texts`` or empty.
    """
    sols = []
    for text in texts:
        try:
            ans = find_answer(text)
            if ans >= 0:
                sols.append(ans)
        except Exception:
            # Best effort by design, but no longer a bare ``except`` so that
            # KeyboardInterrupt / SystemExit still propagate.
            continue
    return sols

(find_answer)

#Function to extract the \boxed{} answer from the generated text 

import re
def find_answer(generate_text):
    r"""Pull the numeric answer out of a ``\boxed{...}`` marker.

    Searches ``generate_text`` for ``\boxed{<digits>}``, cleans the first
    match with ``naive_parse`` and reduces it modulo 1000.

    Args:
        generate_text: Text produced by the model.

    Returns:
        The parsed value in the range 0..999, or -1 when there is no match,
        the cleaned string is empty, or anything raises while parsing.
    """
    try:
        boxed = re.findall(r'\\boxed\{(\d+)\}', generate_text)
        if not boxed:
            return -1
        # NOTE(review): this takes the FIRST boxed value; if the model boxes
        # intermediate results, the last match may be the intended answer.
        digits = naive_parse(boxed[0])
        if len(digits) > 0:
            return int(digits) % 1000
        return -1
    except Exception:
        # Any failure is reported as "no answer".
        return -1

naive_parse

# Function to extract the numerical answer from the text output

def naive_parse(answer):
    """Return the last contiguous run of ASCII digits in ``answer``.

    The input is scanned from its end. Trailing non-digit characters are
    skipped, then digits are collected until the run is interrupted.

    Args:
        answer: String to scan.

    Returns:
        The digit run as a string in its original left-to-right order, or an
        empty string when ``answer`` contains no digit.
    """
    tail_digits = []
    for ch in reversed(list(answer)):
        if ch in '0123456789':
            tail_digits.append(ch)
        elif tail_digits:
            # The first non-digit after the run ends the scan.
            break
    # Digits were gathered back to front; restore reading order.
    return ''.join(reversed(tail_digits))

fin_pred

def fin_pred(sols):
    """Return the final prediction for the collected answers.

    Args:
        sols: Sized collection of candidate answers.

    Returns:
        The result of ``get_majority_vote(sols)`` when ``sols`` is non-empty,
        otherwise 0.
    """
    return get_majority_vote(sols) if len(sols) else 0

get_majority_vote

# Function to determine the most consistent answer. Our final answer will be the most consistent one among the possible ones

from collections import Counter
def get_majority_vote(answers):
    """Return the most frequent answer, as an absolute value when possible.

    The ten most common ``(answer, count)`` pairs and a separator line are
    printed for inspection.

    Args:
        answers: Sized iterable of hashable candidate answers.

    Returns:
        0 when ``answers`` is empty. Otherwise the answer with the highest
        count, passed through ``abs()``; a winner that does not support
        ``abs()`` is returned unchanged.
    """
    if not len(answers):
        return 0
    # Rank once and reuse it for both the winner and the debug print.
    ranked = Counter(answers).most_common()
    value = ranked[0][0]
    print("Most Common answers : ", ranked[:10])
    print("=" * 50)
    try:
        return abs(value)
    except TypeError:
        # Non-numeric winner (e.g. a string): hand it back as-is.
        return value
posted @ 2024-12-30 13:35  HaibaraYuki  阅读(40)  评论(0)    收藏  举报