1.概述

有些情况下,我们需要实现以图搜图的功能,就是根据图片找相似的图片。
实现原理:

  • 使用特征提取模型将图片向量化后入库
  • 将需要查询的图片进行向量化
  • 在数据库使用相似度进行检索
    这里使用 python 更适合做向量检索,当然也可以使用java ,但是java 支持没有那么好。

如果项目是java 项目开发:

  1. 可以将 python 做成一个服务接口,java 调用接口。
  2. 这个功能直接使用 python,将这个功能做成一个镜像服务进行使用。

2.实现过程

2.1. 创建向量数据库表

我使用的向量数据库是postgres ,向量数据库的向量最大不能超过2000 ,因此在创建向量数据库的时候维度指定为 1536.
-- One row per image: UUID key, free-text description, arbitrary JSON metadata,
-- and the embedding stored in a pgvector column sized to the model's
-- reduced 1536-dimension output.
CREATE TABLE "public"."images" (
  "id" uuid NOT NULL,
  "content" text COLLATE "pg_catalog"."default",
  "metadata" jsonb,
  "embedding" vector(1536)
)
;

2.2 下载特征提取模型

下载模型

from huggingface_hub import snapshot_download

# Download the CLIP ViT-B/32 weights from the HF mirror into ./models.
snapshot_download(
    repo_id="openai/clip-vit-base-patch32",
    local_dir="./models",
    endpoint="https://hf-mirror.com",
    local_dir_use_symlinks=False  # copy real files (no symlinks) so the folder is easy to package
)

下载模型后我们可以把这些模型转成 onnx 模型

# export_clip_to_onnx.py
# Exports the vision tower of a locally downloaded CLIP model to ONNX,
# then validates the exported graph and prints its input/output shapes.

import torch
import os
from transformers import CLIPVisionModel, CLIPProcessor
from pathlib import Path

# -------------------------------
# Configuration
# -------------------------------
MODEL_PATH = "./models"  # ← change this to your local model path
ONNX_OUTPUT_PATH = "clip_vision.onnx"
OPSET_VERSION = 14
BATCH_SIZE = 1
DYNAMIC_BATCH = True

HEIGHT, WIDTH = 224, 224

# -------------------------------
# 1. Load the local model and preprocessor (use_fast enabled)
# -------------------------------
print("加载本地模型...")
vision_model = CLIPVisionModel.from_pretrained(MODEL_PATH)
processor = CLIPProcessor.from_pretrained(MODEL_PATH, use_fast=True)  # explicitly enable the fast tokenizer

vision_model.eval()

# -------------------------------
# 2. Build a dummy input tensor (batch, channels, height, width)
# -------------------------------
dummy_input = torch.randn(BATCH_SIZE, 3, HEIGHT, WIDTH)

# -------------------------------
# 3. ONNX export configuration
# -------------------------------
input_names = ["pixel_values"]
output_names = ["last_hidden_state", "pooler_output"]

# Mark the batch axis as dynamic so the exported model accepts any batch size.
dynamic_axes = {
    "pixel_values": {0: "batch"},
    "last_hidden_state": {0: "batch"},
    "pooler_output": {0: "batch"}
} if DYNAMIC_BATCH else None

# -------------------------------
# 4. Export to ONNX (invalid kwargs removed)
# -------------------------------
print("正在导出 ONNX 模型...")
with torch.no_grad():
    torch.onnx.export(
        model=vision_model,
        args=dummy_input,
        f=ONNX_OUTPUT_PATH,
        opset_version=OPSET_VERSION,
        do_constant_folding=True,
        input_names=input_names,
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        verbose=False,
        # removed kwargs torch.onnx.export does not accept:
        # output_attentions, output_hidden_states
    )

print(f" ONNX 模型已成功导出:{ONNX_OUTPUT_PATH}")

# -------------------------------
# 5. Validate the exported ONNX model
# -------------------------------
import onnx
from onnx import shape_inference

print("验证 ONNX 模型结构...")
onnx_model = onnx.load(ONNX_OUTPUT_PATH)
onnx_model = shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(onnx_model)
print("ONNX 模型验证通过!")

# Print graph inputs/outputs for a quick manual check.
print(" 模型输入/输出信息:")
for inp in onnx_model.graph.input:
    print(f"输入: {inp.name} -> {inp.type.tensor_type.shape}")
