1.概述

嵌入式模型是将文本向量化,是向量搜索的首要条件。我们可以使用在线的向量模型,将数据发送给 模型厂家的 接口实现向量化,比如智谱清言就支持文本向量化。 如果 可以使用本地的词嵌入模型,那么可以减少到 服务的依赖,性能也更有保证,另外可以省费用。

我们以北京智源的 BAAI/bge-large-zh-v1.5为例,讲一下如何下载模型,并将模型转换成java 可以使用的模型。

2.下载

模型一般都放在 国外 的huggingface 上,所以下载需要使用 https://hf-mirror.com/ 做为代理下载。

下载之前 我们先安装 python 环境。
可以安装 python 3.10.11 的版本。

安装
pip install huggingface_hub

from huggingface_hub import snapshot_download
 
snapshot_download(
    repo_id="BAAI/bge-large-zh-v1.5",
    local_dir="./models",
    endpoint="https://hf-mirror.com",
    local_dir_use_symlinks=False  # 直接复制文件,便于打包
)

执行上面的脚本下载模型。

image

下载后这些还不能直接使用,将模型导出为ONNX模型文件。

3. 导出ONNX 模型

需要安装 依赖

pip install transformers onnx onnxruntime torch

# export_bge_to_onnx.py

from transformers import AutoTokenizer, AutoModel
from transformers.onnx import export, OnnxConfig
from pathlib import Path
import torch
import logging

# 设置日志,便于调试
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ========================================
# 1. 配置参数
# ========================================

# 模型名称或本地路径(如果你已经下载到本地)
#model_name_or_path = "./BAAI/bge-large-zh-v1.5"
# 如果是本地路径,例如:
model_name_or_path = "./models"

# ONNX 输出目录
output_dir = "./onnx/bge-large-zh-v1.5-onnx"

# ONNX 模型文件名
onnx_model_path = Path(output_dir) / "model.onnx"

# ONNX opset 版本(推荐 13 或以上,兼容性好)
opset = 14

# 是否使用 GPU 导出(如果模型支持且有 GPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"使用设备: {device}")

# ========================================
# 2. 加载 Tokenizer 和 模型
# ========================================

logger.info(f"加载 Tokenizer: {model_name_or_path}")
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

logger.info(f"加载模型: {model_name_or_path}")
model = AutoModel.from_pretrained(model_name_or_path)

# 移动模型到指定设备(GPU 或 CPU)
model.to(device)

# 设置为 eval 模式(关闭 dropout 等训练相关操作)
model.eval()

# ========================================
# 3. 定义 ONNX 配置(可选:自定义输入输出)
# ========================================

# BGE 是一个 sentence-transformers 模型,结构是标准的 BERT-like
# 所以我们可以使用默认的 OnnxConfig,但也可以自定义


class BgeOnnxConfig(OnnxConfig):
    def __init__(self, config):
        super().__init__(config)

    @property
    def inputs(self):
        return {
            "input_ids": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
            "token_type_ids": {0: "batch", 1: "sequence"},  # ✅ 显式添加
        }

    @property
    def outputs(self):
        return {
            "last_hidden_state": {0: "batch", 1: "sequence"},
        }

# 创建配置实例时传入 model.config
onnx_config = BgeOnnxConfig(config=model.config)  



# ========================================
# 4. 创建输出目录
# ========================================

onnx_model_path.parent.mkdir(parents=True, exist_ok=True)
logger.info(f"ONNX 输出目录: {onnx_model_path.parent}")

# ========================================
# 5. 导出模型为 ONNX
# ========================================

logger.info("开始导出模型为 ONNX...")

try:
    export(
        preprocessor=tokenizer,           # 提供 tokenizer 用于生成 dummy input
        model=model,                      # 要导出的 PyTorch 模型
        config=onnx_config,               # ONNX 输入输出配置
        opset=opset,                      # ONNX opset 版本
        output=onnx_model_path,           # 输出文件路径
    )
    logger.info(f"✅ 模型成功导出到: {onnx_model_path}")
