基础RAG实现,最佳入门选择(九)

RAG的相关段提取(RSE)

相关段提取(RSE)技术来提高RAG系统中的上下文质量。不是简单地检索孤立块的集合,而是识别和重建连续的文本段,为我们的语言模型提供更好的上下文。

关键概念

相关的块往往在文档中聚集在一起。通过识别这些集群并保持它们的连续性,我们为LLM提供了更连贯的上下文来使用。

在RSE(Relevant Segment Extraction,相关片段提取)流程中,我们会对每个文本块(chunk)计算其与查询(query)的相关性分数。相关性分数越高,说明该块与问题越相关。

但在实际应用中,很多块与问题无关,如果直接用相关性分数,有些无关块的分数可能接近0甚至略大于0。为了让算法更容易“跳过”这些无关块,我们会人为地给每个块的分数减去一个“惩罚值”,这个惩罚值就是 irrelevant_chunk_penalty。

  • 如果某个块本身相关性分数很低,减去惩罚后就变成负数,更容易被算法忽略。

  • 只有那些相关性分数高于惩罚值的块,才会被选入最终的连续片段。

  • 让无关内容更容易被排除,只保留真正与问题相关的片段。

  • 调节片段选择的“门槛”,惩罚值越大,选出的片段越“精”,但可能漏掉边缘相关内容;惩罚值越小,选出的片段越多,可能包含一些无关内容。

具体代码实现

PDF文本提取

从PDF文件中提取全部文本

def extract_text_from_pdf(pdf_path):
    """
    从PDF文件中提取全部文本
    :param pdf_path: PDF文件路径
    :return: 提取的文本内容(str)
    """
    print(f"[步骤] 正在从PDF提取文本: {pdf_path}")
    with open(pdf_path, 'rb') as f:
        reader = PdfReader(f)
        text = ""
        for i, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text:
                text += page_text
            print(f"  - 已提取第{i+1}页")
    print(f"[完成] PDF文本提取完成,总长度: {len(text)} 字符\n")
    return text

文本分块

文本分割为不重叠的块

def chunk_text(text, chunk_size=800, overlap=0):
    """
    将文本分割为不重叠的块
    :param text: 原始文本
    :param chunk_size: 每块字符数
    :param overlap: 块间重叠字符数(RSE一般为0)
    :return: 文本块列表
    """
    print(f"[步骤] 正在分块: 每块{chunk_size}字符,重叠{overlap}字符")
    chunks = []
    for i in range(0, len(text), chunk_size - overlap):
        chunk = text[i:i + chunk_size]
        if chunk:
            chunks.append(chunk)
    print(f"[完成] 分块完成,共{len(chunks)}块\n")
    return chunks

向量生成

用阿里embedding模型批量生成文本向量

def create_embeddings(texts, model=EMBEDDING_MODEL):
    """
    用阿里embedding模型批量生成文本向量
    :param texts: 文本列表
    :param model: 嵌入模型名
    :return: 向量列表
    """
    if isinstance(texts, str):
        texts = [texts]
    print(f"[嵌入生成] 正在生成{len(texts)}条文本的向量...")
    try:
        response = TextEmbedding.call(
            model=model,
            input=texts,
            api_key=ALI_API_KEY
        )
        if response.status_code == 200:
            embeddings = [np.array(item['embedding']) for item in response.output['embeddings']]
            print(f"[嵌入生成] 成功,返回{len(embeddings)}条向量\n")
            return embeddings
        else:
            print(f"[嵌入生成] 失败: {response.message}")
            return [np.zeros(1536)] * len(texts)
    except Exception as e:
        print(f"[嵌入生成] 异常: {e}")
        return [np.zeros(1536)] * len(texts)

简单向量库

简单的向量存储与检索类

