import json
import sys
import time
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Index
# 连接到 Milvus
def connect_milvus(host='xxxxxx', port='31800'):
    """Connect to the Milvus server at ``host``:``port`` through pymilvus ``connections``.

    Args:
        host: Server hostname or IP address.
        port: Server port (passed through as given, here a string).
    """
    server = dict(host=host, port=port)
    print("Connecting to Milvus...")
    connections.connect(**server)
# 创建或获取集合
def get_or_create_collection(collection_name, dim=256):
    """Return a handle to ``collection_name``, creating the collection if it is missing.

    A newly created collection has two fields:
      * ``item_code``      - VARCHAR primary key (max length 100), supplied by
                             the caller rather than auto-generated;
      * ``blip_embedding`` - FLOAT_VECTOR of dimension ``dim``.

    Args:
        collection_name: Name of the Milvus collection.
        dim: Vector dimension; only used when the collection is created. An
            existing collection is returned as-is without checking its schema.

    Returns:
        A pymilvus ``Collection``.
    """
    if not utility.has_collection(collection_name):
        print(f"Creating collection '{collection_name}'.")
        primary_key = FieldSchema(name="item_code", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100)
        vector_field = FieldSchema(name="blip_embedding", dtype=DataType.FLOAT_VECTOR, dim=dim)
        schema = CollectionSchema([primary_key, vector_field], "item_blip向量")
        return Collection(name=collection_name, schema=schema)
    print(f"Collection '{collection_name}' already exists.")
    return Collection(name=collection_name)
# 创建索引
def create_index_if_not_exists(collection, field_name="blip_embedding", index_params=None):
    """Build a vector index on ``field_name`` unless the collection already reports one.

    Args:
        collection: pymilvus ``Collection`` to index.
        field_name: Vector field to index.
        index_params: Milvus index parameters. When ``None``, an IVF_FLAT index
            with inner-product (IP) metric and ``nlist=1414`` is used; ``nlist``
            should be tuned to the data volume.

    If ``collection.has_index()`` returns True, nothing is built and a message
    is printed instead.
    """
    if collection.has_index():
        print("Index already exists. No need to create a new one.")
        return
    if index_params is None:
        index_params = {
            "index_type": "IVF_FLAT",
            # Must match the metric used at search time (search_similar uses IP).
            "metric_type": "IP",
            # Number of IVF clusters; adjust to the amount of data.
            "params": {"nlist": 1414},
        }
    print(f"Creating index on '{field_name}'...")
    # Use the documented Collection.create_index API rather than relying on the
    # side effect of constructing a pymilvus.Index object.
    collection.create_index(field_name, index_params)
    print("Index has created.")
# 重建索引
def recreate_index(collection, field_name="blip_embedding", index_params=None):
    """Drop the collection's existing index (if any) and build a fresh one on ``field_name``.

    The collection is released first, presumably because Milvus will not drop
    the index of a loaded collection. A failed release is reported and then
    ignored, because the collection may simply not be loaded.

    Args:
        collection: pymilvus ``Collection`` to re-index.
        field_name: Vector field to index.
        index_params: Milvus index parameters. When ``None``, an IVF_FLAT index
            with inner-product (IP) metric and ``nlist=1414`` is used.
    """
    try:
        print("Releasing the collection before dropping the index...")
        collection.release()
    except Exception as e:
        # Best effort: report the real error instead of assuming it was
        # "not loaded", then carry on.
        print(f"Could not release collection ({e}); proceeding to drop index.")

    if collection.has_index():
        print("Dropping existing index...")
        collection.drop_index()
        print("Index dropped.")

    if index_params is None:
        index_params = {
            "index_type": "IVF_FLAT",
            # Must match the metric used at search time (search_similar uses IP).
            "metric_type": "IP",
            # Number of IVF clusters; adjust to the amount of data.
            "params": {"nlist": 1414},
        }
    print(f"Creating new index on '{field_name}'...")
    # Use the documented Collection.create_index API rather than relying on the
    # side effect of constructing a pymilvus.Index object.
    collection.create_index(field_name, index_params)
    print("New index created.")