for out in onnx_model.graph.output:
    print(f"输出: {out.name} -> {out.type.tensor_type.shape}")

2.3 编写相关代码

需要使用 python 3.11 的版本

需要安装依赖

Flask==2.3.2
psycopg2-binary==2.9.7
numpy==1.24.3
torch==2.0.1
torchvision==0.15.2
onnxruntime==1.15.1
Pillow==10.0.0
scikit-learn==1.3.0
import os
import numpy as np
from flask import Flask, request, jsonify
import psycopg2
import json
import uuid
from PIL import Image
import torch
from torchvision import transforms
import onnxruntime as ort
import traceback

app = Flask(__name__)

# Database connection settings (passed straight to psycopg2.connect).
# NOTE(review): credentials are hardcoded — consider environment variables.
DB_CONFIG = {
    'host': '192.168.2.14',
    'port': 5432,
    'database': 'postgres',
    'user': 'postgres',
    'password': 'postgres'
}

# Path to the exported CLIP vision ONNX model.
MODEL_PATH = "D:\\temp\\clip\\clip_vision.onnx"

# Base directory that image paths in /search_similar and /batch_process
# are resolved against.
IMAGE_BASE_PATH = "D:/temp/images/"

class ImageVectorService:
    """CLIP-based image embedding service backed by Postgres/pgvector.

    Loads an ONNX export of the CLIP vision encoder, reduces each image's
    feature vector to 1536 dimensions (the width of the ``vector(1536)``
    column of the ``images`` table) and supports storing vectors and
    cosine-similarity search.
    """

    def __init__(self, db_config, model_path):
        """
        Args:
            db_config: psycopg2 connection keyword arguments (host, port, ...).
            model_path: filesystem path to the CLIP vision ONNX model.
        """
        self.db_config = db_config
        self.model_path = model_path
        self.connection = None
        self.session = None
        self.transform = None
        self._init_model()
        self._init_db()

    def _init_model(self):
        """Create the ONNX Runtime session and the image preprocessing pipeline."""
        try:
            self.session = ort.InferenceSession(self.model_path)
            # 224x224 RGB input with OpenAI CLIP normalization constants.
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.48145466, 0.4578275, 0.40821073],
                    std=[0.26862954, 0.26130258, 0.27577711]
                )
            ])
            print("ONNX模型加载成功")
        except Exception as e:
            print(f"模型加载失败: {e}")
            raise

    def _init_db(self):
        """Open the Postgres connection."""
        try:
            self.connection = psycopg2.connect(**self.db_config)
            print("数据库连接成功")
        except Exception as e:
            print(f"数据库连接失败: {e}")
            raise

    def _rollback_quietly(self):
        """Roll back the current transaction so the connection stays usable
        after a failed statement; ignore rollback errors (e.g. dead link)."""
        try:
            self.connection.rollback()
        except Exception:
            pass

    def preprocess_image(self, image_path):
        """Load an image and convert it to the model's input tensor.

        Args:
            image_path: path of the image file.

        Returns:
            numpy.ndarray: float array of shape (1, 3, 224, 224).
        """
        try:
            image = Image.open(image_path).convert('RGB')
            image_tensor = self.transform(image)
            # Add the batch dimension and hand a numpy array to ONNX Runtime.
            image_tensor = image_tensor.unsqueeze(0).numpy()
            return image_tensor
        except Exception as e:
            print(f"图片预处理失败: {e}")
            raise

    @staticmethod
    def _resize_to_dim(features, target_dim):
        """Reduce or pad a 1-D feature array to exactly ``target_dim`` values.

        Average-pools contiguous chunks when evenly divisible, subsamples at
        evenly spaced indices otherwise, and zero-pads short vectors.
        """
        if len(features) > target_dim:
            if len(features) % target_dim == 0:
                factor = len(features) // target_dim
                return features.reshape(target_dim, factor).mean(axis=1)
            indices = np.linspace(0, len(features) - 1, target_dim).astype(int)
            return features[indices]
        if len(features) < target_dim:
            padded = np.zeros(target_dim)
            padded[:len(features)] = features
            return padded
        return features

    def extract_features(self, image_path):
        """Extract an L2-normalized 1536-dimension feature vector for an image.

        Args:
            image_path: path of the image file.

        Returns:
            list: feature vector of length 1536.
        """
        try:
            image_tensor = self.preprocess_image(image_path)

            # Run the vision encoder.
            input_name = self.session.get_inputs()[0].name
            output = self.session.run(None, {input_name: image_tensor})

            # Flatten the first output and squeeze it to the column width.
            features = output[0].flatten()
            print(f"模型原始特征维度: {len(features)}")

            features = self._resize_to_dim(features, 1536)
            print(f"处理后特征维度: {len(features)}")

            # Normalize so cosine-distance search behaves consistently.
            norm = np.linalg.norm(features)
            if norm > 0:
                features = features / norm

            return features.tolist()
        except Exception as e:
            print(f"特征提取失败: {e}")
            raise

    def store_image_vector(self, image_path, image_id=None, metadata=None):
        """Extract and upsert an image's embedding into the ``images`` table.

        Args:
            image_path: path of the image file.
            image_id: optional UUID string; generated when omitted.
            metadata: optional dict; filename/image_id keys are added.

        Returns:
            str: the image ID used for the row.
        """
        try:
            features = self.extract_features(image_path)

            # Sanity check against the vector(1536) column definition.
            expected_dim = 1536
            if len(features) != expected_dim:
                print(f"警告: 特征维度为{len(features)},与期望的{expected_dim}不符")

            if not image_id:
                image_id = str(uuid.uuid4())

            if not metadata:
                metadata = {}
            metadata['filename'] = os.path.basename(image_path)
            metadata['image_id'] = image_id

            cursor = self.connection.cursor()
            try:
                insert_query = """
                    INSERT INTO images (id, content, metadata, embedding)
                    VALUES (%s, %s, %s, %s)
                    ON CONFLICT (id) DO UPDATE SET
                    content = EXCLUDED.content,
                    metadata = EXCLUDED.metadata,
                    embedding = EXCLUDED.embedding
                """
                cursor.execute(insert_query, (
                    image_id,
                    "Image feature vector",
                    json.dumps(metadata),
                    features  # vector values passed directly
                ))
                self.connection.commit()
            finally:
                # Release the cursor even when execute/commit raises.
                cursor.close()

            print(f"图片向量存储成功: {image_id}")
            return image_id
        except Exception as e:
            self._rollback_quietly()
            print(f"图片向量存储失败: {e}")
            raise

    @staticmethod
    def _coerce_metadata(metadata):
        """Normalize the jsonb column value into a plain dict."""
        if isinstance(metadata, dict):
            return metadata
        if isinstance(metadata, (str, bytes)):
            try:
                return json.loads(metadata)
            except (json.JSONDecodeError, TypeError):
                return {}
        # None or anything unexpected.
        return {}

    def search_similar_images(self, image_path, top_k=5):
        """Return the ``top_k`` stored images most similar to a query image.

        Args:
            image_path: path of the query image.
            top_k: number of results to return (default 5).

        Returns:
            list[dict]: rows with id, content, metadata and similarity score.
        """
        try:
            query_features = self.extract_features(image_path)

            cursor = self.connection.cursor()
            try:
                # pgvector's <=> is cosine distance; similarity = 1 - distance.
                search_query = """
                    SELECT id, content, metadata,
                           1 - (embedding <=> %s::vector) as similarity
                    FROM images
                    ORDER BY embedding <=> %s::vector
                    LIMIT %s
                """
                cursor.execute(search_query, (
                    query_features,
                    query_features,
                    top_k
                ))

                results = []
                for row in cursor.fetchall():
                    row_id, content, metadata, similarity = row
                    results.append({
                        'id': row_id,
                        'content': content,
                        'metadata': self._coerce_metadata(metadata),
                        'similarity': float(similarity)
                    })
                return results
            finally:
                cursor.close()
        except Exception as e:
            self._rollback_quietly()
            error_msg = f"图片相似性查询失败: {str(e)}\n"
            error_msg += f"错误类型: {type(e).__name__}\n"
            error_msg += f"查询参数: image_path={image_path}, top_k={top_k}\n"
            error_msg += f"堆栈跟踪:\n{traceback.format_exc()}"
            print(error_msg)
            raise

    def batch_process_directory(self, directory_path):
        """Store embeddings for every supported image in a directory.

        Args:
            directory_path: directory containing image files.

        Returns:
            dict: processed_count, failed_count and failed_files details.
        """
        supported_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff')
        processed_count = 0
        failed_count = 0
        failed_files = []

        try:
            for filename in os.listdir(directory_path):
                if filename.lower().endswith(supported_extensions):
                    image_path = os.path.join(directory_path, filename)
                    try:
                        self.store_image_vector(image_path)
                        processed_count += 1
                        # Fix: progress messages previously printed a literal
                        # placeholder instead of the actual file name.
                        print(f"已处理: {filename}")
                    except Exception as e:
                        failed_count += 1
                        failed_files.append({
                            'filename': filename,
                            'error': str(e)
                        })
                        print(f"处理失败 {filename}: {e}")

            return {
                'processed_count': processed_count,
                'failed_count': failed_count,
                'failed_files': failed_files
            }
        except Exception as e:
            print(f"批量处理失败: {e}")
            raise