class SimpleVectorStore:
    """
    简单的向量存储与检索类
    """
    def __init__(self, dimension=1536):
        self.dimension = dimension
        self.vectors = []
        self.documents = []
        self.metadata = []

    def add_documents(self, documents, vectors=None, metadata=None):
        if vectors is None:
            vectors = [None] * len(documents)
        if metadata is None:
            metadata = [{} for _ in range(len(documents))]
        for doc, vec, meta in zip(documents, vectors, metadata):
            self.documents.append(doc)
            self.vectors.append(vec)
            self.metadata.append(meta)

    def search(self, query_vector, top_k=5):
        if not self.vectors or not self.documents:
            return []
        query_array = np.array(query_vector)
        similarities = []
        for i, vector in enumerate(self.vectors):
            if vector is not None:
                similarity = np.dot(query_array, vector) / (
                    np.linalg.norm(query_array) * np.linalg.norm(vector)
                )
                similarities.append((i, similarity))
        similarities.sort(key=lambda x: x[1], reverse=True)
        results = []
        for i, score in similarities[:top_k]:
            results.append({
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i]
            })
        return results

文档处理主流程

处理文档,提取文本、分块、生成嵌入、构建向量库

def process_document(pdf_path, chunk_size=800):
    """
    处理文档,提取文本、分块、生成嵌入、构建向量库
    :param pdf_path: PDF路径
    :param chunk_size: 块大小
    :return: 块列表、向量库、文档信息
    """
    print("[流程] 开始处理文档...")
    text = extract_text_from_pdf(pdf_path)
    print("[流程] 正在分块...")
    chunks = chunk_text(text, chunk_size=chunk_size, overlap=0)
    print(f"[流程] 共分得{len(chunks)}块")
    print("[流程] 正在生成嵌入...")
    chunk_embeddings = create_embeddings(chunks)
    vector_store = SimpleVectorStore()
    metadata = [{"chunk_index": i, "source": pdf_path} for i in range(len(chunks))]
    vector_store.add_documents(chunks, chunk_embeddings, metadata)
    doc_info = {
        "chunks": chunks,
        "source": pdf_path,
    }
    print("[流程] 文档处理完成\n")
    return chunks, vector_store, doc_info

计算相关性分数

计算每个块的相关性分数

def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):
    """
    计算每个块的相关性分数
    :param query: 查询
    :param chunks: 块列表
    :param vector_store: 向量库
    :param irrelevant_chunk_penalty: 不相关块惩罚
    :return: 块分数列表
    """
    print("[流程] 正在为查询生成嵌入...")
    query_embedding = create_embeddings([query])[0]
    print("[流程] 正在检索所有块的相关性...")
    num_chunks = len(chunks)
    results = vector_store.search(query_embedding, top_k=num_chunks)
    relevance_scores = {result["metadata"]["chunk_index"]: result["score"] for result in results}
    chunk_values = []
    for i in range(num_chunks):
        score = relevance_scores.get(i, 0.0)
        value = score - irrelevant_chunk_penalty
        chunk_values.append(value)
        print(f"  - 块{i} 相关性: {score:.4f},扣除惩罚后: {value:.4f}")
    print("[流程] 块分数计算完成\n")
    return chunk_values

RSE连续片段提取

用最大子段和算法提取最优连续片段

def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """
    用最大子段和算法提取最优连续片段
    :param chunk_values: 块分数列表
    :param max_segment_length: 单片段最大长度
    :param total_max_length: 总片段最大长度
    :param min_segment_value: 最小片段分数
    :return: [(start, end)], [分数]
    """
    print("[流程] 正在查找最优连续片段...")
    best_segments = []
    segment_scores = []
    total_included_chunks = 0
    while total_included_chunks < total_max_length:
        best_score = min_segment_value
        best_segment = None
        for start in range(len(chunk_values)):
            if any(start >= s[0] and start < s[1] for s in best_segments):
                continue
            for length in range(1, min(max_segment_length, len(chunk_values) - start) + 1):
                end = start + length
                if any(end > s[0] and end <= s[1] for s in best_segments):
                    continue
                segment_value = sum(chunk_values[start:end])
                if segment_value > best_score:
                    best_score = segment_value
                    best_segment = (start, end)
        if best_segment:
            best_segments.append(best_segment)
            segment_scores.append(best_score)
            total_included_chunks += best_segment[1] - best_segment[0]
            print(f"  - 发现片段 {best_segment},分数: {best_score:.4f}")
        else:
            break
    best_segments = sorted(best_segments, key=lambda x: x[0])
    print(f"[流程] 共选出{len(best_segments)}个片段\n")
    return best_segments, segment_scores

