朴素贝叶斯

使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data

代码:

#include <mlpack.hpp>
#include <iostream>
#include <string>
#include <iomanip>

// 使用 mlpack, Armadillo 和标准命名空间
using namespace mlpack;
using namespace mlpack::naive_bayes; // 明确指定 NaiveBayesClassifier 所在的命名空间
using namespace std;
using namespace arma;

/**
 * @brief 使用 mlpack 实现的朴素贝叶斯分类函数
 *
 * @tparam FeaturesType 特征矩阵类型 (例如, mat, fmat)
 * @tparam LabelsType 标签向量类型 (例如, Row<size_t>)
 * @param trainDataFile 训练数据文件路径
 * @param trainLabelsFile 训练标签文件路径
 * @param testDataFile 测试数据文件路径
 * @param testLabelsFile 真实的测试标签文件路径 (用于计算准确率)
 * @param datasetName 数据集名称 (用于输出)
 * @return int 0 表示成功, 非 0 表示失败
 */
template<typename FeaturesType, typename LabelsType>
int runNaiveBayes(const string& trainDataFile,
                  const string& trainLabelsFile,
                  const string& testDataFile,
                  const string& testLabelsFile,
                  const string& datasetName)
{
    cout << "处理数据集: " << datasetName << endl;

    // 加载数据
    FeaturesType trainData, testData;
    LabelsType trainLabels, testLabels;

    // 使用 data::Load 加载 CSV 文件,出错时打印信息
    if (!data::Load(trainDataFile, trainData, true) ||
        !data::Load(trainLabelsFile, trainLabels, true) ||
        !data::Load(testDataFile, testData, true) ||
        !data::Load(testLabelsFile, testLabels, true))
    {
        // 错误信息已由 data::Load 打印
        return 1;
    }

    // Armadillo 的标签通常是行向量,确保它是正确的形状
    if (trainLabels.n_rows > 1) { trainLabels = trainLabels.t(); }
    if (testLabels.n_rows > 1) { testLabels = testLabels.t(); }

    // 维度检查
    if (trainData.n_cols != trainLabels.n_elem) {
        cerr << "错误: 训练数据和标签的样本数不匹配。" << endl;
        return 1;
    }
    if (testData.n_cols != testLabels.n_elem) {
        cerr << "错误: 测试数据和标签的样本数不匹配。" << endl;
        return 1;
    }

    // 1. 训练模型:
    //    首先,确定数据集中有多少个类别
    const size_t numClasses = arma::max(trainLabels) + 1;

    //    然后,创建 NaiveBayesClassifier 对象并进行训练。
    //    构造函数会处理所有训练逻辑。
    NaiveBayesClassifier<> nbc(trainData, trainLabels, numClasses);

    // 2. 预测: 使用训练好的模型对测试数据进行分类
    LabelsType predictions;
    nbc.Classify(testData, predictions);

    // 3. 计算准确率
    double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
    cout << "朴素贝叶斯分类准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;

    return 0;
}

int main()
{
    // 调用 runNaiveBayes 函数,并传入鸢尾花数据集的文件路径
    return runNaiveBayes<mat, Row<size_t>>(
        "iris_train.csv",
        "iris_train_labels.csv",
        "iris_test.csv",
        "iris_test_labels.csv",
        "鸢尾花数据集"
    );
}
liu@liu:~/code/Naive_Bayes$ g++ -o bayes bayes.cpp -larmadillo
liu@liu:~/code/Naive_Bayes$ ./bayes 
处理数据集: 鸢尾花数据集
朴素贝叶斯分类准确率: 96.83%
posted @ 2025-08-14 10:09  Su1f4t3  阅读(8)  评论(0)    收藏  举报