XGBoost--4--代码编写基本流程--分类

这一节主要介绍一下XGBoost算法在CPU/GPU版本下的代码编写基本流程,主要分为以下几个部分:

  • 构造训练集/验证集
  • 算法参数设置
  • XGBoost模型训练/验证
  • 模型预测

主要面对的任务场景是多分类任务,下一节再说回归任务;

另外,除上述几个部分外,会涉及到sklearn用于加载数据集以及最后的模型预测的评价指标计算;

导入使用到的库:

import time
import xgboost as xgb
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, classification
from libs.xgboost_plot import plot_training_merror

1. 构造数据集/验证集

使用sklearn导入数据集,并进一步拆分成训练集、验证集;

# 使用sklearn加载数据集,并进一步拆分
digits = datasets.load_digits()
data, labels = digits.data, digits.target
x_train, x_test, y_train, y_test = train_test_split(data, labels, test_size=0.2, random_state=7)
print("x_train: {}, x_test: {}, classes: {}".format(x_train.shape, x_test.shape, len(set(labels))))

构造XGBoost算法需要的输入格式:

dtrain = xgb.DMatrix(x_train, y_train)  # 训练集
dtest = xgb.DMatrix(x_test, y_test)  # 验证集
evals = [(dtrain, 'train'), (dtest, 'val')]  # 训练过程中进行验证

2. 算法参数设置

算法模型参数设置,详情见:XGBoost Parameters

params = {
    'tree_method': "gpu_hist",
    'booster': 'gbtree',
    'objective': 'multi:softmax',
    'num_class': 10,
    'max_depth': 6,
    'eval_metric': 'merror',
    'eta': 0.01,
    'verbosity': 0,
    'gpu_id': 0
}

简单介绍以下:

  • tree_methodgpu_hist表示使用GPU运算,影响的可以使用hist利用CPU计算;
  • objective,目标函数
  • num_class类别数量,配合multi:softmax使用;

3. XGBoost模型训练/验证

模型训练、保存模型、绘制merror图像;

s_time = time.time()
train_res = {}
model = xgb.train(params, dtrain, num_boost_round=100,
                  evals=evals,
                  evals_result=train_res)
print("模型训练耗时: {}".format(time.time() - s_time))

# 模型保存
save_path = "./saved_model/model.model"
model.save_model(save_path)

# train/val的merror绘图
merror_img_path = "./test/error.png"
plot_training_merror(train_res, merror_img_path)

4. 模型预测

模型预测,打印预测结果

pred_data = model.predict(dtest)
res = classification_report(y_test, pred_data)
print(res)

输出如下:

              precision    recall  f1-score   support

           0       1.00      0.95      0.98        43
           1       0.86      1.00      0.92        42
           2       0.98      1.00      0.99        40
           3       0.89      0.97      0.93        34
           4       0.92      0.89      0.90        37
           5       0.93      0.96      0.95        28
           6       0.96      0.93      0.95        28
           7       0.86      0.94      0.90        33
           8       0.95      0.81      0.88        43
           9       0.96      0.81      0.88        32

    accuracy                           0.93       360
   macro avg       0.93      0.93      0.93       360
weighted avg       0.93      0.93      0.93       360

5. 结语

XGBoost框架最基本的使用,也就是这个流程了:

  • 构建数据集
  • 参数设置
  • 模型训练、保存、预测;

当然,在实际应用中,每一个步骤中,都存在很多的细节值得深究,也必须深究;

不然很难做到知其然知其所以然,对于实际的问题,也很难获得一个很好的结果;

posted @ 2020-10-27 15:55  chenzhen0530  阅读(447)  评论(0编辑  收藏  举报