Spring AI 学习之路 探索Spring AI中的嵌入模型(Embedding Model)

在现代人工智能和机器学习应用中,嵌入模型(Embedding Model)扮演着至关重要的角色。嵌入模型能够将高维度的数据(如文本、图像等)转换为低维度的向量表示,从而使得这些数据能够在机器学习模型中被有效地处理和利用。Spring AI作为一个强大的AI框架,提供了对嵌入模型的全面支持。本文将深入探讨Spring AI中的嵌入模型,并通过代码示例展示如何使用它们。

什么是嵌入模型?

嵌入模型是一种将离散数据(如单词、句子、图像等)转换为连续向量空间中的向量的技术。这些向量不仅能够捕捉到数据的语义信息,还能够在向量空间中进行数学运算,从而揭示数据之间的关系。例如,通过嵌入模型,我们可以将单词“king”转换为一个向量,然后通过向量运算得到“queen”的向量表示。

Spring AI中的嵌入模型

Spring AI提供了多种嵌入模型的实现,包括预训练模型和自定义模型。这些模型可以用于各种任务,如文本分类、情感分析、图像识别等。Spring AI的嵌入模型接口设计得非常灵活,使得开发者可以轻松地集成和使用这些模型。

API介绍

EmbeddingModel

在Spring AI中,EmbeddingModel 是一个核心接口,用于表示嵌入模型。它定义了如何将输入数据(如文本)转换为嵌入向量。以下是 EmbeddingModel 接口的主要方法:

public interface EmbeddingModel extends Model<EmbeddingRequest, EmbeddingResponse> {

	@Override
	EmbeddingResponse call(EmbeddingRequest request);

	/**
	 * Embeds the given text into a vector.
	 * @param text the text to embed.
	 * @return the embedded vector.
	 */
	default float[] embed(String text) {
		Assert.notNull(text, "Text must not be null");
		List<float[]> response = this.embed(List.of(text));
		return response.iterator().next();
	}

	/**
	 * Embeds the given document's content into a vector.
	 * @param document the document to embed.
	 * @return the embedded vector.
	 */
	float[] embed(Document document);

	/**
	 * Embeds a batch of texts into vectors.
	 * @param texts list of texts to embed.
	 * @return list of embedded vectors.
	 */
	default List<float[]> embed(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()))
			.getResults()
			.stream()
			.map(Embedding::getOutput)
			.toList();
	}

	/**
	 * Embeds a batch of {@link Document}s into vectors based on a
	 * {@link BatchingStrategy}.
	 * @param documents list of {@link Document}s.
	 * @param options {@link EmbeddingOptions}.
	 * @param batchingStrategy {@link BatchingStrategy}.
	 * @return a list of float[] that represents the vectors for the incoming
	 * {@link Document}s. The returned list is expected to be in the same order of the
	 * {@link Document} list.
	 */
	default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
		Assert.notNull(documents, "Documents must not be null");
		List<float[]> embeddings = new ArrayList<>(documents.size());
		List<List<Document>> batch = batchingStrategy.batch(documents);
		for (List<Document> subBatch : batch) {
			List<String> texts = subBatch.stream().map(Document::getText).toList();
			EmbeddingRequest request = new EmbeddingRequest(texts, options);
			EmbeddingResponse response = this.call(request);
			for (int i = 0; i < subBatch.size(); i++) {
				embeddings.add(response.getResults().get(i).getOutput());
			}
		}
		Assert.isTrue(embeddings.size() == documents.size(),
				"Embeddings must have the same number as that of the documents");
		return embeddings;
	}

	/**
	 * Embeds a batch of texts into vectors and returns the {@link EmbeddingResponse}.
	 * @param texts list of texts to embed.
	 * @return the embedding response.
	 */
	default EmbeddingResponse embedForResponse(List<String> texts) {
		Assert.notNull(texts, "Texts must not be null");
		return this.call(new EmbeddingRequest(texts, EmbeddingOptionsBuilder.builder().build()));
	}

	/**
	 * Get the number of dimensions of the embedded vectors. Note that by default, this
	 * method will call the remote Embedding endpoint to get the dimensions of the
	 * embedded vectors. If the dimensions are known ahead of time, it is recommended to
	 * override this method.
	 * @return the number of dimensions of the embedded vectors.
	 */
	default int dimensions() {
		return embed("Test String").length;
	}

}

2. EmbeddingRequest 和 EmbeddingResponse

EmbeddingRequestEmbeddingResponse 是用于处理嵌入请求和响应的类。
EmbeddingRequest: 包含一个或多个输入文本,用于生成嵌入向量。
EmbeddingResponse: 包含生成的嵌入向量。

public class EmbeddingRequest implements ModelRequest<List<String>> {

	private final List<String> inputs;

	private final EmbeddingOptions options;
  ...
}
public class EmbeddingResponse implements ModelResponse<Embedding> {

	/**
	 * Embedding data.
	 */
	private final List<Embedding> embeddings;

	/**
	 * Embedding metadata.
	 */
	private final EmbeddingResponseMetadata metadata;
	...}

3. Embedding 类

Embedding 类表示一个嵌入向量,通常是一个浮点数数组。

public class Embedding implements ModelResult<float[]> {

	private float[] embedding;

	private Integer index;

	private EmbeddingResultMetadata metadata;
	...
}

使用阿里的 EmbeddingModel 进行演示

我们在之前的例子中已经引入了阿里的服务,我们可以通过实现 EmbeddingModel 接口来集成这个服务。以下是一个示例:

public class EmbeddingController {
    // 阿里嵌入模型
    private final EmbeddingModel embeddingModel;

    /**
     * 将句子向量化
     * @param prompt 用户提问
     * @return 向量值
     */
    @GetMapping("embed")
    public List<float[]> embed(@RequestParam String prompt) {
        // 文本嵌入
        float[] embed = embeddingModel.embed(prompt);
        return Arrays.asList(embed);
    }
}
posted @ 2025-03-05 09:16  brother_four  阅读(728)  评论(0)    收藏  举报