决策树

使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data

代码:

#include <mlpack.hpp>
#include <iostream>
#include <string>

using namespace mlpack;
using namespace mlpack::tree;
using namespace arma;
using namespace std;

/**
 * @brief 简化版决策树分类函数
 * 
 * @tparam FeaturesType 特征矩阵类型 (e.g., mat, fmat)
 * @tparam LabelsType 标签向量类型 (e.g., Row<size_t>)
 * @param trainDataFile 训练数据文件路径
 * @param trainLabelsFile 训练标签文件路径
 * @param testDataFile 测试数据文件路径
 * @param testLabelsFile 测试标签文件路径
 * @param datasetName 数据集名称(用于输出)
 * @return int 0表示成功,非0表示失败
 */
template<typename FeaturesType, typename LabelsType>
int runDecisionTree(const string& trainDataFile,
                    const string& trainLabelsFile,
                    const string& testDataFile,
                    const string& testLabelsFile,
                    const string& datasetName)
{
    cout << "处理数据集: " << datasetName << endl;
    
    // 加载数据
    FeaturesType trainData, testData;
    LabelsType trainLabels, testLabels;
    
    if (!data::Load(trainDataFile, trainData) || 
        !data::Load(trainLabelsFile, trainLabels) ||
        !data::Load(testDataFile, testData) ||
        !data::Load(testLabelsFile, testLabels)) {
        cerr << "错误: 无法加载数据文件" << endl;
        return 1;
    }
    
    // 检查数据维度
    if (trainData.n_cols != trainLabels.n_elem || testData.n_cols != testLabels.n_elem) {
        cerr << "错误: 数据维度不匹配" << endl;
        return 1;
    }
    
    // 获取类别数量并训练决策树
    const size_t numClasses = arma::max(trainLabels) + 1;
    DecisionTree<> tree(trainData, trainLabels, numClasses);
    
    // 预测并计算准确率
    LabelsType predictions;
    tree.Classify(testData, predictions);
    double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
    
    cout << "准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;
    return 0;
}

int main()
{
    return runDecisionTree<mat, Row<size_t>>(
        "iris_train.csv",
        "iris_train_labels.csv",
        "iris_test.csv",
        "iris_test_labels.csv",
        "鸢尾花数据集"
    );
}
liu@liu:~/code/decision_tree$ g++ -o decision_tree decision_tree.cpp -larmadillo
liu@liu:~/code/decision_tree$ ./decision_tree 
==============================
数据集: 鸢尾花数据集
==============================
正在加载训练数据... 完成!
  - 训练样本数: 87
  - 特征数: 4
正在加载训练标签... 完成!
  - 训练标签数: 87
正在加载测试数据... 完成!
  - 测试样本数: 63
  - 特征数: 4
正在加载测试标签... 完成!
  - 测试标签数: 63
检测到类别数量: 3

正在训练决策树... 完成!
正在对测试集进行预测... 完成!

==============================
      分类结果报告
==============================
数据集名称: 鸢尾花数据集
总样本数: 63
正确预测数: 60
错误预测数: 3
准确率: 95.24%
==============================
posted @ 2025-08-14 10:13  Su1f4t3  阅读(10)  评论(0)    收藏  举报