annoy_word2vec向量索引

一、向量检索工具annoy(还有KD-tree、Faiss)

from annoy import AnnoyIndex

注意:AnnoyIndex.add_item(i, v)的第一个参数i是item的整数索引(idx),第二个参数v才是该item的向量;因此需要另外维护一份idx到原始query的映射(即下文的reverse_word_index)。

怎么构建annoy索引

1.

# Usage example: build an index for a list of queries and persist it.
from annoy_embedding import annoy_ir

aiu = annoy_ir(tc_wv_model)# tc_wv_model is Tencent's Tencent_AILab_ChineseEmbedding.bin

tc_index,reverse_word_index = aiu.query2index(querys)# querys is the list of recommendation words

aiu.saveindex(path + 'pre_embedding_recall/hotword/',tc_index,reverse_word_index)

class annoy_ir(object):
    """Build, persist, load and query an Annoy index over query embeddings.

    A query is embedded as the mean of the word vectors of its tokens that
    exist in the supplied word-embedding model.  Each indexed query gets an
    integer id inside the Annoy index; ``reverse_word_index`` maps the
    string form of that id back to the original query string.

    Args:
        wordembedding: Word-vector model.  It must support ``model[token]``
            and expose its vocabulary as ``index_to_key`` (and, when
            available, ``key_to_index``).
        dim: Length of the vectors stored in the Annoy index.  Defaults to
            200, the value that used to be hard-coded.
    """

    def __init__(self, wordembedding, dim=200):
        self._wordembedding = wordembedding
        self._dim = dim

    def _new_index(self):
        """Return an empty Annoy index of the configured dimension."""
        # 'angular' is Annoy's historical default metric; passing it
        # explicitly keeps the old behaviour without relying on the
        # deprecated implicit default.
        return AnnoyIndex(self._dim, 'angular')

    def _vocabulary(self):
        """Return a container supporting ``token in vocab`` for the model."""
        # Prefer the key->index mapping for O(1) membership tests; fall back
        # to the key list, which is what the lookup originally used.
        vocab = getattr(self._wordembedding, 'key_to_index', None)
        if vocab is None:
            vocab = self._wordembedding.index_to_key
        return vocab

    def query2embedding(self, query):
        """Convert a query into a single embedding vector.

        The query is passed through ``APPDataPro`` and ``splitword_jieba``
        (both defined outside this class) to obtain a list of tokens.

        Returns:
            ``(True, vector)`` where ``vector`` is the mean of the vectors
            of all in-vocabulary tokens, or ``(False, 0)`` when no
            non-empty token is in the vocabulary.
        """
        tokens = splitword_jieba(APPDataPro(str(query)))
        vocab = self._vocabulary()
        # Scan the vocabulary once per token instead of three times.
        known = [tok for tok in tokens if tok in vocab]
        if not any(tok != '' for tok in known):
            return False, 0
        return True, sum(self._wordembedding[tok] for tok in known) / len(known)

    def get_md5_02(self, file_path):
        """Return the lowercase hex MD5 digest of a file's contents.

        The digest acts as a fingerprint of the saved file so a later load
        can tell whether the file changed.

        Returns:
            The digest string, or ``False`` if the file cannot be read.
        """
        md5_obj = hashlib.md5()
        try:
            # ``with`` guarantees the handle is closed even if a read fails.
            with open(file_path, 'rb') as f:
                for chunk in iter(lambda: f.read(8192), b''):
                    md5_obj.update(chunk)
        except (OSError, TypeError, ValueError):
            # Best effort: callers persist str(False) and test for it later.
            return False
        return md5_obj.hexdigest()

    def savemd5(self, file_path, md5):
        """Write ``str(md5)`` to ``file_path`` (no trailing newline)."""
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(str(md5))

    def loadmd5(self, file_path):
        """Return the first line of ``file_path`` without its line ending.

        Only a trailing newline is stripped.  Blindly dropping the last
        character would truncate digests written by ``savemd5``, which
        writes no newline, and would turn a stored ``'False'`` into
        ``'Fals'``.
        """
        with open(file_path, encoding='utf-8') as f:
            return f.readline().rstrip('\r\n')

    def query2index(self, querys):
        """Build an Annoy index for the given queries.

        Queries without an embedding are skipped.  Repeated queries are
        indexed once, so every item in the index has exactly one entry in
        the returned mapping.

        Returns:
            ``(tc_index, reverse_word_index)``: the built index, whose
            items are ``(0, e_q0), (1, e_q1), ...``, and a dict
            ``{"0": q0, "1": q1, ...}`` of original query strings.
        """
        tc_index = self._new_index()
        reverse_word_index = {}
        seen = set()
        for qi in querys:
            if qi in seen:
                continue
            seen.add(qi)
            found, e_qi = self.query2embedding(qi)
            if not found:
                continue
            idx = len(reverse_word_index)
            tc_index.add_item(idx, e_qi)
            reverse_word_index[str(idx)] = qi
        tc_index.build(10)
        return tc_index, reverse_word_index

    def saveindex(self, savepath, tc_index, reverse_word_index):
        """Save the index and its id->query mapping under ``savepath``.

        Writes ``reverse_word_index.json`` and ``annoy_tc.index`` and, next
        to each, a ``.md`` file holding the MD5 digest of that file.
        ``savepath`` is used as a plain string prefix, so it must end with
        a path separator.
        """
        json_path = savepath + 'reverse_word_index.json'
        with open(json_path, 'w') as fp:
            json.dump(reverse_word_index, fp)
        self.savemd5(savepath + 'reverse_word_index.md', self.get_md5_02(json_path))

        index_path = savepath + 'annoy_tc.index'
        tc_index.save(index_path)
        self.savemd5(savepath + 'annoy_tc.md', self.get_md5_02(index_path))

    def _read_index(self, dirpath):
        """Load ``annoy_tc.index`` and ``reverse_word_index.json`` from dirpath."""
        tc_index = self._new_index()
        tc_index.load(dirpath + 'annoy_tc.index')
        with open(dirpath + 'reverse_word_index.json', 'r') as fp:
            reverse_word_index = json.load(fp)
        return tc_index, reverse_word_index

    def loadindex(self, loadpath, loadfarpath):
        """Load the index, refreshing the copy at ``loadpath`` when stale.

        The digests stored under ``loadpath`` are compared with those under
        ``loadfarpath``.  If they differ, are missing under ``loadpath``, or
        record a failed digest (``'False'``), the index is read from
        ``loadfarpath`` and re-saved under ``loadpath``.  Otherwise the copy
        under ``loadpath`` is read.

        Returns:
            ``(tc_index, reverse_word_index)``, or ``("", "")`` when the
            copy under ``loadpath`` is considered current but cannot be
            read.

        Raises:
            OSError: If the digest or index files under ``loadfarpath``
                cannot be read.
        """
        old_json_md5 = ''
        old_index_md5 = ''
        if (os.path.exists(loadpath + 'reverse_word_index.md')
                and os.path.exists(loadpath + 'annoy_tc.md')):
            old_json_md5 = self.loadmd5(loadpath + 'reverse_word_index.md')
            old_index_md5 = self.loadmd5(loadpath + 'annoy_tc.md')
        new_json_md5 = self.loadmd5(loadfarpath + 'reverse_word_index.md')
        new_index_md5 = self.loadmd5(loadfarpath + 'annoy_tc.md')

        local_is_current = (
            old_json_md5 == new_json_md5
            and old_index_md5 == new_index_md5
            # 'False' is what saveindex records when hashing failed; such a
            # digest cannot vouch for either file.
            and new_index_md5 != 'False'
            and new_json_md5 != 'False'
        )
        if not local_is_current:
            tc_index, reverse_word_index = self._read_index(loadfarpath)
            self.saveindex(loadpath, tc_index, reverse_word_index)
            return tc_index, reverse_word_index
        try:
            return self._read_index(loadpath)
        except Exception:
            # Keep the historical "empty result" contract for callers.
            return "", ""

    def annoy_indexrecall(self, query, recallnum, tc_index, reverse_word_index):
        """Return the indexed queries nearest to ``query``.

        Returns:
            A list of up to ``recallnum`` original query strings (not
            embeddings), or ``[]`` when ``query`` has no embedding.
        """
        found, qe = self.query2embedding(query)
        if not found:
            return []
        neighbours = tc_index.get_nns_by_vector(qe, recallnum)
        return [reverse_word_index[str(item)] for item in neighbours]

  

2.建立annoy索引

from annoy_embedding import annoy_ir

# NOTE(review): AnnoyIndex is provided by `from annoy import AnnoyIndex`, not by the import above — confirm it is in scope.
tc_index = AnnoyIndex(5)# the constructor argument is the length of the vectors

tc_index.add_item(i, e_qi) # i is the integer index; e_qi is its vector representation, which may be a list

加完了之后 tc_index.build(10):另外n_trees这个参数很关键,官方文档是这样说的: # n_trees is provided during build time and affects the build time and the index size. # A larger value will give more accurate results, but larger indexes. # 这里首次使用没啥经验,按文档里的是10设置

得到tc_index,reverse_word_index。tc_index里面的item是(0, e_q0)、(1, e_q1)……,idx为整数,e_q0表示query0的embedding;reverse_word_index是一个dict:{"0": q0, "1": q1},key是字符串形式的idx,q0以及q1都是原始字符串

保存:将reverse_word_index保存为json文件(reverse_word_index.json),再计算这个json文件的md5,把md5字符串写入reverse_word_index.md。

tc_index.save(savepath + 'annoy_tc.index'),然后计算annoy_tc.index这个文件的md5,把md5字符串写入annoy_tc.md

# Fragments of the MD5 fingerprinting steps; file_path and the chunk d are not defined in this excerpt.
md5_obj = hashlib.md5()

# Feed one chunk of bytes into the running digest.
md5_obj.update(d)

hash_code = md5_obj.hexdigest()

 md5 = str(hash_code).lower()

# Persist the digest string to file_path.
f = open(file_path,'w',encoding='utf-8')
f.write(str(md5))

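加载:先读取loadpath和loadfarpath下的md5文件(reverse_word_index.md、annoy_tc.md)并比较,不一致时从loadfarpath重新加载并更新loadpath下的副本,一致时直接读取loadpath下的文件;最终由annoy_tc.index和reverse_word_index.json得到 tc_index(AnnoyIndex类)、reverse_word_index(dict)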

posted @ 2022-04-28 20:31  HappierJoanne  阅读(226)  评论(0)    收藏  举报