]> Creatis software - CreaPhase.git/blob - octave_packages/nan-2.5.5/train_lda_sparse.m
Add a useful package (from Source forge) for octave
[CreaPhase.git] / octave_packages / nan-2.5.5 / train_lda_sparse.m
1 function [CC] = train_lda_sparse(X,G,par,tol)
2 % Linear Discriminant Analysis for the Small Sample Size Problem as described in
3 % Algorithm 1 of J. Duintjer Tebbens, P. Schlesinger: 'Improving
4 % Implementation of Linear Discriminant Analysis for the High Dimension/Small Sample Size
5 % Problem', Computational Statistics and Data Analysis, vol. 52, no. 1, pp. 423-437, 2007.  
6 % Input:
7 %               X                 ......       (sparse) training data matrix
8 %               G                 ......       group coding matrix of the training data
9 %               test              ......       (sparse) test data matrix
10 %               Gtest             ......       group coding matrix of the test data
11 %               par               ......       if par = 0 then classification exploits sparsity too
12 %               tol               ......       tolerance to distinguish zero eigenvalues
13 % Output:
14 %               err               ......       Wrong classification rate (in %)
15 %               trafo             ......       LDA transformation vectors
16 %
17 % Reference(s): 
18 % J. Duintjer Tebbens, P. Schlesinger: 'Improving
19 % Implementation of Linear Discriminant Analysis for the High Dimension/Small Sample Size
20 % Problem', Computational Statistics and Data Analysis, vol. 52, no. 1, 
21 % pp. 423-437, 2007.
22 %
23 % Copyright (C) by J. Duintjer Tebbens, Institute of Computer Science of the Academy of Sciences of the Czech Republic,
24 % Pod Vodarenskou vezi 2, 182 07 Praha 8 Liben, 18.July.2006. 
25 % This work was supported by the Program Information Society under project
26 % 1ET400300415.
27 %
28 %
29 % Modified for the use with Matlab6.5 by A. Schloegl, 22.Aug.2006
30 %
31 %       $Id$
32 %       This function is part of the NaN-toolbox
33 %       http://pub.ist.ac.at/~schloegl/matlab/NaN/
34
35 % This program is free software; you can redistribute it and/or
36 % modify it under the terms of the GNU General Public License
37 % as published by the Free Software Foundation; either version 3
38 % of the  License, or (at your option) any later version.
39
40 % This program is distributed in the hope that it will be useful,
41 % but WITHOUT ANY WARRANTY; without even the implied warranty of
42 % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
43 % GNU General Public License for more details.
44
45 % You should have received a copy of the GNU General Public License
46 % along with this program; if not, write to the Free Software
47 % Foundation, Inc., 51 Franklin Street - Fifth Floor, Boston, MA 02110-1301, USA.
48
49 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Step (1)
50 %p = length(X(1,:));n = length(X(:,1));g = length(G(1,:));
51 G = sparse(G);
52 [n,p]=size(X); 
53 g = size(G,2);
54
55 for j=1:g
56         nj(j) = norm(G(:,j))^2;
57 end
58 Dtild = spdiags(nj'.^(-1),0,g,g);
59 Xtild = X*X';
60 Xtild1 = Xtild*ones(n,1);
61 help = ones(n,1)*Xtild1'/n - (ones(1,n)*Xtild'*ones(n,1))/(n^2); 
62 matrix = Xtild - Xtild1*ones(1,n)/n - help;
63 % eliminate non-symmetry of matrix due to rounding error:
64 matrix = (matrix+matrix')/2;
65 [V0,S] = eig(matrix);
66 % [s,I] = sort(diag(S),'descend');
67 [s,I] = sort(-diag(S)); s = -s; 
68
69 cc = sum(s<tol);
70
71 count = n-cc;
72 V1 = V0(:,I(1:count));
73 D1inv = diag(s(1:count).^(-1.0));
74 Dhalfinv = diag(s(1:count).^(-0.5));
75 Dhalf = diag(s(1:count).^(0.5));
76 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Step (2)
77 help2 = V1*D1inv;
78 M1 = Dtild*G'*Xtild;
79 B1 = (G*(M1*(speye(n)-1/n))-help)*help2;
80 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Step (3)
81 opts.issym = 1;opts.isreal = 1;opts.disp = 0;
82 %if 0, 
83 try,
84         [V0,S,flag] = eigs(B1'*B1,g-1,'lm',opts);
85         EV = Dhalfinv*V0;
86         [s,I] = sort(-diag(S)); s = -s; 
87         %else
88 catch
89         % needed as long as eigs is not supported by Octave
90         [V0,S] = eig(B1'*B1);
91         flag   = 0;
92         [s,I]  = sort(-diag(S)); s = -s(I(1:g-1));
93         EV = Dhalfinv * V0(:,I(1:g-1));
94         I = 1:g-1;
95 end;
96 %EV = Dhalfinv*V0;
97 %[s,I] = sort((diag(S)),'descend');
98 %[s,I] = sort(-diag(S)); s = -s; 
99 if flag ~= 0,
100         'eigs did not converge';
101 end
102 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Step (4)
103 for j=1:g-1,
104         C(:,j) = EV(:,I(j))/norm(EV(:,I(j)));
105 end
106 cc = 0;
107 for j=1:g-1,
108         if (1-s(j))<tol
109                 cc = cc+1;
110                 V2(:,j) = EV(:,I(j));
111         else
112                 break
113         end
114 end
115 if cc > 0
116         [Q,R] = qr(V2,0);
117         matrix = B1*Dhalf*Q;
118         [V0,S] = eig(matrix'*matrix);
119         %[s,I] = sort(diag(S),'descend');
120         [s,I] = sort(-diag(S)); s = -s; 
121         for j=1:cc
122                 C(:,j) = Q*V0(:,I(j));
123         end
124 end
125
126 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Step (5)
127 C1 = help2*Dhalf*C;
128 trafo(:,1:g-1) = X'*C1 - (X'*ones(n,1))*(ones(1,n)*C1/n);
129 for j=1:g-1
130         trafo(:,j) = trafo(:,j)/norm(trafo(:,j));
131 end
132 CC.trafo = trafo; 
133
134 if par == 0
135 %    X2 = full(test*X');
136 %    [pred] = classifs(C1,M1,X2);
137         CC.C1 = C1;
138         CC.M1 = M1;
139         CC.X  = X;
140 else
141 %    M = Dtild*G'*X;
142 %    [pred] = classifs(trafo,M,test);
143         CC.C1 = trafo; 
144         CC.M1 = Dtild*G'*X;
145 end