# 检查索引
def check_index_status(collection):
    """Print whatever ``collection.index()`` returns, prefixed with ``Index info:``."""
    print("Index info:", collection.index())
# 批量插入数据
def batch_insert(collection, item_codes, embeddings):
    """Insert one batch of rows into ``collection`` using column-based form.

    Args:
        collection: pymilvus ``Collection``. Assumed to have the schema built by
            ``get_or_create_collection`` (``item_code`` then ``blip_embedding``),
            since columns are passed positionally.
        item_codes: Primary keys, one per row.
        embeddings: Vectors, aligned index-for-index with ``item_codes``.

    Returns:
        The value returned by ``collection.insert``. Returns ``None`` when the
        batch is empty or the insert raised. The error is printed, not
        re-raised, so one bad batch does not abort a long bulk load, but
        callers can still detect it.
    """
    if not item_codes:
        return None
    # One list per schema field, in schema order.
    entities = [item_codes, embeddings]
    try:
        insert_result = collection.insert(entities)
    except Exception as e:
        print('Error during insert:', e)
        return None
    print('Insert result: ', insert_result)
    return insert_result
# 检查商品是否存在
def item_code_exists(collection, item_code):
    """Return True if ``collection`` holds a row whose ``item_code`` equals ``item_code``.

    Args:
        collection: pymilvus ``Collection`` with an ``item_code`` field.
        item_code: Primary key to look up; converted with ``str()``.

    Returns:
        True if the query returned at least one row. Returns False otherwise,
        including when the query raised (the error is printed).
    """
    # Escape backslashes first, then double quotes, so a code containing
    # either cannot terminate or corrupt the string literal in the filter.
    literal = str(item_code).replace('\\', '\\\\').replace('"', '\\"')
    expr = f'item_code == "{literal}"'
    try:
        results = collection.query(expr=expr, output_fields=["item_code"])
    except Exception as e:
        print(f"Error checking item code existence: {e}")
        return False
    return len(results) > 0
# 删除商品
def delete_item(collection, item_code):
    """Delete the row(s) whose ``item_code`` equals ``item_code`` from ``collection``.

    Errors from ``collection.delete`` are printed, not raised. The success
    message is printed only when the delete call did not raise.

    Args:
        collection: pymilvus ``Collection`` with an ``item_code`` field.
        item_code: Primary key to delete; converted with ``str()``.
    """
    # Escape backslashes first, then double quotes, so a code containing
    # either cannot break out of the string literal and widen the delete filter.
    literal = str(item_code).replace('\\', '\\\\').replace('"', '\\"')
    expr = f'item_code == "{literal}"'
    try:
        collection.delete(expr)
    except Exception as e:
        print(f"Error deleting item code: {e}")
        return
    print(f"Deleted item code: {item_code}")
# 主函数示例
def write2milvus(collection_blip, fin='youshi_ic_embedding_all.txt', batch_size=1024):
    """Bulk-load item embeddings from a UTF-8 text file into ``collection_blip``.

    Each non-blank line must be ``<item_code>`` TAB ``<v1>,<v2>,...``. The part
    after the tab is wrapped in brackets and parsed with ``json.loads``. Rows
    are sent in batches of ``batch_size`` through ``batch_insert``, and any
    leftover partial batch is inserted at the end.

    Args:
        collection_blip: Target pymilvus ``Collection``.
        fin: Path of the embedding file.
        batch_size: Rows per insert call; must be at least 1.

    Raises:
        ValueError: If ``batch_size`` < 1, or a non-blank line does not have
            exactly two tab-separated fields / a parseable vector.
    """
    if batch_size < 1:
        # With 0 or a negative size the "batch full" check would never fire
        # and the whole file would be buffered into a single insert.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    item_codes = []
    embeddings = []
    with open(fin, 'r', encoding='utf-8') as r_file:
        for index, line in enumerate(r_file):
            if index % 1000 == 0:
                print('商品写入的个数:', index+1)
            line = line.strip()
            if not line:
                continue  # tolerate blank lines instead of crashing on unpack
            item_code, emb = line.split('\t')
            item_codes.append(item_code)
            embeddings.append(json.loads('[' + emb + ']'))
            if len(embeddings) == batch_size:
                batch_insert(collection_blip, item_codes, embeddings)
                item_codes = []
                embeddings = []
    if item_codes:
        batch_insert(collection_blip, item_codes, embeddings)