片段重构

根据片段索引重构文本

def reconstruct_segments(chunks, best_segments):
    """
    根据片段索引重构文本
    :param chunks: 块列表
    :param best_segments: [(start, end)]
    :return: [{text, segment_range}]
    """
    print("[流程] 正在重构片段文本...")
    reconstructed_segments = []
    for start, end in best_segments:
        segment_text = " ".join(chunks[start:end])
        reconstructed_segments.append({
            "text": segment_text,
            "segment_range": (start, end),
        })
        print(f"  - 片段范围: {start}-{end-1},长度: {len(segment_text)} 字符")
    print("[流程] 片段重构完成\n")
    return reconstructed_segments

片段格式化

格式化片段为上下文字符串

def format_segments_for_context(segments):
    """
    格式化片段为上下文字符串
    :param segments: 片段列表
    :return: 上下文字符串
    """
    print("[流程] 正在格式化片段为上下文...")
    context = []
    for i, segment in enumerate(segments):
        segment_header = f"SEGMENT {i+1} (Chunks {segment['segment_range'][0]}-{segment['segment_range'][1]-1}):"
        context.append(segment_header)
        context.append(segment['text'])
        context.append("-" * 80)
    print("[流程] 片段格式化完成\n")
    return "\n\n".join(context)

LLM生成回答

用大模型基于上下文生成回答

def generate_response(query, context, model=LLM_MODEL):
    """
    用大模型基于上下文生成回答
    :param query: 用户问题
    :param context: 上下文
    :param model: 生成模型名
    :return: 回答内容
    """
    print("[流程] 正在调用大模型生成最终回答...")
    system_prompt = "你是一个AI助手,只能基于给定上下文回答问题。如果上下文无法直接回答,请回复:'信息不足,无法回答。'"
    user_prompt = f"""
上下文:\n{context}\n\n问题:{query}\n\n请只基于上述上下文简明准确作答。"""
    try:
        response = Generation.call(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            api_key=ALI_API_KEY,
            result_format='message'
        )
        if response.status_code == 200:
            print("[流程] 回答生成成功\n")
            return response.output.choices[0].message.content.strip()
        else:
            print(f"[流程] 回答生成失败: {response.message}")
            return ""
    except Exception as e:
        print(f"[流程] 回答生成异常: {e}")
        return ""

RSE主流程

RSE增强RAG主流程

def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):
    """
    RSE增强RAG主流程
    :param pdf_path: 文档路径
    :param query: 用户问题
    :param chunk_size: 块大小
    :param irrelevant_chunk_penalty: 不相关块惩罚
    :return: 结果字典
    """
    print("\n=== RSE增强RAG流程开始 ===")
    print(f"[输入] 问题: {query}")
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)
    print("[流程] 正在计算相关性分数...")
    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)
    best_segments, scores = find_best_segments(
        chunk_values,
        max_segment_length=20,
        total_max_length=30,
        min_segment_value=0.2
    )
    segments = reconstruct_segments(chunks, best_segments)
    context = format_segments_for_context(segments)
    response = generate_response(query, context)
    result = {
        "query": query,
        "segments": segments,
        "response": response
    }
    print("\n=== 最终AI回答 ===")
    print(response)
    print("=== RSE增强RAG流程结束 ===\n")
    return result

执行结果

========== RSE增强RAG主流程演示 ==========
[配置] 使用API密钥: sk-fc6ad...2f23
[配置] PDF路径: data/2888年Java程序员找工作最新场景题.pdf
[配置] 验证集路径: data/java_val.json
[配置] 块大小: 800,不相关块惩罚: 0.2

[主流程] 示例问题: Java程序员面试中常见的技术问题有哪些?

=== RSE增强RAG流程开始 ===
[输入] 问题: Java程序员面试中常见的技术问题有哪些?
[流程] 开始处理文档...
[步骤] 正在从PDF提取文本: data/2888年Java程序员找工作最新场景题.pdf
  - 已提取第1页
  - 已提取第2页
  - 已提取第3页
  - 已提取第4页
  - 已提取第5页
  - 已提取第6页
  - 已提取第7页
  - 已提取第8页
  - 已提取第9页
  - 已提取第10页
