Creatis software - CreaPhase.git/blob - octave_packages/nan-2.5.5/test_sc.m
Add a useful package (from SourceForge) for Octave
[CreaPhase.git] / octave_packages / nan-2.5.5 / test_sc.m
function [R]=test_sc(CC,D,mode,classlabel)
% TEST_SC: apply statistical and SVM classifier to test data
%
%  R = test_sc(CC,D,TYPE [,target_Classlabel])
%
% Input:
%       CC      classifier structure, typically obtained by TRAIN_SC.
%               CC.datatype must start with 'classifier'; otherwise the
%               function returns immediately without assigning R.
%       D       test data, one sample per row
%       TYPE    optional; either a string (e.g. 'MDA') or a struct with the
%               field TYPE. If omitted or empty, the classifier type is taken
%               from the third ':'-separated token of CC.datatype.
%               TYPE is only evaluated for 'classifier:statistical:*'
%               classifiers that do not provide CC.weights.
%       target_Classlabel
%               optional target class of each sample
%
% Output:
%       R.output        output: "signed" distance for each class.
%               This represents the distances between sample D and the separating hyperplane.
%               The "signed distance" is positive if it matches the target class,
%               and negative if it lies on the opposite side of the separating hyperplane.
%       R.classlabel    class for output data. With several output columns this
%               is CC.Labels of the column with the largest output; with a
%               single column it is 1 for negative and 2 for positive output.
%               Samples whose output is NaN get the class label NaN.
%  The target class is optional. If it is provided, the following values
%  are returned as well (all of them are computed by KAPPA):
%       R.kappa         Cohen's kappa coefficient
%       R.sd            second output of KAPPA (presumably the standard error of kappa)
%       R.ACC           Classification accuracy
%       R.H             Confusion matrix
%
% The classifier CC is typically obtained by TRAIN_SC. If a statistical
% classifier is used, TYPE can be used to modify the classifier.
%    TYPE = 'MDA'    mahalanobis distance based classifier
%    TYPE = 'MD2'    mahalanobis distance based classifier
%    TYPE = 'MD3'    mahalanobis distance based classifier
%    TYPE = 'GRB'    Gaussian radial basis function
%    TYPE = 'GRB2'   Gaussian radial basis function (squared distance)
%    TYPE = 'MQU'    multiquadratic
%    TYPE = 'IMQ'    inverse multiquadratic
%    TYPE = 'Cauchy' Cauchy radial basis function
%    TYPE = 'QDA'    quadratic discriminant analysis
%    TYPE = 'QDA2'   quadratic discriminant analysis (uses CC.logSF4)
%    TYPE = 'LD2'    linear discriminant analysis
%    TYPE = 'LD3', 'LDA', 'FDA', 'FLDA'   (Fisher's) linear discriminant analysis
%    TYPE = 'LD4'    linear discriminant analysis
%    TYPE = 'GDBC'   general distance based classifier
%
% see also: TRAIN_SC
%
% References:
% [1] R. Duda, P. Hart, and D. Stork, Pattern Classification, second ed.
%       John Wiley & Sons, 2001.

%       $Id: test_sc.m 9601 2012-02-09 14:14:36Z schloegl $
%       Copyright (C) 2005,2006,2008,2009,2010 by Alois Schloegl <alois.schloegl@gmail.com>
%       This function is part of the NaN-toolbox
%       http://pub.ist.ac.at/~schloegl/matlab/NaN/

% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 3
% of the  License, or (at your option) any later version.

% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
% GNU General Public License for more details.

% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 51 Franklin Street - Fifth Floor, Boston, MA 02110-1301, USA.

if nargin<3,
        mode = [];
end;
if ischar(mode),
        % The documented calling form passes TYPE as a plain string;
        % wrap it so that the code below can rely on a struct with field TYPE.
        if isempty(mode),
                mode = [];
        else
                mode = struct('TYPE',mode);
        end;
end;

% CC.datatype has the form 'classifier:<family>:<type>'
[t1,t] = strtok(CC.datatype,':');
[t2,t] = strtok(t,':');
[t3] = strtok(t,':');
if ~strcmp(t1,'classifier'), return; end;

if isfield(CC,'prewhite')
        % first row of CC.prewhite is the offset, the remaining rows the linear map
        D = D*CC.prewhite(2:end,:) + CC.prewhite(ones(size(D,1),1),:);
        CC = rmfield(CC,'prewhite');
end;

% start positions of the suffixes that mark a dimension-reduced classifier
POS1 = [strfind(CC.datatype,'/gsvd'),strfind(CC.datatype,'/sparse'),strfind(CC.datatype,'/delet')];

if strcmp(CC.datatype,'classifier:nbpw')
        %%%% Naive Bayesian Parzen Window Classifier %%%%
        error('NBPW not implemented yet')


elseif strcmp(CC.datatype,'classifier:nbc') || strcmp(CC.datatype,'classifier:anbc')
        %%%% Naive Bayesian Classifier and Augmented Naive Bayesian Classifier %%%%
        if strcmp(CC.datatype,'classifier:anbc')
                X = D*CC.V;     % the augmented variant first maps the data with CC.V
        else
                X = D;
        end;
        nRows = size(X,1);
        d = repmat(NaN,nRows,size(CC.MEAN,1));
        for k = 1:size(CC.MEAN,1)
                rowk = repmat(k,nRows,1);       % replicates row k for every sample
                z = (X - CC.MEAN(rowk,:)).^2 ./ (CC.VAR(rowk,:));
                z = z + log(CC.VAR(rowk,:)); % + log(2*pi);
                d(:,k) = sum(-z/2, 2) + log(mean(CC.N(k,:)));
        end;
        d = exp(d-log(mean(sum(CC.N,1)))-log(2*pi)/2);


elseif strcmp(CC.datatype,'classifier:statistical:rda')
        % Friedman (1989) Regularized Discriminant analysis
        if isfield(CC,'hyperparameter') && isfield(CC.hyperparameter,'lambda')  && isfield(CC.hyperparameter,'gamma')
                D = [ones(size(D,1),1),D];  % add 1-column
                lambda = CC.hyperparameter.lambda;
                gam    = CC.hyperparameter.gamma;
                ECM = CC.MD./CC.NN;
                NC = size(ECM);
                nClasses = size(ECM,3);         % also valid for a single class
                % one output column per class (3rd dimension of the covariance array)
                d = repmat(NaN,size(D,1),nClasses);
                ECM0 = squeeze(sum(ECM,3));  %decompose ECM
                [M0,sd,COV0] = decovm(ECM0);
                for k = 1:nClasses;
                        [M,sd,s,xc,N] = decovm(squeeze(ECM(:,:,k)));
                        s = ((1-lambda)*N*s+lambda*COV0)/((1-lambda)*N+lambda);
                        s = (1-gam)*s+gam*(trace(s))/(NC(2)-1)*eye(NC(2)-1);
                        ir  = [-M;eye(NC(2)-1)]*inv(s)*[-M',eye(NC(2)-1)];  % inverse correlation matrix extended by mean
                        d(:,k) = -sum((D*ir).*D,2); % calculate distance of each data point to each class
                end;
        else
                error('RDA: hyperparameters lambda and/or gamma not defined')
        end;


elseif strcmp(CC.datatype,'classifier:csp')
        d = filtfilt(CC.FiltB,CC.FiltA,(D*CC.csp_w).^2);
        R = test_sc(CC.CSP,log(d));     % classifier stored in CC.CSP is applied to the log of the filtered, squared projections
        d = R.output;


elseif strcmp(CC.datatype,'classifier:svm:lib:1vs1') || strcmp(CC.datatype,'classifier:svm:lib:rbf');
        nr = size(D,1);
        [cl] = svmpredict_mex(ones(nr,1), D, CC.model);   %Use the classifier
        %Create a pseudo tsd matrix for bci4eval
        d = full(sparse(1:nr,cl,1,nr,CC.model.nr_class));


elseif isfield(CC,'weights'); %strcmpi(t2,'svm') || (strcmpi(t2,'statistical') & strncmpi(t3,'ld',2)) ;
        % linear classifiers like: LDA, SVM, LPM
        % first row of CC.weights is the bias, the remaining rows the weight vector
        d = repmat(NaN,size(D,1),size(CC.weights,2));
        for k = 1:size(CC.weights,2),
                d(:,k) = D * CC.weights(2:end,k) + CC.weights(1,k);
        end;


elseif ~isempty(POS1)   % GSVD, sparse & DELETION
        % strip the suffix and classify the data projected with CC.G
        CC.datatype = CC.datatype(1:POS1(1)-1);
        r = test_sc(CC, D*sparse(CC.G));
        d = r.output;


elseif strcmp(t2,'statistical');
        if isstruct(mode) && ~isempty(mode) && isfield(mode,'TYPE')
                TYPE = mode.TYPE;
        else
                TYPE = upper(t3);       % default: type encoded in CC.datatype
        end;
        D = [ones(size(D,1),1),D];  % add 1-column
        nRows = size(D,1);

        % classifiers that are a function of the quadratic form x*IR{k}*x'
        quadTypes = {'MDA','MD2','GDBC','MD3','QDA','QDA2','GRB','GRB2','MQU','IMQ','Cauchy'};

        if strcmpi(TYPE,'LD2'),
                %d = ldbc2(CC,D);
                ECM = CC.MD./CC.NN;
                W = repmat(NaN, size(D,2), size(CC.MD,3));
                ECM0 = squeeze(sum(ECM,3));  %decompose ECM
                [M0] = decovm(ECM0);
                for k = 1:size(ECM,3);
                        ecm = squeeze(ECM(:,:,k));
                        [M1,sd,COV1] = decovm(ECM0-ecm);        % all other classes
                        [M2,sd,COV2] = decovm(ecm);             % class k
                        w     = (COV1+COV2)\(M2'-M1')*2;
                        w0    = -M0*w;
                        W(:,k) = [w0; w];
                end;
                d = D*W;
        elseif any(strcmpi(TYPE,{'LD3','LDA','FDA','FLDA'}));
                %d = ldbc3(CC,D);
                ECM = CC.MD./CC.NN;
                W = repmat(NaN, size(D,2), size(CC.MD,3));
                ECM0 = squeeze(sum(ECM,3));  %decompose ECM
                [M0,sd,COV0] = decovm(ECM0);
                for k = 1:size(ECM,3);
                        ecm = squeeze(ECM(:,:,k));
                        [M1] = decovm(ECM0-ecm);                % all other classes
                        [M2] = decovm(ecm);                     % class k
                        w     = COV0\(M2'-M1')*2;
                        w0    = -M0*w;
                        W(:,k) = [w0; w];
                end;
                d = D*W;
        elseif strcmpi(TYPE,'LD4');
                %d = ldbc4(CC,D);
                ECM = CC.MD./CC.NN;
                W = repmat(NaN, size(D,2), size(CC.MD,3));
                ECM0 = squeeze(sum(ECM,3));  %decompose ECM
                M0 = decovm(ECM0);
                for k = 1:size(ECM,3);
                        ecm = squeeze(ECM(:,:,k));
                        [M1,sd,COV1,xc,N1] = decovm(ECM0-ecm);  % all other classes
                        [M2,sd,COV2,xc,N2] = decovm(ecm);       % class k
                        w     = (COV1*N1+COV2*N2)\((M2'-M1')*(N1+N2));
                        w0    = -M0*w;
                        W(:,k) = [w0; w];
                end;
                d = D*W;
        elseif any(strcmpi(TYPE,quadTypes));
                nClasses = length(CC.IR);
                q = repmat(NaN,nRows,nClasses);
                for k = 1:nClasses;
                        q(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
                end;

                % class-specific additive term used by some of the types
                if strcmpi(TYPE,'GDBC') || strcmpi(TYPE,'MD3'),
                        sf = CC.logSF7;
                elseif strcmpi(TYPE,'QDA'),
                        sf = CC.logSF5;
                elseif strcmpi(TYPE,'QDA2'),
                        sf = CC.logSF4;
                else
                        sf = zeros(1,nClasses);
                end;
                off = repmat(reshape(sf(1:nClasses),1,nClasses),nRows,1);

                if strcmpi(TYPE,'MDA'),
                        d = -q;
                elseif strcmpi(TYPE,'MD2'),
                        d = -sqrt(q);
                elseif strcmpi(TYPE,'GDBC'),
                        d = exp(-(q+off)/2);
                elseif strcmpi(TYPE,'MD3'),
                        d = exp(-(q+off)/2);
                        d = d./repmat(sum(d,2),1,size(d,2));  % membership probability [1], p.601, equ (18.39)
                elseif strcmpi(TYPE,'QDA'),
                        % [1] (18.33) QCF - quadratic classification function
                        d = -(q-off);
                elseif strcmpi(TYPE,'QDA2'),
                        % [1] (18.33) QCF - quadratic classification function
                        d = -(q+off);
                elseif strcmpi(TYPE,'GRB'),     % Gaussian RBF
                        d = exp(-sqrt(q)/2);
                elseif strcmpi(TYPE,'GRB2'),    % Gaussian RBF
                        d = exp(-q);
                elseif strcmpi(TYPE,'MQU'),     % Multiquadratic
                        d = -sqrt(1+q);
                elseif strcmpi(TYPE,'IMQ'),     % Inverse Multiquadratic
                        d = (1+q).^(-1/2);
                else                            % 'Cauchy': Cauchy RBF
                        d = 1./(1+q);
                end;
        else
                error('Classifier %s not supported. see HELP TRAIN_SC for supported classifiers.',TYPE);
        end;
else
        % NOTE: R stays unassigned on this path
        fprintf(2,'Error TEST_SC: unknown classifier\n');
        return;
end;

if size(d,2)>1,
        % several classes: pick the class with the largest output
        [tmp,cl] = max(d,[],2);
        cl = CC.Labels(cl);
        cl(isnan(tmp)) = NaN;
elseif size(d,2)==1,
        % single output: the sign decides between class 1 (negative) and class 2 (positive)
        cl = (d<0) + 2*(d>0);
        cl(isnan(d)) = NaN;
end;

R.output = d;
R.classlabel = cl;

if nargin>3,
        [R.kappa,R.sd,R.H,z,R.ACC] = kappa(classlabel(:),cl(:));
end;