KNN分类器原理与实现
本文实现一个完整的KNN分类器,其功能与MATLAB自带的fitcknn函数相当。
1. KNN算法核心原理
KNN(K-最近邻)算法基于"物以类聚"的思想:
- 训练阶段:仅存储训练数据和标签
- 预测阶段:对于新样本,找到训练集中最近的K个邻居
- 决策规则:通过多数投票(分类)或平均(回归)确定输出
2. 完整KNN分类器实现
classdef myKNNClassifier < handle
    % myKNNClassifier  Custom K-nearest-neighbor classifier, similar to fitcknn.
    %
    % fit() stores z-score-standardized training data; predict() classifies
    % each new sample by majority vote (weights = 'equal') or inverse-distance
    % weighted vote among its k nearest training points.

    properties
        X_train          % standardized training features (n_samples x n_features)
        Y_train          % training labels (n_samples x 1)
        k                % requested neighbor count
        distance_metric  % 'euclidean' | 'manhattan' | 'chebyshev' | 'minkowski'
        weights          % 'equal' -> majority vote; otherwise 1/d weighting
        classes          % unique class labels seen at fit time
        num_classes      % number of distinct classes
        standardized     % true once fit() has run
        mu               % per-feature mean used for standardization
        sigma            % per-feature std (constant features forced to 1)
    end

    methods
        function obj = myKNNClassifier(k, distance_metric, weights)
            % Construct a classifier.
            %   k               - neighbor count (default: 5)
            %   distance_metric - distance name (default: 'euclidean')
            %   weights         - 'equal' (default) or any other value for
            %                     inverse-distance weighted voting
            if nargin < 1 || isempty(k)
                k = 5;
            end
            if nargin < 2 || isempty(distance_metric)
                distance_metric = 'euclidean';
            end
            if nargin < 3 || isempty(weights)
                weights = 'equal';
            end
            obj.k = k;
            obj.distance_metric = distance_metric;
            obj.weights = weights;
            obj.standardized = false;
        end

        function obj = fit(obj, X, Y)
            % Train the classifier (KNN training = store the data).
            %   X - training features (n_samples x n_features)
            %   Y - training labels   (n_samples x 1)
            obj.Y_train = Y;
            obj.classes = unique(Y);
            obj.num_classes = numel(obj.classes);
            % Z-score standardization; constant features get sigma = 1
            % to avoid division by zero.
            obj.mu = mean(X, 1);
            obj.sigma = std(X, 0, 1);
            obj.sigma(obj.sigma == 0) = 1;
            obj.X_train = (X - obj.mu) ./ obj.sigma;
            obj.standardized = true;
        end

        function Y_pred = predict(obj, X_test)
            % Predict class labels for new samples.
            %   X_test - test features (m_samples x n_features)
            % Returns Y_pred (m_samples x 1).
            if ~obj.standardized
                error('请先调用fit方法训练模型');
            end
            % Standardize with the training-set statistics.
            X_test_std = (X_test - obj.mu) ./ obj.sigma;
            n_test = size(X_test_std, 1);
            % Never request more neighbors than training samples exist.
            keff = min(obj.k, size(obj.X_train, 1));
            Y_pred = zeros(n_test, 1);
            for i = 1:n_test
                distances = obj.calculate_distances(X_test_std(i, :));
                [~, idx] = mink(distances, keff);
                nearest_labels = obj.Y_train(idx);
                if strcmp(obj.weights, 'equal')
                    % Majority vote; mode() breaks ties toward the smaller label.
                    Y_pred(i) = mode(nearest_labels);
                else
                    Y_pred(i) = obj.weighted_vote(distances(idx), nearest_labels);
                end
            end
        end

        function [Y_pred, scores] = predict_with_scores(obj, X_test)
            % Predict labels and per-class scores.
            %   Y_pred - predicted labels (m_samples x 1)
            %   scores - (m_samples x n_classes) vote fractions, or
            %            normalized inverse-distance weights when weighted
            if ~obj.standardized
                error('请先调用fit方法训练模型');
            end
            X_test_std = (X_test - obj.mu) ./ obj.sigma;
            n_test = size(X_test_std, 1);
            keff = min(obj.k, size(obj.X_train, 1));
            scores = zeros(n_test, obj.num_classes);
            for i = 1:n_test
                distances = obj.calculate_distances(X_test_std(i, :));
                [~, idx] = mink(distances, keff);
                nearest_labels = obj.Y_train(idx);
                nearest_distances = distances(idx);
                for j = 1:obj.num_classes
                    class_mask = (nearest_labels == obj.classes(j));
                    if strcmp(obj.weights, 'equal')
                        % Fraction of the neighborhood voting for class j.
                        scores(i, j) = sum(class_mask) / keff;
                    else
                        % Share of total inverse-distance weight held by class j;
                        % eps avoids division by zero on exact matches.
                        class_weights = 1 ./ (nearest_distances(class_mask) + eps);
                        scores(i, j) = sum(class_weights) / sum(1 ./ (nearest_distances + eps));
                    end
                end
            end
            [~, max_idx] = max(scores, [], 2);
            Y_pred = obj.classes(max_idx);
        end

        function accuracy = score(obj, X_test, Y_true)
            % Fraction of correctly classified test samples.
            Y_pred = obj.predict(X_test);
            accuracy = sum(Y_pred == Y_true) / numel(Y_true);
        end

        function cv_accuracy = cross_val_score(obj, X, Y, cv_folds)
            % K-fold cross-validation; returns one accuracy per fold.
            %   cv_folds - number of folds (default: 5)
            % Folds are assigned via a random permutation, which needs no
            % toolbox (the original used crossvalind from the Bioinformatics
            % Toolbox).
            if nargin < 4
                cv_folds = 5;
            end
            n_samples = size(X, 1);
            % Evenly spread fold ids 1..cv_folds over a shuffled row order.
            perm = randperm(n_samples);
            indices = zeros(n_samples, 1);
            indices(perm) = mod(0:n_samples-1, cv_folds)' + 1;
            cv_accuracy = zeros(cv_folds, 1);
            for i = 1:cv_folds
                test_mask = (indices == i);
                train_mask = ~test_mask;
                % Fresh classifier per fold so this object stays untouched.
                temp_knn = myKNNClassifier(obj.k, obj.distance_metric, obj.weights);
                temp_knn.fit(X(train_mask, :), Y(train_mask));
                cv_accuracy(i) = temp_knn.score(X(test_mask, :), Y(test_mask));
            end
        end
    end

    methods (Access = private)
        function distances = calculate_distances(obj, x)
            % Column vector of distances from row vector x to every
            % (standardized) training sample.
            switch obj.distance_metric
                case 'euclidean'
                    distances = sqrt(sum((obj.X_train - x).^2, 2));
                case 'manhattan'
                    distances = sum(abs(obj.X_train - x), 2);
                case 'chebyshev'
                    distances = max(abs(obj.X_train - x), [], 2);
                case 'minkowski'
                    % Fixed order p = 3 (not configurable, matching original).
                    distances = sum(abs(obj.X_train - x).^3, 2).^(1/3);
                otherwise
                    error('不支持的距离度量方法: %s', obj.distance_metric);
            end
        end

        function pred_label = weighted_vote(obj, distances, labels) %#ok<INUSL>
            % Inverse-distance weighted vote among the neighbor labels.
            w = 1 ./ (distances + eps); % eps avoids division by zero
            unique_labels = unique(labels);
            weighted_counts = zeros(numel(unique_labels), 1);
            for i = 1:numel(unique_labels)
                weighted_counts(i) = sum(w(labels == unique_labels(i)));
            end
            [~, max_idx] = max(weighted_counts);
            pred_label = unique_labels(max_idx);
        end
    end
