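Recently I have been using sklearn to run comparative machine learning experiments, and found that for many algorithms the training time grows dramatically as the dataset gets larger. With cross-validation and parameter selection on top, the experiments become very time-consuming.
Many optimizations have been proposed for this. Here is a recommendation for a fast implementation of the K-NN classification algorithm: a faiss-based version, which gave a clear speedup in my own tests.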
(1) Install the GPU version of faiss (https://pypi.org/project/faiss-gpu/)
pip install faiss-gpu
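(2) Install DESlib, a package that provides a faiss wrapper with KNN support
DESlib is an ensemble learning library, similar to sklearn, that offers an interface largely consistent with sklearn's and focuses on implementing state-of-the-art techniques for dynamic classifier and ensemble selection.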
pip install deslib
(3) Usage example; see the API documentation for details: https://deslib.readthedocs.io/en/latest/modules/util/faiss_knn_wrapper.html
from deslib.util.faiss_knn_wrapper import FaissKNNClassifier

clf = FaissKNNClassifier(n_neighbors=5, n_jobs=10, algorithm='brute', n_cells=100, n_probes=2)
clf.fit(X_train, y_train)  # train
y_test_proba = clf.predict_proba(X_test)  # predict class probabilities
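To make clear what a brute-force K-NN classifier computes, here is a minimal pure-Python sketch with a `predict_proba`-style output. It only illustrates the technique and is not how DESlib or faiss implement it; the function name `knn_predict_proba` and the toy data are made up for this example.

```python
import heapq
import math
from collections import Counter


def knn_predict_proba(X_train, y_train, x_query, n_neighbors=5):
    """Brute-force K-NN: class probabilities are neighbor vote fractions."""
    # Distance from the query to every training point (the expensive part).
    dists = ((math.dist(x, x_query), label) for x, label in zip(X_train, y_train))
    # Keep only the n_neighbors closest training points.
    nearest = heapq.nsmallest(n_neighbors, dists, key=lambda pair: pair[0])
    votes = Counter(label for _, label in nearest)
    classes = sorted(set(y_train))
    return {c: votes[c] / len(nearest) for c in classes}


# Toy data (hypothetical): two well-separated clusters.
X_train = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (5.0, 5.0), (5.1, 4.9), (4.9, 5.2)]
y_train = [0, 0, 0, 1, 1, 1]

proba = knn_predict_proba(X_train, y_train, (0.3, 0.3), n_neighbors=3)
print(proba)
```

Every query scans the whole training set, which is why search time grows with the data and why an optimized backend such as faiss can help.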
In my own tests, training was more than 100 times faster than the KNN implementation in sklearn.
satellite (sklearn): 1867s vs satellite (faiss): 8s
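A comparison of this kind can be set up with a small timing helper. The sketch below is an assumption about how such a measurement might be done, not the script that produced the numbers above; `time_fit_predict`, `speedup`, and the `DummyClassifier` stand-in are hypothetical names. Any object with sklearn-style `fit` and `predict` methods would work in its place.

```python
import time


def time_fit_predict(clf, X_train, y_train, X_test):
    """Return (predictions, elapsed seconds) for one fit + predict run."""
    start = time.perf_counter()
    clf.fit(X_train, y_train)
    predictions = clf.predict(X_test)
    return predictions, time.perf_counter() - start


def speedup(baseline_seconds, candidate_seconds):
    """How many times faster the candidate is than the baseline."""
    return baseline_seconds / candidate_seconds


class DummyClassifier:
    """Hypothetical stand-in: always predicts the most frequent training label."""

    def fit(self, X, y):
        self.majority_ = max(set(y), key=list(y).count)
        return self

    def predict(self, X):
        return [self.majority_ for _ in X]


preds, elapsed = time_fit_predict(
    DummyClassifier(), [[0], [1], [2]], [1, 1, 0], [[3], [4]])
print(preds, elapsed >= 0.0)

# Ratio of the satellite timings: 1867 s (sklearn) vs 8 s (faiss).
print(speedup(1867, 8))
```

In practice one would pass `KNeighborsClassifier(n_neighbors=5)` and `FaissKNNClassifier(n_neighbors=5)` with the same data and compare the two elapsed times.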
For more usage and benchmarking, see:
import time
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from deslib.util.faiss_knn_wrapper import FaissKNNClassifier
n_samples = [1000, 10000, 100000, 1000000, 10000000]
rng = 42
faiss_brute = FaissKNNClassifier(n_neighbors=7,
                                 algorithm='brute')
faiss_voronoi = FaissKNNClassifier(n_neighbors=7,
                                   algorithm='voronoi')
faiss_hierarchical = FaissKNNClassifier(n_neighbors=7,
                                        algorithm='hierarchical')
all_knns = [faiss_brute, faiss_voronoi, faiss_hierarchical]
names = ['faiss_brute', 'faiss_voronoi', 'faiss_hierarchical']
list_fitting_time = []
list_search_time = []
for n in n_samples:
    print("Number of samples: {}".format(n))
    X, y = make_classification(n_samples=n,
                               n_features=20,
                               random_state=rng)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=rng)
    temp_fitting_time = []
    temp_search_time = []
    for name, knn in zip(names, all_knns):
        # Time index construction (perf_counter is a high-resolution timer)
        start = time.perf_counter()
        knn.fit(X_train, y_train)
        fitting_time = time.perf_counter() - start
        print("{} fitting time: {}".format(name, fitting_time))
        # Time the neighbor search; returns (distances, indices) as in sklearn
        start = time.perf_counter()
        dists, neighbors = knn.kneighbors(X_test)
        search_time = time.perf_counter() - start
        print("{} neighborhood search time: {}".format(name, search_time))
        temp_fitting_time.append(fitting_time)
        temp_search_time.append(search_time)
    # One row of timings per dataset size, one column per algorithm
    list_fitting_time.append(temp_fitting_time)
    list_search_time.append(temp_search_time)
plt.plot(n_samples, list_search_time)
plt.legend(names)
plt.xlabel("Number of samples")
plt.ylabel("K neighbors search time")
plt.savefig('knn_backbone_benchmark.png')
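As far as I understand, `algorithm='voronoi'` corresponds to an inverted-file style index, where `n_cells` sets how many cells the data is partitioned into and `n_probes` how many cells are scanned per query. Treat that mapping as my assumption and check the DESlib documentation. The pure-Python sketch below illustrates the idea only; it is not faiss's implementation, and `build_cells` and `search` are hypothetical helper names.

```python
import heapq
import math
import random


def build_cells(points, n_cells, seed=0):
    """Assign each point to its nearest centroid (one Voronoi cell each)."""
    rng = random.Random(seed)
    # Simplification: centroids are random training points, not k-means centers.
    centroids = rng.sample(points, n_cells)
    cells = [[] for _ in centroids]
    for idx, p in enumerate(points):
        nearest = min(range(n_cells), key=lambda c: math.dist(p, centroids[c]))
        cells[nearest].append(idx)
    return centroids, cells


def search(points, centroids, cells, query, k, n_probes):
    """Approximate k-NN: scan only the n_probes cells closest to the query."""
    probe = heapq.nsmallest(n_probes, range(len(centroids)),
                            key=lambda c: math.dist(query, centroids[c]))
    candidates = [i for c in probe for i in cells[c]]
    return heapq.nsmallest(k, candidates,
                           key=lambda i: math.dist(query, points[i]))


rng = random.Random(42)
points = [(rng.random(), rng.random()) for _ in range(1000)]
centroids, cells = build_cells(points, n_cells=20)

query = (0.5, 0.5)
# Exhaustive reference: rank every point by distance to the query.
brute = heapq.nsmallest(7, range(len(points)),
                        key=lambda i: math.dist(query, points[i]))
approx = search(points, centroids, cells, query, k=7, n_probes=2)
exact = search(points, centroids, cells, query, k=7, n_probes=20)  # all cells

print("approximate neighbors:", approx)
print("probing every cell matches brute force:", sorted(exact) == sorted(brute))
```

Scanning fewer cells makes the search faster but can miss true neighbors that sit just across a cell boundary, which is the trade-off `n_probes` controls in this sketch.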
For other benchmarks, see this blog post:
https://towardsdatascience.com/make-knn-300-times-faster-than-scikit-learns-in-20-lines-5e29d74e76bb
