ES RAG向量搜索示例,使用BAAI BGE创建embedding
准备:
docker pull docker.elastic.co/elasticsearch/elasticsearch:7.6.2
7.6.2: Pulling from elasticsearch/elasticsearch
c808caf183b6: Pull complete
d6caf8e15a64: Pull complete
b0ba5f324e82: Pull complete
d7e8c1e99b9a: Pull complete
85c4d6c81438: Pull complete
3119218fac98: Pull complete
914accf214bb: Pull complete
Digest: sha256:59342c577e2b7082b819654d119f42514ddf47f0699c8b54dc1f0150250ce7aa
Status: Downloaded newer image for docker.elastic.co/elasticsearch/elasticsearch:7.6.2
docker.elastic.co/elasticsearch/elasticsearch:7.6.2
What's Next?
View a summary of image vulnerabilities and recommendations → docker scout quickview docker.elastic.co/elasticsearch/elasticsearch:7.6.2

PS D:\source\pythonProject> pip install elasticsearch
Requirement already satisfied: elasticsearch in d:\python\python312\lib\site-packages (7.6.0)
Requirement already satisfied: urllib3>=1.21.1 in d:\python\python312\lib\site-packages (from elasticsearch) (1.26.18)
进入容器修改配置(esid 为已启动的 Elasticsearch 容器的 ID 或名称):
docker exec -it esid bash
cd config/
vi elasticsearch.yml
在 elasticsearch.yml 中增加以下配置:
http.cors.enabled: true
http.cors.allow-origin: "*"
discovery.zen.minimum_master_nodes: 1
重启服务:
docker restart esid
在浏览器中访问以下地址查看页面,确认 Elasticsearch 已正常启动(ip 为运行容器的主机地址):
ip:9200

编写代码:
# Minimal example: store 5-dimensional vectors in a dense_vector field and
# rank documents by cosine similarity to a query vector.
from elasticsearch import Elasticsearch

# Connect to Elasticsearch (no host is given, so the client's default address is used).
es = Elasticsearch()

# Index settings and mapping.
index_name = "vector_search_example"
index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0,
    },
    "mappings": {
        "properties": {
            "title": {"type": "text"},
            "embedding": {
                "type": "dense_vector",  # vector field type
                "dims": 5,  # vector dimensionality; adjust to the real embedding size
            },
        }
    },
}

# Create the index only if it does not exist yet.
if not es.indices.exists(index=index_name):
    es.indices.create(index=index_name, body=index_settings)

# Sample documents; the third vector differs noticeably from the first two.
documents = [
    {
        "title": "Hello World Document",
        "embedding": [0.1, 0.2, 0.3, 0.4, 0.5],
    },
    {
        "title": "Another Document Example",
        "embedding": [0.2, 0.35, 0.45, 0.55, 0.6],
    },
    {
        "title": "Yet Another Hello",
        "embedding": [0.7, 0.6, 0.5, 0.4, 0.3],
    },
]

# Document ids are 1, 2, 3, so re-running the script overwrites the same documents.
for doc_id, doc in enumerate(documents, start=1):
    response = es.index(index=index_name, id=doc_id, body=doc)
    print(f"Indexed document: {response['result']}")

# Freshly indexed documents only become searchable after a refresh. Force one
# so that the search below also sees them on a first run.
es.indices.refresh(index=index_name)

# Search for the vectors most similar to the query vector.
query_vector = [0.2, 0.3, 0.4, 0.5, 0.6]
script_query = {
    "script_score": {
        "query": {"match_all": {}},
        "script": {
            # + 1.0 shifts the cosine similarity so the score is never negative.
            "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
            "params": {"query_vector": query_vector},
        },
    }
}
response = es.search(index=index_name, body={"query": script_query}, size=2)

# Print the search hits.
for hit in response["hits"]["hits"]:
    print(f"Document ID: {hit['_id']}, Score: {hit['_score']}, Title: {hit['_source']['title']}")
返回结果:
Indexed document: updated
Indexed document: updated
Indexed document: updated
Document ID: 2, Score: 1.9982954, Title: Another Document Example
Document ID: 1, Score: 1.9949367, Title: Hello World Document
我们再复杂一点,使用BGE模型进行编码,便于搜索:
from elasticsearch import Elasticsearch
from FlagEmbedding import FlagModel
from collections import defaultdict
from time import time
# Open the Elasticsearch client connection.
es = Elasticsearch()

