annoy_word2vec向量索引
一、向量检索工具annoy(还有KD-tree、Faiss)
from annoy import AnnoyIndex
注意:AnnoyIndex.add_item(i, v)中,i是item的整数idx(非负整数),v是该item的向量;检索时返回的也是这些整数idx
怎么构建annoy索引
1.
from annoy_embedding import annoy_ir
aiu = annoy_ir(tc_wv_model)#tc_wv_model就是腾讯的Tencent_AILab_ChineseEmbedding.bin
tc_index,reverse_word_index = aiu.query2index(querys)#querys就是推荐词列表
aiu.saveindex(path + 'pre_embedding_recall/hotword/',tc_index,reverse_word_index)
class annoy_ir(object):
    """Build, persist, reload and query an Annoy index of query embeddings.

    A query is embedded as the mean of the word vectors of its tokens.
    Tokenisation relies on ``splitword_jieba`` and ``APPDataPro``, which are
    defined outside this class.
    """

    # Settings shared by every AnnoyIndex this class creates or loads.
    _VECTOR_DIM = 200      # dimension passed to AnnoyIndex; must match the word vectors
    _METRIC = 'angular'    # Annoy's historical default, now passed explicitly
    _N_TREES = 10          # number of trees passed to AnnoyIndex.build

    def __init__(self, wordembedding):
        """Store the word-embedding model used to embed queries.

        Args:
            wordembedding: object supporting ``wordembedding[token]`` and
                exposing ``index_to_key`` (presumably a gensim
                ``KeyedVectors`` -- confirm against the caller).
        """
        self._wordembedding = wordembedding

    def _vocabulary(self):
        """Return a container suitable for ``token in vocabulary`` tests."""
        # Prefer the key -> index mapping when the model has one (O(1) lookups);
        # otherwise fall back to the index_to_key sequence the original code used.
        vocab = getattr(self._wordembedding, 'key_to_index', None)
        if vocab is None:
            vocab = self._wordembedding.index_to_key
        return vocab

    def query2embedding(self, query):
        """Convert a query into the mean vector of its known tokens.

        Args:
            query: the query; it is converted with ``str()`` before tokenising.

        Returns:
            ``(True, vector)`` when at least one non-empty token is in the
            vocabulary, otherwise ``(False, 0)``.
        """
        tokens = splitword_jieba(APPDataPro(str(query)))
        vocab = self._vocabulary()
        # Filter once instead of re-scanning the vocabulary for every use.
        known = [tok for tok in tokens if tok in vocab]
        if not any(tok != '' for tok in known):
            return False, 0
        return True, sum([self._wordembedding[tok] for tok in known]) / len(known)

    def get_md5_02(self, file_path):
        """Return the MD5 hex digest of a file's contents.

        The digest is taken over the file bytes so that a changed index file
        can be detected by comparing digests.

        Args:
            file_path: path of the file to hash.

        Returns:
            The lowercase hex digest string, or ``False`` if the file cannot
            be opened or read.
        """
        md5_obj = hashlib.md5()
        try:
            # ``with`` guarantees the handle is closed even if a read fails.
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(8096), b''):
                    md5_obj.update(chunk)
        except OSError:
            return False
        return md5_obj.hexdigest()

    def savemd5(self, file_path, md5):
        """Write ``str(md5)`` to ``file_path`` (UTF-8, no trailing newline)."""
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(str(md5))

    def loadmd5(self, file_path):
        """Return the first line of ``file_path`` without its line ending.

        ``savemd5`` writes no trailing newline, so only real line-ending
        characters are stripped; slicing off the last character would drop
        the final digit of the digest.
        """
        with open(file_path, encoding='utf-8') as f:
            return f.readline().rstrip('\r\n')

    def query2index(self, querys):
        """Build an Annoy index over every query that can be embedded.

        Args:
            querys: iterable of queries (they must be hashable).

        Returns:
            ``(tc_index, reverse_word_index)`` where ``tc_index`` is the built
            ``AnnoyIndex`` whose items are ``(i, embedding)`` with integer ids
            ``0, 1, ...`` and ``reverse_word_index`` maps ``str(i)`` to the
            original query.
        """
        tc_index = AnnoyIndex(self._VECTOR_DIM, self._METRIC)
        reverse_word_index = {}
        seen = set()
        next_id = 0
        for qi in querys:
            # Skip repeated queries: every id added to the index must have an
            # entry in reverse_word_index, otherwise recall raises KeyError.
            if qi in seen:
                continue
            found, e_qi = self.query2embedding(qi)
            if not found:
                continue
            seen.add(qi)
            tc_index.add_item(next_id, e_qi)
            reverse_word_index[str(next_id)] = qi
            next_id += 1
        tc_index.build(self._N_TREES)
        return tc_index, reverse_word_index

    def saveindex(self, savepath, tc_index, reverse_word_index):
        """Save the index and id mapping under ``savepath`` with MD5 side files.

        Writes ``reverse_word_index.json`` and ``annoy_tc.index`` and, for
        each, a ``.md`` file holding the MD5 digest of the file just written.
        ``savepath`` is used as a plain string prefix, so it must end with a
        path separator.
        """
        json_path = savepath + 'reverse_word_index.json'
        with open(json_path, 'w') as fp:
            json.dump(reverse_word_index, fp)
        self.savemd5(savepath + 'reverse_word_index.md', self.get_md5_02(json_path))

        index_path = savepath + 'annoy_tc.index'
        tc_index.save(index_path)
        self.savemd5(savepath + 'annoy_tc.md', self.get_md5_02(index_path))

    def _read_index(self, path):
        """Load ``annoy_tc.index`` and ``reverse_word_index.json`` from ``path``."""
        tc_index = AnnoyIndex(self._VECTOR_DIM, self._METRIC)
        tc_index.load(path + 'annoy_tc.index')
        with open(path + 'reverse_word_index.json', 'r') as fp:
            reverse_word_index = json.load(fp)
        return tc_index, reverse_word_index

    def loadindex(self, loadpath, loadfarpath):
        """Load the index, refreshing the copy at ``loadpath`` when stale.

        The digests stored in the ``.md`` files under ``loadpath`` are compared
        with those under ``loadfarpath``. If they differ, are missing under
        ``loadpath``, or the ``loadfarpath`` index digest is the literal
        ``'False'``, the index is read from ``loadfarpath`` and re-saved under
        ``loadpath``. Otherwise the copy under ``loadpath`` is read.

        Returns:
            ``(tc_index, reverse_word_index)``, or ``("", "")`` when reading
            the copy under ``loadpath`` fails.
        """
        oldmd5_reverse_word_index = ''
        oldmd5_annoy_tc = ''
        if (os.path.exists(loadpath + 'reverse_word_index.md')
                and os.path.exists(loadpath + 'annoy_tc.md')):
            oldmd5_reverse_word_index = self.loadmd5(loadpath + 'reverse_word_index.md')
            oldmd5_annoy_tc = self.loadmd5(loadpath + 'annoy_tc.md')
        newmd5_reverse_word_index = self.loadmd5(loadfarpath + 'reverse_word_index.md')
        newmd5_annoy_tc = self.loadmd5(loadfarpath + 'annoy_tc.md')

        up_to_date = (
            oldmd5_reverse_word_index == newmd5_reverse_word_index
            and oldmd5_annoy_tc == newmd5_annoy_tc
            and newmd5_annoy_tc != 'False'
        )
        if not up_to_date:
            tc_index, reverse_word_index = self._read_index(loadfarpath)
            self.saveindex(loadpath, tc_index, reverse_word_index)
            return tc_index, reverse_word_index

        try:
            return self._read_index(loadpath)
        except Exception:
            # Best effort, as before: callers receive ("", "") on failure.
            # Unlike a bare ``except`` this lets KeyboardInterrupt/SystemExit through.
            return "", ""

    def annoy_indexrecall(self, query, recallnum, tc_index, reverse_word_index):
        """Return the indexed queries nearest to ``query``.

        Args:
            query: the query to look up.
            recallnum: number of neighbours requested from the index.
            tc_index: the ``AnnoyIndex`` to search.
            reverse_word_index: mapping from ``str(item id)`` to query.

        Returns:
            A list of the original queries (not embeddings) for the ids given
            by ``tc_index.get_nns_by_vector``; an empty list when ``query``
            cannot be embedded.
        """
        found, qe = self.query2embedding(query)
        if not found:
            return []
        return [
            reverse_word_index[str(item)]
            for item in tc_index.get_nns_by_vector(qe, recallnum)
        ]