# Instantiate the service at import time (loads the ONNX model and opens
# the database connection); failures here abort startup.
image_service = ImageVectorService(DB_CONFIG, MODEL_PATH)

@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: report that the service is up."""
    payload = {"status": "healthy", "message": "服务运行正常"}
    return jsonify(payload)

@app.route('/batch_process', methods=['POST'])
def batch_process():
    """Batch-process a directory of images and store their vectors.

    Request JSON:
        directory_path: image directory (optional; defaults to IMAGE_BASE_PATH)

    Response JSON:
        processed_count: number of successfully processed images
        failed_count:    number of failures
        failed_files:    list of {filename, error}
    """
    try:
        # silent=True: a missing/non-JSON body yields None instead of a 415.
        data = request.get_json(silent=True) or {}
        # Fix: the request parameter was previously ignored and the base
        # path was always used, making the documented parameter dead.
        directory_path = data.get('directory_path') or IMAGE_BASE_PATH

        if not directory_path:
            return jsonify({"error": "缺少directory_path参数"}), 400

        if not os.path.exists(directory_path):
            return jsonify({"error": "目录不存在"}), 400

        result = image_service.batch_process_directory(directory_path)
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/search_similar', methods=['POST'])
def search_similar():
    """Find images similar to a query image.

    Request JSON:
        image_path: query image path, relative to IMAGE_BASE_PATH (required)
        top_k:      number of results (optional, default 5)

    Response JSON: list of {id, content, metadata, similarity}.
    """
    try:
        # silent=True: a missing/non-JSON body yields None instead of a 415.
        data = request.get_json(silent=True) or {}
        image_path = data.get('image_path')
        top_k = data.get('top_k', 5)

        # Fix: validate BEFORE joining — os.path.join(base, None) raises
        # TypeError, turning a missing parameter into a 500 instead of a 400.
        if not image_path:
            return jsonify({"error": "缺少image_path参数"}), 400

        image_path = os.path.join(IMAGE_BASE_PATH, image_path)

        if not os.path.exists(image_path):
            return jsonify({"error": "图片文件不存在"}), 400

        results = image_service.search_similar_images(image_path, top_k)
        return jsonify(results)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/store_single', methods=['POST'])
def store_single():
    """Store a single image's embedding.

    Request JSON:
        image_path: path of the image file (required)
        image_id:   optional id for the row
        metadata:   optional metadata dict

    Response JSON: {"image_id": "<id>"}
    """
    try:
        payload = request.get_json()
        image_path = payload.get('image_path')

        # Guard clauses: reject requests without a usable path up front.
        if not image_path:
            return jsonify({"error": "缺少image_path参数"}), 400
        if not os.path.exists(image_path):
            return jsonify({"error": "图片文件不存在"}), 400

        stored_id = image_service.store_image_vector(
            image_path,
            payload.get('image_id'),
            payload.get('metadata', {}),
        )
        return jsonify({"image_id": stored_id})
    except Exception as e:
        return jsonify({"error": str(e)}), 500

# Run the Flask development server (debug mode) on all interfaces, port 5002.
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5002, debug=True, threaded=True)

2.4 使用命令运行

waitress-serve --host=0.0.0.0 --port=5002 app:app

posted on 2025-09-03 10:00  自由港  阅读(82)  评论(0)    收藏  举报