arXiv Paper Curation RAG System: Building a Production-Grade AI Research Assistant from Scratch

arXiv Paper Curation RAG System

A complete, production-grade retrieval-augmented generation (RAG) system for curating and querying arXiv papers. It automatically fetches the latest AI research papers, processes their PDF content, and provides intelligent question answering over them.

Features

  • Automated paper collection: fetches the latest computer-science AI papers from the arXiv API every day
  • Intelligent PDF parsing: uses Docling to parse scientific-paper PDFs and extract structured content
  • Hybrid search: hybrid retrieval combining BM25 keyword search with vector-based semantic search
  • Local LLM integration: Ollama support for running large-language-model inference locally
  • Production-grade monitoring: Langfuse integration for end-to-end tracing and analysis of the RAG pipeline
  • High-performance caching: Redis caching yields a 150-400x response-speed improvement on repeated queries
  • Web interface: an interactive Gradio web UI with real-time streaming responses
  • Workflow orchestration: Apache Airflow automates the data-processing pipeline (a DAG sketch follows this list)
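
To make the orchestration concrete, here is a minimal sketch of what a daily ingestion DAG could look like. The DAG id, task ids, and helper functions (fetch_daily_papers, parse_and_index) are hypothetical placeholders, not the repository's actual pipeline code.

from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator


def fetch_daily_papers(**context):
    """Placeholder: query the arXiv API for the latest cs.AI submissions and store the metadata."""
    ...


def parse_and_index(**context):
    """Placeholder: parse PDFs with Docling, chunk and embed the text, and index it into OpenSearch."""
    ...


with DAG(
    dag_id="arxiv_daily_ingestion",   # hypothetical DAG id
    start_date=datetime(2025, 1, 1),
    schedule="@daily",                # use schedule_interval on Airflow < 2.4
    catchup=False,
) as dag:
    fetch = PythonOperator(task_id="fetch_papers", python_callable=fetch_daily_papers)
    index = PythonOperator(task_id="parse_and_index", python_callable=parse_and_index)
    fetch >> index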

Installation

System Requirements

  • Python 3.12+
  • Docker & Docker Compose
  • At least 8 GB of available memory

Quick Start

  1. Clone the repository:
git clone <repository-url>
cd arxiv-paper-curator
  2. Set up the environment variables:
cp .env.example .env
  3. Start all services:
docker compose up --build -d
  4. Verify that the services are healthy:
curl http://localhost:8000/api/v1/health

Services

The system runs the following core services (a small port-check sketch follows the list):

  • FastAPI (port 8000): REST API service
  • PostgreSQL (port 5432): paper metadata storage
  • OpenSearch (port 9200): hybrid search engine
  • Apache Airflow (port 8080): workflow orchestration
  • Ollama (port 11434): local LLM service
  • Redis (port 6379): caching service
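
As a quick sanity check after docker compose up, a small script like the one below can confirm that each service's port is reachable. This is a convenience sketch, not part of the repository; the port list simply mirrors the list above.

import socket

# Service name -> host port, mirroring the service list above.
SERVICES = {
    "FastAPI": 8000,
    "PostgreSQL": 5432,
    "OpenSearch": 9200,
    "Airflow": 8080,
    "Ollama": 11434,
    "Redis": 6379,
}


def port_open(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within the timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    for name, port in SERVICES.items():
        status = "up" if port_open("localhost", port) else "unreachable"
        print(f"{name:12s} (port {port}): {status}")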

Usage

Launching the Web Interface

uv run python gradio_launcher.py

Open http://localhost:7861 to use the interactive interface.

API Usage Examples

Basic Q&A

import requests

response = requests.post(
    "http://localhost:8000/api/v1/ask",
    json={
        "query": "什么是机器学习中的注意力机制?",
        "top_k": 3,
        "use_hybrid": True,
        "model": "llama3.2:1b"
    }
)
print(response.json())
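
Judging from the AskResponse model used in the endpoint code later in this post, the JSON body contains query, answer, sources, chunks_used, and search_mode fields, so the result can be unpacked roughly like this (field names taken from that code, not verified against the live API):

data = response.json()
print(data["answer"])
print(f"{data['chunks_used']} chunks used, search mode: {data['search_mode']}")
for source in data["sources"]:
    print(source)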

Streaming Responses

import requests

response = requests.post(
    "http://localhost:8000/api/v1/stream",
    json={
        "query": "解释Transformer架构",
        "top_k": 2
    },
    stream=True
)

for line in response.iter_lines():
    if line:
        print(line.decode('utf-8'))

Hybrid Search

import requests

response = requests.post(
    "http://localhost:8000/api/v1/hybrid-search/",
    json={
        "query": "neural network deep learning",
        "size": 10,
        "categories": ["cs.AI", "cs.LG"],
        "use_hybrid": True
    }
)
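
The repository's fusion logic for combining BM25 and vector results is not shown in this post. As an illustration of one common approach, here is a minimal reciprocal-rank-fusion (RRF) sketch; the function and parameter names are hypothetical, and the actual service may use a different fusion strategy.

from collections import defaultdict
from typing import Dict, List


def reciprocal_rank_fusion(result_lists: List[List[str]], k: int = 60) -> List[str]:
    """Fuse several ranked lists of document ids with reciprocal rank fusion.

    Each document scores sum(1 / (k + rank)) over the lists it appears in;
    k dampens the influence of the very top ranks.
    """
    scores: Dict[str, float] = defaultdict(float)
    for ranked_ids in result_lists:
        for rank, doc_id in enumerate(ranked_ids, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.get, reverse=True)


# Example: fuse a BM25 ranking with a vector-similarity ranking.
bm25_hits = ["chunk_3", "chunk_7", "chunk_1"]
vector_hits = ["chunk_7", "chunk_2", "chunk_3"]
print(reciprocal_rank_fusion([bm25_hits, vector_hits]))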

Core Code

1. Hybrid Indexing Service

class HybridIndexingService:
    """Service that chunks and embeds papers and indexes them for hybrid search"""
    
    def __init__(self, chunker: TextChunker, embeddings_client: JinaEmbeddingsClient, opensearch_client: OpenSearchClient):
        self.chunker = chunker
        self.embeddings_client = embeddings_client
        self.opensearch_client = opensearch_client
        logger.info("Hybrid indexing service initialized")

    async def index_paper(self, paper_data: Dict) -> Dict[str, int]:
        """Index a single paper, including chunking and embedding generation"""
        arxiv_id = paper_data.get("arxiv_id")
        paper_id = str(paper_data.get("id", ""))

        if not arxiv_id:
            logger.error("Paper is missing arxiv_id")
            return {"chunks_created": 0, "chunks_indexed": 0, "embeddings_generated": 0, "errors": 1}

        try:
            # Step 1: chunk the paper using the hybrid sectioning approach
            chunks = self.chunker.chunk_paper(
                title=paper_data.get("title", ""),
                abstract=paper_data.get("abstract", ""),
                full_text=paper_data.get("raw_text", paper_data.get("full_text", "")),
                arxiv_id=arxiv_id,
                paper_id=paper_id,
                sections=paper_data.get("sections")
            )
            
            if not chunks:
                logger.warning(f"No valid chunks for paper {arxiv_id}")
                return {"chunks_created": 0, "chunks_indexed": 0, "embeddings_generated": 0}

            # Step 2: generate embeddings for the chunks
            chunk_texts = [chunk.text for chunk in chunks]
            embeddings = await self.embeddings_client.embed_passages(chunk_texts)
            
            # Step 3: prepare the OpenSearch documents
            documents = []
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                doc = {
                    "chunk_id": f"{arxiv_id}_chunk_{i}",
                    "arxiv_id": arxiv_id,
                    "paper_id": paper_id,
                    "chunk_index": i,
                    "chunk_text": chunk.text,
                    "chunk_word_count": len(chunk.text.split()),
                    "embedding": embedding,
                    "title": paper_data.get("title", ""),
                    "authors": paper_data.get("authors", []),
                    "abstract": paper_data.get("abstract", ""),
                    "categories": paper_data.get("categories", []),
                    "published_date": paper_data.get("published_date"),
                    "section_title": chunk.metadata.section_title,
                    "embedding_model": "jina-embeddings-v3"
                }
                documents.append(doc)

            # Step 4: bulk-index the documents into OpenSearch
            success_count = await self.opensearch_client.bulk_index_documents(documents)
            
            logger.info(f"Indexed paper {arxiv_id}: {success_count}/{len(documents)} chunks")
            return {
                "chunks_created": len(chunks),
                "chunks_indexed": success_count,
                "embeddings_generated": len(embeddings)
            }

        except Exception as e:
            logger.error(f"Error indexing paper {arxiv_id}: {e}")
            return {"chunks_created": 0, "chunks_indexed": 0, "embeddings_generated": 0, "errors": 1}

2. RAG Question-Answering Endpoint

@ask_router.post("/ask", response_model=AskResponse)
async def ask_question(
    request: AskRequest,
    opensearch_client: OpenSearchDep,
    embeddings_service: EmbeddingsDep,
    ollama_client: OllamaDep,
    langfuse_tracer: LangfuseDep,
    cache_client: CacheDep,
) -> AskResponse:
    """RAG question-answering endpoint with caching and tracing"""
    
    # Check the cache first
    if cache_client:
        cached_response = await cache_client.find_cached_response(request)
        if cached_response:
            logger.info("Cache hit; returning cached response")
            return cached_response

    # Create the tracer
    rag_tracer = RAGTracer(langfuse_tracer)
    
    with rag_tracer.trace_request(user_id="api_user", query=request.query) as trace:
        start_time = time.time()
        
        try:
            # Retrieve the relevant chunks
            chunks, arxiv_ids, sources = await _prepare_chunks_and_sources(
                request, opensearch_client, embeddings_service, rag_tracer, trace
            )
            
            if not chunks:
                return AskResponse(
                    query=request.query,
                    answer="No relevant paper content was found to answer this question.",
                    sources=[],
                    chunks_used=0,
                    search_mode="none"
                )

            # Build the prompt
            with rag_tracer.trace_prompt_construction(trace, chunks) as prompt_span:
                prompt_builder = RAGPromptBuilder()
                prompt = prompt_builder.create_rag_prompt(request.query, chunks)
                rag_tracer.end_prompt(prompt_span, prompt)

            # Generate the answer
            with rag_tracer.trace_generation(trace, request.model, prompt) as gen_span:
                response_text = await ollama_client.generate_response(
                    prompt=prompt,
                    model=request.model,
                    system_prompt=prompt_builder.system_prompt
                )
                rag_tracer.end_generation(gen_span, response_text, request.model)

            # Build the response
            response = AskResponse(
                query=request.query,
                answer=response_text,
                sources=sources,
                chunks_used=len(chunks),
                search_mode="hybrid" if request.use_hybrid else "bm25"
            )

            # Cache the response
            if cache_client:
                await cache_client.store_response(request, response)

            total_duration = time.time() - start_time
            rag_tracer.end_request(trace, response_text, total_duration)
            
            logger.info(f"RAG answer completed: {len(chunks)} chunks, {total_duration:.2f}s")
            return response

        except Exception as e:
            logger.error(f"RAG question-answering error: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
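
RAGPromptBuilder is referenced above but its implementation is not included in this post. The minimal sketch below shows one way create_rag_prompt could work, assuming the retrieved chunks are dictionaries with a chunk_text field (as in the indexing documents of section 1); the prompt wording and system_prompt text are assumptions, not the repository's actual prompts.

from typing import List


class RAGPromptBuilder:
    """Illustrative prompt builder, not the repository's actual implementation."""

    system_prompt = (
        "You are a research assistant. Answer strictly from the provided paper "
        "excerpts and say so when the excerpts are insufficient."
    )

    def create_rag_prompt(self, query: str, chunks: List[dict]) -> str:
        # Concatenate the retrieved chunks into a numbered context block,
        # then append the user's question.
        context = "\n\n".join(
            f"[{i + 1}] {chunk.get('chunk_text', '')}" for i, chunk in enumerate(chunks)
        )
        return f"Context:\n{context}\n\nQuestion: {query}\nAnswer:"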

3. Text Chunking Service

class TextChunker:
    """Text chunking service that splits text into overlapping chunks"""
    
    def __init__(self, chunk_size: int = 600, overlap_size: int = 100, min_chunk_size: int = 100):
        self.chunk_size = chunk_size
        self.overlap_size = overlap_size
        self.min_chunk_size = min_chunk_size

        if overlap_size >= chunk_size:
            raise ValueError("Overlap size must be smaller than chunk size")

        logger.info(f"Text chunker initialized: chunk_size={chunk_size}, overlap_size={overlap_size}, min_chunk_size={min_chunk_size}")

    def chunk_paper(
        self,
        title: str,
        abstract: str,
        full_text: str,
        arxiv_id: str,
        paper_id: str,
        sections: Optional[Union[Dict[str, str], str, list]] = None,
    ) -> List[TextChunk]:
        """Chunk a paper using the hybrid sectioning approach"""
        
        chunks = []
        
        # Handle the title and abstract
        if title and abstract:
            title_abstract_text = f"Title: {title}\n\nAbstract: {abstract}"
            title_chunks = self._chunk_text(title_abstract_text, arxiv_id, paper_id, "title_abstract")
            chunks.extend(title_chunks)

        # Handle the body text
        if full_text:
            if sections and isinstance(sections, list):
                # Use the section structure for structure-aware chunking
                for section in sections:
                    if isinstance(section, dict) and 'title' in section and 'content' in section:
                        section_title = section['title']
                        section_content = section['content']
                        
                        if section_content:
                            section_chunks = self._chunk_text(
                                section_content, arxiv_id, paper_id, f"section_{section_title}"
                            )
                            chunks.extend(section_chunks)
            else:
                # Fall back to chunking the full text as one body
                body_chunks = self._chunk_text(full_text, arxiv_id, paper_id, "body")
                chunks.extend(body_chunks)

        logger.info(f"Created {len(chunks)} chunks for paper {arxiv_id}")
        return chunks

    def _chunk_text(self, text: str, arxiv_id: str, paper_id: str, section_title: str) -> List[TextChunk]:
        """Split text into overlapping chunks"""
        if not text or len(text.strip()) == 0:
            return []

        words = self._split_into_words(text)
        if len(words) < self.min_chunk_size:
            return []

        chunks = []
        start_idx = 0
        
        while start_idx < len(words):
            # Chunk boundaries are word-aligned, so no mid-word truncation can occur.
            end_idx = min(start_idx + self.chunk_size, len(words))

            chunk_words = words[start_idx:end_idx]
            chunk_text = self._reconstruct_text(chunk_words)
            
            # Create the chunk metadata (offsets here are word indices, not character offsets)
            metadata = ChunkMetadata(
                chunk_index=len(chunks),
                start_char=start_idx,
                end_char=end_idx,
                word_count=len(chunk_words),
                overlap_with_previous=self.overlap_size if start_idx > 0 else 0,
                overlap_with_next=self.overlap_size if end_idx < len(words) else 0,
                section_title=section_title
            )
            
            chunk = TextChunk(
                text=chunk_text,
                metadata=metadata,
                arxiv_id=arxiv_id,
                paper_id=paper_id
            )
            chunks.append(chunk)
            
            # Advance the start position, accounting for the overlap
            start_idx += (self.chunk_size - self.overlap_size)
            
            # Stop if too little text remains for another chunk
            if len(words) - start_idx < self.min_chunk_size:
                break

        return chunks
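
As a usage sketch, chunking a paper that has no section structure might look like this; the _split_into_words and _reconstruct_text helpers, along with the TextChunk and ChunkMetadata models, live elsewhere in the repository and are not reproduced here.

chunker = TextChunker(chunk_size=600, overlap_size=100, min_chunk_size=100)

chunks = chunker.chunk_paper(
    title="An Example Paper",
    abstract="A short abstract used for illustration.",
    full_text="Full text of the paper ...",
    arxiv_id="2401.00001",   # illustrative arXiv id
    paper_id="42",
    sections=None,           # falls back to chunking the full text as one body
)

for chunk in chunks:
    print(chunk.metadata.chunk_index, chunk.metadata.word_count, chunk.metadata.section_title)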

4. Cache Client

class CacheClient:
    """Redis-backed exact-match cache for RAG queries"""
    
    def __init__(self, redis_client: redis.Redis, settings: RedisSettings):
        self.redis = redis_client
        self.settings = settings
        self.ttl = timedelta(hours=settings.ttl_hours)

    def _generate_cache_key(self, request: AskRequest) -> str:
        """Generate an exact cache key from the request parameters"""
        key_data = {
            "query": request.query,
            "model": request.model,
            "top_k": request.top_k,
            "use_hybrid": request.use_hybrid,
            "categories": sorted(request.categories) if request.categories else [],
        }
        key_string = json.dumps(key_data, sort_keys=True)
        key_hash = hashlib.sha256(key_string.encode()).hexdigest()[:16]
        return f"exact_cache:{key_hash}"

    async def find_cached_response(self, request: AskRequest) -> Optional[AskResponse]:
        """Look up a cached response for an exact query match"""
        try:
            cache_key = self._generate_cache_key(request)
            cached_response = self.redis.get(cache_key)

            if cached_response:
                try:
                    response_data = json.loads(cached_response)
                    logger.info("Exact query match: cache hit")
                    return AskResponse(**response_data)
                except json.JSONDecodeError as e:
                    logger.warning(f"Failed to deserialize cached response: {e}")
                    return None

            return None

        except Exception as e:
            logger.error(f"Error checking the cache: {e}")
            return None

    async def store_response(self, request: AskRequest, response: AskResponse) -> bool:
        """Store the response for an exact query match"""
        try:
            cache_key = self._generate_cache_key(request)
            success = self.redis.set(cache_key, response.model_dump_json(), ex=self.ttl)

            if success:
                logger.info(f"Response stored in the exact cache under key {cache_key[:16]}...")
                return True
            else:
                logger.warning("Failed to store the response in the cache")
                return False

        except Exception as e:
            logger.error(f"Error writing to the cache: {e}")
            return False
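
To show how the pieces fit together, here is a wiring sketch for the cache client. The fields of RedisSettings other than ttl_hours are not shown in this post, so a minimal stand-in settings object is used instead of the repository's real settings class.

import redis  # synchronous redis-py client, matching the type hint above


class _DemoRedisSettings:
    """Stand-in for the repository's RedisSettings; only ttl_hours is known from the code above."""
    ttl_hours = 24


redis_client = redis.Redis(host="localhost", port=6379, db=0)
cache = CacheClient(redis_client, _DemoRedisSettings())

# Inside the /ask endpoint the flow is roughly:
#   cached = await cache.find_cached_response(request)
#   if cached is None:
#       response = ...  # run retrieval + generation
#       await cache.store_response(request, response)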