[完成] PDF文本提取完成,总长度: 6984 字符

[流程] 正在分块...
[步骤] 正在分块: 每块800字符,重叠0字符
[完成] 分块完成,共9块

[流程] 共分得9块
[流程] 正在生成嵌入...
[嵌入生成] 正在生成9条文本的向量...
[嵌入生成] 成功,返回9条向量

[流程] 文档处理完成

[流程] 正在计算相关性分数...
[流程] 正在为查询生成嵌入...
[嵌入生成] 正在生成1条文本的向量...
[嵌入生成] 成功,返回1条向量

[流程] 正在检索所有块的相关性...
  - 块0 相关性: 0.6151,扣除惩罚后: 0.4151
  - 块1 相关性: 0.4011,扣除惩罚后: 0.2011
  - 块2 相关性: 0.3759,扣除惩罚后: 0.1759
  - 块3 相关性: 0.5489,扣除惩罚后: 0.3489
  - 块4 相关性: 0.3395,扣除惩罚后: 0.1395
  - 块5 相关性: 0.2633,扣除惩罚后: 0.0633
  - 块6 相关性: 0.2011,扣除惩罚后: 0.0011
  - 块7 相关性: 0.1932,扣除惩罚后: -0.0068
  - 块8 相关性: 0.1235,扣除惩罚后: -0.0765
[流程] 块分数计算完成

[流程] 正在查找最优连续片段...
  - 发现片段 (0, 7),分数: 1.3449
[流程] 共选出1个片段

[流程] 正在重构片段文本...
  - 片段范围: 0-6,长度: 5606 字符
[流程] 片段重构完成

[流程] 正在格式化片段为上下文...
[流程] 片段格式化完成

[流程] 正在调用大模型生成最终回答...
[流程] 回答生成成功


=== 最终AI回答 ===
根据文档内容,Java程序员面试中常见的技术问题包括:

1. 项目经验与挑战:面试官通常会询问候选人最自豪或最近完成的项目、解决过的最复杂的技术难题、经历过的最具挑战性的项目以及曾犯下的最大技术失误或引发的技术故障。这些问题旨在考察候选人的实际工作经验、问题解决能力、面对困难时的心态以及对待错误的态度和反思成长的能力。

2. 技术深度理解:针对简历上列出的技术技能(如Java并发编程、NIO、JVM、Spring框架等),面试官会深入探讨以评估候选人的真实技术水平。例如,对于声称熟练掌握Java的候选人,可能会被问及并发编程、JVM调优等相关知识;对于提到Go语言的人,则可能需要展示对《Effective Go》的理解程度等。

3. 实际应用场景:比如在电商平台中如何实现订单未支付过期自动关闭的功能,这涉及到定时任务、JDK延迟队列DelayQueue的应用以及Redis过期监听机制等具体技术解决方案的选择与实施。

综上所述,面试不仅关注理论知识,更注重候选人将这些知识应用于解决实际问题的能力。
=== RSE增强RAG流程结束 ===

========== 演示结束 ==========

进程已结束,退出代码为 0

完整示例代码

# -*- coding: utf-8 -*-
"""
RSE增强RAG主流程(阿里大模型版,详细中文注释+详细控制台输出)
作者:AI助手
"""
import os
import numpy as np
from tqdm import tqdm
from PyPDF2 import PdfReader
from dashscope import Generation, TextEmbedding
import json
import re

# ========== 密钥配置:优先从api_keys.py读取,否则用环境变量 ==========
# try:
#     from test.api_keys import ALI_API_KEY
# except ImportError:
#     ALI_API_KEY = os.getenv("ALI_API_KEY", "sk-0c198b40580347xxx3508b57c94")
ALI_API_KEY="sk-fc6ad8ecef4bxxxxxx5372f23"
# ==============================================

LLM_MODEL = "qwen-max"  # 通义千问主力模型
EMBEDDING_MODEL = "text-embedding-v2"  # 阿里云嵌入模型

