arXiv Paper Curation RAG System: Building a Production-Grade AI Research Assistant from Scratch
A complete, production-grade Retrieval-Augmented Generation (RAG) system for managing and querying arXiv academic papers. The system automatically fetches the latest AI research papers, parses their PDF content, and answers questions about them.
Features
- Automated paper collection: fetches the latest computer science AI papers from the arXiv API every day
- Intelligent PDF parsing: uses Docling to parse scientific-paper PDFs and extract structured content
- Hybrid search: combines BM25 keyword search with vector-based semantic search (see the fusion sketch after this list)
- Local LLM integration: runs local large language model inference through Ollama
- Production-grade observability: integrates Langfuse for end-to-end tracing and analysis of the RAG pipeline
- High-performance caching: Redis caching yields a 150-400x response-time speedup
- Web interface: an interactive Gradio UI with real-time streaming responses
- Workflow orchestration: Apache Airflow automates the data-processing pipeline
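The hybrid retrieval idea is easiest to see with a small fusion example. The sketch below shows generic reciprocal-rank fusion (RRF) over two rankings; it is illustrative only, since in this system the actual combination of BM25 and vector scores happens inside OpenSearch's hybrid search and may use a different scheme.

from collections import defaultdict

def reciprocal_rank_fusion(bm25_ranking: list[str], vector_ranking: list[str], k: int = 60) -> list[str]:
    """Merge two ranked lists of document IDs into one ranking (generic RRF, for illustration only)."""
    scores: dict[str, float] = defaultdict(float)
    for ranking in (bm25_ranking, vector_ranking):
        for rank, doc_id in enumerate(ranking, start=1):
            scores[doc_id] += 1.0 / (k + rank)
    return sorted(scores, key=scores.__getitem__, reverse=True)

# "2401.12345" appears near the top of both rankings, so it wins after fusion.
print(reciprocal_rank_fusion(
    ["2401.12345", "2312.00001", "2305.67890"],
    ["2310.11111", "2401.12345", "2312.00001"],
))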
Installation
Requirements
- Python 3.12+
- Docker & Docker Compose
- At least 8 GB of available memory
Quick Start
- Clone the repository:
git clone <repository-url>
cd arxiv-paper-curator
- Set up environment variables:
cp .env.example .env
- Start all services:
docker compose up --build -d
- Verify the services are healthy:
curl http://localhost:8000/api/v1/health
Services
The system comprises the following core services (you can probe their ports with the sketch after this list):
- FastAPI (port 8000): REST API service
- PostgreSQL (port 5432): paper metadata storage
- OpenSearch (port 9200): hybrid search engine
- Apache Airflow (port 8080): workflow orchestration
- Ollama (port 11434): local LLM service
- Redis (port 6379): caching
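Once docker compose up finishes, a quick way to confirm that every container is reachable is to probe the ports listed above. The helper below is not part of the project; it is a small optional script based only on the port numbers in this list.

import socket

# Default host ports taken from the service list above.
SERVICES = {
    "FastAPI": 8000,
    "PostgreSQL": 5432,
    "OpenSearch": 9200,
    "Airflow": 8080,
    "Ollama": 11434,
    "Redis": 6379,
}

for name, port in SERVICES.items():
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(2)
        status = "up" if sock.connect_ex(("localhost", port)) == 0 else "down"
        print(f"{name:11s} port {port}: {status}")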
Usage
Launch the web interface
uv run python gradio_launcher.py
Open http://localhost:7861 to use the interactive interface.
API usage examples
Basic question answering
import requests

response = requests.post(
    "http://localhost:8000/api/v1/ask",
    json={
        "query": "What is the attention mechanism in machine learning?",
        "top_k": 3,
        "use_hybrid": True,
        "model": "llama3.2:1b",
    },
)
print(response.json())
Streaming responses
import requests

response = requests.post(
    "http://localhost:8000/api/v1/stream",
    json={
        "query": "Explain the Transformer architecture",
        "top_k": 2,
    },
    stream=True,
)
for line in response.iter_lines():
    if line:
        print(line.decode("utf-8"))
Hybrid search
import requests

response = requests.post(
    "http://localhost:8000/api/v1/hybrid-search/",
    json={
        "query": "neural networks deep learning",
        "size": 10,
        "categories": ["cs.AI", "cs.LG"],
        "use_hybrid": True,
    },
)
print(response.json())
Core code
1. Hybrid indexing service
class HybridIndexingService:
    """Service that indexes paper chunks and embeddings for hybrid search."""

    def __init__(self, chunker: TextChunker, embeddings_client: JinaEmbeddingsClient, opensearch_client: OpenSearchClient):
        self.chunker = chunker
        self.embeddings_client = embeddings_client
        self.opensearch_client = opensearch_client
        logger.info("Hybrid indexing service initialized")

    async def index_paper(self, paper_data: Dict) -> Dict[str, int]:
        """Index a single paper, including chunking and embedding generation."""
        arxiv_id = paper_data.get("arxiv_id")
        paper_id = str(paper_data.get("id", ""))
        if not arxiv_id:
            logger.error("Paper is missing arxiv_id")
            return {"chunks_created": 0, "chunks_indexed": 0, "embeddings_generated": 0, "errors": 1}
        try:
            # Step 1: chunk the paper using the hybrid sectioning approach
            chunks = self.chunker.chunk_paper(
                title=paper_data.get("title", ""),
                abstract=paper_data.get("abstract", ""),
                full_text=paper_data.get("raw_text", paper_data.get("full_text", "")),
                arxiv_id=arxiv_id,
                paper_id=paper_id,
                sections=paper_data.get("sections"),
            )
            if not chunks:
                logger.warning(f"No valid chunks for paper {arxiv_id}")
                return {"chunks_created": 0, "chunks_indexed": 0, "embeddings_generated": 0}
            # Step 2: generate embeddings for the chunks
            chunk_texts = [chunk.text for chunk in chunks]
            embeddings = await self.embeddings_client.embed_passages(chunk_texts)
            # Step 3: prepare OpenSearch documents
            documents = []
            for i, (chunk, embedding) in enumerate(zip(chunks, embeddings)):
                doc = {
                    "chunk_id": f"{arxiv_id}_chunk_{i}",
                    "arxiv_id": arxiv_id,
                    "paper_id": paper_id,
                    "chunk_index": i,
                    "chunk_text": chunk.text,
                    "chunk_word_count": len(chunk.text.split()),
                    "embedding": embedding,
                    "title": paper_data.get("title", ""),
                    "authors": paper_data.get("authors", []),
                    "abstract": paper_data.get("abstract", ""),
                    "categories": paper_data.get("categories", []),
                    "published_date": paper_data.get("published_date"),
                    "section_title": chunk.metadata.section_title,
                    "embedding_model": "jina-embeddings-v3",
                }
                documents.append(doc)
            # Step 4: bulk index into OpenSearch
            success_count = await self.opensearch_client.bulk_index_documents(documents)
            logger.info(f"Indexed paper {arxiv_id}: {success_count}/{len(documents)} chunks")
            return {
                "chunks_created": len(chunks),
                "chunks_indexed": success_count,
                "embeddings_generated": len(embeddings),
            }
        except Exception as e:
            logger.error(f"Error indexing paper {arxiv_id}: {e}")
            return {"chunks_created": 0, "chunks_indexed": 0, "embeddings_generated": 0, "errors": 1}
2. RAG question-answering endpoint
@ask_router.post("/ask", response_model=AskResponse)
async def ask_question(
    request: AskRequest,
    opensearch_client: OpenSearchDep,
    embeddings_service: EmbeddingsDep,
    ollama_client: OllamaDep,
    langfuse_tracer: LangfuseDep,
    cache_client: CacheDep,
) -> AskResponse:
    """RAG question-answering endpoint with caching and tracing."""
    # Check the cache first
    if cache_client:
        cached_response = await cache_client.find_cached_response(request)
        if cached_response:
            logger.info("Cache hit, returning cached response")
            return cached_response
    # Create the tracer
    rag_tracer = RAGTracer(langfuse_tracer)
    with rag_tracer.trace_request(user_id="api_user", query=request.query) as trace:
        start_time = time.time()
        try:
            # Retrieve relevant chunks
            chunks, arxiv_ids, sources = await _prepare_chunks_and_sources(
                request, opensearch_client, embeddings_service, rag_tracer, trace
            )
            if not chunks:
                return AskResponse(
                    query=request.query,
                    answer="No relevant paper content was found to answer this question.",
                    sources=[],
                    chunks_used=0,
                    search_mode="none",
                )
            # Build the prompt
            with rag_tracer.trace_prompt_construction(trace, chunks) as prompt_span:
                prompt_builder = RAGPromptBuilder()
                prompt = prompt_builder.create_rag_prompt(request.query, chunks)
                rag_tracer.end_prompt(prompt_span, prompt)
            # Generate the answer
            with rag_tracer.trace_generation(trace, request.model, prompt) as gen_span:
                response_text = await ollama_client.generate_response(
                    prompt=prompt,
                    model=request.model,
                    system_prompt=prompt_builder.system_prompt,
                )
                rag_tracer.end_generation(gen_span, response_text, request.model)
            # Build the response
            response = AskResponse(
                query=request.query,
                answer=response_text,
                sources=sources,
                chunks_used=len(chunks),
                search_mode="hybrid" if request.use_hybrid else "bm25",
            )
            # Cache the response
            if cache_client:
                await cache_client.store_response(request, response)
            total_duration = time.time() - start_time
            rag_tracer.end_request(trace, response_text, total_duration)
            logger.info(f"RAG answer complete: {len(chunks)} chunks, {total_duration:.2f}s")
            return response
        except Exception as e:
            logger.error(f"RAG answering error: {e}")
            raise HTTPException(status_code=500, detail=f"Error generating answer: {str(e)}")
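The RAGPromptBuilder used above is not reproduced in this post. Conceptually it folds the retrieved chunk texts and the user's question into a grounded prompt; the stand-in below is only an assumption about what such a builder could look like, not the project's actual prompt.

class SimpleRAGPromptBuilder:
    """Illustrative stand-in for RAGPromptBuilder; the real prompt format may differ."""

    system_prompt = (
        "You are a research assistant. Answer using only the provided paper excerpts, "
        "and say so if the excerpts do not contain the answer."
    )

    def create_rag_prompt(self, query: str, chunk_texts: list[str]) -> str:
        # Number each retrieved excerpt so the model can refer back to its sources.
        context = "\n\n".join(f"[{i + 1}] {text}" for i, text in enumerate(chunk_texts))
        return f"Context from retrieved arXiv papers:\n\n{context}\n\nQuestion: {query}\nAnswer:"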
3. Text chunking service
class TextChunker:
    """Text chunking service that splits text into overlapping chunks."""

    def __init__(self, chunk_size: int = 600, overlap_size: int = 100, min_chunk_size: int = 100):
        self.chunk_size = chunk_size
        self.overlap_size = overlap_size
        self.min_chunk_size = min_chunk_size
        if overlap_size >= chunk_size:
            raise ValueError("Overlap size must be smaller than chunk size")
        logger.info(f"Text chunker initialized: chunk_size={chunk_size}, overlap_size={overlap_size}, min_chunk_size={min_chunk_size}")

    def chunk_paper(
        self,
        title: str,
        abstract: str,
        full_text: str,
        arxiv_id: str,
        paper_id: str,
        sections: Optional[Union[Dict[str, str], str, list]] = None,
    ) -> List[TextChunk]:
        """Chunk a paper using the hybrid sectioning approach."""
        chunks = []
        # Handle the title and abstract
        if title and abstract:
            title_abstract_text = f"Title: {title}\n\nAbstract: {abstract}"
            title_chunks = self._chunk_text(title_abstract_text, arxiv_id, paper_id, "title_abstract")
            chunks.extend(title_chunks)
        # Handle the body text
        if full_text:
            if sections and isinstance(sections, list):
                # Use the section structure for structure-aware chunking
                for section in sections:
                    if isinstance(section, dict) and "title" in section and "content" in section:
                        section_title = section["title"]
                        section_content = section["content"]
                        if section_content:
                            section_chunks = self._chunk_text(
                                section_content, arxiv_id, paper_id, f"section_{section_title}"
                            )
                            chunks.extend(section_chunks)
            else:
                # Fall back to paragraph-based chunking of the full text
                body_chunks = self._chunk_text(full_text, arxiv_id, paper_id, "body")
                chunks.extend(body_chunks)
        logger.info(f"Created {len(chunks)} chunks for paper {arxiv_id}")
        return chunks

    def _chunk_text(self, text: str, arxiv_id: str, paper_id: str, section_title: str) -> List[TextChunk]:
        """Split text into overlapping chunks."""
        if not text or len(text.strip()) == 0:
            return []
        words = self._split_into_words(text)
        if len(words) < self.min_chunk_size:
            return []
        chunks = []
        start_idx = 0
        while start_idx < len(words):
            # Chunk boundaries fall on word boundaries, since the text is split on whitespace
            end_idx = min(start_idx + self.chunk_size, len(words))
            chunk_words = words[start_idx:end_idx]
            chunk_text = self._reconstruct_text(chunk_words)
            # Create chunk metadata (offsets here are word indices)
            metadata = ChunkMetadata(
                chunk_index=len(chunks),
                start_char=start_idx,
                end_char=end_idx,
                word_count=len(chunk_words),
                overlap_with_previous=self.overlap_size if start_idx > 0 else 0,
                overlap_with_next=self.overlap_size if end_idx < len(words) else 0,
                section_title=section_title,
            )
            chunk = TextChunk(
                text=chunk_text,
                metadata=metadata,
                arxiv_id=arxiv_id,
                paper_id=paper_id,
            )
            chunks.append(chunk)
            # Advance the start position, accounting for the overlap
            start_idx += self.chunk_size - self.overlap_size
            # Stop if too little text remains for a meaningful chunk
            if len(words) - start_idx < self.min_chunk_size:
                break
        return chunks
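The helpers _split_into_words and _reconstruct_text are referenced above but not shown. Under the simplest whitespace-based assumption they could look like the methods below (wrapped in a placeholder class so the snippet stands alone; the method names come from the code above, their bodies are assumptions, and the real implementations may handle punctuation and layout more carefully).

class _WhitespaceChunkingHelpers:
    """Placeholder holder for assumed TextChunker helper implementations."""

    def _split_into_words(self, text: str) -> list[str]:
        # Split on any whitespace; consecutive spaces and newlines are collapsed.
        return text.split()

    def _reconstruct_text(self, words: list[str]) -> str:
        # Rejoin words with single spaces; original spacing is not preserved.
        return " ".join(words)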
4. Cache client
class CacheClient:
    """Redis-backed exact-match cache for RAG queries."""

    def __init__(self, redis_client: redis.Redis, settings: RedisSettings):
        self.redis = redis_client
        self.settings = settings
        self.ttl = timedelta(hours=settings.ttl_hours)

    def _generate_cache_key(self, request: AskRequest) -> str:
        """Generate an exact-match cache key from the request parameters."""
        key_data = {
            "query": request.query,
            "model": request.model,
            "top_k": request.top_k,
            "use_hybrid": request.use_hybrid,
            "categories": sorted(request.categories) if request.categories else [],
        }
        key_string = json.dumps(key_data, sort_keys=True)
        key_hash = hashlib.sha256(key_string.encode()).hexdigest()[:16]
        return f"exact_cache:{key_hash}"

    async def find_cached_response(self, request: AskRequest) -> Optional[AskResponse]:
        """Look up a cached response for an exact query match."""
        try:
            cache_key = self._generate_cache_key(request)
            cached_response = self.redis.get(cache_key)
            if cached_response:
                try:
                    response_data = json.loads(cached_response)
                    logger.info("Exact-match cache hit")
                    return AskResponse(**response_data)
                except json.JSONDecodeError as e:
                    logger.warning(f"Failed to deserialize cached response: {e}")
                    return None
            return None
        except Exception as e:
            logger.error(f"Error checking cache: {e}")
            return None

    async def store_response(self, request: AskRequest, response: AskResponse) -> bool:
        """Store a response for exact query matching."""
        try:
            cache_key = self._generate_cache_key(request)
            success = self.redis.set(cache_key, response.model_dump_json(), ex=self.ttl)
            if success:
                logger.info(f"Response stored in exact cache under key {cache_key[:16]}...")
                return True
            else:
                logger.warning("Failed to store response in cache")
                return False
        except Exception as e:
            logger.error(f"Error writing to cache: {e}")
            return False
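Finally, a minimal sketch of wiring the cache client together. The redis.Redis construction is standard redis-py, but the RedisSettings fields shown here (ttl_hours and the connection defaults) are assumptions about the project's settings class.

import redis

# Hypothetical wiring; the project's settings object presumably supplies these values.
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
cache = CacheClient(redis_client, settings=RedisSettings(ttl_hours=24))

# Inside an async request handler:
#     cached = await cache.find_cached_response(request)
#     if cached is None:
#         ...  # generate a fresh answer, then:
#         await cache.store_response(request, response)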