KNN分类器原理与实现

实现一个完整的KNN分类器,其功能与MATLAB自带的KNN函数相当。

KNN分类器原理与实现

1. KNN算法核心原理

KNN(K-最近邻)算法基于"物以类聚"的思想:

  • 训练阶段:仅存储训练数据和标签
  • 预测阶段:对于新样本,找到训练集中最近的K个邻居
  • 决策规则:通过多数投票(分类)或平均(回归)确定输出

2. 完整KNN分类器实现

classdef myKNNClassifier < handle
    % Custom K-nearest-neighbor classifier, similar in spirit to MATLAB's fitcknn.
    %
    % Usage:
    %   knn = myKNNClassifier(5, 'euclidean', 'equal');
    %   knn.fit(X_train, Y_train);
    %   Y_pred = knn.predict(X_test);

    properties
        X_train           % standardized training features (n_samples x n_features)
        Y_train           % training labels (n_samples x 1)
        k                 % number of neighbors
        distance_metric   % 'euclidean' | 'manhattan' | 'chebyshev' | 'minkowski'
        weights           % 'equal' for majority vote; anything else = inverse-distance
        classes           % unique class labels, as returned by unique(Y)
        num_classes       % number of distinct classes
        standardized      % true once fit() has been called
        mu                % per-feature mean used for z-score standardization
        sigma             % per-feature std used for z-score standardization
    end

    methods
        function obj = myKNNClassifier(k, distance_metric, weights)
            % Constructor.
            % Inputs:
            %   k               - number of neighbors (default: 5)
            %   distance_metric - distance metric name (default: 'euclidean')
            %   weights         - voting scheme (default: 'equal')

            if nargin < 1 || isempty(k)
                k = 5;
            end
            if nargin < 2 || isempty(distance_metric)
                distance_metric = 'euclidean';
            end
            if nargin < 3 || isempty(weights)
                weights = 'equal';
            end

            obj.k = k;
            obj.distance_metric = distance_metric;
            obj.weights = weights;
            obj.standardized = false;
        end

        function obj = fit(obj, X, Y)
            % Store the training set and z-score standardize the features.
            % Inputs:
            %   X - training features (n_samples x n_features)
            %   Y - training labels   (n_samples x 1)

            obj.Y_train = Y;
            obj.classes = unique(Y);
            obj.num_classes = length(obj.classes);

            obj.mu = mean(X, 1);
            obj.sigma = std(X, 0, 1);
            obj.sigma(obj.sigma == 0) = 1; % constant features: avoid division by zero

            obj.X_train = (X - obj.mu) ./ obj.sigma;
            obj.standardized = true;
        end

        function Y_pred = predict(obj, X_test)
            % Predict labels for new samples.
            % Input:
            %   X_test - test features (m_samples x n_features)
            % Output:
            %   Y_pred - predicted labels (m_samples x 1)

            if ~obj.standardized
                error('请先调用fit方法训练模型');
            end

            % Standardize test data with the training-set statistics.
            X_test_std = (X_test - obj.mu) ./ obj.sigma;

            n_test = size(X_test_std, 1);
            Y_pred = zeros(n_test, 1);

            for i = 1:n_test
                % Distances from this test sample to every training sample.
                distances = obj.calculate_distances(X_test_std(i, :));

                % Indices of the k nearest training samples.
                [~, idx] = mink(distances, obj.k);
                nearest_labels = obj.Y_train(idx);

                if strcmp(obj.weights, 'equal')
                    % Unweighted majority vote (mode breaks ties by smallest value).
                    Y_pred(i) = mode(nearest_labels);
                else
                    % Inverse-distance weighted vote.
                    Y_pred(i) = obj.weighted_vote(distances(idx), nearest_labels);
                end
            end
        end

        function [Y_pred, scores] = predict_with_scores(obj, X_test)
            % Predict labels and return per-class scores.
            % Outputs:
            %   Y_pred - predicted labels (m_samples x 1)
            %   scores - score matrix (m_samples x n_classes); rows sum to 1

            if ~obj.standardized
                error('请先调用fit方法训练模型');
            end

            X_test_std = (X_test - obj.mu) ./ obj.sigma;
            n_test = size(X_test_std, 1);
            scores = zeros(n_test, obj.num_classes);

            for i = 1:n_test
                distances = obj.calculate_distances(X_test_std(i, :));
                [~, idx] = mink(distances, obj.k);
                nearest_labels = obj.Y_train(idx);
                nearest_distances = distances(idx);

                % Score each class by its share of the k neighbors.
                for j = 1:obj.num_classes
                    class_mask = (nearest_labels == obj.classes(j));
                    if strcmp(obj.weights, 'equal')
                        scores(i, j) = sum(class_mask) / obj.k;
                    else
                        % Inverse-distance weighted share; eps avoids division by zero.
                        class_weights = 1 ./ (nearest_distances(class_mask) + eps);
                        scores(i, j) = sum(class_weights) / sum(1./(nearest_distances+eps));
                    end
                end
            end

            [~, max_idx] = max(scores, [], 2);
            Y_pred = obj.classes(max_idx);
        end

        function accuracy = score(obj, X_test, Y_true)
            % Fraction of test samples classified correctly.
            Y_pred = obj.predict(X_test);
            % Flatten both sides so row/column orientation of Y_true cannot
            % silently turn the comparison into a matrix.
            accuracy = sum(Y_pred(:) == Y_true(:)) / numel(Y_true);
        end

        function cv_accuracy = cross_val_score(obj, X, Y, cv_folds)
            % K-fold cross-validation.
            % Input:
            %   cv_folds - number of folds (default: 5)
            % Output:
            %   cv_accuracy - per-fold accuracies (cv_folds x 1)

            if nargin < 4
                cv_folds = 5;
            end

            n_samples = size(X, 1);
            % Toolbox-free fold assignment: shuffle the sample order, then
            % deal samples round-robin into folds 1..cv_folds. (The original
            % crossvalind call requires the Bioinformatics Toolbox.)
            shuffled = randperm(n_samples);
            indices = zeros(n_samples, 1);
            indices(shuffled) = mod(0:n_samples-1, cv_folds) + 1;
            cv_accuracy = zeros(cv_folds, 1);

            for i = 1:cv_folds
                test_mask = (indices == i);
                train_mask = ~test_mask;

                % Train a fresh classifier with the same hyperparameters.
                temp_knn = myKNNClassifier(obj.k, obj.distance_metric, obj.weights);
                temp_knn.fit(X(train_mask, :), Y(train_mask));

                cv_accuracy(i) = temp_knn.score(X(test_mask, :), Y(test_mask));
            end
        end
    end

    methods (Access = private)
        function distances = calculate_distances(obj, x)
            % Distances from one (standardized) sample x to all training samples.
            switch obj.distance_metric
                case 'euclidean'
                    distances = sqrt(sum((obj.X_train - x).^2, 2));
                case 'manhattan'
                    distances = sum(abs(obj.X_train - x), 2);
                case 'chebyshev'
                    distances = max(abs(obj.X_train - x), [], 2);
                case 'minkowski'
                    % Fixed order p = 3.
                    distances = sum(abs(obj.X_train - x).^3, 2).^(1/3);
                otherwise
                    error('不支持的距离度量方法: %s', obj.distance_metric);
            end
        end

        function pred_label = weighted_vote(obj, distances, labels)
            % Inverse-distance weighted vote over the k nearest neighbors.
            % (Local renamed from 'weights' so it does not shadow obj.weights.)
            w = 1 ./ (distances + eps); % eps avoids division by zero
            unique_labels = unique(labels);
            weighted_counts = zeros(length(unique_labels), 1);

            for i = 1:length(unique_labels)
                mask = (labels == unique_labels(i));
                weighted_counts(i) = sum(w(mask));
            end

            [~, max_idx] = max(weighted_counts);
            pred_label = unique_labels(max_idx);
        end
    end
