DBSCAN算法

DBSCAN是一个简单的聚类算法。之前训练了word2vec模型后需要找出相似的词(句子),于是自己用Python实现了一个。

这个算法很简单,速度也很快,sklearn里也有现成的实现;但是我需要自定义距离(相似度)的计算方式,当时以为sklearn不支持(实际上sklearn的DBSCAN可以通过metric参数传入自定义距离函数或预计算的距离矩阵),所以自己写了一个类,见下面的代码;有需要可以随意拿去用。

import numpy as np 
import gensim 
# Load the pre-trained Word2Vec model once, at import time, from a hard-coded
# absolute path. Dbscan.sent_w2v looks words up in this module-level `w2v`.
w2v = gensim.models.word2vec.Word2Vec.load('/data1/yanjianfeng/sb')

class Dbscan(object):
    """Simplified DBSCAN-style clustering of sentences with a custom similarity.

    Sentences are space-separated strings.  Two sentences are neighbours when
    their similarity is ``>= eps``; a cluster is the set of sentences reachable
    from a seed through chains of neighbours, and it is kept only if it has at
    least ``min_count`` members.  Unlike textbook DBSCAN there is no per-point
    core/border distinction: ``min_count`` is applied to the whole cluster.

    Parameters
    ----------
    eps : float
        Minimum similarity for two sentences to be considered neighbours.
    min_count : int
        Minimum number of members a cluster must have to receive a label.
    if_w2v : bool
        If true, sentences are converted to mean word vectors and compared
        with cosine similarity; otherwise word overlap (``simi``) is used.
    w2v_model : mapping, optional
        Object supporting ``model[word]`` and raising ``KeyError`` for unknown
        words.  When omitted, the module-level ``w2v`` is used.
    verbose : bool
        If true, print progress while fitting (off by default).

    Attributes set by ``fit``
    -------------------------
    label : list of int
        Cluster index per input sentence, ``-1`` for noise.
    clust : int
        Number of clusters found.
    """

    def __init__(self, eps=0.8, min_count=3, if_w2v=False, w2v_model=None,
                 verbose=False):
        self.eps = eps
        self.min_count = min_count
        self.if_w2v = if_w2v
        self.w2v_model = w2v_model
        self.verbose = verbose

    def simi(self, sent_a, sent_b):
        """Return the word-overlap similarity of two space-separated sentences.

        The score is ``2 * |words of a that occur in b| / (len(a) + len(b))``.
        Repeated words in ``sent_a`` are each counted, so the measure is not
        symmetric when a sentence contains duplicates.
        """
        words_a = sent_a.split(' ')
        words_b = sent_b.split(' ')
        vocab_b = set(words_b)  # O(1) membership instead of scanning a list
        shared = sum(1 for word in words_a if word in vocab_b)
        return 2.0 * shared / (len(words_a) + len(words_b))

    def simi_w2v(self, v_1, v_2):
        """Return the cosine similarity of two vectors.

        Returns ``0.0`` when either vector has zero norm (e.g. a sentence with
        no known words) instead of dividing by zero.
        """
        a = np.asarray(v_1, dtype=float)
        b = np.asarray(v_2, dtype=float)
        # Norms first: a zero vector short-circuits before the cross product.
        norm = np.sqrt(np.dot(a, a) * np.dot(b, b))
        if norm == 0:
            return 0.0
        return float(np.dot(a, b) / norm)

    def sent_w2v(self, sent):
        """Return the mean word vector of a space-separated sentence.

        Words missing from the model are skipped.  If no word is known, a
        zero vector is returned (its length is the model's ``vector_size``
        when available, else 100, the size originally hard-coded here).
        """
        model = self.w2v_model if self.w2v_model is not None else w2v
        vectors = []
        for wd in sent.split(' '):
            try:
                vectors.append(model[wd])
            except KeyError:
                # Out-of-vocabulary word: ignore it.
                continue
        if not vectors:
            return np.zeros(getattr(model, 'vector_size', 100))
        return np.mean(np.asarray(vectors, dtype=float), axis=0)

    def fit(self, train_data):
        """Cluster ``train_data`` (a sequence of sentences).

        Results are stored on the instance: ``self.label`` holds one cluster
        index per sentence (``-1`` for noise) and ``self.clust`` the number
        of clusters.
        """
        self.train_data = train_data
        n_items = len(train_data)
        self.label = [-1] * n_items
        self.explored = [False] * n_items
        self.clust = 0
        if self.if_w2v:
            self.train_data = [self.sent_w2v(sent) for sent in train_data]
            if self.verbose:
                print('word2vec process finished')

        for i in range(n_items):
            if self.verbose:
                print(i)
            if not self.explored[i]:
                self.iter(i)

    def iter(self, i):
        """Grow one cluster from seed index ``i`` and label it if big enough.

        Performs a breadth-first expansion: every not-yet-explored item whose
        similarity to a cluster member is ``>= eps`` joins the cluster.
        """
        # Mark the seed itself; otherwise it would be found again by its own
        # scan and counted twice towards ``min_count``.
        self.explored[i] = True
        members = [i]
        similarity = self.simi_w2v if self.if_w2v else self.simi

        # ``members`` doubles as the BFS queue; ``cursor`` is its read head,
        # which avoids copying the queue on every step.
        cursor = 0
        while cursor < len(members):
            key = self.train_data[members[cursor]]
            cursor += 1
            for j, candidate in enumerate(self.train_data):
                if self.explored[j]:
                    continue
                if similarity(key, candidate) >= self.eps:
                    self.explored[j] = True
                    members.append(j)

        if len(members) >= self.min_count:
            for t in members:
                self.label[t] = self.clust
            self.clust += 1

 

posted @ 2017-03-17 14:21  LarryGates  阅读(833)  评论(0)  编辑  收藏  举报