# ========== PDF文本提取 ==========
def extract_text_from_pdf(pdf_path):
    """
    从PDF文件中提取全部文本
    :param pdf_path: PDF文件路径
    :return: 提取的文本内容(str)
    """
    print(f"[步骤] 正在从PDF提取文本: {pdf_path}")
    with open(pdf_path, 'rb') as f:
        reader = PdfReader(f)
        text = ""
        for i, page in enumerate(reader.pages):
            page_text = page.extract_text()
            if page_text:
                text += page_text
            print(f"  - 已提取第{i+1}页")
    print(f"[完成] PDF文本提取完成,总长度: {len(text)} 字符\n")
    return text

# ========== 文本分块 ==========
def chunk_text(text, chunk_size=800, overlap=0):
    """
    将文本分割为不重叠的块
    :param text: 原始文本
    :param chunk_size: 每块字符数
    :param overlap: 块间重叠字符数(RSE一般为0)
    :return: 文本块列表
    """
    print(f"[步骤] 正在分块: 每块{chunk_size}字符,重叠{overlap}字符")
    chunks = []
    for i in range(0, len(text), chunk_size - overlap):
        chunk = text[i:i + chunk_size]
        if chunk:
            chunks.append(chunk)
    print(f"[完成] 分块完成,共{len(chunks)}块\n")
    return chunks

# ========== 向量生成 ==========
def create_embeddings(texts, model=EMBEDDING_MODEL):
    """
    用阿里embedding模型批量生成文本向量
    :param texts: 文本列表
    :param model: 嵌入模型名
    :return: 向量列表
    """
    if isinstance(texts, str):
        texts = [texts]
    print(f"[嵌入生成] 正在生成{len(texts)}条文本的向量...")
    try:
        response = TextEmbedding.call(
            model=model,
            input=texts,
            api_key=ALI_API_KEY
        )
        if response.status_code == 200:
            embeddings = [np.array(item['embedding']) for item in response.output['embeddings']]
            print(f"[嵌入生成] 成功,返回{len(embeddings)}条向量\n")
            return embeddings
        else:
            print(f"[嵌入生成] 失败: {response.message}")
            return [np.zeros(1536)] * len(texts)
    except Exception as e:
        print(f"[嵌入生成] 异常: {e}")
        return [np.zeros(1536)] * len(texts)

# ========== 简单向量库 ==========
class SimpleVectorStore:
    """
    简单的向量存储与检索类
    """
    def __init__(self, dimension=1536):
        self.dimension = dimension
        self.vectors = []
        self.documents = []
        self.metadata = []

    def add_documents(self, documents, vectors=None, metadata=None):
        if vectors is None:
            vectors = [None] * len(documents)
        if metadata is None:
            metadata = [{} for _ in range(len(documents))]
        for doc, vec, meta in zip(documents, vectors, metadata):
            self.documents.append(doc)
            self.vectors.append(vec)
            self.metadata.append(meta)

    def search(self, query_vector, top_k=5):
        if not self.vectors or not self.documents:
            return []
        query_array = np.array(query_vector)
        similarities = []
        for i, vector in enumerate(self.vectors):
            if vector is not None:
                similarity = np.dot(query_array, vector) / (
                    np.linalg.norm(query_array) * np.linalg.norm(vector)
                )
                similarities.append((i, similarity))
        similarities.sort(key=lambda x: x[1], reverse=True)
        results = []
        for i, score in similarities[:top_k]:
            results.append({
                "document": self.documents[i],
                "score": float(score),
                "metadata": self.metadata[i]
            })
        return results

# ========== 文档处理主流程 ==========
def process_document(pdf_path, chunk_size=800):
    """
    处理文档,提取文本、分块、生成嵌入、构建向量库
    :param pdf_path: PDF路径
    :param chunk_size: 块大小
    :return: 块列表、向量库、文档信息
    """
    print("[流程] 开始处理文档...")
    text = extract_text_from_pdf(pdf_path)
    print("[流程] 正在分块...")
    chunks = chunk_text(text, chunk_size=chunk_size, overlap=0)
    print(f"[流程] 共分得{len(chunks)}块")
    print("[流程] 正在生成嵌入...")
    chunk_embeddings = create_embeddings(chunks)
    vector_store = SimpleVectorStore()
    metadata = [{"chunk_index": i, "source": pdf_path} for i in range(len(chunks))]
    vector_store.add_documents(chunks, chunk_embeddings, metadata)
    doc_info = {
        "chunks": chunks,
        "source": pdf_path,
    }
    print("[流程] 文档处理完成\n")
    return chunks, vector_store, doc_info