end

3. 与MATLAB自带KNN功能对比

function compare_with_matlab_knn()
    % Benchmark the custom KNN classifier against MATLAB's built-in fitcknn:
    % same data, same split, compare accuracy and wall-clock time, then show
    % both confusion matrices side by side.

    rng(42); % fixed seed for reproducibility
    [X, Y] = generate_sample_data();

    % 70/30 hold-out split.
    cv = cvpartition(Y, 'HoldOut', 0.3);
    X_train = X(cv.training, :);
    Y_train = Y(cv.training, :);
    X_test  = X(cv.test, :);
    Y_test  = Y(cv.test, :);

    % Built-in classifier.
    tic;
    builtin_model = fitcknn(X_train, Y_train, 'NumNeighbors', 5);
    builtin_pred = predict(builtin_model, X_test);
    builtin_time = toc;
    builtin_acc = sum(builtin_pred == Y_test) / length(Y_test);

    % Custom classifier with matching hyperparameters.
    tic;
    custom_model = myKNNClassifier(5, 'euclidean', 'equal');
    custom_model.fit(X_train, Y_train);
    custom_pred = custom_model.predict(X_test);
    custom_time = toc;
    custom_acc = sum(custom_pred == Y_test) / length(Y_test);

    % Report the comparison.
    fprintf('=== KNN分类器性能比较 ===\n');
    fprintf('MATLAB KNN - 准确率: %.4f, 时间: %.4f秒\n', builtin_acc, builtin_time);
    fprintf('自定义KNN - 准确率: %.4f, 时间: %.4f秒\n', custom_acc, custom_time);
    fprintf('准确率差异: %.6f\n', abs(builtin_acc - custom_acc));

    % Confusion matrices, side by side.
    figure('Position', [100, 100, 1200, 500]);

    subplot(1, 2, 1);
    confusionchart(confusionmat(Y_test, builtin_pred));
    title('MATLAB KNN 混淆矩阵');

    subplot(1, 2, 2);
    confusionchart(confusionmat(Y_test, custom_pred));
    title('自定义KNN 混淆矩阵');
