大模型——三种检索方式：Dense Retrieval / Lexical Retrieval / Multi-Vector Retrieval - 44
1. 参考
M3-Embedding
https://github.com/FlagOpen/FlagEmbedding
https://arxiv.org/pdf/2402.03216
https://huggingface.co/BAAI/bge-m3
2. Dense Retrieval

import torch
import torch.nn as nn
class DenseRetrieval(nn.Module):
    """Dual-encoder dense retrieval scorer.

    Queries and documents pass through two separate (non-weight-sharing)
    Linear -> ReLU -> Linear towers, and every query is scored against
    every document with cosine similarity.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        hidden_dim: Width of each tower's hidden layer (default 128).
        output_dim: Size of the projected vectors that are compared
            (default 64).
    """

    def __init__(self, embedding_dim, hidden_dim=128, output_dim=64):
        super().__init__()
        self.query_encoder = self._make_encoder(embedding_dim, hidden_dim, output_dim)
        self.doc_encoder = self._make_encoder(embedding_dim, hidden_dim, output_dim)

    @staticmethod
    def _make_encoder(embedding_dim, hidden_dim, output_dim):
        """Build one Linear -> ReLU -> Linear projection tower."""
        return nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        )

    def forward(self, query_embeddings, doc_embeddings):
        """Return the cosine-similarity matrix between queries and documents.

        Assumes 2-D inputs: ``query_embeddings`` of shape
        (num_queries, embedding_dim) and ``doc_embeddings`` of shape
        (num_docs, embedding_dim). The result has shape
        (num_queries, num_docs).
        """
        query_vectors = self.query_encoder(query_embeddings)
        doc_vectors = self.doc_encoder(doc_embeddings)
        # Broadcast (Q, 1, D) against (1, N, D) so every query meets every doc.
        return torch.cosine_similarity(
            query_vectors.unsqueeze(1), doc_vectors.unsqueeze(0), dim=2
        )
3. Lexical Retrieval

import torch
import torch.nn as nn
class LexicalRetrieval(nn.Module):
    """Sparse (lexical) retrieval scorer in the style of BGE-M3.

    Each token state ``h`` gets a non-negative term weight
    ``relu(W_lex h)``. A text is turned into a vocabulary-sized weight
    vector in which a token id that occurs several times keeps only its
    largest weight. The relevance score between a query and a document is
    the sum, over token ids they share, of the product of their weights;
    tokens that do not co-occur contribute nothing.

    Args:
        embedding_dim: Size of the last dimension of the token embeddings.
        vocab_size: Number of distinct token ids; every id must lie in
            ``[0, vocab_size)``.
    """

    def __init__(self, embedding_dim, vocab_size):
        super().__init__()
        self.vocab_size = vocab_size
        self.weight_proj = nn.Linear(embedding_dim, 1)

    def _sparse_vector(self, token_embeddings, input_ids, mask=None):
        """Map (batch, seq, dim) token states to (batch, vocab_size) term weights."""
        weights = torch.relu(self.weight_proj(token_embeddings)).squeeze(-1)
        if mask is not None:
            # Zero out positions such as padding or special tokens.
            weights = weights * mask.to(weights.dtype)
        sparse = weights.new_zeros(weights.size(0), self.vocab_size)
        # Repeated token ids keep their maximum weight; starting from zeros is
        # safe because relu weights are never negative.
        return sparse.scatter_reduce(1, input_ids.long(), weights, reduce="amax")

    def forward(self, query_embeddings, doc_embeddings, query_ids, doc_ids,
                query_mask=None, doc_mask=None):
        """Return the (num_queries, num_docs) lexical matching scores.

        Args:
            query_embeddings: Token states of shape (num_queries, q_len, embedding_dim).
            doc_embeddings: Token states of shape (num_docs, d_len, embedding_dim).
            query_ids: Token ids of shape (num_queries, q_len).
            doc_ids: Token ids of shape (num_docs, d_len).
            query_mask: Optional (num_queries, q_len) 0/1 mask; 0 drops a token.
            doc_mask: Optional (num_docs, d_len) 0/1 mask; 0 drops a token.
        """
        query_sparse = self._sparse_vector(query_embeddings, query_ids, query_mask)
        doc_sparse = self._sparse_vector(doc_embeddings, doc_ids, doc_mask)
        # Dot product of vocab-sized vectors == sum over shared tokens of w_q * w_d.
        return query_sparse @ doc_sparse.T
4. Multi-Vector Retrieval

class MultiVectorRetrieval(nn.Module):
    """Late-interaction (multi-vector) scorer.

    A shared linear projection expands each input embedding into
    ``num_vectors`` vectors of size ``embedding_dim``. Each query vector is
    matched to its most similar document vector (dot product), and these
    maxima are averaged over the query vectors.

    Args:
        embedding_dim: Size of the last dimension of the input embeddings.
        num_vectors: Number of vectors each embedding is expanded into.
    """

    def __init__(self, embedding_dim, num_vectors):
        super().__init__()
        # Stored so forward() can reshape; the original referenced an
        # undefined global name here and raised NameError on every call.
        self.embedding_dim = embedding_dim
        self.num_vectors = num_vectors
        self.projection = nn.Linear(embedding_dim, embedding_dim * num_vectors)

    def forward(self, query_embeddings, doc_embeddings):
        """Return one score per (query i, document i) pair.

        Both projections are reshaped to (batch, num_vectors, embedding_dim)
        and combined with ``torch.bmm``. Query and document batches must
        therefore have the same size, and query i is scored only against
        document i. The result has shape (batch,).
        """
        projected_query = self.projection(query_embeddings).view(
            -1, self.num_vectors, self.embedding_dim
        )
        projected_doc = self.projection(doc_embeddings).view(
            -1, self.num_vectors, self.embedding_dim
        )
        # (batch, num_vectors, num_vectors) dot products between all vector pairs.
        similarities = torch.bmm(projected_query, projected_doc.transpose(1, 2))
        # For each query vector keep its best-matching doc vector, then average.
        max_similarities, _ = torch.max(similarities, dim=-1)
        return torch.mean(max_similarities, dim=1)

浙公网安备 33010602011771号