SVD implement, MATLAB

Task description

  1. Implement a Matlab function which computes the SVD of a given matrix.

  2. Most internal functions, like "svd", are not allowed to be used in the codes. The only exceptions are: Matrix/Vector addition and multiplication, 2-norm.

  3. Compare your result with the internal "svd"-function.

  4. Find a gray-scale image and show the rank-r approximation of the image with different r's.

SVD implement

Pseudo code and some supplements

A is a \(m \times n\) matrix, our target is to get \(A = U\Sigma V^*\).

\(m \geq n\) (if \(m < n\), we can transpose \(A\) before the computation, then transpose \(\Sigma\) back and swap \(U\) and \(V\) to get the correct output).

call my_svd(A)

function my_svd(M):

  • Trivial situation : M is a \(m\times 1\) matrix. Easily done and return;

  • General circumstances:

    1. $\sigma = $norm (M);

    2. find a non-zero eigen-vector of \(MM^*\), i.e. find $\vb u $ s.t. \(MM^* \vb u = \sigma^2 \vb u, \vb u \neq \vb 0\).

    3. extend \(\vb u\) to an orthogonal basis \(U_{cur}\) of \(\mathbb{C}^{size(u)}\) (using gschmidt());

    4. Find \(\vb v\) s.t. \(M\vb v = \sigma \vb u\).

    5. extend \(\vb v\) to an orthogonal basis \(V_{cur}\) of \(\mathbb{C}^{size(v)}\) (using gschmidt());

    6. $B \leftarrow U_{cur}^*MV_{cur} $,

      \(B\) delete its first column and first row;

    7. call \([U_B,\Sigma_B, V_B]\) = my_svd(B);

    8. \(U \leftarrow U_{cur} \left[\begin{array}{cc}1 & 0 \\ 0 & U_{B}\end{array}\right]\)

      \(V \leftarrow V_{cur} \left[\begin{array}{cc}1 & 0 \\ 0 & V_{B}\end{array}\right]\)

      \(\Sigma \leftarrow \left[\begin{array}{cc}\sigma & 0 \\ 0 & \Sigma_B\end{array}\right]\)

      return;

Code

The code below computes the SVD correctly for any \(n\times m\) matrix.

However, the variable and function names do not necessarily correspond to what the code actually does (the code was maintained over a long period, and the algorithm changed without the function and variable names being updated).

format short
A = randi([0,10],[3,4]);
% Somehow input a matrix A

% /////////////////start
% The factorization is run on a matrix B with at least as many rows as
% columns. A wide A is therefore factored through its conjugate transpose:
%   A' = U*D*V'   =>   A = V*D'*U'
% A itself is left untouched so it can be used for the error check below.
[m,n] = size(A);
flag = 0;
B = A;
if m < n
    B = A';
    flag = 1;
end
[m,n] = size(B);
% /////////////////end


% Keep the two factorizations in separate variables so that the built-in
% result does not overwrite the hand-written one.
[U, D, V] = mysvd(B);
[U_ref, D_ref, V_ref] = svd(B);


% /////////////////start
% Undo the transposition for BOTH results before anything is displayed, so
% that the printed factors belong to A and not to A'.
if flag == 1
    T = U;
    U = V;
    V = T;
    D = D';

    T = U_ref;
    U_ref = V_ref;
    V_ref = T;
    D_ref = D_ref';
end
% /////////////////end


disp("_________my_svd:_______________");
U, D, V
disp("_________matlab_svd:_______________");
U_ref, D_ref, V_ref