end

function [X, Y] = generate_sample_data(n_samples, n_features, n_classes)
    % Generate a synthetic classification data set of Gaussian class clusters.
    % Inputs (all optional):
    %   n_samples  - total number of samples  (default: 1000)
    %   n_features - number of features       (default: 4)
    %   n_classes  - number of classes        (default: 3)
    % Outputs:
    %   X - features (n_samples x n_features)
    %   Y - integer labels in 1..n_classes (n_samples x 1)
    if nargin < 1
        n_samples = 1000;
    end
    if nargin < 2
        n_features = 4;
    end
    if nargin < 3
        n_classes = 3;
    end

    rng(42); % reproducible
    X = zeros(n_samples, n_features);
    Y = zeros(n_samples, 1);

    samples_per_class = floor(n_samples / n_classes);

    for i = 1:n_classes
        start_idx = (i-1) * samples_per_class + 1;
        end_idx = i * samples_per_class;
        if i == n_classes
            % Fix: when n_samples is not divisible by n_classes, the original
            % left the remainder rows as all-zero samples with spurious label
            % 0. Assign the remainder to the last class instead.
            end_idx = n_samples;
        end
        n_i = end_idx - start_idx + 1;

        % Each class gets a distinct mean and a random diagonal covariance.
        mean_vec = 2 * i * ones(1, n_features);
        std_vec = sqrt(0.5 + 0.5 * rand(1, n_features));

        % Diagonal-covariance Gaussian sampling via randn; equivalent to
        % mvnrnd(mean_vec, diag(std_vec.^2), n_i) but avoids the Statistics
        % Toolbox dependency.
        X(start_idx:end_idx, :) = mean_vec + randn(n_i, n_features) .* std_vec;
        Y(start_idx:end_idx) = i;
    end

    % Shuffle the rows so classes are interleaved.
    idx = randperm(n_samples);
    X = X(idx, :);
    Y = Y(idx);
end

4. 高级功能扩展

function demo_advanced_knn()
    % Demonstrate model-selection features of the custom KNN classifier:
    % K sweep via cross-validation, distance-metric comparison, and a
    % decision-boundary plot.

    [X, Y] = generate_sample_data(500, 3, 3);

    % Sweep odd K values and record mean 5-fold CV accuracy for each.
    k_values = 1:2:15;
    n_k = length(k_values);
    accuracies = zeros(n_k, 1);

    figure('Position', [100, 100, 1000, 800]);

    for kk = 1:n_k
        model = myKNNClassifier(k_values(kk));
        accuracies(kk) = mean(model.cross_val_score(X, Y, 5));
    end

    subplot(2, 2, 1);
    plot(k_values, accuracies, 'bo-', 'LineWidth', 2, 'MarkerSize', 8);
    xlabel('K值');
    ylabel('交叉验证准确率');
    title('K值选择');
    grid on;

    % Compare distance metrics at fixed K = 5.
    distance_metrics = {'euclidean', 'manhattan', 'chebyshev'};
    n_metrics = length(distance_metrics);
    dist_accuracies = zeros(n_metrics, 1);

    for m = 1:n_metrics
        model = myKNNClassifier(5, distance_metrics{m});
        dist_accuracies(m) = mean(model.cross_val_score(X, Y, 5));
    end

    subplot(2, 2, 2);
    bar(dist_accuracies);
    set(gca, 'XTickLabel', distance_metrics);
    xlabel('距离度量');
    ylabel('准确率');
    title('不同距离度量比较');

    % Decision boundary (only drawn for 2-D data; see the helper).
    subplot(2, 2, [3, 4]);
    visualize_decision_boundary(X, Y);
    title('KNN决策边界');