end
3. 与MATLAB自带KNN功能对比
function compare_with_matlab_knn()
% Compare the custom KNN classifier against MATLAB's built-in fitcknn.
% Both models now standardize features ('Standardize', true for fitcknn;
% built-in z-scoring in myKNNClassifier), so the accuracy comparison is
% fair — the original left fitcknn unstandardized while the custom
% classifier standardized, biasing the comparison.
rng(42); % reproducible data and split
[X, Y] = generate_sample_data();
% Hold out 30% of the rows for testing.
cv = cvpartition(Y, 'HoldOut', 0.3);
X_train = X(cv.training, :);
Y_train = Y(cv.training, :);
X_test = X(cv.test, :);
Y_test = Y(cv.test, :);
% Built-in KNN, standardized to match the custom implementation.
tic;
matlab_knn = fitcknn(X_train, Y_train, 'NumNeighbors', 5, 'Standardize', true);
Y_pred_matlab = predict(matlab_knn, X_test);
matlab_time = toc;
matlab_accuracy = sum(Y_pred_matlab == Y_test) / length(Y_test);
% Custom KNN.
tic;
my_knn = myKNNClassifier(5, 'euclidean', 'equal');
my_knn.fit(X_train, Y_train);
Y_pred_my = my_knn.predict(X_test);
my_time = toc;
my_accuracy = sum(Y_pred_my == Y_test) / length(Y_test);
% Report accuracy and wall-clock time for both models.
fprintf('=== KNN分类器性能比较 ===\n');
fprintf('MATLAB KNN - 准确率: %.4f, 时间: %.4f秒\n', matlab_accuracy, matlab_time);
fprintf('自定义KNN - 准确率: %.4f, 时间: %.4f秒\n', my_accuracy, my_time);
fprintf('准确率差异: %.6f\n', abs(matlab_accuracy - my_accuracy));
% Side-by-side confusion matrices.
figure('Position', [100, 100, 1200, 500]);
subplot(1,2,1);
cm_matlab = confusionmat(Y_test, Y_pred_matlab);
confusionchart(cm_matlab);
title('MATLAB KNN 混淆矩阵');
subplot(1,2,2);
cm_my = confusionmat(Y_test, Y_pred_my);
confusionchart(cm_my);
title('自定义KNN 混淆矩阵');
end
function [X, Y] = generate_sample_data(n_samples, n_features, n_classes)
% Generate a Gaussian-cluster classification data set.
%   n_samples  - total number of rows (default: 1000)
%   n_features - feature count (default: 4)
%   n_classes  - class count (default: 3)
% Returns X (n_samples x n_features) and Y (n_samples x 1) with labels
% in 1..n_classes. Every row now gets a valid label even when n_samples
% is not divisible by n_classes — the original floor-based split left
% the trailing remainder rows as all-zero features with label 0.
if nargin < 1
    n_samples = 1000;