# ========== 计算相关性分数 ==========
def calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty=0.2):
    """
    计算每个块的相关性分数
    :param query: 查询
    :param chunks: 块列表
    :param vector_store: 向量库
    :param irrelevant_chunk_penalty: 不相关块惩罚
    :return: 块分数列表
    """
    print("[流程] 正在为查询生成嵌入...")
    query_embedding = create_embeddings([query])[0]
    print("[流程] 正在检索所有块的相关性...")
    num_chunks = len(chunks)
    results = vector_store.search(query_embedding, top_k=num_chunks)
    relevance_scores = {result["metadata"]["chunk_index"]: result["score"] for result in results}
    chunk_values = []
    for i in range(num_chunks):
        score = relevance_scores.get(i, 0.0)
        value = score - irrelevant_chunk_penalty
        chunk_values.append(value)
        print(f"  - 块{i} 相关性: {score:.4f},扣除惩罚后: {value:.4f}")
    print("[流程] 块分数计算完成\n")
    return chunk_values

# ========== RSE连续片段提取 ==========
def find_best_segments(chunk_values, max_segment_length=20, total_max_length=30, min_segment_value=0.2):
    """
    用最大子段和算法提取最优连续片段
    :param chunk_values: 块分数列表
    :param max_segment_length: 单片段最大长度
    :param total_max_length: 总片段最大长度
    :param min_segment_value: 最小片段分数
    :return: [(start, end)], [分数]
    """
    print("[流程] 正在查找最优连续片段...")
    best_segments = []
    segment_scores = []
    total_included_chunks = 0
    while total_included_chunks < total_max_length:
        best_score = min_segment_value
        best_segment = None
        for start in range(len(chunk_values)):
            if any(start >= s[0] and start < s[1] for s in best_segments):
                continue
            for length in range(1, min(max_segment_length, len(chunk_values) - start) + 1):
                end = start + length
                if any(end > s[0] and end <= s[1] for s in best_segments):
                    continue
                segment_value = sum(chunk_values[start:end])
                if segment_value > best_score:
                    best_score = segment_value
                    best_segment = (start, end)
        if best_segment:
            best_segments.append(best_segment)
            segment_scores.append(best_score)
            total_included_chunks += best_segment[1] - best_segment[0]
            print(f"  - 发现片段 {best_segment},分数: {best_score:.4f}")
        else:
            break
    best_segments = sorted(best_segments, key=lambda x: x[0])
    print(f"[流程] 共选出{len(best_segments)}个片段\n")
    return best_segments, segment_scores

# ========== 片段重构 ==========
def reconstruct_segments(chunks, best_segments):
    """
    根据片段索引重构文本
    :param chunks: 块列表
    :param best_segments: [(start, end)]
    :return: [{text, segment_range}]
    """
    print("[流程] 正在重构片段文本...")
    reconstructed_segments = []
    for start, end in best_segments:
        segment_text = " ".join(chunks[start:end])
        reconstructed_segments.append({
            "text": segment_text,
            "segment_range": (start, end),
        })
        print(f"  - 片段范围: {start}-{end-1},长度: {len(segment_text)} 字符")
    print("[流程] 片段重构完成\n")
    return reconstructed_segments

# ========== 片段格式化 ==========
def format_segments_for_context(segments):
    """
    格式化片段为上下文字符串
    :param segments: 片段列表
    :return: 上下文字符串
    """
    print("[流程] 正在格式化片段为上下文...")
    context = []
    for i, segment in enumerate(segments):
        segment_header = f"SEGMENT {i+1} (Chunks {segment['segment_range'][0]}-{segment['segment_range'][1]-1}):"
        context.append(segment_header)
        context.append(segment['text'])
        context.append("-" * 80)
    print("[流程] 片段格式化完成\n")
    return "\n\n".join(context)