# Name, settings and field mapping of the index used by this example.
index_name = "vector_search_sec_tool"
index_settings = dict(
    settings=dict(
        number_of_shards=1,
        number_of_replicas=0,
    ),
    mappings=dict(
        properties=dict(
            description=dict(type="text"),
            embedding=dict(type="dense_vector", dims=768),
        ),
    ),
)

# Create the index on first use; an existing index is left untouched.
if not es.indices.exists(index=index_name):
    es.indices.create(index=index_name, body=index_settings)
def search_sectool_knowledge_base(descriptions):
    """Embed API descriptions, store them in Elasticsearch and run sample queries.

    Args:
        descriptions: iterable of dicts, each with a "path" entry and a
            "methods" dict mapping an HTTP method to its description text.

    Uses the module-level ``es`` client and ``index_name``. Results are printed;
    nothing is returned.
    """
    # Build the corpus and a description -> {method, path} lookup table.
    # defaultdict keeps the lookup in the result loop from raising if the
    # index still holds a document left over from an earlier, larger run.
    corpus = []
    index = defaultdict(dict)
    for item in descriptions:
        for method, description in item['methods'].items():
            index[description] = {"method": method, "path": item["path"]}
            corpus.append(description)

    embedder = FlagModel('bge-base-zh-v1.5/',
                         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",)
    corpus_embeddings = embedder.encode(corpus)

    # Store every description together with its vector in Elasticsearch.
    for i, description in enumerate(corpus):
        doc = {
            "description": description,
            "embedding": corpus_embeddings[i].tolist(),  # numpy array -> plain list
        }
        response = es.index(index=index_name, id=i + 1, body=doc)
        print(f"Indexed document: {response['result']}")

    # Make the documents just written visible to search. Without a refresh a
    # first run can query before the index has refreshed and get no hits.
    es.indices.refresh(index=index_name)

    # Query sentences.
    queries = [
        '搜索告警列表',
        '查询漏洞']
    now = time()
    times = 1  # number of passes over the query list
    for _ in range(times):
        for query in queries:
            query_embedding = embedder.encode(query).tolist()
            script_query = {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                        "params": {"query_vector": query_embedding}
                    }
                }
            }
            response = es.search(index=index_name, body={"query": script_query}, size=3)
            print("\n\n======================\n\n")
            print("Query:", query)
            print("\nTop 3 most similar sentences in corpus:")
            for hit in response["hits"]["hits"]:
                description = hit["_source"]["description"]
                score = hit["_score"]
                print(f"{description} (Score: {score:.4f}) ==> {index[description]}")
    print(f"{times} {time() - now} seconds elapsed")
if __name__ == '__main__':
    # Sample API catalogue as (path, HTTP method, description) rows.
    api_rows = [
        ('/v1/{project_id}/subscriptions/version', 'GET', '获取视图订购信息'),
        ('/v1/{project_id}/workspaces/{workspace_id}/sa/reports', 'GET', '分析报管理获取报告列表'),
        ('/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules', 'GET',
         'corss-workspace智能建模聚合列表接口'),
        ('/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules/metrics', 'GET',
         'cross-workspace智能建模可用模型指标接口'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search', 'POST', '搜索告警列表'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/incidents/search', 'POST', '搜索事件列表'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search', 'POST', '威胁情报列表查询'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/vulnerability/search', 'POST', '查询漏洞列表'),
    ]
    # Expand each row into the {'path': ..., 'methods': {method: description}} shape.
    descriptions = [{'path': api_path, 'methods': {http_method: summary}}
                    for api_path, http_method, summary in api_rows]
    search_sectool_knowledge_base(descriptions)
运行结果:
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
Indexed document: updated
======================
Query: 搜索告警列表
Top 3 most similar sentences in corpus:
搜索告警列表 (Score: 2.0000) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search'}
搜索事件列表 (Score: 1.9030) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/incidents/search'}
威胁情报列表查询 (Score: 1.8769) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search'}
======================
Query: 查询漏洞
Top 3 most similar sentences in corpus:
查询漏洞列表 (Score: 1.9688) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/vulnerability/search'}
威胁情报列表查询 (Score: 1.8580) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search'}
搜索告警列表 (Score: 1.8370) ==> {'method': 'POST', 'path': '/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search'}
1 0.060091257095336914 seconds elapsed
还可以继续优化一下,把 path、methods 等信息也一并存入 ES,这样搜索结果可以直接从 ES 返回完整数据:
from elasticsearch import Elasticsearch
from FlagEmbedding import FlagModel
from collections import defaultdict
from time import time
# Open the Elasticsearch client connection.
es = Elasticsearch()