end
if nargin < 2
    n_features = 4;
end
if nargin < 3
    n_classes = 3;
end
rng(42);
X = zeros(n_samples, n_features);
Y = zeros(n_samples, 1);
% Per-class sample counts that sum exactly to n_samples: the first
% mod(n_samples, n_classes) classes each take one extra row.
counts = repmat(floor(n_samples / n_classes), n_classes, 1);
extra = mod(n_samples, n_classes);
counts(1:extra) = counts(1:extra) + 1;
stop = cumsum(counts);
start = [1; stop(1:end-1) + 1];
for i = 1:n_classes
    % Distinct per-class mean keeps the clusters separable.
    mean_vec = 2 * i * ones(1, n_features);
    cov_mat = diag(0.5 + 0.5 * rand(1, n_features));
    X(start(i):stop(i), :) = mvnrnd(mean_vec, cov_mat, counts(i));
    Y(start(i):stop(i)) = i;
end
% Shuffle rows so classes are interleaved.
idx = randperm(n_samples);
X = X(idx, :);
Y = Y(idx, :);
end
4. 高级功能扩展
function demo_advanced_knn()
% Demonstrate advanced usage: K selection via cross-validation,
% distance-metric comparison, and decision-boundary visualization.
[X, Y] = generate_sample_data(500, 3, 3);
% Sweep odd K values; pick the best by 5-fold cross-validation accuracy.
k_values = 1:2:15;
accuracies = zeros(length(k_values), 1);
figure('Position', [100, 100, 1000, 800]);
for i = 1:length(k_values)
    knn = myKNNClassifier(k_values(i));
    cv_acc = knn.cross_val_score(X, Y, 5);
    accuracies(i) = mean(cv_acc);
