SVM在高光谱遥感图像分类与预测中的MATLAB实现

SVM在高光谱分类中的优势

优势 说明
小样本学习 在高光谱标注样本有限的情况下仍能有效学习
高维处理 适合处理高光谱数据的高维特征
非线性分类 通过核函数处理复杂的非线性分类问题
泛化能力强 基于结构风险最小化原理,泛化性能好

完整MATLAB实现代码

1. 数据加载与预处理

function [data, labels, feature_names] = load_hyperspectral_data()
    % LOAD_HYPERSPECTRAL_DATA  Generate a simulated hyperspectral data set.
    %
    % Returns:
    %   data          - num_samples-by-num_bands matrix of simulated reflectance.
    %   labels        - num_samples-by-1 vector of class labels in 1..num_classes.
    %   feature_names - 1-by-num_bands cell array of names 'Band_1', 'Band_2', ...
    %
    % To use a real data set (e.g. Indian Pines) instead, load it here:
    %   load('Indian_pines.mat');
    %   load('Indian_pines_gt.mat');

    fprintf('生成模拟高光谱数据...\n');

    num_samples = 2000;    % total number of samples
    num_bands = 200;       % number of spectral bands
    num_classes = 6;       % number of land-cover classes

    data = zeros(num_samples, num_bands);
    labels = zeros(num_samples, 1);

    % Each class gets its own amplitude on a shared sinusoidal band profile.
    class_centers = linspace(0.3, 0.8, num_classes);
    band_profile = sin((1:num_bands) / 50);

    % Rounded boundaries split ALL samples across the classes. Using
    % floor(num_samples / num_classes) per class would leave the trailing
    % samples as all-zero spectra with label 0 when the division is not exact.
    class_bounds = round(linspace(0, num_samples, num_classes + 1));

    for i = 1:num_classes
        rows = (class_bounds(i) + 1):class_bounds(i + 1);
        mean_spectrum = class_centers(i) * band_profile + 0.2;

        % Class mean spectrum plus i.i.d. Gaussian noise (sigma = 0.1).
        data(rows, :) = repmat(mean_spectrum, numel(rows), 1) ...
                        + 0.1 * randn(numel(rows), num_bands);
        labels(rows) = i;
    end

    % Feature names, one per band.
    feature_names = arrayfun(@(x) sprintf('Band_%d', x), 1:num_bands, 'UniformOutput', false);

    fprintf('数据维度: %d × %d\n', size(data, 1), size(data, 2));
    fprintf('类别数量: %d\n', num_classes);
    fprintf('类别分布: \n');
    tabulate(labels);
end

2. 特征提取与降维

function [features_selected, selected_indices, transform] = feature_selection_hyperspectral(data, labels, method)
    % FEATURE_SELECTION_HYPERSPECTRAL  Reduce the dimensionality of hyperspectral data.
    %
    % Inputs:
    %   data   - samples-by-bands matrix.
    %   labels - class label per sample (only used by 'RF').
    %   method - 'PCA', 'SPA', 'RF' or 'None'. Any other name (including
    %            'CARS', which is not implemented) keeps all bands and
    %            raises a warning.
    %
    % Outputs:
    %   features_selected - reduced feature matrix for DATA.
    %   selected_indices  - for 'SPA'/'RF'/'None': indices of the kept bands.
    %                       for 'PCA': 1:num_components. These are component
    %                       numbers, NOT band indices, so they must not be
    %                       used to index the columns of other raw data.
    %   transform         - (optional) struct describing the fitted reduction
    %                       so it can be re-applied to unseen data:
    %                         .Method          method name
    %                         .SelectedIndices same as selected_indices
    %                         .Mu, .Coeff      PCA mean and loadings (empty
    %                                          for non-PCA methods); new data
    %                                          X is projected as (X - Mu) * Coeff.

    transform = struct('Method', method, 'SelectedIndices', [], 'Mu', [], 'Coeff', []);

    switch method
        case 'PCA'
            % Principal component analysis; mu is needed to centre new data.
            [coeff, score, ~, ~, explained, mu] = pca(data);

            % Keep the leading components that reach 95% cumulative variance.
            cum_explained = cumsum(explained);
            num_components = find(cum_explained >= 95, 1);
            if isempty(num_components)
                % Floating-point round-off or degenerate data: keep everything.
                num_components = numel(explained);
            end
            features_selected = score(:, 1:num_components);
            selected_indices = 1:num_components;

            transform.Mu = mu;
            transform.Coeff = coeff(:, 1:num_components);

            fprintf('PCA选择 %d 个主成分 (累计方差: %.2f%%)\n', ...
                    num_components, cum_explained(num_components));

        case 'SPA'
            % Successive projections algorithm (simplified implementation).
            num_selected = min(30, size(data, 2));
            selected_indices = successive_projections_algorithm(data, num_selected);
            features_selected = data(:, selected_indices);

        case 'RF'
            % Rank bands by out-of-bag permutation importance of a bagged
            % tree ensemble; this is supervised and therefore needs labels.
            if isempty(labels)
                error('feature_selection_hyperspectral:labelsRequired', ...
                      'The ''RF'' method requires class labels.');
            end
            tree = fitensemble(data, labels, 'Bag', 100, 'Tree', 'Type', 'Classification');
            imp = oobPermutedPredictorImportance(tree);
            [~, idx] = sort(imp, 'descend');
            selected_indices = idx(1:min(50, length(idx)));
            features_selected = data(:, selected_indices);

        otherwise
            % 'None' deliberately keeps every band; anything else is most
            % likely a typo or an unimplemented method, so say so.
            if ~strcmp(method, 'None')
                warning('feature_selection_hyperspectral:unknownMethod', ...
                        'Unknown feature selection method ''%s''; using all features.', method);
            end
            features_selected = data;
            selected_indices = 1:size(data, 2);
    end

    transform.SelectedIndices = selected_indices;
end

