大模型--三种检索方式-Dense retrieval / Lexical Retrieval / Multi-Vector Retrieval- 44

1. 参考

M3-Embedding
https://github.com/FlagOpen/FlagEmbedding
https://arxiv.org/pdf/2402.03216
https://huggingface.co/BAAI/bge-m3

2. Dense retrieval

import torch
import torch.nn as nn

class DenseRetrieval(nn.Module):
    """Two-tower dense scorer: encode queries and documents, then compare.

    Queries and documents pass through separate (non-shared) two-layer MLP
    encoders, and every query vector is compared with every document vector
    by cosine similarity.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        hidden_dim: Width of each encoder's hidden layer (default 128).
        output_dim: Size of the encoded vectors that are compared (default 64).
    """

    def __init__(self, embedding_dim, hidden_dim=128, output_dim=64):
        super(DenseRetrieval, self).__init__()
        # Query encoder is built first so parameter creation order (and hence
        # seeded initialisation) matches the fixed 128/64 layout by default.
        self.query_encoder = self._build_encoder(embedding_dim, hidden_dim, output_dim)
        self.doc_encoder = self._build_encoder(embedding_dim, hidden_dim, output_dim)

    @staticmethod
    def _build_encoder(embedding_dim, hidden_dim, output_dim):
        """Return a Linear -> ReLU -> Linear encoder tower."""
        return nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, query_embeddings, doc_embeddings):
        """Score every query against every document.

        Args:
            query_embeddings: Tensor whose last dimension is ``embedding_dim``;
                assumed to be 2-D ``(num_queries, embedding_dim)`` -- the
                ``dim=2`` reduction below only works for that rank.
            doc_embeddings: Tensor assumed to be ``(num_docs, embedding_dim)``.

        Returns:
            Cosine-similarity matrix of shape ``(num_queries, num_docs)``.
        """
        query_vectors = self.query_encoder(query_embeddings)
        doc_vectors = self.doc_encoder(doc_embeddings)
        # Compute cosine similarity (another similarity measure could be used).
        # Broadcasting (Nq, 1, D) against (1, Nd, D) yields all pairs.
        scores = torch.cosine_similarity(
            query_vectors.unsqueeze(1), doc_vectors.unsqueeze(0), dim=2
        )
        return scores

3. Lexical Retrieval

import torch
import torch.nn as nn

# NOTE(review): this class sits under the "Lexical Retrieval" heading, but it
# is a dense two-tower encoder and performs no term-level (lexical) matching.
# It also rebinds the name DenseRetrieval that is defined earlier in the file.
class DenseRetrieval(nn.Module):
    """Encode queries and documents with separate MLPs and compare by cosine.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        hidden_dim: Width of the hidden layer in both encoders (default 128).
        output_dim: Size of the vectors produced by both encoders (default 64).
    """

    def __init__(self, embedding_dim, hidden_dim=128, output_dim=64):
        super(DenseRetrieval, self).__init__()
        # The two towers have the same layout but do not share weights.
        self.query_encoder = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=output_dim),
        )
        self.doc_encoder = nn.Sequential(
            nn.Linear(in_features=embedding_dim, out_features=hidden_dim),
            nn.ReLU(),
            nn.Linear(in_features=hidden_dim, out_features=output_dim),
        )

    def forward(self, query_embeddings, doc_embeddings):
        """Return the pairwise cosine similarity of encoded queries and docs.

        Args:
            query_embeddings: Assumed 2-D ``(num_queries, embedding_dim)``;
                the ``dim=2`` reduction below requires that rank.
            doc_embeddings: Assumed 2-D ``(num_docs, embedding_dim)``.

        Returns:
            Tensor of shape ``(num_queries, num_docs)``.
        """
        encoded_queries = self.query_encoder(query_embeddings).unsqueeze(1)
        encoded_docs = self.doc_encoder(doc_embeddings).unsqueeze(0)
        # Compute cosine similarity (another similarity measure could be used);
        # the inserted singleton axes broadcast to every (query, doc) pair.
        return torch.cosine_similarity(encoded_queries, encoded_docs, dim=2)

4. Multi-Vector Retrieval

class MultiVectorRetrieval(nn.Module):
    """Score query/document pairs through several projected vectors each.

    One shared linear layer expands every input embedding into
    ``num_vectors`` vectors of size ``embedding_dim``. For each pair, every
    query vector is matched with its best-scoring document vector (max dot
    product), and those maxima are averaged into a single score.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        num_vectors: Number of vectors each embedding is expanded into.
    """

    def __init__(self, embedding_dim, num_vectors):
        super(MultiVectorRetrieval, self).__init__()
        # Kept on the instance because forward() needs it to reshape the
        # projection output; it is not otherwise in scope there.
        self.embedding_dim = embedding_dim
        self.num_vectors = num_vectors
        # The same projection is applied to both queries and documents.
        self.projection = nn.Linear(embedding_dim, embedding_dim * num_vectors)

    def forward(self, query_embeddings, doc_embeddings):
        """Return one similarity score per aligned (query, document) pair.

        Args:
            query_embeddings: Tensor whose last dimension is ``embedding_dim``;
                assumed ``(batch, embedding_dim)``.
            doc_embeddings: Tensor of the same leading size as the queries;
                ``torch.bmm`` requires equal batch sizes, so query ``i`` is
                scored only against document ``i``.

        Returns:
            1-D tensor with one score per pair.
        """
        vector_shape = (-1, self.num_vectors, self.embedding_dim)
        projected_query = self.projection(query_embeddings).view(*vector_shape)
        projected_doc = self.projection(doc_embeddings).view(*vector_shape)

        # Dot product of every query vector with every document vector, then
        # keep the best-matching document vector for each query vector.
        similarities = torch.bmm(projected_query, projected_doc.transpose(1, 2))
        max_similarities, _ = torch.max(similarities, dim=-1)
        avg_similarity = torch.mean(max_similarities, dim=1)

        return avg_similarity
posted @ 2025-02-27 19:11  jack-chen666  阅读(109)  评论(0)    收藏  举报