annoy_word2vec向量索引
一、向量检索工具annoy(还有KD-tree、Faiss)
from annoy import AnnoyIndex
注意:AnnoyIndex 的 add_item(i, vector) 中,第一个参数 i 是 item 的整数下标(idx),第二个参数才是该 item 的向量。
怎么构建annoy索引
1.
from annoy_embedding import annoy_ir
# Wrap the pretrained word vectors. Per the author's note, tc_wv_model is
# Tencent's Tencent_AILab_ChineseEmbedding.bin.
aiu = annoy_ir(tc_wv_model)
# querys is the list of recommended words to index.
tc_index,reverse_word_index = aiu.query2index(querys)
# Write the index, the id -> query mapping and their md5 files into this directory.
aiu.saveindex(path + 'pre_embedding_recall/hotword/',tc_index,reverse_word_index)
class annoy_ir(object):
    """Build, persist, load and query an Annoy index of query embeddings.

    A query is embedded as the mean of the word vectors of its tokens that
    are present in ``wordembedding``. Each embedded query is stored in an
    ``AnnoyIndex`` under an integer id, and ``reverse_word_index`` maps the
    string form of that id back to the original query.
    """

    # Length of the vectors stored in the Annoy index.
    _EMBEDDING_DIM = 200
    # Number of trees passed to AnnoyIndex.build().
    _N_TREES = 10

    def __init__(self, wordembedding):
        """Store the word-vector model used to embed queries.

        Args:
            wordembedding: object supporting ``wordembedding[token]`` and
                exposing ``index_to_key`` (presumably a gensim 4
                ``KeyedVectors`` -- confirm against the caller).
        """
        self._wordembedding = wordembedding

    def _vocabulary(self):
        """Return the container used for token membership tests."""
        # NOTE(review): on gensim 4 KeyedVectors, key_to_index is the dict
        # counterpart of the index_to_key list, so membership is O(1)
        # instead of a list scan. Fall back to index_to_key when absent.
        vocab = getattr(self._wordembedding, 'key_to_index', None)
        if vocab is None:
            vocab = self._wordembedding.index_to_key
        return vocab

    def query2embedding(self, query):
        """Convert a query into the mean vector of its known tokens.

        Args:
            query: the query; it is converted with ``str()`` before cleaning
                and tokenizing.

        Returns:
            ``(True, vector)`` when at least one non-empty token is in the
            vocabulary, otherwise ``(False, 0)``.
        """
        tokens = splitword_jieba(APPDataPro(str(query)))
        vocab = self._vocabulary()
        # One pass over the tokens instead of three separate membership scans.
        known = [tok for tok in tokens if tok in vocab]
        if not any(tok != '' for tok in known):
            return False, 0
        vectors = [self._wordembedding[tok] for tok in known]
        return True, sum(vectors) / len(known)

    def get_md5_02(self, file_path):
        """Return the lowercase hex md5 of a file's contents.

        The digest acts as a fingerprint of the saved file so that
        ``loadindex`` can tell whether the local copy is out of date.

        Returns:
            The hex digest string, or ``False`` if the file cannot be read.
        """
        try:
            md5_obj = hashlib.md5()
            # ``with`` closes the handle even when a read fails midway.
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(8192), b''):
                    md5_obj.update(chunk)
        except (OSError, TypeError, ValueError):
            return False
        return md5_obj.hexdigest().lower()

    def savemd5(self, file_path, md5):
        """Write ``str(md5)`` to ``file_path`` (no trailing newline)."""
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(str(md5))

    def loadmd5(self, file_path):
        """Return the first line of an md5 file without its line ending.

        ``savemd5`` writes no trailing newline, so only an actual line
        terminator is stripped; slicing off the last character would drop
        the final hex digit and turn the ``'False'`` marker into ``'Fals'``.
        """
        with open(file_path, encoding='utf-8') as f:
            return f.readline().rstrip('\r\n')

    def query2index(self, querys):
        """Build an Annoy index over every query that can be embedded.

        Args:
            querys: iterable of hashable queries (the original code also
                used them as dict keys).

        Returns:
            ``(tc_index, reverse_word_index)``: the built ``AnnoyIndex``
            holding items ``(i, embedding_of_query_i)``, and a dict mapping
            ``str(i)`` to the original query.
        """
        # The metric is passed explicitly; 'angular' is Annoy's documented
        # default, and relying on the implicit default is deprecated.
        tc_index = AnnoyIndex(self._EMBEDDING_DIM, 'angular')
        reverse_word_index = {}
        seen = set()
        next_id = 0
        for qi in querys:
            # Skip repeated queries so ids and queries stay one-to-one;
            # otherwise an indexed id could be missing from the mapping and
            # recall would fail with KeyError.
            if qi in seen:
                continue
            found, e_qi = self.query2embedding(qi)
            if not found:
                continue
            seen.add(qi)
            tc_index.add_item(next_id, e_qi)
            reverse_word_index[str(next_id)] = qi
            next_id += 1
        tc_index.build(self._N_TREES)
        return tc_index, reverse_word_index

    def saveindex(self, savepath, tc_index, reverse_word_index):
        """Save the index and id mapping under ``savepath`` with md5 files.

        ``savepath`` is used as a plain string prefix, so it must end with a
        path separator. Written files: ``reverse_word_index.json``,
        ``reverse_word_index.md``, ``annoy_tc.index`` and ``annoy_tc.md``.
        """
        with open(savepath + 'reverse_word_index.json', 'w') as fp:
            json.dump(reverse_word_index, fp)
        self.savemd5(savepath + 'reverse_word_index.md',
                     self.get_md5_02(savepath + 'reverse_word_index.json'))
        tc_index.save(savepath + 'annoy_tc.index')
        self.savemd5(savepath + 'annoy_tc.md',
                     self.get_md5_02(savepath + 'annoy_tc.index'))

    def _read_index(self, prefix):
        """Load ``annoy_tc.index`` and ``reverse_word_index.json`` from ``prefix``."""
        tc_index = AnnoyIndex(self._EMBEDDING_DIM, 'angular')
        tc_index.load(prefix + 'annoy_tc.index')
        with open(prefix + 'reverse_word_index.json', 'r') as fp:
            reverse_word_index = json.load(fp)
        return tc_index, reverse_word_index

    def loadindex(self, loadpath, loadfarpath):
        """Load the index, refreshing the local copy when md5 files differ.

        Args:
            loadpath: local directory prefix (string, ending with a separator).
            loadfarpath: source directory prefix holding the reference copy.

        Returns:
            ``(tc_index, reverse_word_index)``. If the local copy is judged
            current but cannot be read, ``("", "")`` is returned.
        """
        oldmd5_reverse_word_index = ''
        oldmd5_annoy_tc = ''
        if (os.path.exists(loadpath + 'reverse_word_index.md')
                and os.path.exists(loadpath + 'annoy_tc.md')):
            oldmd5_reverse_word_index = self.loadmd5(loadpath + 'reverse_word_index.md')
            oldmd5_annoy_tc = self.loadmd5(loadpath + 'annoy_tc.md')
        newmd5_reverse_word_index = self.loadmd5(loadfarpath + 'reverse_word_index.md')
        newmd5_annoy_tc = self.loadmd5(loadfarpath + 'annoy_tc.md')

        # 'False' is what savemd5 stores when get_md5_02 failed, so it never
        # counts as a valid matching fingerprint.
        local_is_current = (
            oldmd5_reverse_word_index == newmd5_reverse_word_index
            and oldmd5_annoy_tc == newmd5_annoy_tc
            and newmd5_annoy_tc != 'False'
        )
        if not local_is_current:
            tc_index, reverse_word_index = self._read_index(loadfarpath)
            self.saveindex(loadpath, tc_index, reverse_word_index)
            return tc_index, reverse_word_index
        try:
            return self._read_index(loadpath)
        except Exception:
            # Best-effort: callers receive empty strings when the local copy
            # is unreadable.
            return "", ""

    def annoy_indexrecall(self, query, recallnum, tc_index, reverse_word_index):
        """Return up to ``recallnum`` indexed queries nearest to ``query``.

        Returns an empty list when ``query`` has no embedding.
        """
        found, qe = self.query2embedding(query)
        if not found:
            return []
        neighbour_ids = tc_index.get_nns_by_vector(qe, recallnum)
        return [reverse_word_index[str(item)] for item in neighbour_ids]