% Singular vectors are only determined up to sign, so the factors are
% compared through the reconstruction error instead of entry by entry.
fprintf('norm(A - U*D*V'') with mysvd: %g\n', norm(A - U*D*V'));
fprintf('norm(A - U*D*V'') with svd:   %g\n', norm(A - U_ref*D_ref*V_ref'));


function [U,D,V] = mysvd(B)
% MYSVD  Singular value decomposition B = U*D*V' by recursive deflation.
%
%   [U,D,V] = mysvd(B) takes an m-by-n matrix B and returns an m-by-m U,
%   an m-by-n D and an n-by-n V.
%
%   One step of the recursion:
%     1. sigma = norm(B) is the largest singular value of B.
%     2. A unit vector u in the null space of B*B' - sigma^2*I is found by
%        Gaussian elimination (my_gauss, del_null, find_eigen).
%     3. v = B'*u/sigma is the matching right singular vector.
%     4. u and v are completed to square matrices with random columns and
%        orthonormalized with gschmidt, giving U_cur and V_cur.
%     5. The first row and column of U_cur'*B*V_cur are removed and the
%        remaining block is factored recursively.
%
%   A wide input (m < n) is handled by factoring B' and swapping the roles
%   of U and V.
%
%   NOTE(review): step 2 appears to rely on the elimination leaving an
%   (m-1)-by-m system, i.e. on sigma^2 being a simple eigenvalue of B*B'
%   with the zero pivot in the last row - confirm before relying on this
%   for matrices with repeated singular values.
%   The basis completion uses random columns, so U and V are not
%   reproducible from run to run beyond the columns fixed by the algorithm.

    [m,n] = size(B);

    % Wide input: B' = V*D_t*U'  =>  B = U*D_t'*V'.
    if m < n
        [V,D,U] = mysvd(B');
        D = D';
        return;
    end

    U = eye(m,m);
    V = eye(n,n);

    % Zero matrix: identity factors and D = B already form a valid SVD.
    if all(B(:) == 0)
        D = B;
        return;
    end

    % Single column: sigma is the column norm and u the normalized column.
    if n == 1
       sigma = norm(B);
       u = stdlize(B);
       U = gschmidt([u rnd_std_m(u)]);
       V = 1;
       D = zeros(m,n);
       D(1,1) = sigma;
       return;
    end

    % Left singular vector: null vector of B*B' - sigma^2*I.
    sigma = norm(B);
    M = B * B' - (sigma^2 * eye(m,m));
    M = my_gauss(M);
    M = del_null(M);

    % find_eigen solves the reduced system with its last column as the
    % right-hand side, so appending -1 turns the solution into a null
    % vector of M.
    u = [find_eigen(M) -1]';
    u = stdlize(u);

    U_cur = gschmidt([u rnd_std_m(u)]);

    % Right singular vector. Since B*B'*u = sigma^2*u, the vector
    % v = B'*u/sigma satisfies B*v = sigma*u and has unit length. This
    % replaces an elimination on [B sigma*u], which needed B to have full
    % column rank and at most one more row than columns.
    v = (B' * u) / sigma;

    V_cur = gschmidt([v rnd_std_m(v)]);

    % Deflate: drop the first row and column of U_cur'*B*V_cur and recurse
    % on what is left.
    B = U_cur' * B * V_cur;
    B(1,:) = [];
    B(:,1) = [];
    [U_b,D_b,V_b] = mysvd(B);

    % Embed the sub-factors as the lower-right block of [1 0; 0 X].
    U_b = m_ext(U_b);
    V_b = m_ext(V_b);
    D = m_ext(D_b);

    U = U_cur * U_b;
    V = V_cur * V_b;
    D(1,1) = sigma;     % m_ext put a 1 here; the entry is the singular value

end

function [M] = del_null(M)
% DEL_NULL  Remove the rows of M that have no usable diagonal entry.
%
%   M = del_null(M) drops every row i whose diagonal entry M(i,i) is
%   smaller than 1e-4 in absolute value. Rows that lie below the last
%   column (i > size(M,2)) have no diagonal entry and are dropped as well.
%
%   All rows are tested against the ORIGINAL matrix and removed in one
%   step. Deleting rows inside the loop shifts the remaining rows up, which
%   made the loop skip rows and finally index past the end of the matrix.

    [m,n] = size(M);
    tol = 1e-4;

    keep = false(m,1);
    for i = 1:min(m,n)
        % Written as a negated "<" so that a NaN diagonal is kept, exactly
        % as the comparison abs(M(i,i)) < 1e-4 would leave it in place.
        keep(i) = ~(abs(M(i,i)) < tol);
    end

    M = M(keep,:);

end

function [xx] = m_ext(a)
% M_EXT  Embed a as the lower-right block of a matrix one size larger.
%
%   xx = m_ext(a) returns the (m+1)-by-(n+1) block matrix
%       [ 1  0 ]
%       [ 0  a ]
%   for an m-by-n input a.

    [rows, cols] = size(a);
    topRow  = [1, zeros(1, cols)];      % leading 1 followed by zeros
    leftCol = zeros(rows, 1);           % zero column placed in front of a
    xx = [topRow; leftCol, a];
end

function [v] = stdlize(u)
% STDLIZE  Scale a vector to unit 2-norm.
%
%   v = stdlize(u) returns u divided by its Euclidean length. u may be a
%   row or a column vector, real or complex.
%
%   The length is taken over all elements of u. The previous loop ran over
%   1:size(u), which uses only the row count, so a row vector was divided
%   by the magnitude of its first element alone; it also accumulated
%   u(i)*u(i), which is not the squared modulus for complex entries.
%
%   A zero vector has no direction: the division then yields NaN entries,
%   as it did before.

    v = u / norm(u(:));
end

function [a] = rnd_std_m(u)
% RND_STD_M  Random matrix with normalized columns, sized to complement u.
%
%   a = rnd_std_m(u) returns a matrix with size(u,1) rows and one column
%   fewer than rows. The entries are drawn as integers from 0 to 100 and
%   every column is then passed through stdlize. Only the row count of u
%   is used.

    dim = size(u, 1);
    a = randi([0, 100], [dim, dim-1]);

    for col = 1:size(a, 2)
        a(:,col) = stdlize(a(:,col));
    end

end

function [M, x] = my_gauss(a)
% MY_GAUSS  Row-reduce a matrix to upper-triangular form.
%
%   [M, x] = my_gauss(a) eliminates the entries below the diagonal of the
%   m-by-n matrix a. Row i is reduced against rows 1..i-1 in turn. When a
%   pivot a(j,j) is exactly zero, rows i and j are exchanged first and the
%   message "swap" is printed; if the pivot is still zero afterwards the
%   elimination step for that column is skipped.
%
%   M is the reduced matrix. x is an m-by-1 vector giving, for every row
%   of M, the index of the row of the input that ended up there.
%
%   Two fixes compared with the earlier version:
%     * the pivot column is limited to the n existing columns, so inputs
%       with more than n+1 rows no longer index past the last column;
%     * x is updated on every row exchange (it used to stay 1..m even when
%       rows had been swapped).

    [m, n] = size(a);
    s = (1:m)';

    for i = 1:m
        for j = 1:min(i-1, n)
            if a(j,j) == 0
                a([i,j],:) = a([j,i],:);
                s([i,j]) = s([j,i]);
                disp("swap");
            end
            if a(j,j) ~= 0
                a(i,:) = a(i,:) - a(j,:) * (a(i,j)./a(j,j));
            end
        end
    end

    M = a;
    x = s;
end

function [x] = find_eigen(a)
% FIND_EIGEN  Solve a linear system given as an augmented matrix.
%
%   x = find_eigen(a) treats the m-by-n matrix a as [A b], with b the last
%   column, reduces it by Gauss-Jordan elimination and returns the
%   resulting last column as a 1-by-m row vector. For a square,
%   non-singular A (n = m+1) this is the solution of A*x' = b.
%
%   Despite its name the function does not compute eigenvectors itself;
%   callers obtain a null vector by passing a reduced system and appending
%   -1 to the result.
%
%   Pivoting: when a(j,j) is exactly zero, row j is exchanged with the
%   first row below it that has a non-zero entry in column j. The earlier
%   version exchanged row 1 instead of row j, which destroyed the columns
%   that had already been eliminated and gave wrong answers whenever the
%   zero pivot was not in the first column.
%
%   If no non-zero pivot exists the divisions below produce Inf/NaN; no
%   error is raised.

    [m,n] = size(a);
    x = zeros(1, m, 'like', a);

    % Forward elimination.
    for j = 1:m-1
        if a(j,j) == 0
            for z = j+1:m
                if a(z,j) ~= 0
                    a([j,z],:) = a([z,j],:);
                    break;
                end
            end
        end
        for i = j+1:m
            a(i,:) = a(i,:) - a(j,:)*(a(i,j)/a(j,j));
        end
    end

    % Back substitution: clear the entries above every pivot.
    for j = m:-1:2
        for i = j-1:-1:1
            a(i,:) = a(i,:) - a(j,:)*(a(i,j)/a(j,j));
        end
    end

    % Scale every row to a unit pivot and read the solution off the last
    % column.
    for s = 1:m
        a(s,:) = a(s,:)/a(s,s);
        x(s) = a(s,n);
    end
end

function [Q,R] = gschmidt(V)
% GSCHMIDT  Classical Gram-Schmidt orthonormalization of the columns of V.
%
%   [Q,R] = gschmidt(V) processes the columns of V from left to right.
%   Column k is projected onto the columns of Q found so far, the
%   projections are subtracted, and the remainder is scaled to unit norm.
%   R is the n-by-n upper-triangular matrix holding the projection
%   coefficients above the diagonal and the remainder norms on it.
%
%   No check is made for linearly dependent columns: a zero remainder
%   leads to a division by zero.

    ncols = size(V, 2);
    R = zeros(ncols);

    % The first column only needs to be normalized.
    first = V(:,1);
    R(1,1) = norm(first);
    Q = first / R(1,1);

    for k = 2:ncols
        col  = V(:,k);
        prev = Q(:,1:k-1);

        coeffs    = prev' * col;            % components along earlier columns
        remainder = col - prev * coeffs;    % part orthogonal to them

        R(1:k-1,k) = coeffs;
        R(k,k)     = norm(remainder);
        Q = [Q, remainder / R(k,k)];
    end
end

Output:

A = randi([0,10],[3,4]);

>> mysvd_re
_________my_svd:_______________

U =

   -0.4161    0.7207   -0.3644    0.4179
   -0.3270   -0.6105   -0.0067    0.7213
   -0.6796   -0.2782   -0.4017   -0.5472
   -0.5080    0.1747    0.8401   -0.0746


D =

   21.8823         0         0
         0    8.9410         0
         0         0    2.0548
         0         0         0


V =

   -0.5890    0.0374    0.8072
   -0.5230   -0.7792   -0.3455
   -0.6161    0.6257   -0.4785

_________matlab_svd:_______________

U =

   -0.4161    0.7207    0.3644   -0.4179
   -0.3270   -0.6105    0.0067   -0.7213
   -0.6796   -0.2782    0.4017    0.5472
   -0.5080    0.1747   -0.8401    0.0746


D =

   21.8823         0         0
         0    8.9410         0
         0         0    2.0548
         0         0         0


V =

   -0.5890    0.0374   -0.8072
   -0.5230   -0.7792    0.3455
   -0.6161    0.6257    0.4785

A = randi([0,10],[9,9]);

>> mysvd_re
_________my_svd:_______________

U =

   -0.3591   -0.2647    0.3080    0.6610   -0.1323   -0.1422    0.3165   -0.2326    0.2778
   -0.3656    0.4015   -0.0847    0.1034    0.7136   -0.1976   -0.2605   -0.0346    0.2643
   -0.3051    0.1391   -0.5811   -0.3061   -0.1893    0.0770    0.4136   -0.4423    0.2185
   -0.2861    0.6491    0.1971    0.1364   -0.4970    0.3521   -0.2415    0.1003    0.0084
   -0.1647    0.0202    0.0063   -0.1010   -0.3092   -0.7252   -0.3662   -0.3161   -0.3266
   -0.4680    0.0736    0.2498   -0.2705    0.1171   -0.1367    0.5122    0.3951   -0.4350
   -0.2497   -0.2632    0.4500   -0.3554    0.1769    0.4176   -0.2209   -0.5216   -0.1139
   -0.3601   -0.3965   -0.0300   -0.3197   -0.2179   -0.0711   -0.2964    0.4453    0.5209
   -0.3526   -0.3095   -0.5055    0.3581    0.0480    0.2946   -0.2637    0.1121   -0.4742


D =

   49.8159         0         0         0         0         0         0         0         0
         0   16.1835         0         0         0         0         0         0         0
         0         0   12.0945         0         0         0         0         0         0
         0         0         0   10.2072         0         0         0         0         0
         0         0         0         0    8.4920         0         0         0         0
         0         0         0         0         0    6.1391         0         0         0
         0         0         0         0         0         0    4.6952         0         0
         0         0         0         0         0         0         0    4.0203         0
         0         0         0         0         0         0         0         0    1.4139


V =

   -0.3989   -0.0656   -0.4200    0.0748   -0.3255    0.3455    0.1666   -0.0922    0.6269
   -0.3837   -0.0912   -0.3022   -0.4258    0.5781   -0.0623   -0.1726    0.4496    0.0412
   -0.2489    0.5891   -0.0864   -0.3244   -0.1767   -0.4511   -0.2768   -0.4049    0.0549
   -0.2887   -0.0276    0.5825   -0.5447   -0.0138    0.4051    0.2740   -0.1915   -0.0627
   -0.3626    0.1437    0.4721    0.3591   -0.1577    0.1177   -0.5521    0.3740    0.1129
   -0.3142   -0.4470   -0.2751   -0.0451   -0.3487    0.0932   -0.3055   -0.1871   -0.6043
   -0.3146   -0.2476    0.1482    0.4012    0.5437   -0.1364   -0.0230   -0.5755    0.1043
   -0.3059   -0.3058    0.2004    0.0238   -0.2757   -0.6656    0.4295    0.2566    0.0521
   -0.3545    0.5141   -0.1575    0.3443    0.1015    0.1581    0.4557    0.1300   -0.4548

_________matlab_svd:_______________

U =

   -0.3591    0.2647    0.3080    0.6610   -0.1323   -0.1422   -0.3165    0.2326   -0.2778
   -0.3656   -0.4015   -0.0847    0.1034    0.7136   -0.1976    0.2605    0.0346   -0.2643
   -0.3051   -0.1391   -0.5811   -0.3061   -0.1893    0.0770   -0.4136    0.4423   -0.2185
   -0.2861   -0.6491    0.1971    0.1364   -0.4970    0.3521    0.2415   -0.1003   -0.0084
   -0.1647   -0.0202    0.0063   -0.1010   -0.3092   -0.7252    0.3662    0.3161    0.3266
   -0.4680   -0.0736    0.2498   -0.2705    0.1171   -0.1367   -0.5122   -0.3951    0.4350
   -0.2497    0.2632    0.4500   -0.3554    0.1769    0.4176    0.2209    0.5216    0.1139
   -0.3601    0.3965   -0.0300   -0.3197   -0.2179   -0.0711    0.2964   -0.4453   -0.5209
   -0.3526    0.3095   -0.5055    0.3581    0.0480    0.2946    0.2637   -0.1121    0.4742


D =

   49.8159         0         0         0         0         0         0         0         0
         0   16.1835         0         0         0         0         0         0         0
         0         0   12.0945         0         0         0         0         0         0
         0         0         0   10.2072         0         0         0         0         0
         0         0         0         0    8.4920         0         0         0         0
         0         0         0         0         0    6.1391         0         0         0
         0         0         0         0         0         0    4.6952         0         0
         0         0         0         0         0         0         0    4.0203         0
         0         0         0         0         0         0         0         0    1.4139


V =

   -0.3989    0.0656   -0.4200    0.0748   -0.3255    0.3455   -0.1666    0.0922   -0.6269
   -0.3837    0.0912   -0.3022   -0.4258    0.5781   -0.0623    0.1726   -0.4496   -0.0412
   -0.2489   -0.5891   -0.0864   -0.3244   -0.1767   -0.4511    0.2768    0.4049   -0.0549
   -0.2887    0.0276    0.5825   -0.5447   -0.0138    0.4051   -0.2740    0.1915    0.0627
   -0.3626   -0.1437    0.4721    0.3591   -0.1577    0.1177    0.5521   -0.3740   -0.1129
   -0.3142    0.4470   -0.2751   -0.0451   -0.3487    0.0932    0.3055    0.1871    0.6043
   -0.3146    0.2476    0.1482    0.4012    0.5437   -0.1364    0.0230    0.5755   -0.1043
   -0.3059    0.3058    0.2004    0.0238   -0.2757   -0.6656   -0.4295   -0.2566   -0.0521
   -0.3545   -0.5141   -0.1575    0.3443    0.1015    0.1581   -0.4557   -0.1300    0.4548

A = randi([0,10],[3,4]);

>> mysvd_re
_________my_svd:_______________

U =

   -0.5059   -0.1218    0.7662   -0.3770
   -0.5475   -0.2204   -0.6392   -0.4931
   -0.5713   -0.2437   -0.0306    0.7831
   -0.3433    0.9366   -0.0587    0.0387


D =

   22.2952         0         0
         0    6.7864         0
         0         0    1.3673
         0         0         0


V =

   -0.7296   -0.4135   -0.5447
   -0.4301    0.8967   -0.1047
   -0.5318   -0.1579    0.8320

_________matlab_svd:_______________

U =

   -0.5059    0.1218    0.7662    0.3770
   -0.5475    0.2204   -0.6392    0.4931
   -0.5713    0.2437   -0.0306   -0.7831
   -0.3433   -0.9366   -0.0587   -0.0387


D =

   22.2952         0         0
         0    6.7864         0
         0         0    1.3673
         0         0         0


V =

   -0.7296    0.4135   -0.5447
   -0.4301   -0.8967   -0.1047
   -0.5318    0.1579    0.8320

SVD: Image Compress

% Low-rank approximations of a gray-scale image, shown next to the original.
imgPath = '/Users/kion/Desktop/IMG_2638.jpeg';
A = imread(imgPath);

% rgb2gray expects a colour image; an image that is already gray-scale
% (a single plane) is used as it is.
if size(A,3) == 3
    A = rgb2gray(A);
end

imshow(A)
title(['Original (',sprintf('Rank %d)',rank(double(A)))])

% Three approximations: two tolerances, plus one with the subspace
% dimension capped at 15.
[U1,S1,V1] = svdsketch(double(A),1e-2);
Anew1 = uint8(U1*S1*V1');

[U2,S2,V2] = svdsketch(double(A),1e-1);
Anew2 = uint8(U2*S2*V2');

[U3,S3,V3,apxErr] = svdsketch(double(A),1e-1,'MaxSubspaceDimension',15);
Anew3 = uint8(U3*S3*V3');

% Open a second figure for the comparison; tiledlayout would otherwise
% replace the plot of the original drawn above.
figure
tiledlayout(2,2,'TileSpacing','Compact')
nexttile
imshow(A)
title('Original')
nexttile
imshow(Anew1)
title(sprintf('Rank %d approximation',size(S1,1)))
nexttile
imshow(Anew2)
title(sprintf('Rank %d approximation',size(S2,1)))
nexttile
imshow(Anew3)
title(sprintf('Rank %d approximation',size(S3,1)))

Screen Shot 2021-10-16 at 21.00.43

posted @ 2022-03-07 13:10  miyasaka  阅读(218)  评论(0)    收藏  举报