2.建立annoy索引
from annoy_embedding import annoy_ir
tc_index = AnnoyIndex(5)#括号里面的参数代表向量的长度
tc_index.add_item(i, e_qi) #i是索引为整数类型,e_qi是他的向量表示,可以是一个list
加完了之后 tc_index.build(10):另外n_trees这个参数很关键,官方文档是这样说的: # n_trees is provided during build time and affects the build time and the index size. # A larger value will give more accurate results, but larger indexes. # 这里首次使用没啥经验,按文档里的是10设置
得到tc_index, reverse_word_index。tc_index里面的每个item是(i, e_qi):i是整数idx(0, 1, …),e_qi表示第i个query的embedding;reverse_word_index是{"0": q0, "1": q1, …},键是字符串形式的idx,值q0、q1都是原始query字符串
保存:先将reverse_word_index保存为json文件(reverse_word_index.json),再计算该json文件的md5校验值,写入reverse_word_index.md。
tc_index.save(savepath + 'annoy_tc.index'),然后计算annoy_tc.index这个文件的md5校验值,写入annoy_tc.md
md5_obj = hashlib.md5()
md5_obj.update(d)
hash_code = md5_obj.hexdigest()
md5 = str(hash_code).lower()
f = open(file_path,'w',encoding='utf-8')
f.write(str(md5))
加载:先读取本地和远端的md5文件并比较,判断索引是否有更新;然后从annoy_tc.index加载得到 tc_index(AnnoyIndex类),从reverse_word_index.json加载得到 reverse_word_index(dict)