稀疏贝叶斯分类算法(Relevance Vector Machine, RVM)
稀疏贝叶斯分类是一种基于贝叶斯推断的分类方法,最著名的实现是 相关向量机(Relevance Vector Machine, RVM),由 Michael E. Tipping 于 2001 年提出。它通过引入自动相关性确定(ARD)先验,迫使大部分权重趋于零,从而实现稀疏性,同时输出概率预测。
一、核心思想
给定训练集 (\(\{(\mathbf{x}_i, t_i)\}_{i=1}^N\)),其中 (\(t_i \in \{0,1\}\))(二分类),RVM 采用类似于支持向量机(SVM)的基函数展开,但使用贝叶斯框架:
其中 (\(K(\cdot,\cdot)\)) 是核函数(如 RBF 核)。分类决策通过 sigmoid 函数链接:
二、贝叶斯建模
2.1 似然函数
假设训练样本独立,似然函数为:
2.2 先验:自动相关性确定(ARD)
对每个权重 (\(w_i\)) 赋予高斯先验,但每个有自己的超参数 (\(\alpha_i\))(精度):
当 (\(\alpha_i \to \infty\)) 时,对应权重 (\(w_i\)) 被强制为零,从而移除该基函数(稀疏性)。
2.3 超先验
对超参数 (\(\alpha_i\)) 和噪声精度 (\(\beta\))(分类问题中固定为 1)赋予 Gamma 分布:
通常取无信息先验 (\(a=b=c=d=10^{-6}\))。
三、推断与学习(Laplace 近似)
由于后验 (\(p(\mathbf{w}|\mathbf{t},\boldsymbol{\alpha})\)) 非高斯,采用 Laplace 近似:
-
固定当前 (\(\boldsymbol{\alpha}\)),找到权重后验众数 (\(\mathbf{w}_{\text{MP}}\))(通过迭代重加权最小二乘,IRLS):
\[\mathbf{w}_{\text{new}} = \mathbf{w}_{\text{old}} - \mathbf{H}^{-1}\nabla E \]其中 (\(E = -\ln p(\mathbf{t}|\mathbf{w}) - \frac{1}{2}\mathbf{w}^T\mathbf{A}\mathbf{w}\)),(\(\mathbf{A} = \text{diag}(\alpha_i)\)),(\(\mathbf{H}\)) 是 Hessian 矩阵。
-
更新超参数(通过最大化边际似然的证据近似):
\[\alpha_i^{\text{new}} = \frac{\gamma_i}{w_i^2}, \quad \gamma_i = 1 - \alpha_i \Sigma_{ii} \]其中 (\(\Sigma = (\mathbf{H} + \mathbf{A})^{-1}\)) 是后验协方差。
-
重复直到收敛。最终只有少量 (\alpha_i) 有限(对应“相关向量”),其余趋于无穷大。
四、预测
对于新样本 (\(\mathbf{x}_*\)),预测概率:
其中 (\(\mathbf{k}_* = [1, K(\mathbf{x}_*, \mathbf{x}_1), \dots, K(\mathbf{x}_*, \mathbf{x}_N)]^T\))。
五、优点与缺点
| 优点 | 缺点 |
|---|---|
| 极高的稀疏性(通常比 SVM 更稀疏) | 训练复杂度较高((O(N^3))) |
| 输出概率(可进行不确定性估计) | 核函数选择仍需经验 |
| 无需交叉验证(超参数自动学习) | 对初始值敏感 |
| 可自然地扩展到回归和多分类 | 大样本时内存消耗大 |
六、MATLAB 实现(核心函数)
以下提供一个简洁的 RVM 二分类训练函数(使用 Laplace 近似 + IRLS),不依赖任何工具箱。
function [rvm_model] = rvm_train(X, t, kernel, max_iter, tol)
% RVM 二分类训练
% X: N×d 训练数据
% t: N×1 标签 {0,1}
% kernel: 核函数句柄 @(x1,x2) value
% max_iter: 最大迭代次数
% tol: 收敛容忍度
% 返回结构体包含:权重 w,相关向量索引 RV,超参数 alpha,核函数等
N = size(X,1);
% 设计矩阵 Phi: N×(N+1) 包含偏置项
Phi = [ones(N,1), zeros(N,N)];
for i = 1:N
for j = 1:N
Phi(i, j+1) = kernel(X(i,:), X(j,:));
end
end
% 初始化超参数 alpha (N+1 维)
alpha = ones(N+1, 1) * 1e6; % 初始很大,鼓励稀疏
% 但为了启动,给一个较小的 alpha 让权重有自由度
alpha(1) = 1e-3; % 偏置项
% 初始化权重(使用正则化逻辑回归)
w = zeros(N+1, 1);
% IRLS 循环
for iter = 1:max_iter
% 计算 y = Phi * w
y = Phi * w;
% Sigmoid
sigma = 1 ./ (1 + exp(-y));
% 避免数值问题
sigma = max(min(sigma, 1-1e-12), 1e-12);
% 对角线权重矩阵 R (N×N)
R = diag(sigma .* (1 - sigma));
% Hessian: H = -Phi' * R * Phi - diag(alpha)
H = -Phi' * R * Phi - diag(alpha);
% 梯度: g = Phi' * (t - sigma) - diag(alpha) * w
g = Phi' * (t - sigma) - alpha .* w;
% 牛顿更新
delta = -H \ g;
w_new = w + delta;
% 更新 alpha (使用 MacKay 的 evidence procedure)
Sigma = -inv(H); % 后验协方差
gamma = 1 - alpha .* diag(Sigma);
alpha_new = gamma ./ (w_new.^2 + eps);
% 固定偏置项的 alpha 为小值(不使其无穷大)
alpha_new(1) = min(alpha_new(1), 1e-3);
% 检查收敛
change = norm(w_new - w) / norm(w + eps);
w = w_new;
alpha = alpha_new;
if change < tol
break;
end
end
% 找出相关向量(alpha 有限且非极大)
RV_idx = find(alpha < 1e6); % 阈值可调
% 移除 alpha 极大的索引(对应权重接近零)
% 更严格:alpha > 1e6 视为无穷大
RV_idx = intersect(RV_idx, find(abs(w) > 1e-6));
% 保存模型
rvm_model.w = w;
rvm_model.alpha = alpha;
rvm_model.RV_idx = RV_idx;
rvm_model.X = X;
rvm_model.kernel = kernel;
rvm_model.Phi = Phi;
end
function prob = rvm_predict(model, X_test)
% 预测概率
N_test = size(X_test,1);
N_RV = length(model.RV_idx);
Phi_test = [ones(N_test,1), zeros(N_test, N_RV)];
for i = 1:N_test
for j = 1:N_RV
rv_idx = model.RV_idx(j);
Phi_test(i, j+1) = model.kernel(X_test(i,:), model.X(rv_idx,:));
end
end
y = Phi_test * model.w([1; model.RV_idx]);
prob = 1 ./ (1 + exp(-y));
end
使用示例:
% 生成二分类数据
X = [randn(50,2)+1; randn(50,2)-1];
t = [ones(50,1); zeros(50,1)];
kernel = @(x1,x2) exp(-sum((x1-x2).^2)/2); % RBF 核 sigma=1
model = rvm_train(X, t, kernel, 100, 1e-4);
fprintf('相关向量个数: %d\n', length(model.RV_idx));
% 预测
prob = rvm_predict(model, X);
pred = prob > 0.5;
accuracy = mean(pred == t);
fprintf('训练准确率: %.2f%%\n', accuracy*100);
参考代码 稀疏贝叶斯分类算法 www.youwenfan.com/contentcnv/81299.html
七、应用场景
- 基因表达数据分类(高维小样本,稀疏性至关重要)
- 文本分类(词袋特征,大量冗余特征)
- 故障诊断(需要概率输出以评估风险)
- 脑机接口(EEG 信号分类,稀疏性帮助解释)
八、与 SVM 的关键区别
| 方面 | SVM | RVM |
|---|---|---|
| 输出 | 判别函数(无概率) | 概率输出 |
| 稀疏性 | 支持向量数量较多 | 通常更稀疏 |
| 超参数 | C, γ(需交叉验证) | 自动学习 |
| 核函数 | 必须满足 Mercer 条件 | 任意核 |
| 训练复杂度 | (O(N^2)) ~ (O(N^3)) | (O(N^3)) |
九、扩展阅读
- 多分类:采用一对多(OvR)或概率结对比较(pairwise coupling)。
- 回归 RVM:类似框架,但似然为高斯分布,可解析求解。
- 增量 RVM:在线学习版本。
浙公网安备 33010602011771号