# Name, settings and field mapping of the index used by this example.
index_name = "vector_search_example222"
index_settings = dict(
    settings=dict(
        number_of_shards=1,
        number_of_replicas=0,
    ),
    mappings=dict(
        properties=dict(
            path=dict(type="text"),
            methods=dict(type="object"),
            description=dict(type="text"),
            # 768 assumes the FlagModel in use emits 768-dimensional vectors.
            embedding=dict(type="dense_vector", dims=768),
        ),
    ),
)

# Create the index on first use; an existing index is left untouched.
if not es.indices.exists(index=index_name):
    es.indices.create(index=index_name, body=index_settings)
def search_sec_knowledge_base(descriptions):
    """Store full API records with embeddings in Elasticsearch and run sample queries.

    Args:
        descriptions: iterable of dicts, each with a "path" entry and a
            "methods" dict mapping an HTTP method to its description text.

    Uses the module-level ``es`` client and ``index_name``. Each stored document
    carries path, methods, description and embedding, so hits are printed
    straight from ``_source``. Nothing is returned.
    """
    # Build the corpus and a description -> {method, path} lookup table.
    corpus = []
    index = defaultdict(dict)
    for item in descriptions:
        for method, description in item['methods'].items():
            index[description] = {"method": method, "path": item["path"]}
            corpus.append(description)

    embedder = FlagModel('bge-base-zh-v1.5/',
                         query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章:",)
    corpus_embeddings = embedder.encode(corpus)

    # Store every record together with its vector in Elasticsearch.
    for i, description in enumerate(corpus):
        doc = {
            "path": index[description]["path"],
            "methods": {index[description]["method"]: description},
            "description": description,
            "embedding": [float(x) for x in corpus_embeddings[i]]  # plain list of floats
        }
        response = es.index(index=index_name, id=i + 1, body=doc)
        print(f"Indexed document: {response['result']}")

    # Make the documents just written visible to search. Without a refresh a
    # first run can query before the index has refreshed and get no hits.
    es.indices.refresh(index=index_name)

    # Query sentences.
    queries = [
        '搜索告警列表',
        '查询漏洞',
        'Someone in a gorilla costume is playing a set of drums.',
        'A cheetah chases prey on across a field.']
    now = time()
    times = 1  # number of passes over the query list
    for _ in range(times):
        for query in queries:
            query_embedding = [float(x) for x in embedder.encode(query)]
            script_query = {
                "script_score": {
                    "query": {"match_all": {}},
                    "script": {
                        "source": "cosineSimilarity(params.query_vector, 'embedding') + 1.0",
                        "params": {"query_vector": query_embedding}
                    }
                }
            }
            response = es.search(index=index_name, body={"query": script_query}, size=5)
            print("\n\n======================\n\n")
            print("Query:", query)
            print("\nTop 5 most similar sentences in corpus:")
            for hit in response["hits"]["hits"]:
                source = hit["_source"]
                description = source["description"]
                score = hit["_score"]
                print(f"{description} (Score: {score:.4f}) ==> Path: {source['path']}, Methods: {source['methods']}")
    print(f"{times} {time() - now} seconds elapsed")
if __name__ == '__main__':
    # Sample API catalogue: one single-method record is appended per endpoint.
    descriptions = []
    for api_path, http_method, summary in (
        ('/v1/{project_id}/subscriptions/version', 'GET', '获取视图订购信息'),
        ('/v1/{project_id}/workspaces/{workspace_id}/sa/reports', 'GET', '分析报管理获取报告列表'),
        ('/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules', 'GET',
         'corss-workspace智能建模聚合列表接口'),
        ('/v1/{project_id}/workspaces/{workspace_id}/siem/alert-rules/metrics', 'GET',
         'cross-workspace智能建模可用模型指标接口'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/alerts/search', 'POST', '搜索告警列表'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/incidents/search', 'POST', '搜索事件列表'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/indicators/search', 'POST', '威胁情报列表查询'),
        ('/v1/{project_id}/workspaces/{workspace_id}/soc/vulnerability/search', 'POST', '查询漏洞列表'),
    ):
        descriptions.append({'path': api_path, 'methods': {http_method: summary}})
    search_sec_knowledge_base(descriptions)

浙公网安备 33010602011771号