end

function visualize_decision_boundary(X, Y)
    % Visualize the KNN decision boundary for 2-D data by classifying a
    % dense grid over the feature range. Prints a notice and returns for
    % data with other dimensionality.
    if size(X, 2) ~= 2
        fprintf('只支持2D数据可视化\n');
        return;
    end

    % Build an evaluation grid covering the data with a 1-unit margin.
    h = 0.02;
    x1_min = min(X(:,1)) - 1; x1_max = max(X(:,1)) + 1;
    x2_min = min(X(:,2)) - 1; x2_max = max(X(:,2)) + 1;
    [xx1, xx2] = meshgrid(x1_min:h:x1_max, x2_min:h:x2_max);

    % Fit a 5-NN classifier and label every grid point.
    knn = myKNNClassifier(5);
    knn.fit(X, Y);
    X_grid = [xx1(:), xx2(:)];
    Z = knn.predict(X_grid);
    Z = reshape(Z, size(xx1));

    % Filled contour of the predicted class regions.
    % Fix: 'AlphaData' is not a Contour property; passing it to contourf
    % throws at runtime, so it has been removed.
    contourf(xx1, xx2, Z);
    hold on;

    % Overlay the data points, one color per class (colors cycle if there
    % are more classes than listed colors).
    classes = unique(Y);
    colors = {'red', 'blue', 'green', 'yellow', 'cyan'};
    for i = 1:length(classes)
        c = colors{mod(i-1, numel(colors)) + 1};
        scatter(X(Y==classes(i), 1), X(Y==classes(i), 2), 50, c, 'filled');
    end

    xlabel('特征1');
    ylabel('特征2');
    colorbar;
    hold off;
end

5. 使用示例

% 基本使用示例
function basic_usage_example()
    % End-to-end usage example: build data, train, predict, score, and
    % cross-validate with the custom KNN classifier.

    [X, Y] = generate_sample_data(200, 2, 2);

    % Train a 3-nearest-neighbor model.
    model = myKNNClassifier(3);
    model.fit(X, Y);

    % Classify two new points.
    X_new = [1.5, 2.0; 3.0, 1.0];
    predictions = model.predict(X_new);
    fprintf('预测结果: %d, %d\n', predictions);

    % Per-class score matrix for the same points.
    [~, scores] = model.predict_with_scores(X_new);
    fprintf('类别得分:\n');
    disp(scores);

    % 5-fold cross-validation accuracy.
    cv_score = model.cross_val_score(X, Y, 5);
    fprintf('5折交叉验证准确率: %.4f\n', mean(cv_score));
end

% Run the comparison against MATLAB's built-in KNN
compare_with_matlab_knn();

% Run the advanced-features demo
demo_advanced_knn();

% Run the basic usage example
basic_usage_example();

参考代码 knn分类器原理实现,可等同于matlab自带KNN函数 www.youwenfan.com/contentcnm/82555.html

功能特性总结

| 特性 | MATLAB KNN | 自定义KNN |
| --- | --- | --- |
| 基本KNN分类 | ✓ | ✓ |
| 多种距离度量 | ✓ | ✓ |
| 加权投票 | ✓ | ✓ |
| 交叉验证 | ✓ | ✓ |
| 决策边界可视化 | ✗ | ✓ |
| 预测概率/得分 | ✓ | ✓ |
| 数据标准化 | ✓ | ✓ |

这个自定义KNN分类器实现了MATLAB自带KNN函数的核心功能,并提供了良好的扩展性。

posted @ 2025-11-26 10:09  kiyte  阅读(0)  评论(0)    收藏  举报