except Exception as e:
    logger.error(f"❌ 导出失败: {e}")
    raise



# ========================================
# 6. (可选)验证 ONNX 模型是否可用
# ========================================
try:
    import onnxruntime as ort
    # 检查 ONNX 模型是否结构正确
    ort_session = ort.InferenceSession(str(onnx_model_path))
    # 获取输入名称
    input_names = [inp.name for inp in ort_session.get_inputs()]
    logger.info(f"ONNX 模型输入: {input_names}")

    # 使用 tokenizer 编码一个示例句子
    inputs = tokenizer(
        "这是一个测试句子。",
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="np"  # ONNX Runtime 使用 numpy
    )

    # 准备输入数据(转换为 numpy)
    onnx_inputs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
    }

    # ✅ 关键:如果模型需要 token_type_ids,补上全 0 的数组
    if "token_type_ids" in input_names:
        # BERT/BGE 使用 token_type_ids 区分句子对,单句任务中全为 0
        import numpy as np
        onnx_inputs["token_type_ids"] = np.zeros_like(inputs["input_ids"])  # 形状一致,全 0
        logger.info(f"已添加 token_type_ids (全 0),形状: {onnx_inputs['token_type_ids'].shape}")

    # 运行推理
    outputs = ort_session.run(None, onnx_inputs)
    # 输出结果形状
    logger.info(f"ONNX 推理成功!输出形状: {outputs[0].shape}")
except Exception as e:
    logger.warning(f"⚠️ ONNX 验证失败: {e}")
    raise  # 可选:让错误暴露出来,便于调试

logger.info("🎉 ONNX 导出和验证完成!")

脚本 中有两个变量可以做调整

  • 模型目录
    model_name_or_path = "./models"
  • 输出目录
    output_dir = "./onnx/bge-large-zh-v1.5-onnx"

导出后的效果
image

在项目中使用自定义嵌入模型

使用 starter 实现

  • 引入依赖包
<dependency>
	<groupId>org.springframework.ai</groupId>
	<artifactId>spring-ai-starter-model-transformers</artifactId>
</dependency>
  • 增加配置
spring:
	ai:
		model:
		  embedding: transformers
		embedding:
		  transformer:
			onnx:
			  model-uri: classpath:/model/onnx/model.onnx
			tokenizer:
			  uri: classpath:/model/onnx/tokenizer.json

这样就可以使用自定义嵌入模型了

使用手动方式实现

  • 导入依赖包
<dependency>
  <groupId>org.springframework.ai</groupId>
  <artifactId>spring-ai-transformers</artifactId>
</dependency>
@Configuration
@EnableConfigurationProperties
public class EmbeddingModelConfig {

    @Bean
    public EmbeddingModel embeddingModel() throws Exception {
        TransformersEmbeddingModel embeddingModel = new TransformersEmbeddingModel();

        // 设置 ONNX 模型路径 (从 classpath 加载)
        embeddingModel.setModelResource(new ClassPathResource("model/onnx/model.onnx"));
        // 设置分词器配置路径
        embeddingModel.setTokenizerResource(new ClassPathResource("model/onnx/tokenizer.json"));

        // 可选:设置其他分词选项,例如启用填充
        embeddingModel.setTokenizerOptions(java.util.Map.of("padding", "true"));

//        embeddingModel.setGpuDeviceId(0);

        // 初始化模型
        embeddingModel.afterPropertiesSet();

        return embeddingModel;
    }
}
  • 需要注意的点

这个和 cuda 有关,所以在部署的时候,需要部署对应 的 cuda tookit 版本,我之前部署了 12.9 的版本,发现报错,改成 12.4 的版本就可以了。

posted on 2025-08-28 12:53  自由港  阅读(41)  评论(0)    收藏  举报