提升机器算法
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
"""
提升机器算法
LightGBM是个快速的,分布式的,高性能的基于决策树算法的梯度提升框架。
可用于排序,分类,回归以及很多其他的机器学习任务中。
"""
import multiprocessing
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
import lightgbm as lgb
from sklearn.model_selection import cross_val_score, GridSearchCV, train_test_split
from wssc.app.config import DIM_DICT, DATASET
def create_bag_of_centroids(wordlist, dim_dict=DIM_DICT, num_centroids=len(DIM_DICT)):
    """Build a bag-of-centroids count vector for one list of words.

    Every word that is present in ``dim_dict`` increments the vector slot
    whose index is the value stored for that word.  Words that are not in
    ``dim_dict`` are ignored.

    :param wordlist: iterable of words to count.
    :param dim_dict: mapping from a word to its slot index in the vector.
    :param num_centroids: length of the returned vector.  The default is
        ``len(DIM_DICT)``, evaluated once when the function is defined, so
        pass it explicitly when supplying a different ``dim_dict``.
    :return: ``float32`` NumPy array of length ``num_centroids`` holding the
        per-slot occurrence counts.
    """
    bag_of_centroids = np.zeros(num_centroids, dtype="float32")
    for word in wordlist:
        # Use the ``in`` / subscript operators rather than calling the
        # ``__contains__`` / ``__getitem__`` dunder methods directly.
        if word in dim_dict:
            bag_of_centroids[dim_dict[word]] += 1
    return bag_of_centroids
def main():
    """Train and evaluate a LightGBM multiclass classifier.

    Reads the CSV at ``DATASET % "result_1"``, converts each row's
    semicolon-separated ``DIM`` string into a bag-of-centroids vector, splits
    the rows into train and test parts, trains a gradient-boosted model with
    early stopping on the test part, and prints the fraction of test rows
    whose predicted class equals the ``C_ID`` label.
    """
    num_cores = multiprocessing.cpu_count()
    raw_data = pd.read_csv(DATASET % "result_1")

    # Vectorise every row in parallel.  Iterate over the column values
    # directly instead of doing a positional row lookup for each index.
    bags = Parallel(n_jobs=num_cores, verbose=10)(
        delayed(create_bag_of_centroids)(dims.split(";"))
        for dims in raw_data["DIM"]
    )
    X = np.array(bags)
    Y = np.array(raw_data["C_ID"])

    seed = 7
    test_size = 0.22
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=test_size, random_state=seed
    )

    # Wrap both splits in LightGBM Datasets; the validation set is created
    # with a reference to the training set.
    lgb_train = lgb.Dataset(X_train, y_train)
    lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

    params = {
        'boosting_type': 'gbdt',
        'objective': 'multiclass',
        # NOTE(review): hard-coded class count; presumably C_ID takes values
        # 0..49 -- confirm against the dataset.
        'num_class': 50,
        'metric': 'multi_error',
        'num_leaves': 300,
        'min_data_in_leaf': 100,
        'learning_rate': 0.01,
        'feature_fraction': 0.8,
        'bagging_fraction': 0.8,
        'bagging_freq': 5,
        'lambda_l1': 0.4,
        'lambda_l2': 0.5,
        'min_gain_to_split': 0.2,
        'verbose': 5,
        # NOTE(review): LightGBM documents is_unbalance for the binary and
        # multiclassova objectives; confirm it has any effect here.
        'is_unbalance': True
    }

    print('Start training...')
    # Early stopping is passed as a callback: the ``early_stopping_rounds``
    # keyword of ``lgb.train`` no longer exists in LightGBM 4.x, while the
    # callback form is accepted by older releases as well.
    gbm = lgb.train(
        params,
        lgb_train,
        num_boost_round=10000,
        valid_sets=[lgb_eval],
        callbacks=[lgb.early_stopping(stopping_rounds=500)],
    )

    print('Start predicting...')
    # Predict with the best iteration found by early stopping; the result has
    # one score column per class, so argmax over axis 1 picks the class.
    y_pred = gbm.predict(X_test, num_iteration=gbm.best_iteration)
    pred = np.argmax(y_pred, axis=1)

    # Accuracy: share of test rows where the predicted class matches.
    print((pred == y_test).mean())
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
posted on 2019-12-02 14:21 nnnnnnnnnnnnnnnn 阅读(238) 评论(0) 收藏 举报
                    
                
                
            
        
浙公网安备 33010602011771号