1 function [R]=test_sc(CC,D,mode,classlabel)
2 % TEST_SC: apply statistical and SVM classifier to test data
4 % R = test_sc(CC,D,TYPE [,target_Classlabel])
5 % R.output output: "signed" distance for each class.
6 % This represents the distances between sample D and the separating hyperplane
7 % The "signed distance" is positive if it matches the target class,
8 % and negative if it lies on the opposite side of the separating hyperplane.
9 % R.classlabel class for output data
10 % The target class is optional. If it is provided, the following values are returned.
11 % R.kappa Cohen's kappa coefficient
12 % R.ACC Classification accuracy
13 % R.H Confusion matrix
15 % The classifier CC is typically obtained by TRAIN_SC. If a statistical
16 % classifier is used, TYPE can be used to modify the classifier.
17 % TYPE = 'MDA' mahalanobis distance based classifier
18 % TYPE = 'MD2' mahalanobis distance based classifier
19 % TYPE = 'MD3' mahalanobis distance based classifier
20 % TYPE = 'GRB' Gaussian radial basis function
21 % TYPE = 'QDA' quadratic discriminant analysis
22 % TYPE = 'LD2' linear discriminant analysis
23 % TYPE = 'LD3', 'LDA', 'FDA', 'FLDA' (Fisher's) linear discriminant analysis
24 % TYPE = 'LD4' linear discriminant analysis
25 % TYPE = 'GDBC' general distance based classifier
30 % [1] R. Duda, P. Hart, and D. Stork, Pattern Classification, second ed.
31 % John Wiley & Sons, 2001.
33 % $Id: test_sc.m 9601 2012-02-09 14:14:36Z schloegl $
34 % Copyright (C) 2005,2006,2008,2009,2010 by Alois Schloegl <alois.schloegl@gmail.com>
35 % This function is part of the NaN-toolbox
36 % http://pub.ist.ac.at/~schloegl/matlab/NaN/
38 % This program is free software; you can redistribute it and/or
39 % modify it under the terms of the GNU General Public License
40 % as published by the Free Software Foundation; either version 3
41 % of the License, or (at your option) any later version.
43 % This program is distributed in the hope that it will be useful,
44 % but WITHOUT ANY WARRANTY; without even the implied warranty of
45 % MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
46 % GNU General Public License for more details.
48 % You should have received a copy of the GNU General Public License
49 % along with this program; if not, write to the Free Software
50 % Foundation, Inc., 51 Franklin Street - Fifth Floor, Boston, MA 02110-1301, USA.
55 [t1,t] = strtok(CC.datatype,':');
56 [t2,t] = strtok(t,':');
58 if ~strcmp(t1,'classifier'), return; end;
60 if isfield(CC,'prewhite')
61 D = D*CC.prewhite(2:end,:) + CC.prewhite(ones(size(D,1),1),:);
62 CC = rmfield(CC,'prewhite');
65 POS1 = [strfind(CC.datatype,'/gsvd'),strfind(CC.datatype,'/sparse'),strfind(CC.datatype,'/delet')];
70 elseif strcmp(CC.datatype,'classifier:nbpw')
71 error('NBPW not implemented yet')
72 %%%% Naive Bayesian Parzen Window Classifier %%%%
73 d = repmat(NaN,size(D,1),size(CC.MEAN,1));
74 for k = 1:size(CC.MEAN,1)
75 z = (D - CC.MEAN(repmat(k,size(D,1),1),:)).^2 ./ (CC.VAR(repmat(k,size(D,1),1),:));
76 z = z + log(CC.VAR(repmat(k,size(D,1),1),:)); % + log(2*pi);
77 d(:,k) = sum(-z/2, 2) + log(mean(CC.N(k,:)));
79 d = exp(d-log(mean(sum(CC.N,1)))-log(2*pi)/2);
82 elseif strcmp(CC.datatype,'classifier:nbc')
83 %%%% Naive Bayesian Classifier %%%%
84 d = repmat(NaN,size(D,1),size(CC.MEAN,1));
85 for k = 1:size(CC.MEAN,1)
86 z = (D - CC.MEAN(repmat(k,size(D,1),1),:)).^2 ./ (CC.VAR(repmat(k,size(D,1),1),:));
87 z = z + log(CC.VAR(repmat(k,size(D,1),1),:)); % + log(2*pi);
88 d(:,k) = sum(-z/2, 2) + log(mean(CC.N(k,:)));
90 d = exp(d-log(mean(sum(CC.N,1)))-log(2*pi)/2);
93 elseif strcmp(CC.datatype,'classifier:anbc')
94 %%%% Augmented Naive Bayesian Classifier %%%%
95 d = repmat(NaN,size(D,1),size(CC.MEAN,1));
96 for k = 1:size(CC.MEAN,1)
97 z = (D*CC.V - CC.MEAN(repmat(k,size(D,1),1),:)).^2 ./ (CC.VAR(repmat(k,size(D,1),1),:));
98 z = z + log(CC.VAR(repmat(k,size(D,1),1),:)); % + log(2*pi);
99 d(:,k) = sum(-z/2, 2) + log(mean(CC.N(k,:)));
101 d = exp(d-log(mean(sum(CC.N,1)))-log(2*pi)/2);
104 elseif strcmp(CC.datatype,'classifier:statistical:rda')
105 % Friedman (1989) Regularized Discriminant analysis
106 if isfield(CC,'hyperparameter') && isfield(CC.hyperparameter,'lambda') && isfield(CC.hyperparameter,'gamma')
107 D = [ones(size(D,1),1),D]; % add 1-column
108 lambda = CC.hyperparameter.lambda;
109 gamma = CC.hyperparameter.gamma;
110 d = repmat(NaN,size(D,1),size(CC.MD,1));
113 ECM0 = squeeze(sum(ECM,3)); %decompose ECM
114 [M0,sd,COV0] = decovm(ECM0);
116 [M,sd,s,xc,N] = decovm(squeeze(ECM(:,:,k)));
117 s = ((1-lambda)*N*s+lambda*COV0)/((1-lambda)*N+lambda);
118 s = (1-gamma)*s+gamma*(trace(s))/(NC(2)-1)*eye(NC(2)-1);
119 ir = [-M;eye(NC(2)-1)]*inv(s)*[-M',eye(NC(2)-1)]; % inverse correlation matrix extended by mean
120 d(:,k) = -sum((D*ir).*D,2); % calculate distance of each data point to each class
123 error('QDA: hyperparamters lambda and/or gamma not defined')
127 elseif strcmp(CC.datatype,'classifier:csp')
128 d = filtfilt(CC.FiltB,CC.FiltA,(D*CC.csp_w).^2);
129 R = test_sc(CC.CSP,log(d)); % LDA classifier of
133 elseif strcmp(CC.datatype,'classifier:svm:lib:1vs1') || strcmp(CC.datatype,'classifier:svm:lib:rbf');
135 [cl] = svmpredict_mex(ones(nr,1), D, CC.model); %Use the classifier
136 %Create a pseudo tsd matrix for bci4eval
137 d = full(sparse(1:nr,cl,1,nr,CC.model.nr_class));
140 elseif isfield(CC,'weights'); %strcmpi(t2,'svm') || (strcmpi(t2,'statistical') & strncmpi(t3,'ld',2)) ;
141 % linear classifiers like: LDA, SVM, LPM
142 %d = [ones(size(D,1),1), D] * CC.weights;
143 d = repmat(NaN,size(D,1),size(CC.weights,2));
144 for k = 1:size(CC.weights,2),
145 d(:,k) = D * CC.weights(2:end,k) + CC.weights(1,k);
149 elseif ~isempty(POS1) % GSVD, sparse & DELETION
150 CC.datatype = CC.datatype(1:POS1(1)-1);
151 r = test_sc(CC, D*sparse(CC.G));
155 elseif strcmp(t2,'statistical');
157 mode.TYPE = upper(t3);
159 D = [ones(size(D,1),1),D]; % add 1-column
160 W = repmat(NaN, size(D,2), size(CC.MD,3));
163 elseif strcmpi(mode.TYPE,'LD2'),
167 ECM0 = squeeze(sum(ECM,3)); %decompose ECM
170 ecm = squeeze(ECM(:,:,k));
171 [M1,sd,COV1] = decovm(ECM0-ecm);
172 [M2,sd,COV2] = decovm(ecm);
173 w = (COV1+COV2)\(M2'-M1')*2;
178 elseif strcmpi(mode.TYPE,'LD3') || strcmpi(mode.TYPE,'FLDA');
182 ECM0 = squeeze(sum(ECM,3)); %decompose ECM
183 [M0,sd,COV0] = decovm(ECM0);
185 ecm = squeeze(ECM(:,:,k));
186 [M1] = decovm(ECM0-ecm);
188 w = COV0\(M2'-M1')*2;
193 elseif strcmpi(mode.TYPE,'LD4');
197 ECM0 = squeeze(sum(ECM,3)); %decompose ECM
200 ecm = squeeze(ECM(:,:,k));
201 [M1,sd,COV1,xc,N1] = decovm(ECM0-ecm);
202 [M2,sd,COV2,xc,N2] = decovm(ecm);
203 w = (COV1*N1+COV2*N2)\((M2'-M1')*(N1+N2));
208 elseif strcmpi(mode.TYPE,'MDA');
209 d = repmat(NaN,size(D,1),length(CC.IR));
210 for k = 1:length(CC.IR);
211 d(:,k) = -sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
213 elseif strcmpi(mode.TYPE,'MD2');
214 d = repmat(NaN,size(D,1),length(CC.IR));
215 for k = 1:length(CC.IR);
216 d(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
219 elseif strcmpi(mode.TYPE,'GDBC');
220 d = repmat(NaN,size(D,1),length(CC.IR));
221 for k = 1:length(CC.IR);
222 d(:,k) = sum((D*CC.IR{k}).*D,2) + CC.logSF7(k); % calculate distance of each data point to each class
225 elseif strcmpi(mode.TYPE,'MD3');
226 d = repmat(NaN,size(D,1),length(CC.IR));
227 for k = 1:length(CC.IR);
228 d(:,k) = sum((D*CC.IR{k}).*D,2) + CC.logSF7(k); % calculate distance of each data point to each class
231 d = d./repmat(sum(d,2),1,size(d,2)); % Zuordungswahrscheinlichkeit [1], p.601, equ (18.39)
232 elseif strcmpi(mode.TYPE,'QDA');
233 d = repmat(NaN,size(D,1),length(CC.IR));
234 for k = 1:length(CC.IR);
235 % [1] (18.33) QCF - quadratic classification function
236 d(:,k) = -(sum((D*CC.IR{k}).*D,2) - CC.logSF5(k));
238 elseif strcmpi(mode.TYPE,'QDA2');
239 d = repmat(NaN,size(D,1),length(CC.IR));
240 for k = 1:length(CC.IR);
241 % [1] (18.33) QCF - quadratic classification function
242 d(:,k) = -(sum((D*(CC.IR{k})).*D,2) + CC.logSF4(k));
244 elseif strcmpi(mode.TYPE,'GRB'); % Gaussian RBF
245 d = repmat(NaN,size(D,1),length(CC.IR));
246 for k = 1:length(CC.IR);
247 d(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
250 elseif strcmpi(mode.TYPE,'GRB2'); % Gaussian RBF
251 d = repmat(NaN,size(D,1),length(CC.IR));
252 for k = 1:length(CC.IR);
253 d(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
256 elseif strcmpi(mode.TYPE,'MQU'); % Multiquadratic
257 d = repmat(NaN,size(D,1),length(CC.IR));
258 for k = 1:length(CC.IR);
259 d(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
262 elseif strcmpi(mode.TYPE,'IMQ'); % Inverse Multiquadratic
263 d = repmat(NaN,size(D,1),length(CC.IR));
264 for k = 1:length(CC.IR);
265 d(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
268 elseif strcmpi(mode.TYPE,'Cauchy'); % Cauchy RBF
269 d = repmat(NaN,size(D,1),length(CC.IR));
270 for k = 1:length(CC.IR);
271 d(:,k) = sum((D*CC.IR{k}).*D,2); % calculate distance of each data point to each class
275 error('Classifier %s not supported. see HELP TRAIN_SC for supported classifiers.',mode.TYPE);
278 fprintf(2,'Error TEST_SC: unknown classifier\n');
283 [tmp,cl] = max(d,[],2);
285 cl(isnan(tmp)) = NaN;
287 cl = (d<0) + 2*(d>0);
295 [R.kappa,R.sd,R.H,z,R.ACC] = kappa(classlabel(:),cl(:));