# ========== LLM生成回答 ==========
def generate_response(query, context, model=LLM_MODEL):
    """
    用大模型基于上下文生成回答
    :param query: 用户问题
    :param context: 上下文
    :param model: 生成模型名
    :return: 回答内容
    """
    print("[流程] 正在调用大模型生成最终回答...")
    system_prompt = "你是一个AI助手,只能基于给定上下文回答问题。如果上下文无法直接回答,请回复:'信息不足,无法回答。'"
    user_prompt = f"""
上下文:\n{context}\n\n问题:{query}\n\n请只基于上述上下文简明准确作答。"""
    try:
        response = Generation.call(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            api_key=ALI_API_KEY,
            result_format='message'
        )
        if response.status_code == 200:
            print("[流程] 回答生成成功\n")
            return response.output.choices[0].message.content.strip()
        else:
            print(f"[流程] 回答生成失败: {response.message}")
            return ""
    except Exception as e:
        print(f"[流程] 回答生成异常: {e}")
        return ""

# ========== RSE主流程 ==========
def rag_with_rse(pdf_path, query, chunk_size=800, irrelevant_chunk_penalty=0.2):
    """
    RSE增强RAG主流程
    :param pdf_path: 文档路径
    :param query: 用户问题
    :param chunk_size: 块大小
    :param irrelevant_chunk_penalty: 不相关块惩罚
    :return: 结果字典
    """
    print("\n=== RSE增强RAG流程开始 ===")
    print(f"[输入] 问题: {query}")
    chunks, vector_store, doc_info = process_document(pdf_path, chunk_size)
    print("[流程] 正在计算相关性分数...")
    chunk_values = calculate_chunk_values(query, chunks, vector_store, irrelevant_chunk_penalty)
    best_segments, scores = find_best_segments(
        chunk_values,
        max_segment_length=20,
        total_max_length=30,
        min_segment_value=0.2
    )
    segments = reconstruct_segments(chunks, best_segments)
    context = format_segments_for_context(segments)
    response = generate_response(query, context)
    result = {
        "query": query,
        "segments": segments,
        "response": response
    }
    print("\n=== 最终AI回答 ===")
    print(response)
    print("=== RSE增强RAG流程结束 ===\n")
    return result

# ========== main方法示例 ==========
def main():
    """
    主方法示例:自动读取val.json第一个问题和PDF,体验RSE增强RAG
    在RSE(Relevant Segment Extraction,相关片段提取)流程中,我们会对每个文本块(chunk)计算其与查询(query)的相关性分数。相关性分数越高,说明该块与问题越相关。 但在实际应用中,很多块与问题无关,如果直接用相关性分数,有些无关块的分数可能接近0甚至略大于0。为了让算法更容易“跳过”这些无关块,我们会人为地给每个块的分数减去一个“惩罚值”,这个惩罚值就是 irrelevant_chunk_penalty。 • 如果某个块本身相关性分数很低,减去惩罚后就变成负数,更容易被算法忽略。
    """
    # 路径配置
    pdf_path = "data/2888年Java程序员找工作最新场景题.pdf"
    val_path = "data/java_val.json"
    chunk_size = 800
    irrelevant_chunk_penalty = 0.2
    print("\n========== RSE增强RAG主流程演示 ==========")
    print(f"[配置] 使用API密钥: {ALI_API_KEY[:8]}...{ALI_API_KEY[-4:]}")
    print(f"[配置] PDF路径: {pdf_path}")
    print(f"[配置] 验证集路径: {val_path}")
    print(f"[配置] 块大小: {chunk_size},不相关块惩罚: {irrelevant_chunk_penalty}\n")
    # 读取验证集
    with open(val_path, encoding='utf-8') as f:
        data = json.load(f)
    query = data[0]['question']
    print(f"[主流程] 示例问题: {query}")
    # 执行RSE增强RAG
    rag_with_rse(pdf_path, query, chunk_size, irrelevant_chunk_penalty)
    print("========== 演示结束 ==========")

if __name__ == "__main__":
    main()

posted @ 2025-06-25 16:57  舒一笑不秃头  阅读(33)  评论(0)    收藏  举报