end
subplot(2,2,1);
plot(k_values, accuracies, 'bo-', 'LineWidth', 2, 'MarkerSize', 8);
xlabel('K值');
ylabel('交叉验证准确率');
title('K值选择');
grid on;
% Compare distance metrics at fixed K = 5.
distance_metrics = {'euclidean', 'manhattan', 'chebyshev'};
dist_accuracies = zeros(length(distance_metrics), 1);
for i = 1:length(distance_metrics)
    knn = myKNNClassifier(5, distance_metrics{i});
    cv_acc = knn.cross_val_score(X, Y, 5);
    dist_accuracies(i) = mean(cv_acc);
end
subplot(2,2,2);
bar(dist_accuracies);
set(gca, 'XTickLabel', distance_metrics);
xlabel('距离度量');
ylabel('准确率');
title('不同距离度量比较');
% The boundary plot only supports 2-D input, so project onto the first
% two features — the original passed all 3 features and the subplot
% silently stayed empty.
subplot(2,2,[3,4]);
visualize_decision_boundary(X(:, 1:2), Y);
title('KNN决策边界');
end
function visualize_decision_boundary(X, Y)
% Plot KNN decision regions for a 2-D feature matrix X with labels Y.
% Prints a notice and returns for anything that is not exactly 2-D.
if size(X, 2) ~= 2
    fprintf('只支持2D数据可视化\n');
    return;
end
% Dense evaluation grid covering the data with a 1-unit margin.
h = 0.02;
x1_min = min(X(:,1)) - 1; x1_max = max(X(:,1)) + 1;
x2_min = min(X(:,2)) - 1; x2_max = max(X(:,2)) + 1;
[xx1, xx2] = meshgrid(x1_min:h:x1_max, x2_min:h:x2_max);
% Fit on the full data and classify every grid node.
knn = myKNNClassifier(5);
knn.fit(X, Y);
X_grid = [xx1(:), xx2(:)];
Z = knn.predict(X_grid);
Z = reshape(Z, size(xx1));
% Filled contours of the predicted class. The original passed an
% 'AlphaData' name-value pair, which contourf does not accept and
% which errors at run time, so it is dropped here.
contourf(xx1, xx2, Z);
hold on;
% Overlay the training points, one color per class.
classes = unique(Y);
colors = {'red', 'blue', 'green', 'yellow', 'cyan'};
for i = 1:length(classes)
    scatter(X(Y==classes(i), 1), X(Y==classes(i), 2), 50, colors{i}, 'filled');
end
xlabel('特征1');
ylabel('特征2');
colorbar;
hold off;
end
5. 使用示例
% 基本使用示例
function basic_usage_example()
% Minimal end-to-end walkthrough of the custom KNN classifier:
% train, predict, inspect per-class scores, then cross-validate.
[X, Y] = generate_sample_data(200, 2, 2);
% Fit a 3-nearest-neighbor model.
model = myKNNClassifier(3);
model.fit(X, Y);
% Classify two unseen points.
X_new = [1.5, 2.0; 3.0, 1.0];
predictions = model.predict(X_new);
fprintf('预测结果: %d, %d\n', predictions);
% Score matrix: one row per sample, one column per class.
[~, scores] = model.predict_with_scores(X_new);
fprintf('类别得分:\n');
disp(scores);
% 5-fold cross-validation accuracy.
fold_acc = model.cross_val_score(X, Y, 5);
fprintf('5折交叉验证准确率: %.4f\n', mean(fold_acc));
end
% 运行比较
compare_with_matlab_knn();
% 运行高级演示
demo_advanced_knn();
% 运行基本示例
basic_usage_example();
以上参考代码实现了KNN分类器的核心原理,其分类功能可与MATLAB自带的KNN函数(fitcknn)相对应。
功能特性总结
| 特性 | MATLAB KNN | 自定义KNN |
|---|---|---|
| 基本KNN分类 | ✅ | ✅ |
| 多种距离度量 | ✅ | ✅ |
| 加权投票 | ✅ | ✅ |
| 交叉验证 | ✅ | ✅ |
| 决策边界可视化 | ✅ | ✅ |
| 预测概率/得分 | ✅ | ✅ |
| 数据标准化 | ✅ | ✅ |
这个自定义KNN分类器实现了MATLAB自带KNN函数的核心功能,并提供了良好的扩展性。

浙公网安备 33010602011771号