Lasso算法在数据挖掘中的深入解析与MATLAB实现

Lasso算法在数据挖掘中的深入解析与MATLAB实现

Lasso(Least Absolute Shrinkage and Selection Operator)是一种广泛应用于数据挖掘和机器学习中的回归分析方法,特别擅长处理高维数据和特征选择问题。

Lasso算法核心原理

Lasso通过在普通最小二乘回归中加入L1正则化项实现特征选择和模型简化:

目标函数

\(\min_{\beta} \left\{ \frac{1}{2N} \sum_{i=1}^{N} (y_i - \beta_0 - \sum_{j=1}^{p} \beta_j x_{ij})^2 + \lambda \sum_{j=1}^{p} |\beta_j| \right\}\)

其中:

  • λ:正则化参数,控制惩罚强度
  • βj:特征系数
  • p:特征数量
  • N:样本数量

L1正则化的关键特性是能够将部分特征系数压缩为零,实现特征选择。

MATLAB代码

基础Lasso实现

function [beta, lambda_opt] = lasso_regression(X, y, varargin)
% LASSO回归实现
% 输入:
%   X: 特征矩阵 (n×p)
%   y: 响应变量 (n×1)
%  可选参数:
%   'Lambda': 正则化参数值或向量
%   'CV': 交叉验证折数
%   'Alpha': 弹性网络混合参数 (1为纯Lasso)
% 输出:
%   beta: 系数向量
%   lambda_opt: 最优lambda值

% 解析输入参数
p = inputParser;
addParameter(p, 'Lambda', logspace(-4, 1, 100), @isnumeric);
addParameter(p, 'CV', 10, @isscalar);
addParameter(p, 'Alpha', 1, @(x) x>0 && x<=1);
parse(p, varargin{:});

lambda = p.Results.Lambda;
k = p.Results.CV;
alpha = p.Results.Alpha;

% 数据标准化
X = normalize(X);
y = normalize(y);
[n, p] = size(X);

% 使用坐标下降法求解
if length(lambda) == 1
    beta = lasso_cd(X, y, lambda, alpha);
    lambda_opt = lambda;
else
    % 交叉验证选择最优lambda
    cv_indices = crossvalind('Kfold', n, k);
    mse_cv = zeros(length(lambda), k);
    
    for i = 1:k
        test_idx = (cv_indices == i);
        train_idx = ~test_idx;
        
        X_train = X(train_idx, :);
        y_train = y(train_idx);
        X_test = X(test_idx, :);
        y_test = y(test_idx);
        
        for j = 1:length(lambda)
            beta_j = lasso_cd(X_train, y_train, lambda(j), alpha);
            y_pred = X_test * beta_j;
            mse_cv(j, i) = mean((y_test - y_pred).^2);
        end
    end
    
    % 计算平均MSE
    mse_mean = mean(mse_cv, 2);
    [~, min_idx] = min(mse_mean);
    lambda_opt = lambda(min_idx);
    
    % 使用最优lambda训练全模型
    beta = lasso_cd(X, y, lambda_opt, alpha);
end

% 绘制结果
plot_lasso_results(X, y, beta, lambda_opt, mse_mean);

end

function beta = lasso_cd(X, y, lambda, alpha, max_iter, tol)
% 坐标下降法求解Lasso
if nargin < 5, max_iter = 1000; end
if nargin < 6, tol = 1e-4; end

[n, p] = size(X);
beta = zeros(p, 1); % 初始化系数
r = y; % 初始残差

for iter = 1:max_iter
    beta_old = beta;
    
    for j = 1:p
        % 计算当前特征的伪残差
        r_j = r + X(:, j) * beta(j);
        
        % 计算软阈值
        rho_j = X(:, j)' * r_j / n;
        
        % 更新系数
        if rho_j < -lambda * alpha / 2
            beta(j) = (rho_j + lambda * alpha / 2) / (X(:, j)' * X(:, j) / n);
        elseif rho_j > lambda * alpha / 2
            beta(j) = (rho_j - lambda * alpha / 2) / (X(:, j)' * X(:, j) / n);
        else
            beta(j) = 0;
        end
        
        % 更新残差
        r = r_j - X(:, j) * beta(j);
    end
    
    % 检查收敛
    if norm(beta - beta_old) < tol
        break;
    end
end
end

function plot_lasso_results(X, y, beta, lambda_opt, mse_cv)
% 可视化Lasso结果
figure('Position', [100, 100, 1200, 800], 'Color', 'w');

% 1. 特征系数可视化
subplot(2, 2, 1);
stem(beta, 'filled', 'LineWidth', 1.5);
xlabel('特征索引');
ylabel('系数值');
title('Lasso系数');
grid on;
hold on;
plot(xlim, [0 0], 'k--');
non_zero_idx = find(beta ~= 0);
plot(non_zero_idx, beta(non_zero_idx), 'ro', 'MarkerSize', 8);
legend('系数', '非零系数');

% 2. 预测值与实际值比较
subplot(2, 2, 2);
y_pred = X * beta;
scatter(y, y_pred, 50, 'filled');
hold on;
plot([min(y), max(y)], [min(y), max(y)], 'r--', 'LineWidth', 1.5);
xlabel('实际值');
ylabel('预测值');
title('预测性能');
grid on;
axis equal;
R2 = 1 - sum((y - y_pred).^2) / sum((y - mean(y)).^2);
text(0.05, 0.95, sprintf('R² = %.3f', R2), 'Units', 'normalized');

% 3. 交叉验证误差曲线
if ~isempty(mse_cv)
    subplot(2, 2, 3);
    semilogx(lambda_opt, min(mse_cv), 'ro', 'MarkerSize', 10, 'MarkerFaceColor', 'r');
    hold on;
    semilogx(lambda, mse_cv, 'b-', 'LineWidth', 1.5);
    xlabel('正则化参数 \lambda');
    ylabel('交叉验证均方误差');
    title('正则化路径');
    grid on;
    legend('最优 \lambda', 'MSE');
end

% 4. 特征重要性
subplot(2, 2, 4);
[~, idx] = sort(abs(beta), 'descend');
bar(abs(beta(idx)), 'FaceColor', [0.5 0.5 0.8]);
xlabel('特征排序');
ylabel('系数绝对值');
title('特征重要性排序');
grid on;
end

弹性网络扩展

function beta = elastic_net(X, y, lambda, alpha, max_iter, tol)
% 弹性网络回归 (L1 + L2正则化)
% alpha: L1和L2的混合参数 (0=Ridge, 1=Lasso)

if nargin < 4, alpha = 0.5; end
if nargin < 5, max_iter = 1000; end
if nargin < 6, tol = 1e-4; end

[n, p] = size(X);
beta = zeros(p, 1);
r = y;
lambda1 = lambda * alpha;
lambda2 = lambda * (1 - alpha);

for iter = 1:max_iter
    beta_old = beta;
    
    for j = 1:p
        % 计算当前特征的伪残差
        r_j = r + X(:, j) * beta(j);
        rho_j = X(:, j)' * r_j / n;
        
        % 更新系数 (弹性网络)
        if rho_j < -lambda1 / 2
            beta(j) = (rho_j + lambda1 / 2) / (X(:, j)' * X(:, j) / n + lambda2);
        elseif rho_j > lambda1 / 2
            beta(j) = (rho_j - lambda1 / 2) / (X(:, j)' * X(:, j) / n + lambda2);
        else
            beta(j) = 0;
        end
        
        % 更新残差
        r = r_j - X(:, j) * beta(j);
    end
    
    % 检查收敛
    if norm(beta - beta_old) < tol
        break;
    end
end
end

关键应用场景

1. 高维特征选择

% 生成高维数据 (n=100, p=1000)
X = randn(100, 1000);
true_beta = zeros(1000, 1);
true_beta(randperm(1000, 10)) = randn(10, 1); % 10个相关特征
y = X * true_beta + randn(100, 1)*0.5;

% 应用Lasso
[beta, lambda_opt] = lasso_regression(X, y, 'CV', 5);

% 评估特征选择性能
selected_features = find(beta ~= 0);
true_features = find(true_beta ~= 0);
precision = sum(ismember(selected_features, true_features)) / length(selected_features);
recall = sum(ismember(true_features, selected_features)) / length(true_features);
fprintf('特征选择精度: %.2f, 召回率: %.2f\n', precision, recall);

2. 多重共线性处理

% 创建具有多重共线性的数据
X1 = randn(100, 1);
X2 = X1 + randn(100, 1)*0.1; % 高度相关
X3 = randn(100, 1);
X = [X1, X2, X3, randn(100, 7)]; % 10个特征
y = 3*X1 + 2*X3 + randn(100, 1)*0.5;

% 普通线性回归
lin_reg = fitlm(X, y);
disp('普通线性回归系数:');
disp(lin_reg.Coefficients.Estimate(2:end)');

% Lasso回归
[beta, ~] = lasso_regression(X, y);
disp('Lasso系数:');
disp(beta');

3. 预测模型构建

% 加载加州房价数据集
load california_housing.mat; % X_train, y_train, X_test, y_test

% 训练Lasso模型
[beta, lambda_opt] = lasso_regression(X_train, y_train, 'CV', 10);

% 测试集预测
y_pred = X_test * beta;

% 评估指标
mse = mean((y_test - y_pred).^2);
mae = mean(abs(y_test - y_pred));
R2 = 1 - sum((y_test - y_pred).^2) / sum((y_test - mean(y_test)).^2);

fprintf('测试集性能:\nMSE=%.4f, MAE=%.4f, R²=%.4f\n', mse, mae, R2);
fprintf('选择的特征数: %d/%d\n', nnz(beta), length(beta));

参考代码 针对数据挖掘中的lasso算法 youwenfan.com/contentcnb/82518.html

结论

Lasso算法是数据挖掘中强大的特征选择和回归建模工具,特别适用于高维数据集。通过MATLAB实现,我们可以:

  1. 高效处理特征数大于样本数的场景
  2. 自动识别最相关特征,提高模型可解释性
  3. 构建更稳健的预测模型,防止过拟合
  4. 扩展应用于各种复杂场景(时间序列、分组特征等)

实际应用中,应结合交叉验证进行参数调优,并考虑使用弹性网络等扩展方法处理高度相关特征。Lasso与其他方法的系统比较有助于选择最适合特定问题的建模策略。

posted @ 2025-08-08 16:22  kang_ms  阅读(96)  评论(0)    收藏  举报