导航

XGBoost判断蘑菇是否有毒示例

Posted on 2021-07-27 19:50  蝈蝈俊  阅读(918)  评论(0编辑  收藏  举报

数据文件说明

本示例的数据集文件可以在https://github.com/dmlc/xgboost/tree/master/demo/data这里获得。
该数据集描述的是不同蘑菇的相关特征,如大小、颜色等,并且每一种蘑菇都会被标记为可食用的(标记为0)或有毒的(标记为1)。

LibSVM 格式说明

这个数据是LibSVM格式的

LibSVM 使用的训练数据和检验数据文件格式如下:

[label] [index1]:[value1] [index2]:[value2] …
[label] [index1]:[value1] [index2]:[value2] …

label 目标值,就是说class(属于哪一类),就是你要分类的种类,通常是一些整数。

index 是有顺序的索引,通常是连续的整数。就是指特征编号,必须按照升序排列

value 就是特征值,用来train的数据,通常是一堆实数组成。

格式特征:

  • 每行包含一个实例,并以“ \ n”字符结尾。
  • 对于分类,
  • 是一个从1开始的整数, 是一个实数。唯一的例外是预先计算的内核, 从0开始;

参考:libsvm的数据格式及制作

我们这个例子中的数据文件

1 3:1 10:1 11:1 21:1 30:1 34:1 36:1 40:1 41:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 20:1 21:1 23:1 34:1 36:1 39:1 41:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 120:1
0 1:1 10:1 19:1 21:1 24:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 106:1 116:1 122:1
1 3:1 9:1 19:1 21:1 30:1 34:1 36:1 40:1 42:1 53:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 117:1 124:1
0 3:1 10:1 14:1 22:1 29:1 34:1 37:1 39:1 41:1 54:1 58:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 98:1 106:1 114:1 120:1
0 3:1 9:1 20:1 21:1 23:1 34:1 36:1 39:1 42:1 53:1 56:1 65:1 69:1 77:1 86:1 88:1 92:1 95:1 102:1 105:1 116:1 120:1

每一行的label值,标记该蘑菇可食用的(标记为0)或有毒的(标记为1)。

数据源说明

这个例子的数据源来自:http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/
数据中包括蘑菇对形状、颜色等特征,以及是否有毒的标签。

agaricus-lepiota.data

原始数据存放在agaricus-lepiota.data里,内容如下所示。它有23列,其中第一列是标签列,p表示有毒,e表示没有毒。后面的22列是22个特征对应的特征值。

agaricus-lepiota.names

agaricus-lepiota.names 文件里存放特征映射关系,比如蘑菇头形状(cap-shap)为钟型(bell)的用b表示,圆锥型(conical)的用c表示;蘑菇头颜色(cap-color)为棕色(brown)的用n表示,浅黄色(buff)的用b表示,等等。总共22个特征映射,对应agaricus-lepiota.data里的第1~22列(第0列为标签)。

数据准备

这里我们已经把这个数据变化成了LibSVM格式。

另外我们还把数据随机分成训练集(agaricus.txt.train)和测试集(agaricus.txt.test)两部分,80%的数据分配给训练集,20%分配给测试集。

参考:xgboost小试

训练模型

我们的任务是对蘑菇特征数据进行学习,训练相关模型,然后利用训练好的模型预测未知的蘑菇样本是否有毒。

import xgboost as xgb

# 数据读取
xgb_train = xgb.DMatrix('./agaricus.txt.train')
xgb_test = xgb.DMatrix('./agaricus.txt.test')

# 定义模型训练参数
params = {
    "objective":"binary:logistic",
    "booster":"gbtree",
    "max_depth":3
}

# 训练轮数
num_round = 5

# 训练过程中实时输出评估结果
watchlist = [(xgb_train,'train'),(xgb_test,'test')]

# 模型训练
model = xgb.train(params,xgb_train,num_round,watchlist)

输出结果

% python ./xgb20.py 
[19:25:53] WARNING: /opt/concourse/worker/volumes/live/7a2b9f41-3287-451b-6691-43e9a6c0910f/volume/xgboost-split_1619728204606/work/src/learner.cc:1061: Starting in XGBoost 1.3.0, the default evaluation metric used with the objective 'binary:logistic' was changed from 'error' to 'logloss'. Explicitly set eval_metric if you'd like to restore the old behavior.
[0]     train-logloss:0.45224   test-logloss:0.45317
[1]     train-logloss:0.32281   test-logloss:0.32412
[2]     train-logloss:0.23637   test-logloss:0.23739
[3]     train-logloss:0.16933   test-logloss:0.16935
[4]     train-logloss:0.12386   test-logloss:0.12352

XGBoost训练过程中实时输出了训练集和测试集的错误率评估结果。随着训练的进行,训练集和测试集的错误率均在不断下降,说明模型对于特征数据的学习是十分有效的。

参数说明

  • "objective":"binary:logistic" objective 该参数用来指定目标函数,XGBoost可以根据该参数判断进行何种学习任务,binary:logistic和binary:logitraw都表示学习任务类型为二分类。binary:logistic输出为概率,binary:logitraw输出为逻辑转换前的输出分数。
  • booster为gbtree表示采用XGBoost中的树模型。
  • 参数max_depth表示决策树分裂的最大深度。

预测


# 对测试集进行预测
preds = model.predict(xgb_test)

参考: