Linear Discriminant Analysis (LDA) for Dimensionality Reduction: Principles, Implementation, and Applications

1. LDA Algorithm Principles

Basic Idea

LDA is a supervised dimensionality-reduction method. Its goal is to find projection directions that maximize the distance between classes while minimizing the spread within each class, so that samples from different classes are as separable as possible in the projected space.

Mathematical Formulation

Let \(X_i\) denote the set of samples in class \(i\), \(n_i = |X_i|\) its size, \(\mu_i\) its mean, and \(\mu\) the overall mean across all \(c\) classes.

  • Within-class scatter matrix: \(S_w = \sum_{i=1}^c \sum_{x \in X_i} (x - \mu_i)(x - \mu_i)^T\)
  • Between-class scatter matrix: \(S_b = \sum_{i=1}^c n_i(\mu_i - \mu)(\mu_i - \mu)^T\)
  • Objective (the generalized Rayleigh quotient, stated here for a single projection direction \(w\)): \(J(w) = \frac{w^T S_b w}{w^T S_w w}\)
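Setting the gradient of \(J(w)\) to zero shows that the optimal directions solve the generalized eigenvalue problem

\[ S_b\, w = \lambda\, S_w\, w, \]

and the eigenvectors with the largest eigenvalues form the columns of the projection matrix \(W\). Because \(S_b\) is a sum of \(c\) rank-one terms whose mean differences are linearly dependent, its rank is at most \(c - 1\), which is why LDA can produce at most \(c - 1\) useful components. The MATLAB code below solves exactly this problem.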

2. MATLAB Implementation

Basic LDA Implementation

function [W, projected_data] = lda(X, y, k)
% LDA - linear discriminant analysis
% Inputs:
%   X: data matrix (n x d), n samples, d features
%   y: label vector (n x 1)
%   k: target dimension (capped at c-1, where c = number of classes)
% Outputs:
%   W: projection matrix (d x min(k, c-1))
%   projected_data: reduced data (n x min(k, c-1))

    [~, d] = size(X);
    classes = unique(y);
    c = length(classes);
    
    % Overall mean of all samples
    total_mean = mean(X);
    
    % Initialize scatter matrices
    S_w = zeros(d, d);
    S_b = zeros(d, d);
    
    % Accumulate within-class and between-class scatter
    for i = 1:c
        class_idx = (y == classes(i));
        X_class = X(class_idx, :);
        n_i = sum(class_idx);
        
        % Within-class scatter: sum of squared deviations from the class mean
        class_mean = mean(X_class);
        class_scatter = (X_class - class_mean)' * (X_class - class_mean);
        S_w = S_w + class_scatter;
        
        % Between-class scatter
        mean_diff = (class_mean - total_mean)';
        S_b = S_b + n_i * (mean_diff * mean_diff');
    end
    
    % Solve the generalized eigenvalue problem: S_b * w = lambda * S_w * w
    [eigen_vectors, eigen_values] = eig(S_b, S_w);
    
    % Sort eigenvalues in descending order (take real parts to guard
    % against tiny imaginary components from floating-point noise)
    [~, sort_idx] = sort(real(diag(eigen_values)), 'descend');
    eigen_vectors_sorted = real(eigen_vectors(:, sort_idx));
    
    % Keep the top eigenvectors (at most c-1 are meaningful)
    W = eigen_vectors_sorted(:, 1:min(k, c-1));
    
    % Project the data
    projected_data = X * W;
end
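A minimal usage sketch for the function above (synthetic two-class data; assumes lda.m is on the MATLAB path):

% Two 3-D Gaussian classes; with c = 2 classes LDA yields at most
% c - 1 = 1 component.
rng(1);
X = [randn(30, 3) + 1; randn(30, 3) - 1];
y = [ones(30, 1); 2 * ones(30, 1)];

[W, Z] = lda(X, y, 1);
gscatter(Z, zeros(size(Z)), y);   % classes separate along one axis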

Improved LDA Implementation (Handling Singular Matrices)

function [W, projected_data, explained] = lda_improved(X, y, k)
% Improved LDA implementation that handles singular scatter matrices

    [~, d] = size(X);
    classes = unique(y);
    c = length(classes);
    
    % Cap k at the maximum meaningful dimension
    max_k = min(d, c - 1);
    if k > max_k
        k = max_k;
        warning('Target dimension adjusted to %d', k);
    end
    
    total_mean = mean(X);
    S_w = zeros(d, d);
    S_b = zeros(d, d);
    
    % Accumulate scatter matrices
    for i = 1:c
        class_idx = (y == classes(i));
        X_class = X(class_idx, :);
        n_i = sum(class_idx);
        
        class_mean = mean(X_class);
        % cov normalizes by (n_i - 1); multiplying it back yields the
        % within-class scatter (sum of squared deviations)
        class_scatter = cov(X_class) * (n_i - 1);
        
        S_w = S_w + class_scatter;
        
        mean_diff = (class_mean - total_mean)';
        S_b = S_b + n_i * (mean_diff * mean_diff');
    end
    
    % Regularize to prevent a singular S_w
    regularization = 1e-6 * eye(d);
    S_w = S_w + regularization;
    
    % Use the SVD for numerical stability; S_w is symmetric positive
    % definite after regularization, so U equals V here
    [U, S, V] = svd(S_w);
    
    % Inverse square root of S_w
    S_inv_sqrt = U * diag(1 ./ sqrt(diag(S))) * V';
    
    % Whiten: reduce to a standard symmetric eigenvalue problem
    S_b_transformed = S_inv_sqrt' * S_b * S_inv_sqrt;
    S_b_transformed = (S_b_transformed + S_b_transformed') / 2; % enforce symmetry
    
    [eigen_vectors, eigen_values] = eig(S_b_transformed);
    
    % Sort eigenvalues in descending order
    [eigen_values_sorted, sort_idx] = sort(diag(eigen_values), 'descend');
    eigen_vectors_sorted = eigen_vectors(:, sort_idx);
    
    % Keep the top k eigenvectors
    W_transformed = eigen_vectors_sorted(:, 1:k);
    
    % Map back to the original space
    W = S_inv_sqrt * W_transformed;
    
    % Project the data
    projected_data = X * W;
    
    % Discriminative variance explained by each component (only the
    % first c-1 eigenvalues are meaningfully nonzero)
    explained = eigen_values_sorted(1:k) / sum(eigen_values_sorted) * 100;
end
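A quick sanity-check sketch (illustrative synthetic data; assumes lda_improved.m is on the MATLAB path). On well-separated data, the two components should account for essentially all of the reported discriminative variance:

% Three 5-D Gaussian classes with non-collinear means (hypothetical data)
rng(0);
X = [randn(40, 5); randn(40, 5) + [3 0 0 0 0]; randn(40, 5) + [0 3 0 0 0]];
y = [ones(40, 1); 2 * ones(40, 1); 3 * ones(40, 1)];

[W, Z, explained] = lda_improved(X, y, 2);
fprintf('Component 1: %.1f%%, Component 2: %.1f%%\n', explained(1), explained(2));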

3. Visualization Example

function demo_lda()
% LDA demonstration

    % Generate example data (3 classes, 4 dimensions)
    rng(42); % fix the random seed for reproducibility
    
    n_per_class = 50;
    X1 = mvnrnd([2, 2, 1, 1], eye(4), n_per_class);
    X2 = mvnrnd([-1, -1, 2, 2], eye(4), n_per_class);
    X3 = mvnrnd([0, 3, -1, 0], eye(4), n_per_class);
    
    X = [X1; X2; X3];
    y = [ones(n_per_class, 1); 2*ones(n_per_class, 1); 3*ones(n_per_class, 1)];
    
    % Reduce to 2 dimensions with LDA
    k = 2;
    [W, projected_data, explained] = lda_improved(X, y, k);
    
    % Visualize the results
    figure('Position', [100, 100, 1200, 400]);
    
    % Original data (first two dimensions)
    subplot(1, 3, 1);
    gscatter(X(:, 1), X(:, 2), y);
    title('Original Data (First Two Dimensions)');
    xlabel('Feature 1'); ylabel('Feature 2');
    legend('Class 1', 'Class 2', 'Class 3');
    
    % LDA projection
    subplot(1, 3, 2);
    gscatter(projected_data(:, 1), projected_data(:, 2), y);
    title('LDA Projection (2-D)');
    xlabel(sprintf('LDA Component 1 (%.1f%%)', explained(1)));
    ylabel(sprintf('LDA Component 2 (%.1f%%)', explained(2)));
    legend('Class 1', 'Class 2', 'Class 3');
    
    % Projection matrix heatmap
    subplot(1, 3, 3);
    imagesc(W);
    colorbar;
    title('Projection Matrix W');
    xlabel('LDA Component');
    ylabel('Original Feature');
    
    % Print summary statistics
    fprintf('Data summary:\n');
    fprintf('  Samples: %d\n', size(X, 1));
    fprintf('  Original dimensions: %d\n', size(X, 2));
    fprintf('  Reduced dimensions: %d\n', k);
    fprintf('  Classes: %d\n', length(unique(y)));
    fprintf('Explained discriminative variance:\n');
    for i = 1:k
        fprintf('  Component %d: %.2f%%\n', i, explained(i));
    end
end

4. Comparison with PCA

function compare_lda_pca()
% Compare LDA and PCA projections

    % Load a built-in MATLAB dataset
    load fisheriris;
    X = meas;
    y = grp2idx(species);
    
    % Standardize the data
    X = zscore(X);
    
    % LDA projection
    k = 2;
    [W_lda, X_lda] = lda_improved(X, y, k);
    
    % PCA projection
    [coeff, score, ~, ~, explained_pca] = pca(X);
    X_pca = score(:, 1:k);
    
    % Side-by-side visualization
    figure('Position', [100, 100, 1000, 400]);
    
    subplot(1, 2, 1);
    gscatter(X_pca(:, 1), X_pca(:, 2), y);
    title('PCA Projection');
    xlabel(sprintf('PC1 (%.1f%%)', explained_pca(1)));
    ylabel(sprintf('PC2 (%.1f%%)', explained_pca(2)));
    legend('Setosa', 'Versicolor', 'Virginica');
    
    subplot(1, 2, 2);
    gscatter(X_lda(:, 1), X_lda(:, 2), y);
    title('LDA Projection');
    xlabel('LDA Component 1');
    ylabel('LDA Component 2');
    legend('Setosa', 'Versicolor', 'Virginica');
    
    % Between-class / within-class distance ratio in each projected space
    fprintf('Separability evaluation:\n');
    
    lda_ratio = calculate_separation_ratio(X_lda, y);
    fprintf('LDA between/within distance ratio: %.4f\n', lda_ratio);
    
    pca_ratio = calculate_separation_ratio(X_pca, y);
    fprintf('PCA between/within distance ratio: %.4f\n', pca_ratio);
end

function separation_ratio = calculate_separation_ratio(X, y)
% Ratio of between-class distance to within-class distance

    classes = unique(y);
    c = length(classes);
    
    % Overall mean
    total_mean = mean(X);
    
    % Within-class distance: sum of squared distances to the class means
    within_distance = 0;
    for i = 1:c
        class_idx = (y == classes(i));
        X_class = X(class_idx, :);
        class_mean = mean(X_class);
        
        within_distance = within_distance + sum(sum((X_class - class_mean).^2, 2));
    end
    
    % Between-class distance: class-size-weighted squared distances of
    % the class means to the overall mean
    between_distance = 0;
    for i = 1:c
        class_idx = (y == classes(i));
        n_i = sum(class_idx);
        class_mean = mean(X(class_idx, :));
        
        between_distance = between_distance + n_i * sum((class_mean - total_mean).^2);
    end
    
    separation_ratio = between_distance / within_distance;
end
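For reference, the quantity computed above is the trace ratio of the scatter matrices evaluated in the projected space:

\[
\text{ratio} = \frac{\operatorname{tr}(S_b)}{\operatorname{tr}(S_w)} = \frac{\sum_{i=1}^{c} n_i \,\|\mu_i - \mu\|^2}{\sum_{i=1}^{c} \sum_{x \in X_i} \|x - \mu_i\|^2},
\]

a scalar summary of class separability. Since LDA optimizes for exactly this kind of separability, it should yield a larger value than PCA for the same number of components.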

5. Practical Application Example

function lda_classification_demo()
% Applying LDA in a classification pipeline

    % Try to load a wine dataset
    [wine_data, wine_labels] = load_wine_data();
    
    if isempty(wine_data)
        fprintf('Wine dataset unavailable; falling back to the Fisher iris data\n');
        load fisheriris;
        X = meas;
        y = grp2idx(species);
        feature_names = {'SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'};
    else
        X = wine_data;
        y = wine_labels;
        feature_names = {};
    end
    
    % Standardize the data
    X = zscore(X);
    
    % Split into training and test sets
    rng(123);
    cv = cvpartition(y, 'HoldOut', 0.3);
    X_train = X(cv.training, :);
    y_train = y(cv.training);
    X_test = X(cv.test, :);
    y_test = y(cv.test);
    
    % Method 1: classify directly on the original features
    fprintf('Method 1: original features\n');
    mdl_direct = fitcdiscr(X_train, y_train);
    y_pred_direct = predict(mdl_direct, X_test);
    accuracy_direct = sum(y_pred_direct == y_test) / length(y_test);
    fprintf('Accuracy on original features: %.2f%%\n', accuracy_direct * 100);
    
    % Method 2: reduce with LDA, then classify
    fprintf('\nMethod 2: classify after LDA reduction\n');
    k = min(2, length(unique(y_train)) - 1);
    [W_lda, X_train_lda] = lda_improved(X_train, y_train, k);
    X_test_lda = X_test * W_lda;
    
    mdl_lda = fitcdiscr(X_train_lda, y_train);
    y_pred_lda = predict(mdl_lda, X_test_lda);
    accuracy_lda = sum(y_pred_lda == y_test) / length(y_test);
    fprintf('Accuracy after LDA reduction: %.2f%%\n', accuracy_lda * 100);
    
    % Visualize decision boundaries (only meaningful in 2-D). Note that
    % mdl_direct uses all features and cannot be plotted on a 2-D grid,
    % so we fit an auxiliary model on the first two features for the
    % first panel.
    if k == 2
        figure('Position', [100, 100, 800, 400]);
        
        subplot(1, 2, 1);
        mdl_2d = fitcdiscr(X_train(:, 1:2), y_train);
        plot_decision_boundary(X_train(:, 1:2), y_train, mdl_2d, ...
            'Decision Boundary (First Two Original Features)', ...
            'Feature 1', 'Feature 2');
        
        subplot(1, 2, 2);
        plot_decision_boundary(X_train_lda, y_train, mdl_lda, ...
            'Decision Boundary (LDA Features)', ...
            'LDA Component 1', 'LDA Component 2');
    end
end
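The post never defines load_wine_data, so here is a minimal, hypothetical stub consistent with how the demo uses it. It assumes a local copy of the UCI Wine dataset file wine.data (class label in column 1, 13 features in columns 2-14); on failure it returns empty arrays so the demo falls back to fisheriris:

function [X, y] = load_wine_data()
% Hypothetical helper (not part of the original post). Reads a local
% copy of the UCI Wine dataset; returns empty arrays if unavailable.
    X = []; y = [];
    if exist('wine.data', 'file') == 2
        raw = readmatrix('wine.data', 'FileType', 'text');
        y = raw(:, 1);
        X = raw(:, 2:end);
    end
end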

function plot_decision_boundary(X, y, model, title_str, x_label, y_label)
% Plot a classifier's decision regions over 2-D data

    h = 0.02; % grid step
    x1_min = min(X(:, 1)) - 1; x1_max = max(X(:, 1)) + 1;
    x2_min = min(X(:, 2)) - 1; x2_max = max(X(:, 2)) + 1;
    
    [xx1, xx2] = meshgrid(x1_min:h:x1_max, x2_min:h:x2_max);
    Z = predict(model, [xx1(:), xx2(:)]);
    Z = reshape(Z, size(xx1));
    
    % Note: contourf does not accept an 'Alpha' option, so plot the
    % filled regions first and overlay the sample points on top
    contourf(xx1, xx2, Z);
    hold on;
    gscatter(X(:, 1), X(:, 2), y);
    title(title_str);
    xlabel(x_label);
    ylabel(y_label);
    hold off;
end

6. Summary of LDA Characteristics

Property            Description
Learning type       Supervised; requires label information
Max reduced dims    Number of classes - 1
Objective           Maximize between-class variance, minimize within-class variance
Typical uses        Classification problems, feature extraction
Strengths           Uses class information; good class separability
Weaknesses          Strong assumptions about the data distribution; needs sufficient samples


Usage Tips

  1. Preprocessing: standardize the data first (e.g., with zscore).
  2. Choosing the dimension: the maximum reduced dimension is min(number of features, number of classes - 1).
  3. Singular matrices: add a regularization term to S_w to improve numerical stability.
  4. Combining with PCA: apply PCA first, then refine with LDA (see the sketch below).
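A minimal sketch of tip 4, assuming lda_improved.m from section 2 is on the MATLAB path (the number of retained principal components here is an illustrative choice):

% PCA followed by LDA on the iris data: PCA removes near-null
% directions first, so S_w is better conditioned when LDA is applied.
load fisheriris;
X = zscore(meas);
y = grp2idx(species);

[~, score] = pca(X);
X_pca = score(:, 1:3);                 % keep 3 PCs (illustrative)

[W, X_lda] = lda_improved(X_pca, y, 2);
gscatter(X_lda(:, 1), X_lda(:, 2), y);
xlabel('LDA Component 1'); ylabel('LDA Component 2');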