Python LangChain RAG Demo

The example below builds a small retrieval-augmented generation (RAG) setup: a TF-IDF document store that persists itself to disk, combined with a locally loaded Qwen3 model wrapped behind LangChain's LLM interface.

Example code:
import json
import os
import pickle
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from loguru import logger
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class PersistentDocumentStore:
    """Document store with disk persistence."""

    def __init__(self, storage_dir: str = "./rag_storage"):
        """
        Initialize the persistent document store.

        Args:
            storage_dir: path of the storage directory
        """
        self.__storage_dir = storage_dir
        self.documents_file = os.path.join(storage_dir, "documents.json")
        self.vectors_file = os.path.join(storage_dir, "vectors.pkl")
        self.vectorizer_file = os.path.join(storage_dir, "vectorizer.pkl")

        # Make sure the storage directory exists
        os.makedirs(storage_dir, exist_ok=True)

        # Initialize state
        self.documents: List[str] = []
        self.vectorizer = TfidfVectorizer(max_features=2000)
        self.vectors = None

        # Load any existing data from disk
        self._load_from_disk()

    def _load_from_disk(self):
        """Load data from disk."""
        try:
            # Load documents
            if os.path.exists(self.documents_file):
                with open(self.documents_file, "r", encoding="utf-8") as f:
                    self.documents = json.load(f)
                logger.info(f"Loaded {len(self.documents)} documents from disk")

            # Load the vectorizer
            if os.path.exists(self.vectorizer_file):
                with open(self.vectorizer_file, "rb") as f:
                    self.vectorizer = pickle.load(f)
                logger.info("Loaded the vectorizer from disk")

            # Load the vector matrix
            if os.path.exists(self.vectors_file):
                with open(self.vectors_file, "rb") as f:
                    self.vectors = pickle.load(f)
                logger.info("Loaded the vector matrix from disk")
        except Exception as e:
            logger.warning(f"Failed to load data from disk: {e}")
            self.documents = []
            self.vectors = None

    def _save_to_disk(self):
        """Persist data to disk."""
        try:
            # Save documents
            with open(self.documents_file, "w", encoding="utf-8") as f:
                json.dump(self.documents, f, ensure_ascii=False, indent=2)

            # Save the vectorizer
            with open(self.vectorizer_file, "wb") as f:
                pickle.dump(self.vectorizer, f)

            # Save the vector matrix
            if self.vectors is not None:
                with open(self.vectors_file, "wb") as f:
                    pickle.dump(self.vectors, f)

            logger.info("Data saved to disk")
        except Exception as e:
            logger.error(f"Failed to save data to disk: {e}")

    def add_documents(self, docs: List[str]):
        """Add documents and update the store."""
        new_docs = [doc for doc in docs if doc not in self.documents]
        if new_docs:
            self.documents.extend(new_docs)
            if self.documents:
                # Refit TF-IDF over the whole corpus so all documents share one vocabulary
                self.vectors = self.vectorizer.fit_transform(self.documents)
            self._save_to_disk()
            logger.info(
                f"Added {len(new_docs)} new documents, {len(self.documents)} in total"
            )

    def remove_documents(self, indices: List[int]):
        """Remove documents by index."""
        if not indices:
            return
        removed = 0
        # Delete from the highest index down so earlier deletions don't shift later ones
        for idx in sorted(indices, reverse=True):
            if 0 <= idx < len(self.documents):
                del self.documents[idx]
                removed += 1
        if self.documents:
            self.vectors = self.vectorizer.fit_transform(self.documents)
        else:
            self.vectors = None
        self._save_to_disk()
        logger.info(f"Removed {removed} documents, {len(self.documents)} remaining")

    def search(self, query: str, top_k: int = 3) -> List[Tuple[str, float, int]]:
        """Search for relevant documents; returns (document, similarity, index) tuples."""
        if not self.documents or self.vectors is None:
            return []

        query_vec = self.vectorizer.transform([query])
        similarities = cosine_similarity(query_vec, self.vectors).flatten()

        # Indices of the most relevant documents, best first
        top_indices = np.argsort(similarities)[-top_k:][::-1]

        return [(self.documents[i], similarities[i], i) for i in top_indices]  # type: ignore

    def get_stats(self) -> Dict[str, Any]:
        """Return store statistics."""
        return {
            "total_documents": len(self.documents),
            "storage_dir": self.__storage_dir,
            "vector_dim": self.vectors.shape[1] if self.vectors is not None else 0,
        }
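The store can be exercised on its own, before any model is loaded. A minimal sketch (the documents, query, and "./demo_storage" path here are made up for illustration; only scikit-learn, numpy, and loguru are needed):

store = PersistentDocumentStore(storage_dir="./demo_storage")
store.add_documents([
    "LangChain is a framework for building LLM applications.",
    "TF-IDF weighs terms by how rare they are across the corpus.",
])
for doc, score, idx in store.search("What is LangChain?", top_k=2):
    print(f"[{idx}] {score:.2f} {doc}")
# Re-creating the store with the same storage_dir reloads the persisted data

The RAG wrapper below builds on this store and exposes it through LangChain's LLM interface: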
class Qwen3PersistentRAG(LLM):
    """RAG system with disk persistence."""

    tokenizer: Optional[AutoTokenizer] = None
    model: Optional[AutoModelForCausalLM] = None

    def __init__(self, model_dir: str, storage_dir: str = "./rag_storage"):
        super().__init__()
        self.__model_dir = model_dir
        self.__storage_dir = storage_dir

        # Initialize the persistent document store
        self.__document_store = PersistentDocumentStore(storage_dir)

        logger.info("Initializing the persistent RAG system...")
        self._load_model()

    def _load_model(self):
        """Load the model."""
        logger.info("Detecting device...")
        self.__device = self._get_device()

        logger.info("Loading tokenizer...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            self.__model_dir, trust_remote_code=True
        )

        logger.info("Loading model...")
        self.model = AutoModelForCausalLM.from_pretrained(
            self.__model_dir,
            device_map=None,
            trust_remote_code=True,
            torch_dtype=(
                torch.float16
                if self.__device != torch.device("cpu")
                else torch.float32
            ),
            low_cpu_mem_usage=True,
        ).to(
            self.__device  # type: ignore
        )  # type: ignore

        logger.info("Loading generation config...")
        self.model.generation_config = GenerationConfig.from_pretrained(  # type: ignore
            self.__model_dir, trust_remote_code=True
        )

    def _get_device(self) -> torch.device:
        """Pick the best available device."""
        if torch.backends.mps.is_available():
            logger.info("Using Apple Silicon GPU (MPS)")
            return torch.device("mps")
        elif torch.cuda.is_available():
            logger.info(f"Using NVIDIA GPU: {torch.cuda.get_device_name(0)}")
            return torch.device("cuda")
        else:
            logger.info("Using CPU")
            return torch.device("cpu")

    def add_documents(self, docs: List[str]):
        """Add documents to the store."""
        self.__document_store.add_documents(docs)

    def remove_documents(self, indices: List[int]):
        """Remove the documents at the given indices."""
        self.__document_store.remove_documents(indices)

    def list_documents(self) -> List[str]:
        """List all documents."""
        return self.__document_store.documents

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """RAG-augmented call."""
        # Retrieve relevant documents
        relevant_docs = self.__document_store.search(prompt, top_k=3)

        # Build the augmented prompt
        context_parts = []
        for doc, score, idx in relevant_docs:
            context_parts.append(f"[Document {idx + 1} - relevance {score:.2f}]\n{doc}")
        context = "\n\n".join(context_parts)

        enhanced_prompt = f"""Answer the question based on the following context:

{context}

Question: {prompt}

Answer using the context provided above. If the context contains no relevant information, say so.

Answer:"""

        # Generate the answer
        messages = [{"role": "user", "content": enhanced_prompt}]
        text = self.tokenizer.apply_chat_template(  # type: ignore
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.__device)  # type: ignore

        generated_ids = self.model.generate(  # type: ignore
            **model_inputs, max_new_tokens=1024, temperature=0.7, do_sample=True
        )
        # Drop the prompt tokens so only the newly generated text is decoded
        output_ids = generated_ids[0][model_inputs.input_ids.shape[1]:]
        response = self.tokenizer.decode(output_ids, skip_special_tokens=True)  # type: ignore

        # Strip the thinking block, keeping only the final answer
        if "</think>" in response:
            response = response.split("</think>")[-1].strip()

        return response

    def get_stats(self) -> Dict[str, Any]:
        """Return system statistics."""
        stats = self.__document_store.get_stats()
        stats.update({"model_dir": self.__model_dir, "device": str(self.__device)})
        return stats

    @property
    def _llm_type(self) -> str:
        return "Qwen3PersistentRAG"
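Used programmatically, the wrapper behaves like any other LangChain LLM: a recent langchain exposes the Runnable .invoke API on it, which is also what the interactive demo further down calls. A minimal sketch (the checkpoint path and the sample document are placeholders for illustration):

rag = Qwen3PersistentRAG(
    model_dir="/path/to/Qwen3-checkpoint",  # placeholder: your local model directory
    storage_dir="./my_rag_storage",
)
rag.add_documents(["The capital of France is Paris."])
print(rag.invoke("What is the capital of France?"))

The interactive command-line demo below drives the same methods in a loop: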
# Usage example
if __name__ == "__main__":
    import time

    # Initialize the persistent RAG system
    rag = Qwen3PersistentRAG(
        model_dir="<path to the folder holding the model and its config>",
        storage_dir="./my_rag_storage",
    )

    # Simple interactive interface for managing documents
    print("=== Persistent RAG system ===")
    print("Commands:")
    print("  add <document text>  - add a document")
    print("  list                 - list all documents")
    print("  remove <index>       - remove a document")
    print("  stats                - show statistics")
    print("  q                    - quit")
    print()

    while True:
        command = input("> ").strip()

        if command.lower() == "q":
            break
        elif command.startswith("add "):
            doc = command[4:]
            rag.add_documents([doc])
            print("Document added")
        elif command == "list":
            docs = rag.list_documents()
            for i, doc in enumerate(docs):
                print(f"{i}: {doc[:50]}...")
        elif command.startswith("remove "):
            try:
                idx = int(command[7:])
                rag.remove_documents([idx])
                print(f"Removed document {idx}")
            except ValueError:
                print("Please enter a valid index")
        elif command == "stats":
            stats = rag.get_stats()
            print(json.dumps(stats, indent=2, ensure_ascii=False))
        else:
            # Plain Q&A
            if command:
                start_time = time.time()
                answer = rag.invoke(command)
                end_time = time.time()
                print(f"\nAnswer: {answer}")
                print(f"Elapsed: {end_time - start_time:.2f}s\n")
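Two caveats about the persistence layer are worth noting. The vectorizer and vector matrix are stored with pickle, so the storage directory should only be loaded if it comes from a trusted source. And every add or remove refits TF-IDF over the full corpus, which is fine for a demo but scales poorly with many documents. To reset the store between experiments, it is enough to delete the storage directory, for example:

import shutil

shutil.rmtree("./my_rag_storage", ignore_errors=True)  # wipes documents.json and the pickle files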