随机森林
使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data
代码:
#include <mlpack.hpp>
#include <iostream>
#include <string>
#include <iomanip>
// 使用 mlpack, Armadillo 和标准命名空间
using namespace mlpack;
using namespace mlpack::tree; // RandomForest is in the tree namespace
using namespace std;
using namespace arma;
/**
* @brief 使用 mlpack 实现的随机森林分类函数
*
* @tparam FeaturesType 特征矩阵类型 (例如, mat, fmat)
* @tparam LabelsType 标签向量类型 (例如, Row<size_t>)
* @param trainDataFile 训练数据文件路径
* @param trainLabelsFile 训练标签文件路径
* @param testDataFile 测试数据文件路径
* @param testLabelsFile 真实的测试标签文件路径 (用于计算准确率)
* @param datasetName 数据集名称 (用于输出)
* @param numTrees 森林中决策树的数量
* @return int 0 表示成功, 非 0 表示失败
*/
template<typename FeaturesType, typename LabelsType>
int runRandomForest(const string& trainDataFile,
const string& trainLabelsFile,
const string& testDataFile,
const string& testLabelsFile,
const string& datasetName,
const size_t numTrees = 10) // 默认创建10棵树
{
cout << "处理数据集: " << datasetName << endl;
// 加载数据
FeaturesType trainData, testData;
LabelsType trainLabels, testLabels;
// 使用 data::Load 加载 CSV 文件
if (!data::Load(trainDataFile, trainData, true) ||
!data::Load(trainLabelsFile, trainLabels, true) ||
!data::Load(testDataFile, testData, true) ||
!data::Load(testLabelsFile, testLabels, true))
{
return 1;
}
// 确保标签是行向量
if (trainLabels.n_rows > 1) { trainLabels = trainLabels.t(); }
if (testLabels.n_rows > 1) { testLabels = testLabels.t(); }
// 维度检查
if (trainData.n_cols != trainLabels.n_elem) {
cerr << "错误: 训练数据和标签的样本数不匹配。" << endl;
return 1;
}
if (testData.n_cols != testLabels.n_elem) {
cerr << "错误: 测试数据和标签的样本数不匹配。" << endl;
return 1;
}
// 1. 训练模型:
// 首先,确定数据集中有多少个类别
const size_t numClasses = arma::max(trainLabels) + 1;
// 然后,创建 RandomForest 对象。
// 构造函数接受训练数据、标签、类别数和树的数量。
RandomForest<> rf(trainData, trainLabels, numClasses, numTrees);
// 2. 预测: 使用训练好的森林对测试数据进行分类
LabelsType predictions;
rf.Classify(testData, predictions);
// 3. 计算准确率
double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
cout << "随机森林 (" << numTrees << " 棵树) 分类准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;
return 0;
}
int main()
{
// 调用 runRandomForest 函数,并传入鸢尾花数据集的文件路径
return runRandomForest<mat, Row<size_t>>(
"iris_train.csv",
"iris_train_labels.csv",
"iris_test.csv",
"iris_test_labels.csv",
"鸢尾花数据集",
20 // 尝试使用20棵树
);
}
liu@liu:~/code/Random_Forest$ g++ -o random random.cpp -larmadillo
liu@liu:~/code/Random_Forest$ ./random
处理数据集: 鸢尾花数据集
随机森林 (20 棵树) 分类准确率: 95.24%
liu@liu:~/code/Random_Forest$ g++ -o random random.cpp -larmadillo
liu@liu:~/code/Random_Forest$ ./random
处理数据集: 鸢尾花数据集
随机森林 (30 棵树) 分类准确率: 96.83%

浙公网安备 33010602011771号