随机森林

使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data

代码:

#include <mlpack.hpp>
#include <iostream>
#include <string>
#include <iomanip>

// 使用 mlpack, Armadillo 和标准命名空间
using namespace mlpack;
using namespace mlpack::tree; // RandomForest is in the tree namespace
using namespace std;
using namespace arma;

/**
 * @brief 使用 mlpack 实现的随机森林分类函数
 *
 * @tparam FeaturesType 特征矩阵类型 (例如, mat, fmat)
 * @tparam LabelsType 标签向量类型 (例如, Row<size_t>)
 * @param trainDataFile 训练数据文件路径
 * @param trainLabelsFile 训练标签文件路径
 * @param testDataFile 测试数据文件路径
 * @param testLabelsFile 真实的测试标签文件路径 (用于计算准确率)
 * @param datasetName 数据集名称 (用于输出)
 * @param numTrees 森林中决策树的数量
 * @return int 0 表示成功, 非 0 表示失败
 */
template<typename FeaturesType, typename LabelsType>
int runRandomForest(const string& trainDataFile,
                    const string& trainLabelsFile,
                    const string& testDataFile,
                    const string& testLabelsFile,
                    const string& datasetName,
                    const size_t numTrees = 10) // 默认创建10棵树
{
    cout << "处理数据集: " << datasetName << endl;

    // 加载数据
    FeaturesType trainData, testData;
    LabelsType trainLabels, testLabels;

    // 使用 data::Load 加载 CSV 文件
    if (!data::Load(trainDataFile, trainData, true) ||
        !data::Load(trainLabelsFile, trainLabels, true) ||
        !data::Load(testDataFile, testData, true) ||
        !data::Load(testLabelsFile, testLabels, true))
    {
        return 1;
    }

    // 确保标签是行向量
    if (trainLabels.n_rows > 1) { trainLabels = trainLabels.t(); }
    if (testLabels.n_rows > 1) { testLabels = testLabels.t(); }

    // 维度检查
    if (trainData.n_cols != trainLabels.n_elem) {
        cerr << "错误: 训练数据和标签的样本数不匹配。" << endl;
        return 1;
    }
    if (testData.n_cols != testLabels.n_elem) {
        cerr << "错误: 测试数据和标签的样本数不匹配。" << endl;
        return 1;
    }

    // 1. 训练模型:
    //    首先,确定数据集中有多少个类别
    const size_t numClasses = arma::max(trainLabels) + 1;

    //    然后,创建 RandomForest 对象。
    //    构造函数接受训练数据、标签、类别数和树的数量。
    RandomForest<> rf(trainData, trainLabels, numClasses, numTrees);

    // 2. 预测: 使用训练好的森林对测试数据进行分类
    LabelsType predictions;
    rf.Classify(testData, predictions);

    // 3. 计算准确率
    double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
    cout << "随机森林 (" << numTrees << " 棵树) 分类准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;

    return 0;
}

int main()
{
    // 调用 runRandomForest 函数,并传入鸢尾花数据集的文件路径
    return runRandomForest<mat, Row<size_t>>(
        "iris_train.csv",
        "iris_train_labels.csv",
        "iris_test.csv",
        "iris_test_labels.csv",
        "鸢尾花数据集",
        20 // 尝试使用20棵树
    );
}
liu@liu:~/code/Random_Forest$ g++ -o random random.cpp -larmadillo
liu@liu:~/code/Random_Forest$ ./random 
处理数据集: 鸢尾花数据集
随机森林 (20 棵树) 分类准确率: 95.24%
liu@liu:~/code/Random_Forest$ g++ -o random random.cpp -larmadillo
liu@liu:~/code/Random_Forest$ ./random 
处理数据集: 鸢尾花数据集
随机森林 (30 棵树) 分类准确率: 96.83%
posted @ 2025-08-14 10:08  Su1f4t3  阅读(11)  评论(0)    收藏  举报