基于PCA人脸识别算法的Matlab实现(2)
基于人脸识别算法PCA的另一个Matlab工程
妈妈再也不用担心我的人脸识别算法,
但是怎么移植到嵌入式系统上,
要用C重构的话,
我选择死亡。
main.m
% main.m -- PCA + sparse-representation face recognition experiment.
% Reads the face database, uses images 1..train_samplesize of every class
% for training and the remaining ones for testing, then for several PCA
% dimensions projects both sets, normalises every sample to unit length
% and reports the recognition rate.
clear all
clc
close all
database = fullfile(pwd, 'ORL');    % face database folder
train_samplesize = 5;               % training images per class
% File-name prefix; readsample appends <class>_<index><image_fmt>.
% fullfile keeps this portable instead of hard-coding a '\' separator.
address = fullfile(database, 's');
rows = 112;                         % image height after resizing
cols = 92;                          % image width after resizing
ClassNum = 40;                      % number of classes
tol_num = 10;                       % images per class in total
image_fmt = '.bmp';
%------------------------ train / test split
train = 1:train_samplesize;
test = train_samplesize+1:tol_num;
[train_sample, train_label] = readsample(address, ClassNum, train, rows, cols, image_fmt);
[test_sample, test_label] = readsample(address, ClassNum, test, rows, cols, image_fmt);
for pro_dim = 40:10:90
    % PCA dimensionality reduction. The basis is learned from mean-centred
    % training data, so both sets are centred with the training mean
    % before being projected onto it.
    [Pro_Matrix, Mean_Image] = my_pca(train_sample, pro_dim);
    train_project = Pro_Matrix' * bsxfun(@minus, train_sample, Mean_Image);
    test_project = Pro_Matrix' * bsxfun(@minus, test_sample, Mean_Image);
    % Normalise every projected sample (column) to unit length.
    train_norm = normc(train_project);
    test_norm = normc(test_project);
    accuracy = computaccuracy(train_norm, ClassNum, train_label, test_norm, test_label);
    fprintf('投影维数为:%d\n', pro_dim);
    fprintf('每类训练样本个数为:%d\n', train_samplesize);
    fprintf(2, '识别率为:%3.2f%%\n\n', accuracy*100);
end
computaccuracy.m
function [accuracy,xp,r]=computaccuracy(trainsample,classnum,train_label,testsample,test_label)
%COMPUTACCURACY  Recognition rate of sparse-representation classification.
%   Every test sample (column of TESTSAMPLE) is coded over the whole
%   training dictionary TRAINSAMPLE with SolveHomotopy_CBM_std
%   (lambda = 0.01). For each class j = 1..CLASSNUM, only the coefficients
%   of training columns whose label equals j are kept, and the
%   reconstruction residual is measured. The class with the smallest
%   residual is the prediction.
%
%   Inputs:
%     trainsample - training dictionary, one sample per column
%     classnum    - number of classes; labels are compared with 1..classnum
%     train_label - class label of every training column
%     testsample  - test samples, one per column
%     test_label  - true class label of every test column
%   Outputs:
%     accuracy - fraction of test samples whose predicted label is correct
%     xp       - sparse code of the LAST test sample processed
%     r        - per-class residuals of the LAST test sample processed
test_tol=size(testsample,2);
train_tol=size(trainsample,2);
pre_label=zeros(1,test_tol);
r=zeros(1,classnum);
h = waitbar(0,'Please wait...');
% Close the progress bar on exit, including when the solver throws.
cleanup = onCleanup(@() close_waitbar(h));
for i=1:test_tol
    y = testsample(:,i);
    xp = SolveHomotopy_CBM_std(trainsample, y,'lambda', 0.01);
    for j=1:classnum
        % Keep only the coefficients that belong to class j.
        mmu=zeros(train_tol,1);
        ind=(j==train_label);
        mmu(ind)=xp(ind);
        r(j)=norm(y-trainsample*mmu);
    end
    [~,index]=min(r);
    pre_label(i)=index;
    % computations take place here
    per = i / test_tol;
    waitbar(per, h ,sprintf('%2.0f%%',per*100))
end
accuracy=sum(pre_label==test_label)/test_tol;

function close_waitbar(h)
% Close the waitbar H if it is still open.
if ishghandle(h)
    close(h);
end
l1eq_pd.m
% l1eq_pd.m
%
% Solve
% min_x ||x||_1 s.t. Ax = b
%
% Recast as linear program
% min_{x,u} sum(u) s.t. -u <= x <= u, Ax=b
% and use primal-dual interior point method
%
% Usage: xp = l1eq_pd(x0, A, At, b, pdtol, pdmaxiter, cgtol, cgmaxiter)
%
% x0 - Nx1 vector, initial point.
%
% A - Either a handle to a function that takes a N vector and returns a K
% vector , or a KxN matrix. If A is a function handle, the algorithm
% operates in "largescale" mode, solving the Newton systems via the
% Conjugate Gradients algorithm.
%
% At - Handle to a function that takes a K vector and returns an N vector.
% If A is a KxN matrix, At is ignored.
%
% b - Kx1 vector of observations.
%
% pdtol - Tolerance for primal-dual algorithm (algorithm terminates if
% the duality gap is less than pdtol).
% Default = 1e-3.
%
% pdmaxiter - Maximum number of primal-dual iterations.
% Default = 50.
%
% cgtol - Tolerance for Conjugate Gradients; ignored if A is a matrix.
% Default = 1e-8.
%
% cgmaxiter - Maximum number of iterations for Conjugate Gradients; ignored
% if A is a matrix.
% Default = 200.
%
% Written by: Justin Romberg, Caltech
% Email: jrom@acm.caltech.edu
% Created: October 2005
%
function xp = l1eq_pd(x0, A, At, b, pdtol, pdmaxiter, cgtol, cgmaxiter)
% L1EQ_PD  Solve min ||x||_1 s.t. Ax = b with a primal-dual interior-point
% method applied to the LP  min sum(u) s.t. -u <= x <= u, Ax = b.
% Returns the final primal iterate xp. If a Newton system cannot be solved
% or the line search stalls, the previous iterate is returned early.
% A may be a KxN matrix or a function handle. In the function-handle case,
% At is used for the transpose and the Newton systems are solved with
% cgsolve, which is not defined in this file.
largescale = isa(A,'function_handle');
% Defaults for arguments that were not supplied.
if (nargin < 5), pdtol = 1e-3; end
if (nargin < 6), pdmaxiter = 50; end
if (nargin < 7), cgtol = 1e-8; end
if (nargin < 8), cgmaxiter = 200; end
N = length(x0);
% Line-search constants: alpha = required fractional decrease of the
% residual norm, beta = step shrink factor. mu scales the barrier
% parameter tau relative to 2N/sdg.
alpha = 0.01;
beta = 0.5;
mu = 10;
% Gradient of the LP objective sum(u) with respect to the stacked [x; u].
gradf0 = [zeros(N,1); ones(N,1)];
x = x0;
% Start u strictly above |x0| (by at least 0.05*max|x0|), so that both
% inequality constraint functions fu1 = x - u and fu2 = -x - u are < 0.
% NOTE(review): if x0 is identically zero, u, fu1 and fu2 are zero and
% lamu1/lamu2 below become infinite; presumably callers pass a nonzero x0.
u = (0.95)*abs(x0) + (0.10)*max(abs(x0));
fu1 = x - u;
fu2 = -x - u;
% Dual variables of the inequality constraints (positive while fu < 0).
lamu1 = -1./fu1;
lamu2 = -1./fu2;
% Dual variable v of the equality constraint, A'*v, and the primal
% residual Ax - b.
if (largescale)
v = -A(lamu1-lamu2);
Atv = At(v);
rpri = A(x) - b;
else
v = -A*(lamu1-lamu2);
Atv = A'*v;
rpri = A*x - b;
end
% Surrogate duality gap and the barrier parameter derived from it.
sdg = -(fu1'*lamu1 + fu2'*lamu2);
tau = mu*2*N/sdg;
% Centrality and dual residuals. resnorm (with the primal residual) is the
% quantity the backtracking line search must decrease.
rcent = [-lamu1.*fu1; -lamu2.*fu2] - (1/tau);
rdual = gradf0 + [lamu1-lamu2; -lamu1-lamu2] + [Atv; zeros(N,1)];
resnorm = norm([rdual; rcent; rpri]);
pditer = 0;
% Stop once the surrogate gap is below pdtol or the iteration cap is hit.
done = (sdg < pdtol) | (pditer >= pdmaxiter);
while (~done)
pditer = pditer + 1;
% Newton system. After eliminating du and the inequality duals, only a
% reduced system in dv (the equality dual step) remains to be solved.
w1 = -1/tau*(-1./fu1 + 1./fu2) - Atv;
w2 = -1 - 1/tau*(1./fu1 + 1./fu2);
w3 = -rpri;
sig1 = -lamu1./fu1 - lamu2./fu2;
sig2 = lamu1./fu1 - lamu2./fu2;
sigx = sig1 - sig2.^2./sig1;
if (largescale)
% Matrix-free branch: solve the reduced system with conjugate gradients
% and give up if the returned CG residual exceeds 1/2.
w1p = w3 - A(w1./sigx - w2.*sig2./(sigx.*sig1));
h11pfun = @(z) -A(1./sigx.*At(z));
[dv, cgres, cgiter] = cgsolve(h11pfun, w1p, cgtol, cgmaxiter, 0);
if (cgres > 1/2)
disp('Primal-dual: Cannot solve system. Returning previous iterate.');
xp = x;
return
end
dx = (w1 - w2.*sig2./sig1 - At(dv))./sigx;
Adx = A(dx);
Atdv = At(dv);
else
% Explicit-matrix branch: form the reduced matrix and solve with linsolve;
% its second output is an estimate of the reciprocal condition number.
H11p = -A*diag(1./sigx)*A';
w1p = w3 - A*(w1./sigx - w2.*sig2./(sigx.*sig1));
[dv,hcond] = linsolve(H11p,w1p);
if (hcond < 1e-14)
disp('Primal-dual: Matrix ill-conditioned. Returning previous iterate.');
xp = x;
return
end
dx = (w1 - w2.*sig2./sig1 - A'*dv)./sigx;
Adx = A*dx;
Atdv = A'*dv;
end
% Recover the remaining search directions from dx.
du = (w2 - sig2.*dx)./sig1;
dlamu1 = (lamu1./fu1).*(-dx+du) - lamu1 - (1/tau)*1./fu1;
dlamu2 = (lamu2./fu2).*(dx+du) - lamu2 - 1/tau*1./fu2;
% make sure that the step is feasible: keeps lamu1,lamu2 > 0, fu1,fu2 < 0
indp = find(dlamu1 < 0); indn = find(dlamu2 < 0);
s = min([1; -lamu1(indp)./dlamu1(indp); -lamu2(indn)./dlamu2(indn)]);
indp = find((dx-du) > 0); indn = find((-dx-du) > 0);
% Back off to 99% of the maximal feasible step to stay strictly interior.
s = (0.99)*min([s; -fu1(indp)./(dx(indp)-du(indp)); -fu2(indn)./(-dx(indn)-du(indn))]);
% backtracking line search
backiter = 0;
xp = x + s*dx; up = u + s*du;
vp = v + s*dv; Atvp = Atv + s*Atdv;
lamu1p = lamu1 + s*dlamu1; lamu2p = lamu2 + s*dlamu2;
fu1p = xp - up; fu2p = -xp - up;
rdp = gradf0 + [lamu1p-lamu2p; -lamu1p-lamu2p] + [Atvp; zeros(N,1)];
rcp = [-lamu1p.*fu1p; -lamu2p.*fu2p] - (1/tau);
rpp = rpri + s*Adx;
% Shrink s by beta until the residual norm drops by the factor
% (1 - alpha*s); abandon the step after 32 reductions.
while(norm([rdp; rcp; rpp]) > (1-alpha*s)*resnorm)
s = beta*s;
xp = x + s*dx; up = u + s*du;
vp = v + s*dv; Atvp = Atv + s*Atdv;
lamu1p = lamu1 + s*dlamu1; lamu2p = lamu2 + s*dlamu2;
fu1p = xp - up; fu2p = -xp - up;
rdp = gradf0 + [lamu1p-lamu2p; -lamu1p-lamu2p] + [Atvp; zeros(N,1)];
rcp = [-lamu1p.*fu1p; -lamu2p.*fu2p] - (1/tau);
rpp = rpri + s*Adx;
backiter = backiter+1;
if (backiter > 32)
disp('Stuck backtracking, returning last iterate.')
xp = x;
return
end
end
% next iteration
x = xp; u = up;
v = vp; Atv = Atvp;
lamu1 = lamu1p; lamu2 = lamu2p;
fu1 = fu1p; fu2 = fu2p;
% surrogate duality gap
sdg = -(fu1'*lamu1 + fu2'*lamu2);
tau = mu*2*N/sdg;
rpri = rpp;
rcent = [-lamu1.*fu1; -lamu2.*fu2] - (1/tau);
rdual = gradf0 + [lamu1-lamu2; -lamu1-lamu2] + [Atv; zeros(N,1)];
resnorm = norm([rdual; rcent; rpri]);
done = (sdg < pdtol) | (pditer >= pdmaxiter);
%disp(sprintf('Iteration = %d, tau = %8.3e, Primal = %8.3e, PDGap = %8.3e, Dual res = %8.3e, Primal res = %8.3e',...
% pditer, tau, sum(u), sdg, norm(rdual), norm(rpri)));
% if (largescale)
% disp(sprintf(' CG Res = %8.3e, CG Iter = %d', cgres, cgiter));
% else
% disp(sprintf(' H11p condition number = %8.3e', hcond));
% end
end
my_pca.m
function [Pro_Matrix,Mean_Image]=my_pca(Train_SET,Eigen_NUM)
%MY_PCA  PCA projection basis learned from a set of training samples.
%   [Pro_Matrix, Mean_Image] = my_pca(Train_SET, Eigen_NUM)
%
%   Inputs:
%     Train_SET - Dim x Train_Num training matrix, one sample per column,
%                 one feature per row
%     Eigen_NUM - number of principal directions to keep (projection dim)
%   Outputs:
%     Pro_Matrix - Dim x Eigen_NUM matrix of unit-norm principal directions,
%                  ordered by decreasing variance
%     Mean_Image - Dim x 1 mean of the training samples; subtract it from
%                  the data before multiplying by Pro_Matrix'
%
%   When Dim <= Train_Num, the Dim x Dim covariance matrix is decomposed
%   directly. Otherwise the Train_Num x Train_Num Gram matrix is decomposed
%   and its eigenvectors are mapped back to the Dim-dimensional space.
%   Both branches return columns of unit length.
[Dim,Train_Num]=size(Train_SET);
Mean_Image=mean(Train_SET,2);
Centered=bsxfun(@minus,Train_SET,Mean_Image);
if Dim<=Train_Num
    % Fewer features than samples: decompose the covariance directly.
    R=Centered*Centered'/(Train_Num-1);
    [eig_vec,eig_val]=eig(R);
    [~,ind]=sort(diag(eig_val),'descend');
    Pro_Matrix=eig_vec(:,ind(1:Eigen_NUM));
else
    % More features than samples: decompose the small Gram matrix.
    R=Centered'*Centered/(Train_Num-1);
    [eig_vec,eig_val]=eig(R);
    [val,ind]=sort(diag(eig_val),'descend');
    lead_val=val(1:Eigen_NUM);
    % Directions with (numerically) zero variance cannot be normalised.
    tol=max(size(R))*eps(max(abs(val)));
    if any(lead_val<=tol)
        error('my_pca:rankDeficient', ...
            'Eigen_NUM=%d exceeds the number of non-degenerate principal components.', Eigen_NUM);
    end
    % If R*w = val*w with ||w|| = 1, then ||Centered*w||^2 = (Train_Num-1)*val.
    % Dividing by sqrt((Train_Num-1)*val) therefore gives unit-length
    % directions, consistent with the eigenvectors of the branch above.
    Pro_Matrix=Centered*eig_vec(:,ind(1:Eigen_NUM))*diag(1./sqrt((Train_Num-1)*lead_val));
end
end
pca.m
function [newsample,basevector]=pca(patterns,num)
%PCA  Principal component analysis of row-wise samples.
%   [newsample, basevector] = pca(patterns, num)
%
%   patterns   - u x v matrix, one sample per ROW (u samples, v features)
%   num        - if num > 1: number of principal components to keep;
%                otherwise: keep the smallest number of components whose
%                eigenvalues account for at least the fraction num of the
%                total eigenvalue sum
%   basevector - v x k matrix of unit-norm principal directions, ordered by
%                decreasing eigenvalue magnitude
%   newsample  - u x k projection patterns*basevector (the original,
%                un-centred patterns are projected)
%
%   The small u x u matrix gensample*gensample' of the centred samples is
%   eigen-decomposed. Each eigenvector w is mapped back to feature space as
%   gensample'*w / sqrt(eigenvalue).
u=size(patterns,1);
% Subtract the mean sample from every row.
gensample=bsxfun(@minus,patterns,mean(patterns));
sigma=gensample*gensample';
[U,V]=eig(sigma);
d=diag(V);
% Order the eigenvalues by decreasing magnitude. This matches the ordering
% of dsort, which was used previously but belongs to the Control System
% Toolbox rather than core MATLAB.
[~,index]=sort(abs(d),'descend');
d1=d(index);
if num>1
    k=num;
else
    % Smallest k whose leading eigenvalues reach the requested energy share.
    sumv=sum(d1);
    for k=1:u
        if sum(d1(1:k))/sumv>=num
            break;
        end
    end
end
vector=U(:,index(1:k));
% Scaling by eigenvalue^(-1/2) makes every mapped direction unit-norm.
base=bsxfun(@times,gensample'*vector,(d(index(1:k)).^(-1/2))');
newsample=patterns*base;
basevector=base;
readsample.m
function [sample,label]=readsample(address,ClassNum,data,rows,cols,image_fmt)
%READSAMPLE  Load images into a sample matrix, one image per column.
%   For every class i = 1..ClassNum and every index j in DATA, the file
%   strcat(address, num2str(i), '_', num2str(j), image_fmt) is read,
%   converted to double, resized to rows x cols and stored as one column.
%
%   Inputs:
%     address   - path prefix of the image files
%     ClassNum  - number of classes to read
%     data      - vector (row or column) of image indices to read per class
%     rows,cols - size every image is resized to
%     image_fmt - file extension, e.g. '.bmp'
%   Outputs:
%     sample - (rows*cols) x (ClassNum*numel(data)) matrix, one image per
%              column and one feature per row, grouped by class
%     label  - 1 x (ClassNum*numel(data)) row vector of class indices
ImageSize=rows*cols;
% Iterate over a row vector: a FOR loop over a column vector would execute
% only once, with the whole column as the loop value.
idx=reshape(data,1,[]);
PerClass=numel(idx);
% Preallocate instead of growing the matrices inside the loop.
sample=zeros(ImageSize,ClassNum*PerClass);
label=zeros(1,ClassNum*PerClass);
col=0;
for i=1:ClassNum
    for j=idx
        a=double(imread(strcat(address,num2str(i),'_',num2str(j),image_fmt)));
        a=imresize(a,[rows cols]);
        col=col+1;
        sample(:,col)=reshape(a,ImageSize,1);
        label(col)=i;
    end
end
SolveHomotopy_CBM_std.m
%% This function is modified from Matlab Package: L1-Homotopy
% BPDN_homotopy_function.m
%
% Solves the basis pursuit denoising (BPDN) problem over the stacked
% dictionary B = [A eye(K)], i.e. for w = [x; e]
%    min_w \lambda ||w||_1 + 1/2*||y - A*x - e||_2^2
% by homotopy: the regularisation epsilon starts at max|B'*y| (in the
% default signed mode) and decreases until it reaches lambda.
%
% Inputs:
% A - K x n measurement matrix
% y - K x 1 measurement vector
% Optional name/value pairs (names are case-insensitive):
% 'lambda'            - final value of regularization parameter (default 1e-6)
% 'maxiteration'      - maximum number of homotopy iterations (default 100)
% 'stoppingcriterion' - -2 time, -1 ground truth, 2 sparse support,
%                       3 objective value (default), 4 subgradient
% 'tolerance'         - tolerance of the stopping criterion (default 1e-4)
% 'initialization'    - (K+n) x 1 warm start for the stacked vector [x; e]
% 'groundtruth'       - reference x, required by stopping criterion -1
% 'maxtime'           - time budget in seconds for stopping criterion -2
% 'isnonnegative'     - nonnegativity handling for the x part (default false)
% 'verbose'           - print diagnostic messages (default false)
%
% Outputs:
% x_out - n x 1 coefficient vector (first n entries of w)
% e_out - K x 1 error vector (last K entries of w)
% total_iter - number of homotopy iterations taken by the solver
%
% Written by: Salman Asif, Georgia Tech
% Email: sasif@ece.gatech.edu
%
%-------------------------------------------+
% Copyright (c) 2007. Muhammad Salman Asif
%-------------------------------------------+
function [x_out, e_out, total_iter] = SolveHomotopy_CBM_std(A, y, varargin)
% Solver state shared with update_primal through globals.
global N n gamma_x z_x xk_temp del_x_vec pk_temp dk epsilon isNonnegative
t0 = tic ;
lambda = 1e-6;
maxiter = 100;
isNonnegative = false;
verbose = false;
xk_1 = [];
xG = [];
maxTime = inf;
tolerance = 1e-4 ;
STOPPING_TIME = -2;
STOPPING_GROUND_TRUTH = -1;
STOPPING_DUALITY_GAP = 1;
STOPPING_SPARSE_SUPPORT = 2;
STOPPING_OBJECTIVE_VALUE = 3;
STOPPING_SUBGRADIENT = 4;
STOPPING_DEFAULT = STOPPING_OBJECTIVE_VALUE;
stoppingCriterion = STOPPING_DEFAULT;
% Parse the optional inputs.
if (mod(length(varargin), 2) ~= 0 ),
    error(['Extra Parameters passed to the function ''' mfilename ''' must be passed in pairs.']);
end
parameterCount = length(varargin)/2;
for parameterIndex = 1:parameterCount,
    parameterName = varargin{parameterIndex*2 - 1};
    parameterValue = varargin{parameterIndex*2};
    switch lower(parameterName)
        case 'stoppingcriterion'
            stoppingCriterion = parameterValue;
        case 'initialization'
            % Validated below, once the problem size is known.
            xk_1 = parameterValue;
        case 'groundtruth'
            xG = parameterValue;
        case 'lambda'
            lambda = parameterValue;
        case 'maxiteration'
            maxiter = parameterValue;
        case 'isnonnegative'
            isNonnegative = parameterValue;
        case 'tolerance'
            tolerance = parameterValue;
        case 'verbose'
            verbose = parameterValue;
        case 'maxtime'
            maxTime = parameterValue;
        otherwise
            error(['The parameter ''' parameterName ''' is not recognized by the function ''' mfilename '''.']);
    end
end
clear varargin
[K, n] = size(A);
B = [A eye(K)];
At = A';
Bt = B';
N = K + n;
% A warm start must be a column for the whole stacked vector [x; e].
% This is checked only now, because n and N are not known while parsing.
if ~isempty(xk_1) && ~isequal(size(xk_1), [N, 1])
    error('The dimension of the initial x0 does not match.');
end
if stoppingCriterion == STOPPING_GROUND_TRUTH && isempty(xG)
    error('The ground truth must be provided for the ground-truth stopping criterion.');
end
timeSteps = nan(1,maxiter) ;
% Initialization of primal and dual sign and support
z_x = zeros(N,1);
gamma_x = []; % Primal support
% Initial step: negated correlations of y with every column of B.
Primal_constrk = -[At*y; y];
if isNonnegative
    [c i] = min(Primal_constrk);
    c = max(-c, 0);
else
    [c i] = max(abs(Primal_constrk));
end
% The homotopy starts at epsilon = c.
epsilon = c;
nz_x = zeros(N,1);
if isempty(xk_1)
    xk_1 = zeros(N,1);
    gamma_xk = i;
else
    gamma_xk = find(abs(xk_1)>eps*10);
    nz_x(gamma_xk) = 1;
end
f = epsilon*norm(xk_1,1) + 1/2*norm(y - A*xk_1(1:n) - xk_1(n+1:end))^2;
z_x(gamma_xk) = -sign(Primal_constrk(gamma_xk));
%Primal_constrk(gamma_xk) = sign(Primal_constrk(gamma_xk))*epsilon;
z_xk = z_x;
% loop parameters
iter = 0;
out_x = [];
old_delta = 0;
count_delta_stop = 0;
% Gram matrix of the active columns of B, split into A-columns (indices
% <= n) and identity columns (indices > n), and its inverse. The inverse
% is updated incrementally when the support changes.
gamma_xkx = gamma_xk(gamma_xk<=n);
gamma_xke = gamma_xk(gamma_xk>n)-n;
AtgxAgx = [At(gamma_xkx,:)*A(:,gamma_xkx) At(gamma_xkx,gamma_xke); A(gamma_xke,gamma_xkx) eye(length(gamma_xke))];
iAtgxAgx = inv(AtgxAgx);
while iter < maxiter
    iter = iter+1;
    gamma_x = gamma_xk;
    gamma_xx = gamma_x(gamma_x<=n);
    gamma_xe = gamma_x(gamma_x>n)-n;
    z_x = z_xk;
    x_k = xk_1;
    %%%%%%%%%%%%%%%%%%%%%
    %%%% update on x %%%%
    %%%%%%%%%%%%%%%%%%%%%
    % Update direction
    del_x = iAtgxAgx*z_x(gamma_x);
    del_x_vec = zeros(N,1);
    del_x_vec(gamma_x) = del_x;
    Asupported = B(:,gamma_x);
    Agdelx = Asupported*del_x;
    dk = [At*Agdelx; Agdelx];
    %%% CONTROL THE MACHINE PRECISION ERROR AT EVERY OPERATION: LIKE BELOW.
    pk_temp = Primal_constrk;
    gammaL_temp = find(abs(abs(Primal_constrk)-epsilon)<min(epsilon,2*eps));
    pk_temp(gammaL_temp) = sign(Primal_constrk(gammaL_temp))*epsilon;
    xk_temp = x_k;
    xk_temp(abs(x_k)<2*eps) = 0;
    %%%---
    % Compute the step size
    [i_delta, delta, out_x] = update_primal(out_x);
    % Give up after 500 consecutive (near-)zero steps.
    if old_delta < 4*eps && delta < 4*eps
        count_delta_stop = count_delta_stop + 1;
        if count_delta_stop >= 500
            if verbose
                disp('stuck in some corner');
            end
            break;
        end
    else
        count_delta_stop = 0;
    end
    old_delta = delta;
    xk_1 = x_k+delta*del_x_vec;
    Primal_constrk = Primal_constrk+delta*dk;
    epsilon_old = epsilon;
    epsilon = epsilon-delta;
    if epsilon <= lambda;
        % xk_1 = x_k + (epsilon_old-lambda)*del_x_vec;
        % disp('Reach prescribed lambda in SolveHomotopy_CBM.');
        break;
    end
    timeSteps(iter) = toc(t0) ;
    % if mod(iter, 100) == 0
    % disp([ 'epsilon = ' num2str(epsilon) ', time =' num2str(timeSteps(iter))]);
    % end
    % compute stopping criteria and test for termination
    keep_going = true;
    if delta~=0
        switch stoppingCriterion
            case STOPPING_GROUND_TRUTH
                keep_going = norm(xk_1(1:n)-xG)>tolerance;
            case STOPPING_SPARSE_SUPPORT
                nz_x_prev = nz_x;
                nz_x = (abs(xk_1)>eps*10);
                num_nz_x = sum(nz_x(:));
                num_changes_active = (sum(nz_x(:)~=nz_x_prev(:)));
                if num_nz_x >= 1
                    criterionActiveSet = num_changes_active / num_nz_x;
                    keep_going = (criterionActiveSet > tolerance);
                end
            case STOPPING_DUALITY_GAP
                error('Duality gap is not a valid stopping criterion for Homotopy.');
            case STOPPING_OBJECTIVE_VALUE
                % continue while the relative change of the objective exceeds tolerance
                prev_f = f;
                f = lambda*norm(xk_1,1) + 1/2*norm(y-Asupported*xk_1(gamma_x))^2;
                keep_going = (abs((prev_f-f)/prev_f) > tolerance);
            case STOPPING_SUBGRADIENT
                keep_going = norm(delta*del_x_vec)>tolerance;
            case STOPPING_TIME
                keep_going = timeSteps(iter) < maxTime ;
            otherwise,
                error('Undefined stopping criterion');
        end % end of the stopping criteria switch
    end
    if ~keep_going
        % Report the actual reason, and only when asked to: this solver is
        % called once per test sample, so unconditional output floods the
        % console.
        if verbose
            if stoppingCriterion == STOPPING_TIME
                disp('Maximum time reached!');
            else
                disp('Stopping criterion satisfied.');
            end
        end
        break;
    end
    if ~isempty(out_x)
        % If an element is removed from gamma_x
        len_gamma = length(gamma_x);
        outx_index = find(gamma_x==out_x(1));
        gamma_x(outx_index) = gamma_x(end);
        gamma_x = gamma_x(1:end-1);
        gamma_xk = gamma_x;
        rowi = outx_index; % ith row of A is swapped with last row (out_x)
        colj = outx_index; % jth column of A is swapped with last column (out_lambda)
        AtgxAgx_ij = AtgxAgx;
        temp_row = AtgxAgx_ij(rowi,:);
        AtgxAgx_ij(rowi,:) = AtgxAgx_ij(len_gamma,:);
        AtgxAgx_ij(len_gamma,:) = temp_row;
        temp_col = AtgxAgx_ij(:,colj);
        AtgxAgx_ij(:,colj) = AtgxAgx_ij(:,len_gamma);
        AtgxAgx_ij(:,len_gamma) = temp_col;
        iAtgxAgx_ij = iAtgxAgx;
        temp_row = iAtgxAgx_ij(colj,:);
        iAtgxAgx_ij(colj,:) = iAtgxAgx_ij(len_gamma,:);
        iAtgxAgx_ij(len_gamma,:) = temp_row;
        temp_col = iAtgxAgx_ij(:,rowi);
        iAtgxAgx_ij(:,rowi) = iAtgxAgx_ij(:,len_gamma);
        iAtgxAgx_ij(:,len_gamma) = temp_col;
        AtgxAgx = AtgxAgx_ij(1:len_gamma-1,1:len_gamma-1);
        nn = size(AtgxAgx_ij,1);
        %delete columns (rank-one downdate of the inverse)
        Q11 = iAtgxAgx_ij(1:nn-1,1:nn-1);
        Q12 = iAtgxAgx_ij(1:nn-1,nn);
        Q21 = iAtgxAgx_ij(nn,1:nn-1);
        Q22 = iAtgxAgx_ij(nn,nn);
        Q12Q21_Q22 = Q12*(Q21/Q22);
        iAtgxAgx = Q11 - Q12Q21_Q22;
        xk_1(out_x(1)) = 0;
    else
        % If an element is added to gamma_x
        gamma_xk = [gamma_xk; i_delta];
        if i_delta>n
            AtgxAnx = Bt(gamma_x,i_delta-n);
            AtgxAgx_mod = [AtgxAgx AtgxAnx; AtgxAnx' 1];
        else
            AtgxAnx = Bt(gamma_x,:)*A(:,i_delta);
            AtgxAgx_mod = [AtgxAgx AtgxAnx; AtgxAnx' A(:,i_delta).'*A(:,i_delta)];
        end
        AtgxAgx = AtgxAgx_mod;
        %iAtgxAgx = update_inverse(AtgxAgx, iAtgxAgx,1);
        nn = size(AtgxAgx,1);
        % add columns (block-inverse update via the Schur complement S)
        iA11 = iAtgxAgx;
        iA11A12 = iA11*AtgxAgx(1:nn-1,nn);
        A21iA11 = AtgxAgx(nn,1:nn-1)*iA11;
        S = AtgxAgx(nn,nn)-AtgxAgx(nn,1:nn-1)*iA11A12;
        Q11_right = iA11A12*(A21iA11/S);
        % Q11 = iA11+ Q11_right;
        % Q12 = -iA11A12/S;
        % Q21 = -A21iA11/S;
        % Q22 = 1/S;
        iAtgxAgx = zeros(nn);
        %iAtB = [Q11 Q12; Q21 Q22];
        iAtgxAgx(1:nn-1,1:nn-1) = iA11+ Q11_right;
        iAtgxAgx(1:nn-1,nn) = -iA11A12/S;
        iAtgxAgx(nn,1:nn-1) = -A21iA11/S;
        iAtgxAgx(nn,nn) = 1/S;
        xk_1(i_delta) = 0;
    end
    z_xk = zeros(N,1);
    z_xk(gamma_xk) = -sign(Primal_constrk(gamma_xk));
    Primal_constrk(gamma_x) = sign(Primal_constrk(gamma_x))*epsilon;
end
total_iter = iter;
x_out = xk_1(1:n);
e_out = xk_1(n+1:end);
timeSteps = timeSteps(1:total_iter-1) ;
% Debiasing
% x_out = zeros(N,1);
% x_out(intersect(gamma_x,1:n)) = A(:,intersect(gamma_x,1:n))\(y-e_out);
% update_primal.m
%
% Computes the step size along the current primal update direction and the
% resulting change of the support. Either an inactive index becomes active
% (i_delta), or active coefficient(s) reach zero or the wrong sign and must
% leave the support (out_x).
%
% All solver state is read from globals written by SolveHomotopy_CBM_std:
%   N, n          - length of the stacked vector [x; e] and of its x part
%   gamma_x       - current support
%   z_x           - sign sequence of the support
%   xk_temp       - current iterate, entries of magnitude below 2*eps zeroed
%   del_x_vec     - primal update direction
%   pk_temp       - current correlations, near-active ones snapped to +/-epsilon
%   dk            - change of the correlations per unit step
%   epsilon       - current homotopy parameter
%   isNonnegative - if true, entries of the x part (indices 1..n) must not
%                   become negative
%
% Input:
%   out_x   - element(s) removed from the support in the previous step (if
%             any); they are excluded from the entering candidates
%
% Outputs:
%   i_delta - index whose correlation becomes active with this step
%   delta   - primal step size (0 for a pure support correction)
%   out_x   - element(s) leaving the support with this step (empty if none)
%
% Written by: Salman Asif, Georgia Tech
% Email: sasif@ece.gatech.edu
function [i_delta, delta, out_x] = update_primal(out_x)
global N n gamma_x z_x xk_temp del_x_vec pk_temp dk epsilon isNonnegative
% Inactive indices that are candidates for entering the support.
gamma_lc = setdiff(1:N, union(gamma_x, out_x));
gamma_lc_nonneg = setdiff(n+1:N, union(gamma_x, out_x));
% delta1: smallest positive step at which an inactive correlation reaches
% +(epsilon - delta). In nonnegative mode only indices > n are considered.
if isNonnegative
    delta1_constr = (epsilon-pk_temp(gamma_lc_nonneg))./(1+dk(gamma_lc_nonneg));
    delta1_pos_ind = find(delta1_constr>0);
    delta1_pos = delta1_constr(delta1_pos_ind);
    [delta1 i_delta1] = min(delta1_pos);
    if isempty(delta1)
        delta1 = inf;
    end
else
    delta1_constr = (epsilon-pk_temp(gamma_lc))./(1+dk(gamma_lc));
    delta1_pos_ind = find(delta1_constr>0);
    delta1_pos = delta1_constr(delta1_pos_ind);
    [delta1 i_delta1] = min(delta1_pos);
    if isempty(delta1)
        delta1 = inf;
    end
end
% delta2: smallest positive step at which an inactive correlation reaches
% -(epsilon - delta).
delta2_constr = (epsilon+pk_temp(gamma_lc))./(1-dk(gamma_lc));
delta2_pos_ind = find(delta2_constr>0);
delta2_pos = delta2_constr(delta2_pos_ind);
[delta2 i_delta2] = min(delta2_pos);
if isempty(delta2)
    delta2 = inf;
end
if delta1>delta2
    delta = delta2;
    i_delta = gamma_lc(delta2_pos_ind(i_delta2));
else
    delta = delta1;
    if isNonnegative
        i_delta = gamma_lc_nonneg(delta1_pos_ind(i_delta1));
    else
        i_delta = gamma_lc(delta1_pos_ind(i_delta1));
    end
end
% delta3: smallest positive step at which an active coefficient hits zero.
delta3_constr = (-xk_temp(gamma_x)./del_x_vec(gamma_x));
delta3_pos_index = find(delta3_constr>0);
[delta3 i_delta3] = min(delta3_constr(delta3_pos_index));
out_x_index = gamma_x(delta3_pos_index(i_delta3));
out_x = [];
if ~isempty(delta3) && (delta3 > 0) && (delta3 <= delta)
    delta = delta3;
    out_x = out_x_index;
end
%%% THESE ARE PROBABLY UNNECESSARY
%%% NEED TO REMOVE THEM.
% The following checks are just to deal with degenerate cases when more
% than one elements want to enter or leave the support at any step
% (e.g., Bernoulli matrix with small number of measurements)
% This one is ONLY for those indices which are zero. And we don't know where
% will its dx point in next steps, so after we calculate dx and its in opposite
% direction to z_x, we will have to remove that index from the support.
xk_1 = xk_temp+delta*del_x_vec;
xk_1(out_x) = 0;
% Positions (within gamma_x) whose sign after the step contradicts z_x.
wrong_sign = find(sign(xk_1(gamma_x)).*z_x(gamma_x)==-1);
if isNonnegative
    % Positions (within gamma_x, like wrong_sign) of x-part entries that
    % became negative. Indexing through gamma_x keeps both sets in the same
    % coordinate system before they are merged.
    wrong_sign = union(wrong_sign, find(gamma_x<=n & xk_1(gamma_x)<0));
end
if ~isempty(gamma_x(wrong_sign))
    delta = 0;
    % can also choose specific element which became non-zero first but all
    % that matters here is AtA(gx,gl) doesn't become singular.
    [val_wrong_x ind_wrong_x] = sort(abs(del_x_vec(gamma_x(wrong_sign))),'descend');
    out_x = gamma_x(wrong_sign(ind_wrong_x));
end
% If more than one primal constraints became active in previous iteration i.e.,
% more than one elements wanted to enter the support and we added only one.
% So here we need to check if those remaining elements are still active.
i_delta_temp = gamma_lc(abs(pk_temp(gamma_lc)+delta*dk(gamma_lc))-(epsilon-delta) >= 10*eps);
if ~isempty(i_delta_temp)
    i_delta_more = i_delta_temp;
    if (length(i_delta_more)>=1) && (~any((i_delta_temp==i_delta)))
        % ideal way would be to check that incoming element doesn't make AtA
        % singular!
        [v_temp i_temp] = max(-pk_temp(i_delta_more)./dk(i_delta_more));
        i_delta = i_delta_more(i_temp);
        delta = 0;
        out_x = [];
    end
end
不想做机器学习的硬件工程师不是好的CCIE
浙公网安备 33010602011771号