function selected_indices = successive_projections_algorithm(data, k)
    % SUCCESSIVE_PROJECTIONS_ALGORITHM  Pick k minimally collinear columns.
    %
    % Inputs:
    %   data - samples-by-bands matrix.
    %   k    - number of bands to select (clamped to the number of columns).
    %
    % Output:
    %   selected_indices - 1-by-k row vector of selected column indices, in
    %                      selection order.
    %
    % The first band is the one with the largest variance. Every following
    % band is the one with the largest norm after projecting all columns onto
    % the orthogonal complement of the bands chosen so far. The projection is
    % done by Gram-Schmidt deflation of a residual matrix; subtracting
    % projections onto the raw (mutually non-orthogonal) selected columns
    % would not give a true orthogonal projection.

    p = size(data, 2);
    k = min(k, p);                       % cannot select more bands than exist
    selected_indices = zeros(1, k);
    if k == 0
        return;
    end

    % Initial band: largest variance across samples (explicit dim so a
    % single-row input is still treated column-wise).
    [~, selected_indices(1)] = max(var(data, 0, 1));

    residual = data;
    for i = 2:k
        % Remove the component along the most recently selected (already
        % deflated) column from every column at once.
        pivot = residual(:, selected_indices(i-1));
        pivot_energy = pivot' * pivot;
        if pivot_energy > 0
            residual = residual - pivot * ((pivot' * residual) / pivot_energy);
        end

        residual_norms = sqrt(sum(residual .^ 2, 1));
        residual_norms(selected_indices(1:i-1)) = -Inf;   % never re-select a band
        [~, selected_indices(i)] = max(residual_norms);
    end
end

3. SVM分类器实现

function svm_model = train_svm_classifier(features, labels, kernel_type)
    % TRAIN_SVM_CLASSIFIER  Train a multi-class (one-vs-one ECOC) SVM.
    %
    % Inputs:
    %   features    - samples-by-features matrix of RAW (unscaled) features.
    %   labels      - class label per sample.
    %   kernel_type - 'linear', 'rbf' or 'polynomial'.
    %
    % Output:
    %   svm_model - trained ECOC model returned by fitcecoc.
    %
    % Standardization is left entirely to the SVM learners
    % ('Standardize', true): they store the training mean/std and apply them
    % inside predict(). The features must NOT also be z-scored here, because
    % the learners would then memorise mean ~0 / std ~1 and raw data passed
    % to predict() later would effectively go unscaled.

    % Kernel-specific template options.
    switch kernel_type
        case 'linear'
            kernel_args = {};
        case 'rbf'
            kernel_args = {'KernelScale', 'auto'};
        case 'polynomial'
            kernel_args = {'PolynomialOrder', 3};
        otherwise
            error('train_svm_classifier:unknownKernel', ...
                  'Unsupported kernel type ''%s''. Use ''linear'', ''rbf'' or ''polynomial''.', ...
                  kernel_type);
    end

    template = templateSVM('KernelFunction', kernel_type, ...
                          'BoxConstraint', 1, ...
                          kernel_args{:}, ...
                          'Standardize', true);

    % Train the multi-class SVM with one-vs-one coding.
    svm_model = fitcecoc(features, labels, ...
                        'Learners', template, ...
                        'Coding', 'onevsone', ...
                        'Verbose', 1);

    fprintf('SVM分类器训练完成 (核函数: %s)\n', kernel_type);
end

function [accuracy, confusion_mat, class_report] = evaluate_svm_model(model, features_test, labels_test)
    % EVALUATE_SVM_MODEL  Score a trained classifier on a held-out set.
    %
    % Inputs:
    %   model         - trained model accepted by predict().
    %   features_test - test feature matrix.
    %   labels_test   - true labels of the test samples.
    %
    % Outputs:
    %   accuracy      - fraction of test samples predicted correctly.
    %   confusion_mat - confusionmat(labels_test, predicted labels).
    %   class_report  - struct array, one element per class present in
    %                   labels_test, with fields Class, Precision, Recall,
    %                   F1_Score and Support.

    predicted = predict(model, features_test);

    % Overall accuracy and confusion matrix.
    accuracy = sum(predicted == labels_test) / length(labels_test);
    confusion_mat = confusionmat(labels_test, predicted);

    % Per-class precision / recall / F1 from boolean membership masks.
    classes = unique(labels_test);
    class_report = struct();

    for c = 1:length(classes)
        is_actual = (labels_test == classes(c));
        is_predicted = (predicted == classes(c));

        tp = sum(is_actual & is_predicted);
        fp = sum(~is_actual & is_predicted);
        fn = sum(is_actual & ~is_predicted);

        % eps in the denominators avoids 0/0 for classes never predicted.
        p = tp / (tp + fp + eps);
        r = tp / (tp + fn + eps);

        class_report(c).Class = classes(c);
        class_report(c).Precision = p;
        class_report(c).Recall = r;
        class_report(c).F1_Score = 2 * (p * r) / (p + r + eps);
        class_report(c).Support = sum(is_actual);
    end

    fprintf('测试集准确率: %.4f\n', accuracy);
end

4. 参数优化与交叉验证

function [best_model, best_params] = optimize_svm_parameters(features, labels)
    % OPTIMIZE_SVM_PARAMETERS  Tune SVM hyper-parameters with Bayesian optimization.
    %
    % Inputs:
    %   features - samples-by-features training matrix.
    %   labels   - class label per training sample.
    %
    % Outputs:
    %   best_model  - ECOC SVM model retrained on all of FEATURES with the
    %                 best parameters found.
    %   best_params - results.XAtMinObjective from bayesopt.
    %
    % An RBF kernel (BoxConstraint and KernelScale tuned) is used when there
    % are more than 10 features; otherwise a linear kernel (BoxConstraint
    % only). The objective is the loss returned by svm_crossval_loss.

    % Search-space definitions, both on a log scale.
    c_var = optimizableVariable('BoxConstraint', [0.1, 100], 'Transform', 'log');
    scale_var = optimizableVariable('KernelScale', [0.1, 100], 'Transform', 'log');

    % Kernel choice depends on the feature dimensionality.
    use_rbf = size(features, 2) > 10;
    if use_rbf
        kernel_function = 'rbf';
        search_space = [c_var, scale_var];
    else
        kernel_function = 'linear';
        search_space = c_var;
    end

    % Objective: cross-validated classification loss for a candidate point.
    objective = @(params)svm_crossval_loss(features, labels, kernel_function, params);

    % Bayesian optimization.
    results = bayesopt(objective, search_space, ...
                      'MaxTime', 300, ...
                      'IsObjectiveDeterministic', false, ...
                      'NumSeedPoints', 10, ...
                      'Verbose', 1);

    best_params = results.XAtMinObjective;

    % Assemble the template options once; KernelScale only applies to RBF.
    svm_args = {'KernelFunction', kernel_function, ...
                'BoxConstraint', best_params.BoxConstraint};
    if use_rbf
        svm_args = [svm_args, {'KernelScale', best_params.KernelScale}];
    end
    svm_args = [svm_args, {'Standardize', true}];

    % Retrain the final model with the best parameters.
    best_model = fitcecoc(features, labels, 'Learners', templateSVM(svm_args{:}));

    fprintf('参数优化完成\n');
end

function loss = svm_crossval_loss(features, labels, kernel_function, params)
    % SVM_CROSSVAL_LOSS  5-fold cross-validated loss of an ECOC SVM.
    %
    % Inputs:
    %   features        - samples-by-features matrix.
    %   labels          - class label per sample.
    %   kernel_function - 'rbf' (uses params.KernelScale) or any other
    %                     templateSVM kernel name (BoxConstraint only).
    %   params          - struct or one-row table with BoxConstraint and,
    %                     for 'rbf', KernelScale.
    %
    % Output:
    %   loss - kfoldLoss of the 5-fold cross-validated model.

    % Build the learner template for the candidate parameters.
    svm_args = {'KernelFunction', kernel_function, ...
                'BoxConstraint', params.BoxConstraint};
    if strcmp(kernel_function, 'rbf')
        svm_args = [svm_args, {'KernelScale', params.KernelScale}];
    end
    template = templateSVM(svm_args{:}, 'Standardize', true);

    % Ask fitcecoc for the cross-validated model directly. Training a full
    % model first and then calling crossval() on it fits one extra model on
    % all the data per objective evaluation, only to throw it away.
    cv_model = fitcecoc(features, labels, 'Learners', template, 'KFold', 5);

    % Classification error averaged over the folds.
    loss = kfoldLoss(cv_model);
end

5. 完整的分类系统

function hyperspectral_svm_classification_system()
    % HYPERSPECTRAL_SVM_CLASSIFICATION_SYSTEM  End-to-end SVM classification demo.
    %
    % Pipeline: load data -> 70/30 hold-out split -> feature selection
    % (PCA, RF, none) -> train and evaluate linear and RBF SVMs for each
    % feature set -> Bayesian hyper-parameter optimization on the PCA
    % features -> plots and a printed summary.

    close all; clc;

    fprintf('=== 高光谱遥感图像SVM分类系统 ===\n\n');

    %% 1. Data loading and exploration
    fprintf('步骤1: 数据加载...\n');
    [data, labels] = load_hyperspectral_data();

    % Exploratory plot of the per-class mean spectra.
    figure('Position', [100, 100, 1200, 800]);
    subplot(2, 3, 1);
    plot_mean_spectra(data, labels);
    title('各类别平均光谱曲线');

    %% 2. Preprocessing
    fprintf('步骤2: 数据预处理...\n');

    % Stratified hold-out split (70% train, 30% test).
    rng(42);  % fixed seed for reproducibility
    cv = cvpartition(labels, 'HoldOut', 0.3);

    data_train = data(cv.training, :);
    labels_train = labels(cv.training);
    data_test = data(cv.test, :);
    labels_test = labels(cv.test);

    fprintf('训练集: %d 样本\n', size(data_train, 1));
    fprintf('测试集: %d 样本\n', size(data_test, 1));

    %% 3. Feature selection
    fprintf('步骤3: 特征选择...\n');

    feature_methods = {'PCA', 'RF', 'None'};
    feature_results = struct();

    for i = 1:length(feature_methods)
        method = feature_methods{i};
        fprintf('  使用 %s 方法进行特征选择...\n', method);

        [features_train, selected_idx] = feature_selection_hyperspectral(data_train, labels_train, method);

        if strcmp(method, 'PCA')
            % For PCA, selected_idx holds component numbers, not band
            % indices, so indexing raw test bands with it would feed the
            % classifier features from a different space than it was trained
            % on. Project the test data with the training-set PCA instead
            % (pca() is deterministic, so refitting on data_train reproduces
            % the loadings behind features_train).
            num_pc = size(features_train, 2);
            train_mean = mean(data_train, 1);
            pc_coeff = pca(data_train, 'NumComponents', num_pc);
            features_test = bsxfun(@minus, data_test, train_mean) * pc_coeff;
        else
            features_test = data_test(:, selected_idx);
        end

        % Store the feature sets for the training stage.
        feature_results(i).Method = method;
        feature_results(i).FeaturesTrain = features_train;
        feature_results(i).FeaturesTest = features_test;
        feature_results(i).SelectedIndices = selected_idx;
    end

    %% 4. SVM training and evaluation
    fprintf('步骤4: SVM模型训练...\n');

    kernel_types = {'linear', 'rbf'};
    results = struct();
    result_count = 1;

    for feat_idx = 1:length(feature_methods)
        for kernel_idx = 1:length(kernel_types)
            fprintf('  训练: %s特征 + %s核SVM\n', ...
                    feature_methods{feat_idx}, kernel_types{kernel_idx});

            % Train the SVM model.
            svm_model = train_svm_classifier(...
                feature_results(feat_idx).FeaturesTrain, ...
                labels_train, kernel_types{kernel_idx});

            % Evaluate on the held-out set.
            [accuracy, confusion_mat, class_report] = evaluate_svm_model(...
                svm_model, ...
                feature_results(feat_idx).FeaturesTest, ...
                labels_test);

            % Store the result (all fields indexed by result_count).
            results(result_count).FeatureMethod = feature_methods{feat_idx};
            results(result_count).KernelType = kernel_types{kernel_idx};
            results(result_count).Accuracy = accuracy;
            results(result_count).ConfusionMatrix = confusion_mat;
            results(result_count).ClassReport = class_report;
            results(result_count).Model = svm_model;

            result_count = result_count + 1;
        end
    end

    %% 5. Hyper-parameter optimization
    fprintf('步骤5: 参数优化...\n');
    best_feat_idx = 1;  % optimize on the PCA features
    [optimized_model, best_params] = optimize_svm_parameters(...
        feature_results(best_feat_idx).FeaturesTrain, labels_train);

    % Evaluate the optimized model.
    [opt_accuracy, opt_confusion, opt_report] = evaluate_svm_model(...
        optimized_model, ...
        feature_results(best_feat_idx).FeaturesTest, ...
        labels_test);

    % The optimizer picks the kernel itself; KernelScale is only tuned for
    % the RBF kernel, so its presence tells which kernel was used.
    if ismember('KernelScale', best_params.Properties.VariableNames)
        optimized_kernel = 'rbf';
    else
        optimized_kernel = 'linear';
    end

    results(result_count).FeatureMethod = 'PCA_Optimized';
    results(result_count).KernelType = optimized_kernel;
    results(result_count).Accuracy = opt_accuracy;
    results(result_count).ConfusionMatrix = opt_confusion;
    results(result_count).ClassReport = opt_report;
    results(result_count).Model = optimized_model;
    results(result_count).BestParams = best_params;

    %% 6. Visualization
    fprintf('步骤6: 结果可视化...\n');
    plot_classification_results(results, data_test, labels_test, feature_results);

    %% 7. Performance summary
    print_performance_summary(results);
end

6. 可视化函数

function plot_mean_spectra(data, labels)
    % PLOT_MEAN_SPECTRA  Plot each class's mean spectrum with a +/-1 std band.
    %
    % Inputs:
    %   data   - samples-by-bands matrix.
    %   labels - class label per sample.
    %
    % Draws into the current axes: one line per class (legend entry
    % 'Class <label>') plus a translucent patch spanning mean +/- std.

    unique_labels = unique(labels);
    colors = lines(length(unique_labels));

    hold on;
    for i = 1:length(unique_labels)
        class_data = data(labels == unique_labels(i), :);
        mean_spectrum = mean(class_data, 1);
        std_spectrum = std(class_data, 0, 1);

        x = 1:length(mean_spectrum);
        plot(x, mean_spectrum, 'Color', colors(i,:), 'LineWidth', 2, ...
             'DisplayName', sprintf('Class %d', unique_labels(i)));

        % Standard-deviation band. Hidden from the legend, otherwise every
        % patch shows up as an extra auto-named 'dataN' entry.
        patch([x, fliplr(x)], ...
              [mean_spectrum + std_spectrum, fliplr(mean_spectrum - std_spectrum)], ...
              colors(i,:), 'FaceAlpha', 0.2, 'EdgeColor', 'none', ...
              'HandleVisibility', 'off');
    end
    hold off;

    xlabel('波段');
    ylabel('反射率');
    legend('show');
    grid on;
end

function plot_classification_results(results, data_test, labels_test, feature_results) %#ok<INUSL>
    % PLOT_CLASSIFICATION_RESULTS  Six-panel overview of the experiments.
    %
    % Inputs:
    %   results         - struct array with fields FeatureMethod, KernelType,
    %                     Accuracy, ConfusionMatrix and ClassReport.
    %   data_test       - unused; kept for interface compatibility.
    %   labels_test     - true test labels (colours the PCA scatter).
    %   feature_results - struct array with fields FeaturesTest and
    %                     SelectedIndices and, normally, Method.
    %
    % The PCA and RF entries of feature_results are looked up by their
    % Method name. If there is no Method field, the legacy positions are
    % assumed (1 = PCA, 2 = RF).

    figure('Position', [100, 100, 1400, 1000]);

    % Locate the PCA / RF feature sets by name rather than by position.
    if isfield(feature_results, 'Method')
        feature_method_names = {feature_results.Method};
        pca_feat_idx = find(strcmp(feature_method_names, 'PCA'), 1);
        rf_feat_idx = find(strcmp(feature_method_names, 'RF'), 1);
    else
        pca_feat_idx = 1;
        rf_feat_idx = 2;
        if numel(feature_results) < 2
            rf_feat_idx = [];
        end
    end

    % 1. Accuracy comparison
    subplot(2, 3, 1);
    accuracies = [results.Accuracy];
    methods = cellfun(@(x,y) sprintf('%s\n%s', x, y), ...
                     {results.FeatureMethod}, {results.KernelType}, ...
                     'UniformOutput', false);

    bar(accuracies);
    % Pin one tick per bar; with automatic ticks the labels would be
    % attached to whatever tick positions MATLAB happens to choose.
    set(gca, 'XTick', 1:numel(accuracies), ...
             'XTickLabel', methods, 'XTickLabelRotation', 45);
    ylabel('准确率');
    title('不同方法准确率比较');
    grid on;

    % Numeric label above each bar.
    for i = 1:length(accuracies)
        text(i, accuracies(i) + 0.01, sprintf('%.3f', accuracies(i)), ...
             'HorizontalAlignment', 'center');
    end

    % 2. Confusion matrix of the best model
    subplot(2, 3, 2);
    [~, best_idx] = max(accuracies);
    best_confusion = results(best_idx).ConfusionMatrix;

    imagesc(best_confusion);
    colorbar;
    title(sprintf('最佳模型混淆矩阵\n(%s + %s)', ...
          results(best_idx).FeatureMethod, results(best_idx).KernelType));
    xlabel('预测类别');
    ylabel('真实类别');

    % 3. Per-class F1 score of the best model
    subplot(2, 3, 3);
    best_report = results(best_idx).ClassReport;
    f1_scores = [best_report.F1_Score];
    class_labels = [best_report.Class];

    bar(f1_scores);
    set(gca, 'XTick', 1:numel(f1_scores), ...
             'XTickLabel', arrayfun(@num2str, class_labels, 'UniformOutput', false));
    ylabel('F1分数');
    title('各类别F1分数');
    grid on;

    % 4. Bands kept by the RF feature selection (if it was run)
    subplot(2, 3, 4);
    has_rf_result = any(strcmp({results.FeatureMethod}, 'RF'));
    if has_rf_result && ~isempty(rf_feat_idx)
        % Put a marker at every selected band index so the panel shows
        % WHICH bands were kept, not just how many.
        selected_bands = feature_results(rf_feat_idx).SelectedIndices;
        stem(selected_bands, ones(size(selected_bands)), 'filled');
        title('选择的特征波段');
        xlabel('特征索引');
        ylabel('选择状态');
        grid on;
    end

    % 5. PCA projection of the test set
    subplot(2, 3, 5);
    if ~isempty(pca_feat_idx)
        pca_features = feature_results(pca_feat_idx).FeaturesTest;
        if size(pca_features, 2) >= 2
            scatter(pca_features(:,1), pca_features(:,2), 30, labels_test, 'filled');
            xlabel('第一主成分');
            ylabel('第二主成分');
            title('PCA投影可视化');
            colorbar;
        end
    end

    % 6. Accuracy across experiments (simplified "learning curve")
    subplot(2, 3, 6);
    plot(accuracies, 'o-', 'LineWidth', 2, 'MarkerSize', 8);
    xlabel('实验编号');
    ylabel('准确率');
    title('模型性能趋势');
    grid on;

    sgtitle('高光谱图像SVM分类结果分析', 'FontSize', 14, 'FontWeight', 'bold');
end

function print_performance_summary(results)
    % PRINT_PERFORMANCE_SUMMARY  Print an accuracy table and the best model.
    %
    % Input:
    %   results - struct array with fields FeatureMethod, KernelType and
    %             Accuracy, and optionally BestParams (filled only for
    %             models that went through hyper-parameter optimization).

    fprintf('\n=== 性能总结 ===\n');
    fprintf('%-20s %-10s %-8s\n', '方法', '核函数', '准确率');
    fprintf('----------------------------------------\n');

    % Nothing to rank when there are no results (or no accuracies recorded).
    if isempty(results) || ~isfield(results, 'Accuracy') || isempty([results.Accuracy])
        return;
    end

    for i = 1:length(results)
        fprintf('%-20s %-10s %.4f\n', ...
                results(i).FeatureMethod, ...
                results(i).KernelType, ...
                results(i).Accuracy);
    end

    [best_acc, best_idx] = max([results.Accuracy]);
    fprintf('\n最佳模型: %s + %s核SVM\n', ...
            results(best_idx).FeatureMethod, results(best_idx).KernelType);
    fprintf('最佳准确率: %.4f\n', best_acc);

    % Every element of a struct array shares the same fields, so isfield
    % alone is true for all models as soon as one has BestParams. Only print
    % the parameters when the best model actually carries some.
    if isfield(results, 'BestParams') && ~isempty(results(best_idx).BestParams)
        fprintf('最佳参数:\n');
        disp(results(best_idx).BestParams);
    end
end

7. 预测新样本

function [predicted_labels, scores] = predict_new_samples(model, new_data, feature_method)
    % PREDICT_NEW_SAMPLES  Classify unseen samples with a trained model.
    %
    % Inputs:
    %   model          - trained model accepted by predict(). It is expected
    %                    to carry its own standardization (e.g. SVM learners
    %                    created with 'Standardize', true).
    %   new_data       - samples-by-bands matrix of raw spectra.
    %   feature_method - (optional) how to map NEW_DATA into the feature
    %                    space the model was trained on:
    %                      * omitted, [] or 'None': use NEW_DATA as is;
    %                      * numeric/logical vector: band indices selected at
    %                        training time, applied as new_data(:, idx);
    %                      * struct from training time with fields Mu and
    %                        Coeff (PCA projection (X - Mu) * Coeff) or,
    %                        when Coeff is absent/empty, SelectedIndices;
    %                      * any other method name (legacy): the selection is
    %                        RE-FITTED on NEW_DATA, which generally does not
    %                        match the training feature space; a warning is
    %                        issued.
    %
    % Outputs:
    %   predicted_labels - predicted class per sample.
    %   scores           - second output of predict().
    %
    % The features are not z-scored here: standardizing with the statistics
    % of the new batch makes the result depend on which samples happen to be
    % predicted together (a single sample would collapse to all zeros).

    if nargin < 3 || isempty(feature_method)
        new_features = new_data;

    elseif isstruct(feature_method)
        if isfield(feature_method, 'Coeff') && ~isempty(feature_method.Coeff)
            % Project with the PCA mean/loadings fitted on the training set.
            new_features = bsxfun(@minus, new_data, feature_method.Mu) * feature_method.Coeff;
        elseif isfield(feature_method, 'SelectedIndices') && ~isempty(feature_method.SelectedIndices)
            new_features = new_data(:, feature_method.SelectedIndices);
        else
            new_features = new_data;
        end

    elseif isnumeric(feature_method) || islogical(feature_method)
        % Band indices chosen at training time.
        new_features = new_data(:, feature_method);

    elseif strcmp(feature_method, 'None')
        new_features = new_data;

    else
        % Legacy path kept for backward compatibility.
        warning('predict_new_samples:refitFeatureSelection', ...
                ['Feature selection ''%s'' is being re-fitted on the new data. ', ...
                 'Pass the indices or transform saved at training time instead.'], ...
                feature_method);
        new_features = feature_selection_hyperspectral(new_data, [], feature_method);
    end

    % Predict.
    [predicted_labels, scores] = predict(model, new_features);

    fprintf('完成 %d 个新样本的预测\n', size(new_data, 1));
end

使用说明

  1. 运行完整系统
hyperspectral_svm_classification_system();
  2. 单独训练模型
[data, labels] = load_hyperspectral_data();
features = feature_selection_hyperspectral(data, labels, 'PCA');
model = train_svm_classifier(features, labels, 'rbf');
  3. 预测新数据
new_labels = predict_new_samples(model, new_data, 'PCA');

参考代码 SVM分类用于高光谱遥感图像分类、预测 www.youwenfan.com/contentcnl/79960.html

关键改进建议

  1. 实际数据适配

    • 替换load_hyperspectral_data函数以加载真实的高光谱数据
    • 调整特征选择参数以适应具体数据集
  2. 性能优化

    • 对于大数据集,考虑使用随机子空间或特征bagging
    • 使用GPU加速SVM训练(如果可用)
  3. 高级技术

    • 结合空间-光谱特征(如Gabor滤波、形态学剖面)
    • 使用深度特征提取器预处理数据
posted @ 2025-11-09 10:05  kiyte  阅读(0)  评论(0)    收藏  举报