Python LangChain RAG Demo
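
The demo below wires a disk-persistent TF-IDF document store (built on scikit-learn) into a LangChain-compatible LLM wrapper around a locally stored Qwen3 model loaded with Hugging Face Transformers: each query first retrieves the most similar stored documents, then prepends them as context to the prompt before generation.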

Sample code:

import json
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from loguru import logger
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class PersistentDocumentStore:
    """支持磁盘持久化的文档存储系统"""

    def __init__(self, storage_dir: str = "./rag_storage"):
        """
        初始化持久化文档存储

        Args:
            storage_dir: 存储目录路径
        """
        self.__storage_dir = storage_dir
        self.documents_file = os.path.join(storage_dir, "documents.json")
        self.vectors_file = os.path.join(storage_dir, "vectors.pkl")
        self.vectorizer_file = os.path.join(storage_dir, "vectorizer.pkl")

        # Make sure the storage directory exists
        os.makedirs(storage_dir, exist_ok=True)

        # Initialize in-memory state
        self.documents: List[str] = []
        self.vectorizer = TfidfVectorizer(max_features=2000)
        self.vectors = None

        # Load any existing data from disk
        self._load_from_disk()

    def _load_from_disk(self):
        """Load persisted data from disk."""
        try:
            # Load documents
            if os.path.exists(self.documents_file):
                with open(self.documents_file, "r", encoding="utf-8") as f:
                    self.documents = json.load(f)
                logger.info(f"Loaded {len(self.documents)} documents from disk")

            # Load the fitted vectorizer (pickle should only ever be used on
            # files this process wrote itself; never unpickle untrusted data)
            if os.path.exists(self.vectorizer_file):
                with open(self.vectorizer_file, "rb") as f:
                    self.vectorizer = pickle.load(f)
                logger.info("Loaded vectorizer from disk")

            # Load the document vectors
            if os.path.exists(self.vectors_file):
                with open(self.vectors_file, "rb") as f:
                    self.vectors = pickle.load(f)
                logger.info("Loaded vector matrix from disk")

        except Exception as e:
            logger.warning(f"Failed to load data from disk: {e}")
            self.documents = []
            self.vectors = None

    def _save_to_disk(self):
        """Persist data to disk."""
        try:
            # Save documents
            with open(self.documents_file, "w", encoding="utf-8") as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)

            # Save the fitted vectorizer
            with open(self.vectorizer_file, "wb") as f:
                pickle.dump(self.vectorizer, f)

            # Save the document vectors
            if self.vectors is not None:
                with open(self.vectors_file, "wb") as f:
                    pickle.dump(self.vectors, f)

            logger.info("Data saved to disk")

        except Exception as e:
            logger.error(f"Failed to save data to disk: {e}")

    def add_documents(self, docs: List[str]):
        """Add new (deduplicated) documents and refresh the vector index."""
        new_docs = [doc for doc in docs if doc not in self.documents]
        if new_docs:
            self.documents.extend(new_docs)
            # Re-fit TF-IDF over the full corpus so all vectors stay consistent
            self.vectors = self.vectorizer.fit_transform(self.documents)
            self._save_to_disk()
            logger.info(
                f"Added {len(new_docs)} new documents, {len(self.documents)} in total"
            )

    def remove_documents(self, indices: List[int]):
        """Remove documents by index."""
        if not indices:
            return

        # Delete from the highest index down so earlier deletions do not
        # shift the positions of documents still to be removed
        removed = 0
        for idx in sorted(set(indices), reverse=True):
            if 0 <= idx < len(self.documents):
                del self.documents[idx]
                removed += 1

        if self.documents:
            self.vectors = self.vectorizer.fit_transform(self.documents)
        else:
            self.vectors = None

        self._save_to_disk()
        logger.info(f"Removed {removed} documents, {len(self.documents)} remaining")

    def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float, int]]:
        """搜索相关文档,返回文档、相似度和索引"""
        if not self.documents or self.vectors is None:
            return []

        query_vec = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vec, self.vectors).flatten()

        # Take the top_k highest-scoring document indices, best first
        top_indices = np.argsort(similarities)[-top_k:][::-1]
        return [(self.documents[i], similarities[i], i) for i in top_indices]  # type: ignore

    def get_stats(self) -> Dict[str, Any]:
        """获取存储统计信息"""
        return {
            "total_documents": len(self.documents),
            "storage_dir": self.__storage_dir,
            "vector_dim": self.vectors.shape[1] if self.vectors is not None else 0,
        }
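
# Standalone usage sketch for the store above (paths and texts are illustrative):
#
#     store = PersistentDocumentStore("./demo_storage")
#     store.add_documents(["LangChain is an LLM application framework.",
#                          "TF-IDF weighs terms by in-document frequency and corpus rarity."])
#     for doc, score, idx in store.search("What is LangChain?", top_k=2):
#         print(idx, f"{score:.2f}", doc)
#
# Note that add_documents/remove_documents re-fit the TfidfVectorizer over the
# whole corpus, which keeps all vectors consistent at the cost of a full
# rebuild on every change.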


class Qwen3PersistentRAG(LLM):
    """支持磁盘持久化的RAG系统"""

    tokenizer: Optional[AutoTokenizer] = None
    model: Optional[AutoModelForCausalLM] = None

    def __init__(self, model_dir: str, storage_dir: str = "./rag_storage"):
        super().__init__()
        self.__model_dir = model_dir
        self.__storage_dir = storage_dir

        # Initialize the persistent document store
        self.__document_store = PersistentDocumentStore(storage_dir)

        logger.info("正在初始化持久化RAG系统...")
        self._load_model()

    def _load_model(self):
        """加载模型"""
        logger.info("检测设备...")
        self.__device = self._get_device()

        logger.info("加载分词器...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.__model_dir, trust_remote_code=True
        )

        logger.info("加载模型...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.__model_dir,
            device_map=None,
            trust_remote_code=True,
            torch_dtype=(
                torch.float16 if self.__device != torch.device("cpu") else torch.float32
            ),
            low_cpu_mem_usage=True,
        ).to(
            self.__device  # type: ignore
        )  # type: ignore

        logger.info("加载生成配置...")
        self.model.generation_config = GenerationConfig.from_pretrained(  # type: ignore
            self.__model_dir, trust_remote_code=True
        )

    def _get_device(self) -> torch.device:
        """获取最佳设备"""
        if torch.backends.mps.is_available():
            logger.info("使用Apple Silicon GPU (MPS)")
            return torch.device("mps")
        elif torch.cuda.is_available():
            logger.info(f"使用NVIDIA GPU: {torch.cuda.get_device_name(0)}")
            return torch.device("cuda")
        else:
            logger.info("使用CPU")
            return torch.device("cpu")

    def add_documents(self, docs: List[str]):
        """添加文档到存储"""
        self.__document_store.add_documents(docs)

    def remove_documents(self, indices: List[int]):
        """移除指定索引的文档"""
        self.__document_store.remove_documents(indices)

    def list_documents(self) -> List[str]:
        """列出所有文档"""
        return self.__document_store.documents

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """RAG增强的调用方法"""

        # 检索相关文档
        relevant_docs = self.__document_store.search(prompt, top_k=3)

        # 构建增强的prompt
        context_parts = []
        for doc, score, idx in relevant_docs:
            context_parts.append(f"[文档{idx+1} - 相关度{score:.2f}]\n{doc}")

        context = "\n\n".join(context_parts)
        enhanced_prompt = f"""基于以下上下文回答问题:

            {context}

            问题:{prompt}

            请基于提供的上下文信息回答问题。如果上下文中没有相关信息,请说明。

            回答:"""

        # Generate the answer
        messages = [{"role": "user", "content": enhanced_prompt}]
        text = self.tokenizer.apply_chat_template(  # type: ignore
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )

        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.__device)  # type: ignore
        generated_ids = self.model.generate(  # type: ignore
            **model_inputs, max_new_tokens=1024, temperature=0.7, do_sample=True
        )

        # Decode only the newly generated tokens, not the echoed prompt
        output_ids = generated_ids[0][model_inputs.input_ids.shape[1]:]  # type: ignore
        response = self.tokenizer.decode(output_ids, skip_special_tokens=True)  # type: ignore

        # Drop the thinking block produced by enable_thinking=True
        if "</think>" in response:
            response = response.split("</think>")[-1].strip()

        return response

    def get_stats(self) -> Dict[str, Any]:
        """获取系统统计信息"""
        stats = self.__document_store.get_stats()
        stats.update({"model_dir": self.__model_dir, "device": str(self.__device)})
        return stats

    @property
    def _llm_type(self) -> str:
        return "Qwen3PersistentRAG"


# Usage example
if __name__ == "__main__":
    import time

    # Initialize the persistent RAG system
    rag = Qwen3PersistentRAG(
        model_dir="path to the folder holding the model and its config",
        storage_dir="./my_rag_storage",
    )

    # Interactive document-management loop
    print("=== Persistent RAG System ===")
    print("Commands:")
    print("  add <document text> - add a document")
    print("  list - list all documents")
    print("  remove <index> - remove a document")
    print("  stats - show statistics")
    print("  q - quit")
    print()

    while True:
        command = input("> ").strip()

        if command.lower() == "q":
            break
        elif command.startswith("add "):
            doc = command[4:]
            rag.add_documents([doc])
            print(f"已添加文档")
        elif command == "list":
            docs = rag.list_documents()
            for i, doc in enumerate(docs):
                print(f"{i}: {doc[:50]}...")
        elif command.startswith("remove "):
            try:
                idx = int(command[7:])
                rag.remove_documents([idx])
                print(f"已移除文档 {idx}")
            except ValueError:
                print("请输入有效索引")
        elif command == "stats":
            stats = rag.get_stats()
            print(json.dumps(stats, indent=2, ensure_ascii=False))
        else:
            # Treat anything else as a question
            if command:
                start_time = time.time()
                answer = rag.invoke(command)
                end_time = time.time()

                print(f"\nAnswer: {answer}")
                print(f"Elapsed: {end_time - start_time:.2f}s\n")
