from chromadb.config import Settings
from chromadb.utils import embedding_functions
import os
import chromadb
# Chroma configuration: keep the on-disk database in ./database.
persist_directory = "database"
# exist_ok=True avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(persist_directory, exist_ok=True)
# PersistentClient's first parameter is `path` (the storage directory); a
# Settings object passed positionally would be bound to `path` instead.
client = chromadb.PersistentClient(path=persist_directory)
from chromadb.api.types import Documents, Embeddings
from chromadb.utils.embedding_functions import EmbeddingFunction
class CustomEmbeddingFunction(EmbeddingFunction):
    """Placeholder embedding function that returns random vectors.

    Replace the body of ``__call__`` with a real embedding model. Because the
    vectors are random, similarity search over them is meaningless and query
    results change from run to run.
    """

    # Length of every vector produced by __call__.
    embedding_dim = 384

    def __call__(self, input: Documents) -> Embeddings:
        """Return one embedding (a list of floats) per document in *input*.

        The parameter must be named ``input``: Chroma checks that a custom
        embedding function's ``__call__`` signature matches its protocol.
        """
        import numpy as np

        # Implement your embedding logic here; random vectors are only an example.
        return [np.random.rand(self.embedding_dim).tolist() for _ in input]

    @staticmethod
    def name() -> str:
        """Return the name identifying this embedding function."""
        return "custom_embedding_function"
# Create the collection, or reopen it if it already exists in the persisted
# database; the HNSW index uses cosine distance.
collection = client.get_or_create_collection(
    name="my-collection",
    metadata={"hnsw:space": "cosine"},
    embedding_function=CustomEmbeddingFunction(),
)

# Add data. No embeddings are passed, so Chroma computes them with the
# collection's embedding function. upsert (rather than add) lets the script
# be rerun against the persisted store without colliding on existing ids.
# It returns nothing, so its result is not bound to a name.
ids = ["id1", "id2"]
documents = ["This is the first document.", "This is the second document."]
collection.upsert(ids=ids, documents=documents)

# Query data: fetch the 2 closest results for each query text.
query_texts = ["This is a query about the first document."]
results = collection.query(query_texts=query_texts, n_results=2)
print(results)