KNN
使用的数据集是mlpack自带的鸢尾花数据集,路径是:~/mlpack-4.6.2/src/mlpack/tests/data
代码:
#include <mlpack.hpp>
#include <iostream>
#include <string>
#include <iomanip>
#include <map>
// 使用 mlpack, Armadillo 和标准命名空间
using namespace mlpack;
using namespace mlpack::neighbor; // 明确指定 NeighborSearch 所在的命名空间
using namespace std;
using namespace arma;
/**
* @brief 使用 mlpack 实现的 KNN 分类函数 (兼容 mlpack 4.x API)
*
* @tparam FeaturesType 特征矩阵类型 (例如, mat, fmat)
* @tparam LabelsType 标签向量类型 (例如, Row<size_t>)
* @param trainDataFile 训练数据文件路径
* @param trainLabelsFile 训练标签文件路径
* @param testDataFile 测试数据文件路径
* @param testLabelsFile 真实的测试标签文件路径 (用于计算准确率)
* @param datasetName 数据集名称 (用于输出)
* @param k K-近邻算法中的K值
* @return int 0 表示成功, 非 0 表示失败
*/
template<typename FeaturesType, typename LabelsType>
int runKnn(const string& trainDataFile,
const string& trainLabelsFile,
const string& testDataFile,
const string& testLabelsFile,
const string& datasetName,
const size_t k = 5)
{
cout << "处理数据集: " << datasetName << endl;
// 加载数据
FeaturesType trainData, testData;
LabelsType trainLabels, testLabels;
// 使用 data::Load 加载 CSV 文件,出错时打印信息
if (!data::Load(trainDataFile, trainData, true) ||
!data::Load(trainLabelsFile, trainLabels, true) ||
!data::Load(testDataFile, testData, true) ||
!data::Load(testLabelsFile, testLabels, true))
{
// 错误信息已由 data::Load 打印
return 1;
}
// Armadillo 的标签通常是行向量,确保它是正确的形状
// 如果 trainLabels 是列向量,转置它
if (trainLabels.n_rows > 1) { trainLabels = trainLabels.t(); }
if (testLabels.n_rows > 1) { testLabels = testLabels.t(); }
// 维度检查
if (trainData.n_cols != trainLabels.n_elem) {
cerr << "错误: 训练数据和标签的样本数不匹配。" << endl;
return 1;
}
if (testData.n_cols != testLabels.n_elem) {
cerr << "错误: 测试数据和标签的样本数不匹配。" << endl;
return 1;
}
// 1. 训练模型: 对于 KNN,这仅意味着构建一个可以快速搜索的数据结构
// 构造函数只接受训练特征数据。
KNN knn(trainData);
// 2. 搜索邻居: 为测试集中的每个点找到 k 个最近的邻居
arma::Mat<size_t> resultingNeighbors;
arma::mat resultingDistances; // 距离矩阵,虽然这里不用,但API需要它
// 调用 Search 方法
knn.Search(testData, k, resultingNeighbors, resultingDistances);
// 3. 投票分类: 手动根据邻居的标签来预测类别
LabelsType predictions(testData.n_cols);
for (size_t i = 0; i < testData.n_cols; ++i)
{
// 使用 map 来统计每个邻居标签出现的次数
std::map<size_t, size_t> counts;
for (size_t j = 0; j < k; ++j)
{
const size_t neighborIndex = resultingNeighbors(j, i);
const size_t neighborLabel = trainLabels(neighborIndex);
counts[neighborLabel]++;
}
// 找到出现次数最多的标签作为预测结果
size_t bestLabel = 0;
size_t maxCount = 0;
for (auto const& [label, count] : counts)
{
if (count > maxCount)
{
maxCount = count;
bestLabel = label;
}
}
predictions(i) = bestLabel;
}
// 计算准确率
double accuracy = 100.0 * arma::accu(predictions == testLabels) / testLabels.n_elem;
cout << "当 K = " << k << " 时,KNN 分类准确率: " << fixed << setprecision(2) << accuracy << "%" << endl;
return 0;
}
int main()
{
return runKnn<mat, Row<size_t>>(
"iris_train.csv",
"iris_train_labels.csv",
"iris_test.csv",
"iris_test_labels.csv",
"鸢尾花数据集",
5 // 在这里调整 K 值
);
}
liu@liu:~/code/KNN1$ g++ -o knn knn.cpp -larmadillo
liu@liu:~/code/KNN1$ ./knn
处理数据集: 鸢尾花数据集
当 K = 5 时,KNN 分类准确率: 98.41%
liu@liu:~/code/KNN1$ g++ -o knn knn.cpp -larmadillo
liu@liu:~/code/KNN1$ ./knn
处理数据集: 鸢尾花数据集
当 K = 8 时,KNN 分类准确率: 95.24%

浙公网安备 33010602011771号