决策树
使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data
代码:
#include <mlpack.hpp>
#include <iostream>
#include <string>
using namespace mlpack;
using namespace mlpack::tree;
using namespace arma;
using namespace std;
/**
* @brief 简化版决策树分类函数
*
* @tparam FeaturesType 特征矩阵类型 (e.g., mat, fmat)
* @tparam LabelsType 标签向量类型 (e.g., Row<size_t>)
* @param trainDataFile 训练数据文件路径
* @param trainLabelsFile 训练标签文件路径
* @param testDataFile 测试数据文件路径
* @param testLabelsFile 测试标签文件路径
* @param datasetName 数据集名称(用于输出)
* @return int 0表示成功,非0表示失败
*/
template<typename FeaturesType, typename LabelsType>
int runDecisionTree(const string& trainDataFile,
const string& trainLabelsFile,
const string& testDataFile,
const string& testLabelsFile,
const string& datasetName)
{
cout << "处理数据集: " << datasetName << endl;
// 加载数据
FeaturesType trainData, testData;
LabelsType trainLabels, testLabels;
if (!data::Load(trainDataFile, trainData) ||
!data::Load(trainLabelsFile, trainLabels) ||
!data::Load(testDataFile, testData) ||
!data::Load(testLabelsFile, testLabels)) {
cerr << "错误: 无法加载数据文件" << endl;
return 1;
}
// 检查数据维度
if (trainData.n_cols != trainLabels.n_elem || testData.n_cols != testLabels.n_elem) {
cerr << "错误: 数据维度不匹配" << endl;
return 1;
}
// 获取类别数量并训练决策树
const size_t numClasses = arma::max(trainLabels) + 1;
DecisionTree<> tree(trainData, trainLabels, numClasses);
// 预测并计算准确率
LabelsType predictions;
tree.Classify(testData, predictions);
double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
cout << "准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;
return 0;
}
int main()
{
return runDecisionTree<mat, Row<size_t>>(
"iris_train.csv",
"iris_train_labels.csv",
"iris_test.csv",
"iris_test_labels.csv",
"鸢尾花数据集"
);
}
liu@liu:~/code/decision_tree$ g++ -o decision_tree decision_tree.cpp -larmadillo
liu@liu:~/code/decision_tree$ ./decision_tree
==============================
数据集: 鸢尾花数据集
==============================
正在加载训练数据... 完成!
- 训练样本数: 87
- 特征数: 4
正在加载训练标签... 完成!
- 训练标签数: 87
正在加载测试数据... 完成!
- 测试样本数: 63
- 特征数: 4
正在加载测试标签... 完成!
- 测试标签数: 63
检测到类别数量: 3
正在训练决策树... 完成!
正在对测试集进行预测... 完成!
==============================
分类结果报告
==============================
数据集名称: 鸢尾花数据集
总样本数: 63
正确预测数: 60
错误预测数: 3
准确率: 95.24%
==============================

浙公网安备 33010602011771号