2.建立annoy索引
from annoy_embedding import annoy_ir
# The constructor argument is the length of the vectors.
tc_index = AnnoyIndex(5)
# i is the integer item id; e_qi is its vector representation and may be a list.
tc_index.add_item(i, e_qi)
添加完所有向量后调用 tc_index.build(10)。其中 n_trees 参数很关键,官方文档的说明是:n_trees is provided during build time and affects the build time and the index size. A larger value will give more accurate results, but larger indexes. 这里是首次使用、没有调参经验,所以按文档示例设置为 10。
得到 tc_index 和 reverse_word_index:tc_index 里的每个 item 是 (i, e_qi),即整数下标 i 和 query i 的 embedding,例如 (0, e_q0);reverse_word_index 是 {"0": q0, "1": q1} 形式的字典,key 为字符串形式的下标,q0、q1 都是原始 query 字符串。
保存:将 reverse_word_index 保存为 reverse_word_index.json 文件,再计算该 json 文件的 md5 并写入 reverse_word_index.md。
tc_index.save(savepath + 'annoy_tc.index'),然后计算 annoy_tc.index 这个文件的 md5,并把它写入 annoy_tc.md。
# Excerpt of the md5 steps: hash the data, then store the hex digest as text.
# d (bytes to hash) and file_path (destination) are not defined in this
# excerpt and must be supplied by the surrounding code.
md5_obj = hashlib.md5()
md5_obj.update(d)
hash_code = md5_obj.hexdigest()
md5 = str(hash_code).lower()
# ``with`` closes the file after writing; the original never closed it.
with open(file_path, 'w', encoding='utf-8') as f:
    f.write(str(md5))
加载:先读取本地和远端的 md5 文件并比较,再加载 annoy_tc.index 和 reverse_word_index.json,得到 tc_index(AnnoyIndex 类)、reverse_word_index(dict)。
浙公网安备 33010602011771号