KNN

使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data

代码:

#include <mlpack.hpp>
#include <iostream>
#include <string>
#include <iomanip>
#include <map>

// 使用 mlpack, Armadillo 和标准命名空间
using namespace mlpack;
using namespace mlpack::neighbor; // 明确指定 NeighborSearch 所在的命名空间
using namespace std;
using namespace arma;

/**
 * @brief 使用 mlpack 实现的 KNN 分类函数 (兼容 mlpack 4.x API)
 *
 * @tparam FeaturesType 特征矩阵类型 (例如, mat, fmat)
 * @tparam LabelsType 标签向量类型 (例如, Row<size_t>)
 * @param trainDataFile 训练数据文件路径
 * @param trainLabelsFile 训练标签文件路径
 * @param testDataFile 测试数据文件路径
 * @param testLabelsFile 真实的测试标签文件路径 (用于计算准确率)
 * @param datasetName 数据集名称 (用于输出)
 * @param k K-近邻算法中的K值
 * @return int 0 表示成功, 非 0 表示失败
 */
template<typename FeaturesType, typename LabelsType>
int runKnn(const string& trainDataFile,
           const string& trainLabelsFile,
           const string& testDataFile,
           const string& testLabelsFile,
           const string& datasetName,
           const size_t k = 5)
{
    cout << "处理数据集: " << datasetName << endl;

    // 加载数据
    FeaturesType trainData, testData;
    LabelsType trainLabels, testLabels;

    // 使用 data::Load 加载 CSV 文件,出错时打印信息
    if (!data::Load(trainDataFile, trainData, true) ||
        !data::Load(trainLabelsFile, trainLabels, true) ||
        !data::Load(testDataFile, testData, true) ||
        !data::Load(testLabelsFile, testLabels, true))
    {
        // 错误信息已由 data::Load 打印
        return 1;
    }
    
    // Armadillo 的标签通常是行向量,确保它是正确的形状
    // 如果 trainLabels 是列向量,转置它
    if (trainLabels.n_rows > 1) { trainLabels = trainLabels.t(); }
    if (testLabels.n_rows > 1) { testLabels = testLabels.t(); }


    // 维度检查
    if (trainData.n_cols != trainLabels.n_elem) {
        cerr << "错误: 训练数据和标签的样本数不匹配。" << endl;
        return 1;
    }
    if (testData.n_cols != testLabels.n_elem) {
        cerr << "错误: 测试数据和标签的样本数不匹配。" << endl;
        return 1;
    }

    // 1. 训练模型: 对于 KNN,这仅意味着构建一个可以快速搜索的数据结构
    //    构造函数只接受训练特征数据。
    KNN knn(trainData);

    // 2. 搜索邻居: 为测试集中的每个点找到 k 个最近的邻居
    arma::Mat<size_t> resultingNeighbors;
    arma::mat resultingDistances; // 距离矩阵,虽然这里不用,但API需要它
    
    // 调用 Search 方法
    knn.Search(testData, k, resultingNeighbors, resultingDistances);

    // 3. 投票分类: 手动根据邻居的标签来预测类别
    LabelsType predictions(testData.n_cols);
    for (size_t i = 0; i < testData.n_cols; ++i)
    {
        // 使用 map 来统计每个邻居标签出现的次数
        std::map<size_t, size_t> counts;
        for (size_t j = 0; j < k; ++j)
        {
            const size_t neighborIndex = resultingNeighbors(j, i);
            const size_t neighborLabel = trainLabels(neighborIndex);
            counts[neighborLabel]++;
        }

        // 找到出现次数最多的标签作为预测结果
        size_t bestLabel = 0;
        size_t maxCount = 0;
        for (auto const& [label, count] : counts)
        {
            if (count > maxCount)
            {
                maxCount = count;
                bestLabel = label;
            }
        }
        predictions(i) = bestLabel;
    }

    // 计算准确率
    double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
    cout << "当 K = " << k << " 时,KNN 分类准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;

    return 0;
}

int main()
{
    return runKnn<mat, Row<size_t>>(
        "iris_train.csv",
        "iris_train_labels.csv",
        "iris_test.csv",
        "iris_test_labels.csv",
        "鸢尾花数据集",
        5 // 在这里调整 K 值
    );
}
liu@liu:~/code/KNN1$ g++ -o knn knn.cpp -larmadillo
liu@liu:~/code/KNN1$ ./knn 
处理数据集: 鸢尾花数据集
当 K = 5 时,KNN 分类准确率: 98.41%
liu@liu:~/code/KNN1$ g++ -o knn knn.cpp -larmadillo
liu@liu:~/code/KNN1$ ./knn 
处理数据集: 鸢尾花数据集
当 K = 8 时,KNN 分类准确率: 95.24%
posted @ 2025-08-14 10:09  Su1f4t3  阅读(7)  评论(0)    收藏  举报