# 删除集合
def drop_collection(collection_name):
    """Drop the Milvus collection named ``collection_name``.

    Any exception, whether from opening the collection handle (presumably
    including a collection that does not exist) or from the drop itself, is
    printed rather than raised.
    """
    try:
        Collection(name=collection_name).drop()
    except Exception as err:
        print(f"Error dropping collection '{collection_name}': {err}")
    else:
        print(f"Collection '{collection_name}' has been dropped.")
# 进行相似性搜索
def search_similar(collection, item_vec, limit=50, *, load=True, nprobe=128, field_name="blip_embedding"):
    """Return the ``item_code`` values of the nearest neighbours of ``item_vec``.

    The search uses the inner-product (IP) metric, which is the metric the
    indexes in this module are built with.

    Args:
        collection: pymilvus ``Collection`` to search.
        item_vec: Query vector (a sequence of floats).
        limit: Maximum number of hits to return.
        load: Call ``collection.load()`` before searching (the previous,
            always-on behaviour). Pass ``False`` when the caller has already
            loaded the collection, e.g. when issuing many queries in a loop, to
            avoid a redundant load request per query.
        nprobe: IVF ``nprobe`` search parameter (clusters probed per query).
        field_name: Vector field to search.

    Returns:
        The ``item_code`` of each hit, in the order returned by
        ``collection.search``. Returns ``[]`` if the search raises (the error
        is printed). Errors from ``collection.load()`` propagate.
    """
    if load:
        collection.load()
    search_params = {
        "metric_type": "IP",
        "params": {"nprobe": nprobe},
    }
    try:
        result = collection.search([item_vec], field_name, search_params, limit=limit, output_fields=["item_code"])
        return [hit.entity.get('item_code') for hits in result for hit in hits]
    except Exception as e:
        print(f"Error during search: {e}")
        return []
def get_search_similar_all(collection, fin, fout='youshi_ic_similar_rs.txt', limit=10):
    """Write the nearest neighbours of every item listed in ``fin`` to ``fout``.

    Each non-blank input line must be ``<item_code>`` TAB ``<v1>,<v2>,...``.
    The vector is parsed with ``json.loads`` after wrapping it in brackets,
    then ``search_similar`` is queried for up to ``limit`` hits. When hits
    come back, a line ``<item_code>`` TAB ``<hit1>#<hit2>#...`` is written.
    Items with no hits (including failed searches) are omitted from the output.

    Args:
        collection: pymilvus ``Collection`` to search.
        fin: Path of the UTF-8 query file.
        fout: Path of the UTF-8 result file (overwritten).
        limit: Number of neighbours requested per item.

    Raises:
        ValueError: If a non-blank line does not have exactly two
            tab-separated fields or its vector cannot be parsed.
    """
    with open(fin, 'r', encoding='utf-8') as r_file, open(fout, 'w', encoding='utf-8') as out:
        for index, line in enumerate(r_file):
            if index % 1000 == 0:
                print('完成商品相似品计算的个数:', index+1)
            line = line.strip()
            if not line:
                continue  # tolerate blank lines instead of crashing on unpack
            item_code, emb = line.split('\t')
            emb = json.loads('[' + emb + ']')
            result = search_similar(collection, emb, limit=limit)
            if result:
                out.write("{}\t{}\n".format(item_code, '#'.join(result)))
if __name__ == "__main__":
    # Script flow: connect, open (or create) the collection, rebuild its index,
    # then write the nearest neighbours of every item in ./test.txt.
    connect_milvus()
    collection = get_or_create_collection('youshi_item_blip_vec')
    # Enable to (re)load embeddings from disk before indexing:
    # write2milvus(collection)
    recreate_index(collection)
    # The index was just rebuilt, so this only reports that it already exists.
    create_index_if_not_exists(collection)
    get_search_similar_all(collection, './test.txt')