朴素贝叶斯
使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data
代码:
#include <mlpack.hpp>
#include <iostream>
#include <string>
#include <iomanip>
// 使用 mlpack, Armadillo 和标准命名空间
using namespace mlpack;
using namespace mlpack::naive_bayes; // 明确指定 NaiveBayesClassifier 所在的命名空间
using namespace std;
using namespace arma;
/**
* @brief 使用 mlpack 实现的朴素贝叶斯分类函数
*
* @tparam FeaturesType 特征矩阵类型 (例如, mat, fmat)
* @tparam LabelsType 标签向量类型 (例如, Row<size_t>)
* @param trainDataFile 训练数据文件路径
* @param trainLabelsFile 训练标签文件路径
* @param testDataFile 测试数据文件路径
* @param testLabelsFile 真实的测试标签文件路径 (用于计算准确率)
* @param datasetName 数据集名称 (用于输出)
* @return int 0 表示成功, 非 0 表示失败
*/
template<typename FeaturesType, typename LabelsType>
int runNaiveBayes(const string& trainDataFile,
const string& trainLabelsFile,
const string& testDataFile,
const string& testLabelsFile,
const string& datasetName)
{
cout << "处理数据集: " << datasetName << endl;
// 加载数据
FeaturesType trainData, testData;
LabelsType trainLabels, testLabels;
// 使用 data::Load 加载 CSV 文件,出错时打印信息
if (!data::Load(trainDataFile, trainData, true) ||
!data::Load(trainLabelsFile, trainLabels, true) ||
!data::Load(testDataFile, testData, true) ||
!data::Load(testLabelsFile, testLabels, true))
{
// 错误信息已由 data::Load 打印
return 1;
}
// Armadillo 的标签通常是行向量,确保它是正确的形状
if (trainLabels.n_rows > 1) { trainLabels = trainLabels.t(); }
if (testLabels.n_rows > 1) { testLabels = testLabels.t(); }
// 维度检查
if (trainData.n_cols != trainLabels.n_elem) {
cerr << "错误: 训练数据和标签的样本数不匹配。" << endl;
return 1;
}
if (testData.n_cols != testLabels.n_elem) {
cerr << "错误: 测试数据和标签的样本数不匹配。" << endl;
return 1;
}
// 1. 训练模型:
// 首先,确定数据集中有多少个类别
const size_t numClasses = arma::max(trainLabels) + 1;
// 然后,创建 NaiveBayesClassifier 对象并进行训练。
// 构造函数会处理所有训练逻辑。
NaiveBayesClassifier<> nbc(trainData, trainLabels, numClasses);
// 2. 预测: 使用训练好的模型对测试数据进行分类
LabelsType predictions;
nbc.Classify(testData, predictions);
// 3. 计算准确率
double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
cout << "朴素贝叶斯分类准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;
return 0;
}
int main()
{
// 调用 runNaiveBayes 函数,并传入鸢尾花数据集的文件路径
return runNaiveBayes<mat, Row<size_t>>(
"iris_train.csv",
"iris_train_labels.csv",
"iris_test.csv",
"iris_test_labels.csv",
"鸢尾花数据集"
);
}
liu@liu:~/code/Naive_Bayes$ g++ -o bayes bayes.cpp -larmadillo
liu@liu:~/code/Naive_Bayes$ ./bayes
处理数据集: 鸢尾花数据集
朴素贝叶斯分类准确率: 96.83%

浙公网安备 33010602011771号