function [CC]=train_sc(D,classlabel,MODE,W)
% Train a (statistical) classifier
% CC = train_sc(D,classlabel)
% CC = train_sc(D,classlabel,MODE)
% CC = train_sc(D,classlabel,MODE, W)
% weighting D(k,:) with weight W(k) (not all classifiers support weighting)
% CC contains the model parameters of a classifier which can be applied
% to test data using test_sc.
% R = test_sc(CC,D,...)
% D training samples (each row is a sample, each column is a feature)
% classlabel labels of each sample, must have the same number of rows as D.
% Two different encodings are supported:
% {-1,1}-encoding (multiple classes with separate columns for each class) or
% So [1;2;3;1;4] is equivalent to
% Note, samples with classlabel=0 are ignored.
% The following classifier types are supported MODE.TYPE
% 'MDA' mahalanobis distance based classifier [1]
% 'MD2' mahalanobis distance based classifier [1]
% 'MD3' mahalanobis distance based classifier [1]
% 'GRB' Gaussian radial basis function [1]
% 'QDA' quadratic discriminant analysis [1]
% 'LD2' linear discriminant analysis (see LDBC2) [1]
% MODE.hyperparameter.gamma: regularization parameter [default 0]
% 'LD3', 'FDA', 'LDA', 'FLDA'
% linear discriminant analysis (see LDBC3) [1]
% MODE.hyperparameter.gamma: regularization parameter [default 0]
% 'LD4' linear discriminant analysis (see LDBC4) [1]
% MODE.hyperparameter.gamma: regularization parameter [default 0]
% 'LD5' another LDA (motivated by CSP)
% MODE.hyperparameter.gamma: regularization parameter [default 0]
% 'RDA' regularized discriminant analysis [7]
% MODE.hyperparameter.gamma: regularization parameter
% MODE.hyperparameter.lambda =
% gamma = 0, lambda = 0 : MDA
% gamma = 0, lambda = 1 : LDA [default]
% Hint: hyperparameters are used only in test_sc.m; testing different
% hyperparameters does not need repetitive calls to train_sc,
% it is sufficient to modify CC.hyperparameter before calling test_sc.
% 'GDBC' general distance based classifier [1]
% '' statistical classifier, requires Mode argument in TEST_SC
% '###/DELETION' if the data contains missing values (encoded as NaNs),
% a row-wise or column-wise deletion (depending on which method
% removes less data values) is applied;
% '###/GSVD' GSVD and statistical classifier [2,3],
% '###/sparse' sparse [5]
% '###' must be 'LDA' or any other classifier
% 'PLS' (linear) partial least squares regression
% 'REG' regression analysis;
% 'WienerHopf' Wiener-Hopf equation
% 'NBC' Naive Bayesian Classifier [6]
% 'aNBC' Augmented Naive Bayesian Classifier [6]
% 'NBPW' Naive Bayesian Parzen Window [9]
% 'PLA' Perceptron Learning Algorithm [11]
% MODE.hyperparameter.alpha = alpha [default: 1]
% w = w + alpha * e'*x
% 'LMS', 'AdaLine' Least mean squares, adaptive line element, Widrow-Hoff, delta rule
% MODE.hyperparameter.alpha = alpha [default: 1]
% 'Winnow2' Winnow2 algorithm [12]
% 'PSVM' Proximal SVM [8]
% MODE.hyperparameter.nu (default: 1.0)
% 'LPM' Linear Programming Machine
% uses and requires train_LPM of the iLog CPLEX optimizer
% MODE.hyperparameter.c_value =
% 'CSP' CommonSpatialPattern is very experimental and just a hack
% uses a smoothing window of 50 samples.
% 'SVM','SVM1r' support vector machines, one-vs-rest
% MODE.hyperparameter.c_value =
% 'SVM11' support vector machines, one-vs-one + voting
% MODE.hyperparameter.c_value =
% 'RBF' Support Vector Machines with RBF Kernel
% MODE.hyperparameter.c_value =
% MODE.hyperparameter.gamma =
% 'SVM:LIB' libSVM [default SVM algorithm]
% 'SVM:bioinfo' uses and requires svmtrain from the bioinfo toolbox
% 'SVM:OSU' uses and requires mexSVMTrain from the OSU-SVM toolbox
% 'SVM:LOO' uses and requires svcm_train from the LOO-SVM toolbox
% 'SVM:Gunn' uses and requires svc-functions from the Gunn-SVM toolbox
% 'SVM:KM' uses and requires svmclass-function from the KM-SVM toolbox
% 'SVM:LINz' LibLinear [10] (requires train.mex from LibLinear somewhere in the path)
% z=0 (default) LibLinear with -- L2-regularized logistic regression
% z=1 LibLinear with -- L2-loss support vector machines (dual)
% z=2 LibLinear with -- L2-loss support vector machines (primal)
% z=3 LibLinear with -- L1-loss support vector machines (dual)
% 'SVM:LIN4' LibLinear with -- multi-class support vector machines by Crammer and Singer
% 'DT' decision tree - not implemented yet.
% {'REG','MDA','MD2','QDA','QDA2','LD2','LD3','LD4','LD5','LD6','NBC','aNBC','WienerHopf','LDA/GSVD','MDA/GSVD', 'LDA/sparse','MDA/sparse', 'PLA', 'LMS','LDA/DELETION','MDA/DELETION','NBC/DELETION','RDA/DELETION','REG/DELETION','RDA','GDBC','SVM','RBF','PSVM','SVM11','SVM:LIN4','SVM:LIN0','SVM:LIN1','SVM:LIN2','SVM:LIN3','WINNOW', 'DT'};
% CC contains the model parameters of a classifier. Some time ago,
% CC was a statistical classifier containing the mean
% and the covariance of the data of each class (encoded in the
% so-called "extended covariance matrices". Nowadays, also other
% classifiers are supported.
% see also: TEST_SC, COVM, ROW_COL_DELETION
% [1] R. Duda, P. Hart, and D. Stork, Pattern Classification, second ed.
% John Wiley & Sons, 2001.
% [2] Peg Howland and Haesun Park,
% Generalizing Discriminant Analysis Using the Generalized Singular Value Decomposition
% IEEE Transactions on Pattern Analysis and Machine Intelligence, 26(8), 2004.
% dx.doi.org/10.1109/TPAMI.2004.46
% [3] http://www-static.cc.gatech.edu/~kihwan23/face_recog_gsvd.htm
% [4] Jieping Ye, Ravi Janardan, Cheong Hee Park, Haesun Park
% A new optimization criterion for generalized discriminant analysis on undersampled problems.
% The Third IEEE International Conference on Data Mining, Melbourne, Florida, USA
% November 19 - 22, 2003
% [5] J.D. Tebbens and P. Schlesinger (2006),
% Improving Implementation of Linear Discriminant Analysis for the Small Sample Size Problem
% Computational Statistics & Data Analysis, vol 52(1): 423-437, 2007
% http://www.cs.cas.cz/mweb/download/publi/JdtSchl2006.pdf
% [6] H. Zhang, The optimality of Naive Bayes,
% http://www.cs.unb.ca/profs/hzhang/publications/FLAIRS04ZhangH.pdf
% [7] J.H. Friedman. Regularized discriminant analysis.
% Journal of the American Statistical Association, 84:165–175, 1989.
% [8] G. Fung and O.L. Mangasarian, Proximal Support Vector Machine Classifiers, KDD 2001.
% Eds. F. Provost and R. Srikant, Proc. KDD-2001: Knowledge Discovery and Data Mining, August 26-29, 2001, San Francisco, CA.
% [9] Kai Keng Ang, Zhang Yang Chin, Haihong Zhang, Cuntai Guan.
% Filter Bank Common Spatial Pattern (FBCSP) in Brain-Computer Interface.
% IEEE International Joint Conference on Neural Networks, 2008. IJCNN 2008. (IEEE World Congress on Computational Intelligence).
% 1-8 June 2008 Page(s):2390 - 2397
% [10] R.-E. Fan, K.-W. Chang, C.-J. Hsieh, X.-R. Wang, and C.-J. Lin.
% LIBLINEAR: A Library for Large Linear Classification, Journal of Machine Learning Research 9(2008), 1871-1874.
% Software available at http://www.csie.ntu.edu.tw/~cjlin/liblinear
% [11] http://en.wikipedia.org/wiki/Perceptron#Learning_algorithm
% [12] Littlestone, N. (1988)
% "Learning Quickly When Irrelevant Attributes Abound: A New Linear-threshold Algorithm"
% Machine Learning 285-318(2)
% http://en.wikipedia.org/wiki/Winnow_(algorithm)
% $Id: train_sc.m 9601 2012-02-09 14:14:36Z schloegl $
% Copyright (C) 2005,2006,2007,2008,2009,2010 by Alois Schloegl <alois.schloegl@gmail.com>
% This function is part of the NaN-toolbox
% http://pub.ist.ac.at/~schloegl/matlab/NaN/
% This program is free software; you can redistribute it and/or
% modify it under the terms of the GNU General Public License
% as published by the Free Software Foundation; either version 3
% of the License, or (at your option) any later version.
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
% You should have received a copy of the GNU General Public License
% along with this program; if not, write to the Free Software
% Foundation, Inc., 51 Franklin Street - Fifth Floor, Boston, MA 02110-1301, USA.

% NOTE(review): this file is an incomplete listing — many lines (including
% several if/else/end statements and loop headers) are missing from this
% excerpt. Code below is preserved verbatim; do not treat the control flow
% as complete. Compare against the original train_sc.m before editing logic.

% --- input checking ---
error('insufficient input arguments\n\tusage: train_sc(D,C,...)\n');
if nargin<3, MODE = 'LDA'; end
if nargin<4, W = []; end
elseif ~isfield(MODE,'TYPE')
if isfield(MODE,'hyperparameters') && ~isfield(MODE,'hyperparameter'),
%% for backwards compatibility, this might become obsolete
warning('MODE.hyperparameters are used, You should use MODE.hyperparameter instead!!!');
MODE.hyperparameter = MODE.hyperparameters;
if sz(1)~=size(classlabel,1),
error('length of data and classlabel does not fit');
% several classifiers can deal with NaN's, there is no need to remove them.
%% TODO: some classifiers can deal with NaN's in D. Test whether this can be relaxed.
%ix = any(isnan([classlabel]),2);
ix = any(isnan([D,classlabel]),2);
%ix = any(isnan([classlabel]),2);
ix = any(isnan([D,classlabel]),2);
warning('support for weighting of samples is still experimental');
if sz(1)~=length(classlabel),
error('length of data and classlabel does not fit');
if ~isfield(MODE,'hyperparameter')
MODE.hyperparameter = [];

% --- '###/DELETION': row- or column-wise deletion of NaNs, then recurse ---
elseif ~isempty(strfind(lower(MODE.TYPE),'/delet'))
POS1 = find(MODE.TYPE=='/');
[rix,cix] = row_col_deletion(D);
if ~isempty(W), W=W(rix); end
CC = train_sc(D(rix,cix),classlabel(rix,:),MODE.TYPE(1:POS1(1)-1),W);
CC.G = sparse(cix, 1:length(cix), 1, size(D,2), length(cix));
if isfield(CC,'weights')
% re-embed the weights of the reduced feature set into full dimension
W = [CC.weights(1,:); CC.weights(2:end,:)];
CC.weights = sparse(size(D,2)+1, size(W,2));
CC.weights([1,cix+1],:) = W;
CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];
CC.datatype = [CC.datatype,'/delet'];

% --- 'NBPW': Naive Bayesian Parzen Window (not implemented) ---
elseif ~isempty(strfind(lower(MODE.TYPE),'nbpw'))
error('NBPW not implemented yet')
%%%% Naive Bayesian Parzen Window Classifier.
[classlabel,CC.Labels] = CL1M(classlabel);
for k = 1:length(CC.Labels),
[d,CC.MEAN(k,:)] = center(D(classlabel==CC.Labels(k),:),1);
[CC.VAR(k,:),CC.N(k,:)] = sumskipnan(d.^2,1);
h2_opt = (4./(3*CC.N(k,:))).^(2/5).*CC.VAR(k,:);

% --- 'NBC' / 'aNBC': (augmented) Naive Bayesian classifier ---
elseif ~isempty(strfind(lower(MODE.TYPE),'nbc'))
%%%% Naive Bayesian Classifier
if ~isempty(strfind(lower(MODE.TYPE),'anbc'))
%%%% Augmented Naive Bayesian classifier.
[CC.V,L] = eig(covm(D,'M',W));
CC.V = eye(size(D,2));
[classlabel,CC.Labels] = CL1M(classlabel);
for k = 1:length(CC.Labels),
ix = classlabel==CC.Labels(k);
%% [d,CC.MEAN(k,:)] = center(D(ix,:),1);
[s,n] = sumskipnan(D(ix,:),1,W(ix));
d = D(ix,:) - CC.MEAN(repmat(k,sum(ix),1),:);
[CC.VAR(k,:),CC.N(k,:)] = sumskipnan(d.^2,1,W(ix));
[s,n] = sumskipnan(D(ix,:),1);
d = D(ix,:) - CC.MEAN(repmat(k,sum(ix),1),:);
[CC.VAR(k,:),CC.N(k,:)] = sumskipnan(d.^2,1);
CC.VAR = CC.VAR./max(CC.N-1,0);
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'LPM': linear programming machine (needs iLog CPLEX) ---
elseif ~isempty(strfind(lower(MODE.TYPE),'lpm'))
error('Error TRAIN_SC: Classifier (%s) does not support weighted samples.',MODE.TYPE);
% linear programming machine
% CPLEX optimizer: ILOG solver, ilog cplex 6.5 reference manual http://www.ilog.com
if ~isfield(MODE.hyperparameter,'c_value')
MODE.hyperparameter.c_value = 1;
[classlabel,CC.Labels] = CL1M(classlabel);
M = length(CC.Labels);
if M==2, M=1; end % For a 2-class problem, only 1 Discriminant is needed
%LPM = train_LPM(D,(classlabel==CC.Labels(k)),'C',MODE.hyperparameter.c_value);
LPM = train_LPM(D',(classlabel'==CC.Labels(k)));
CC.weights(:,k) = [-LPM.b; LPM.w(:)];
CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'PLA': perceptron learning algorithm ---
elseif ~isempty(strfind(lower(MODE.TYPE),'pla')),
% Perceptron Learning Algorithm
[rix,cix] = row_col_deletion(D);
[CL101,CC.Labels] = cl101(classlabel);
weights = sparse(length(cix)+1,M);
%ix = randperm(size(D,1)); %% randomize samples ???
% NOTE(review): the two consecutive isfield checks on 'alpha' below are
% contradictory as shown (~isfield immediately followed by isfield) —
% lines are likely missing between them; confirm against the original.
if ~isfield(MODE.hyperparameter,'alpha')
if isfield(MODE.hyperparameter,'alpha')
alpha = MODE.hyperparameter.alpha;
%e = ((classlabel(k)==(1:M))-.5) - sign([1, D(k,cix)] * weights)/2;
e = CL101(k,:) - sign([1, D(k,cix)] * weights);
weights = weights + alpha * [1,D(k,cix)]' * e ;
if isfield(MODE.hyperparameter,'alpha')
W = W*MODE.hyperparameter.alpha;
%e = ((classlabel(k)==(1:M))-.5) - sign([1, D(k,cix)] * weights)/2;
e = CL101(k,:) - sign([1, D(k,cix)] * weights);
weights = weights + W(k) * [1,D(k,cix)]' * e ;
CC.weights = sparse(size(D,2)+1,M);
CC.weights([1,cix+1],:) = weights;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'LMS' / 'AdaLine': least mean squares / delta rule ---
elseif ~isempty(strfind(lower(MODE.TYPE),'adaline')) || ~isempty(strfind(lower(MODE.TYPE),'lms')),
% adaptive linear element, least mean squares, delta rule, Widrow-Hoff,
[rix,cix] = row_col_deletion(D);
[CL101,CC.Labels] = cl101(classlabel);
weights = sparse(length(cix)+1,M);
%ix = randperm(size(D,1)); %% randomize samples ???
if isfield(MODE.hyperparameter,'alpha')
alpha = MODE.hyperparameter.alpha;
%e = (classlabel(k)==(1:M)) - [1, D(k,cix)] * weights;
e = CL101(k,:) - sign([1, D(k,cix)] * weights);
weights = weights + alpha * [1,D(k,cix)]' * e ;
if isfield(MODE.hyperparameter,'alpha')
W = W*MODE.hyperparameter.alpha;
%e = (classlabel(k)==(1:M)) - [1, D(k,cix)] * weights;
e = CL101(k,:) - sign([1, D(k,cix)] * weights);
weights = weights + W(k) * [1,D(k,cix)]' * e ;
CC.weights = sparse(size(D,2)+1,M);
CC.weights([1,cix+1],:) = weights;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'Winnow2' ---
elseif ~isempty(strfind(lower(MODE.TYPE),'winnow'))
error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
[rix,cix] = row_col_deletion(D);
[CL101,CC.Labels] = cl101(classlabel);
weights = ones(length(cix),M);
e = CL101(k,:) - sign(D(k,cix) * weights - theta);
weights = weights.* 2.^(D(k,cix)' * e);
CC.weights = sparse(size(D,2)+1,M);
CC.weights(cix+1,:) = weights;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'PLS' / 'REG': partial least squares / regression analysis ---
elseif ~isempty(strfind(lower(MODE.TYPE),'pls')) || ~isempty(strfind(lower(MODE.TYPE),'reg'))
% 4th version: support for weighted samples - works well with unequally distributed data:
% regression analysis, can handle sparse data, too.
[rix, cix] = row_col_deletion(D);
wD = [ones(length(rix),1),D(rix,cix)];
wD(:,k) = W(rix).*wD(:,k);
[CL101, CC.Labels] = cl101(classlabel(rix,:));
CC.weights = sparse(sz(2)+1,M);
%[rix, cix] = row_col_deletion(wD);
% NOTE(review): the next two assignments are presumably the unweighted and
% weighted branches of a missing if/else — verify against the original file.
CC.weights([1,cix+1],:) = r\(q'*CL101);
CC.weights([1,cix+1],:) = r\(q'*(W(rix,ones(1,M)).*CL101));
% CC.weights(cix,k) = r\(q'*(W.*CL101(rix,k)));
CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];

% --- 'WienerHopf': solve the Wiener-Hopf equation (NaN-tolerant LDA/regression) ---
elseif ~isempty(strfind(MODE.TYPE,'WienerHopf'))
% Q: equivalent to LDA
% equivalent to Regression, except regression can not deal with NaN's
[CL101,CC.Labels] = cl101(classlabel);
CC.weights = sparse(size(D,2)+1,M);
%c1 = classlabel(~isnan(classlabel));
%c2 = ones(sum(~isnan(classlabel)),M);
% c2(:,k) = c1==CC.Labels(k);
%CC.weights = cc\covm([ones(size(c2,1),1),D(~isnan(classlabel),:)],2*real(c2)-1,'M',W);
CC.weights = cc\covm([ones(size(D,1),1),D],CL101,'M',W);
CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];

% --- '###/GSVD': dimensionality reduction via GSVD, then recurse ---
elseif ~isempty(strfind(lower(MODE.TYPE),'/gsvd'))
error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
% [2] Peg Howland and Haesun Park, 2004
% Generalizing Discriminant Analysis Using the Generalized Singular Value Decomposition
% IEEE Transactions on Pattern Analysis and Machine Intelligence, 26(8), 2004.
% dx.doi.org/10.1109/TPAMI.2004.46
% [3] http://www-static.cc.gatech.edu/~kihwan23/face_recog_gsvd.htm
[classlabel,CC.Labels] = CL1M(classlabel);
[rix,cix] = row_col_deletion(D);
Hw = zeros(length(rix)+length(CC.Labels), length(cix));
m0 = mean(D(rix,cix));
K = length(CC.Labels);
ix = find(classlabel(rix)==CC.Labels(k));
[Hw(ix,:), mu] = center(D(rix(ix),cix));
%Hb(k,:) = sqrt(N(k))*(mu(k,:)-m0);
Hw(length(rix)+k,:) = sqrt(N(k))*(mu-m0); % Hb(k,:)
[P,R,Q] = svd(Hw,'econ');
catch % needed because SVD(..,'econ') not supported in Matlab 6.x
%[size(D);size(P);size(Q);size(R)]
%P = P(1:size(D,1),1:t);
[U,E,W] = svd(P(1:length(rix),1:t),0);
%[size(U);size(E);size(W)]
%[size(Q);size(R);size(W)]
%G = Q(1:t,:)'*[R\W'];
G = Q(:,1:t)*(R\W'); % this works as well and needs only 'econ'-SVD
%G = G(:,1:t); % not needed
% do not use this, gives very bad results for Medline database
%G = G(:,1:K); this seems to be a typo in [2] and [3].
CC = train_sc(D(:,cix)*G,classlabel,MODE.TYPE(1:find(MODE.TYPE=='/')-1));
CC.G = sparse(size(D,2),size(G,2));
if isfield(CC,'weights')
CC.weights = sparse([CC.weights(1,:); CC.G*CC.weights(2:end,:)]);
CC.datatype = ['classifier:statistical:', lower(MODE.TYPE)];
CC.datatype = [CC.datatype,'/gsvd'];

% --- '###/sparse': sparse LDA preprocessing [5], then recurse ---
elseif ~isempty(strfind(lower(MODE.TYPE),'sparse'))
error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
% [5] J.D. Tebbens and P.Schlesinger (2006),
% Improving Implementation of Linear Discriminant Analysis for the Small Sample Size Problem
% http://www.cs.cas.cz/mweb/download/publi/JdtSchl2006.pdf
[classlabel,CC.Labels] = CL1M(classlabel);
[rix,cix] = row_col_deletion(D);
warning('sparse LDA is sensitive to linear transformations')
M = length(CC.Labels);
G = sparse([],[],[],length(rix),M,length(rix));
G(classlabel(rix)==CC.Labels(k),k) = 1;
G = train_lda_sparse(D(rix,cix),G,1,tol);
CC.datatype = 'classifier:slda';
POS1 = find(MODE.TYPE=='/');
%G = v(:,1:size(G.trafo,2)).*G.trafo;
%CC.weights = s * CC.weights(2:end,:) + sparse(1,1:M,CC.weights(1,:),sz(2)+1,M);
CC = train_sc(D(rix,cix)*G.trafo,classlabel(rix),MODE.TYPE(1:POS1(1)-1));
CC.G = sparse(size(D,2),size(G.trafo,2));
CC.G(cix,:) = G.trafo;
if isfield(CC,'weights')
CC.weights = sparse([CC.weights(1,:); CC.G*CC.weights(2:end,:)]);
CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];
CC.datatype = [CC.datatype,'/sparse'];

% --- 'RBF': SVM with RBF kernel (libSVM) ---
elseif ~isempty(strfind(lower(MODE.TYPE),'rbf'))
error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
% Martin Hieden's RBF-SVM
if exist('svmpredict_mex','file')==3,
MODE.TYPE = 'SVM:LIB:RBF';
error('No SVM training algorithm available. Install LibSVM for Matlab.\n');
CC.options = '-t 2 -q'; %use RBF kernel, set C, set gamma
% NOTE(review): the isfield guards below appear swapped — the 'gamma' check
% guards use of c_value and the 'c_value' check guards use of gamma, so a
% missing field would cause a runtime error; confirm against the original.
if isfield(MODE.hyperparameter,'gamma')
CC.options = sprintf('%s -c %g', CC.options, MODE.hyperparameter.c_value); % set C
if isfield(MODE.hyperparameter,'c_value')
CC.options = sprintf('%s -g %g', CC.options, MODE.hyperparameter.gamma); % set gamma
CC.prewhite = sparse(2:sz(2)+1,1:sz(2),r,sz(2)+1,sz(2),2*sz(2));
CC.prewhite(1,:) = -m.*r;
[classlabel,CC.Labels] = CL1M(classlabel);
CC.model = svmtrain_mex(classlabel, D, CC.options); % Call the training mex File
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'SVM11': one-vs-one SVM with voting (libSVM) ---
elseif ~isempty(strfind(lower(MODE.TYPE),'svm11'))
error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
if ~isfield(MODE.hyperparameter,'c_value')
MODE.hyperparameter.c_value = 1;
CC.options=sprintf('-c %g -t 0 -q',MODE.hyperparameter.c_value); %use linear kernel, set C
CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
CC.prewhite = sparse(2:sz(2)+1,1:sz(2),r,sz(2)+1,sz(2),2*sz(2));
CC.prewhite(1,:) = -m.*r;
[classlabel,CC.Labels] = CL1M(classlabel);
CC.model = svmtrain_mex(classlabel, D, CC.options); % Call the training mex File
FUN = 'SVM:LIB:1vs1';
CC.datatype = ['classifier:',lower(FUN)];

% --- 'PSVM': proximal SVM [8] ---
elseif ~isempty(strfind(lower(MODE.TYPE),'psvm'))
%%% error('Classifier (%s) does not support weighted samples.',MODE.TYPE);
warning('Classifier (%s) in combination with weighted samples is not tested.',MODE.TYPE);
if ~isfield(MODE,'hyperparameter')
elseif isfield(MODE.hyperparameter,'nu')
nu = MODE.hyperparameter.nu;
[CL101,CC.Labels] = cl101(classlabel);
CC.weights = sparse(n+1,size(CL101,2));
d = sparse(1:m,1:m,CL101(:,k));
H = d * [ones(m,1),D];
r = sumskipnan(H,1,W)';
%%% r = (speye(n+1)/nu + H' * H)\r; %solve (I/nu+H'*H)r=H'*e
[HTH, nn] = covm(H,H,'M',W);
r = (speye(n+1)/nu + HTH)\r; %solve (I/nu+H'*H)r=H'*e
%%% CC.weights(:,k) = u'*H;
[c,nn] = covm(u,H,'M',W);
CC.weights(:,k) = c';
CC.hyperparameter.nu = nu;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'SVM:LIN4': Crammer-Singer multi-class SVM (LibLinear -s 4) ---
elseif ~isempty(strfind(lower(MODE.TYPE),'svm:lin4'))
if ~isfield(MODE.hyperparameter,'c_value')
MODE.hyperparameter.c_value = 1;
[classlabel,CC.Labels] = CL1M(classlabel);
M = length(CC.Labels);
CC.weights = sparse(size(D,2)+1,M);
[rix,cix] = row_col_deletion(D);
[D,r,m]=zscore(D(rix,cix),1);
s = sparse(2:sz2+1,1:sz2,r,sz2+1,sz2,2*sz2);
CC.options = sprintf('-s 4 -B 1 -c %f -q', MODE.hyperparameter.c_value); % C-SVC, C=1, linear kernel, degree = 1,
model = train(W, classlabel, sparse(D), CC.options); % C-SVC, C=1, linear kernel, degree = 1,
weights = model.w([end,1:end-1],:)';
% NOTE(review): the two assignments below look like alternative versions of
% the same pre-whitening back-transformation (an if/else is presumably
% missing in this excerpt); confirm against the original.
CC.weights([1,cix+1],:) = s * weights(2:end,:) + sparse(1,1:M,weights(1,:),sz2+1,M); % include pre-whitening transformation
CC.weights([1,cix+1],:) = s * CC.weights(cix+1,:) + sparse(1,1:M,CC.weights(1,:),sz2+1,M); % include pre-whitening transformation
CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'SVM' and variants: one-vs-rest SVM via whichever backend is installed ---
elseif ~isempty(strfind(lower(MODE.TYPE),'svm'))
if ~isfield(MODE.hyperparameter,'c_value')
MODE.hyperparameter.c_value = 1;
% auto-detect an available SVM backend, in order of preference
if any(MODE.TYPE==':'),
elseif exist('train','file')==3,
MODE.TYPE = 'SVM:LIN'; %% liblinear
elseif exist('svmtrain_mex','file')==3,
MODE.TYPE = 'SVM:LIB';
elseif (exist('svmtrain','file')==3),
MODE.TYPE = 'SVM:LIB';
fprintf(1,'You need to rename %s to svmtrain_mex.mex !! \n Press any key to continue !!!\n',which('svmtrain.mex'));
elseif exist('svmtrain','file')==2,
MODE.TYPE = 'SVM:bioinfo';
elseif exist('mexSVMTrain','file')==3,
MODE.TYPE = 'SVM:OSU';
elseif exist('svcm_train','file')==2,
MODE.TYPE = 'SVM:LOO';
elseif exist('svmclass','file')==2,
MODE.TYPE = 'SVM:KM';
elseif exist('svc','file')==2,
MODE.TYPE = 'SVM:Gunn';
error('No SVM training algorithm available. Install OSV-SVM, or LOO-SVM, or libSVM for Matlab.\n');
%%CC = train_svm(D,classlabel,MODE);
[CL101,CC.Labels] = cl101(classlabel);
[rix,cix] = row_col_deletion(D);
CC.weights = sparse(sz(2)+1, M);
[D,r,m]=zscore(D(rix,cix),1);
s = sparse(2:sz2+1,1:sz2,r,sz2+1,sz2,2*sz2);
if strncmp(MODE.TYPE, 'SVM:LIN',7);
if isfield(MODE,'options')
CC.options = MODE.options;
if length(MODE.TYPE)>7, t=MODE.TYPE(8)-'0'; end
if (t<0 || t>6) t=0; end
CC.options = sprintf('-s %i -B 1 -c %f -q',t, MODE.hyperparameter.c_value); % C-SVC, C=1, linear kernel, degree = 1,
model = train(W, cl, sparse(D), CC.options); % C-SVC, C=1, linear kernel, degree = 1,
w = -model.w(:,1:end-1)';
Bias = -model.w(:,end)';
elseif strcmp(MODE.TYPE, 'SVM:LIB'); %% tested with libsvm-mat-2.9-1
if isfield(MODE,'options')
CC.options = MODE.options;
CC.options = sprintf('-s 0 -c %f -t 0 -d 1 -q', MODE.hyperparameter.c_value); % C-SVC, C=1, linear kernel, degree = 1,
model = svmtrain_mex(cl, D, CC.options); % C-SVC, C=1, linear kernel, degree = 1,
w = cl(1) * model.SVs' * model.sv_coef; %Calculate decision hyperplane weight vector
% ensure correct sign of weight vector and Bias according to class label
Bias = model.rho * cl(1);
elseif strcmp(MODE.TYPE, 'SVM:bioinfo');
% SVM classifier from bioinformatics toolbox.
% Settings suggested by Ian Daly, 2011-06-06
options = optimset('Display','iter','maxiter',20000, 'largescale','off');
CC.SVMstruct = svmtrain(D, cl, 'AUTOSCALE', 0, 'quadprog_opts', options, 'Method', 'LS', 'kernel_function', 'polynomial');
Bias = -CC.SVMstruct.Bias;
w = -CC.SVMstruct.Alpha'*CC.SVMstruct.SupportVectors;
elseif strcmp(MODE.TYPE, 'SVM:OSU');
[AlphaY, SVs, Bias] = mexSVMTrain(D', cl', [0 1 1 1 MODE.hyperparameter.c_value]); % Linear Kernel, C=1; degree=1, c-SVM
w = -SVs * AlphaY'*cl(1); %Calculate decision hyperplane weight vector
% ensure correct sign of weight vector and Bias according to class label
Bias = -Bias * cl(1);
elseif strcmp(MODE.TYPE, 'SVM:LOO');
[a, Bias, g, inds] = svcm_train(D, cl, MODE.hyperparameter.c_value); % C = 1;
w = D(inds,:)' * (a(inds).*cl(inds)) ;
elseif strcmp(MODE.TYPE, 'SVM:Gunn');
[nsv, alpha, Bias,svi] = svc(D, cl, 1, MODE.hyperparameter.c_value); % linear kernel, C = 1;
w = D(svi,:)' * alpha(svi) * cl(1);
elseif strcmp(MODE.TYPE, 'SVM:KM');
[xsup,w1,Bias,inds] = svmclass(D, cl, MODE.hyperparameter.c_value, 1, 'poly', 1); % C = 1;
w = -D(inds,:)' * w1;
fprintf(2,'Error TRAIN_SVM: no SVM training algorithm available\n');
CC.weights(1,k) = -Bias;
CC.weights(cix+1,k) = w;
CC.weights([1,cix+1],:) = s * CC.weights(cix+1,:) + sparse(1,1:M,CC.weights(1,:),sz2+1,M); % include pre-whitening transformation
CC.hyperparameter.c_value = MODE.hyperparameter.c_value;
CC.datatype = ['classifier:',lower(MODE.TYPE)];

% --- 'CSP': common spatial patterns (experimental hack, see header) ---
elseif ~isempty(strfind(lower(MODE.TYPE),'csp'))
CC.datatype = ['classifier:',lower(MODE.TYPE)];
[classlabel,CC.Labels] = CL1M(classlabel);
CC.MD = repmat(NaN,[sz(2)+[1,1],length(CC.Labels)]);
for k = 1:length(CC.Labels),
%% [CC.MD(k,:,:),CC.NN(k,:,:)] = covm(D(classlabel==CC.Labels(k),:),'E');
ix = classlabel==CC.Labels(k);
[CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E');
[CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E', W(ix));
%%% ### This is a hack ###
CC.FiltB = ones(CC.FiltA,1);
d = filtfilt(CC.FiltB,CC.FiltA,(D*W).^2);
CC.CSP = train_sc(log(d),classlabel);

% --- default: linear and quadratic statistical classifiers (LDA/QDA/MDA/RDA/...) ---
else % Linear and Quadratic statistical classifiers
CC.datatype = ['classifier:statistical:',lower(MODE.TYPE)];
[classlabel,CC.Labels] = CL1M(classlabel);
% CC.MD/CC.NN hold the per-class extended covariance matrices (see covm 'E')
CC.MD = repmat(NaN,[sz(2)+[1,1],length(CC.Labels)]);
for k = 1:length(CC.Labels),
ix = classlabel==CC.Labels(k);
[CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E');
[CC.MD(:,:,k),CC.NN(:,:,k)] = covm(D(ix,:), 'E', W(ix));
if strncmpi(MODE.TYPE,'LD',2) || strncmpi(MODE.TYPE,'FDA',3) || strncmpi(MODE.TYPE,'FLDA',3),
%if NC(1)==2, NC(1)=1; end % linear two class problem needs only one discriminant
CC.weights = repmat(NaN,NC(2),NC(3)); % memory allocation
type = MODE.TYPE(3)-'0';
ECM0 = squeeze(sum(ECM,3)); %decompose ECM
ix = [1:k-1,k+1:NC(3)];
dM = CC.MD(:,1,k)./CC.NN(:,1,k) - sum(CC.MD(:,1,ix),3)./sum(CC.NN(:,1,ix),3);
ecm0 = (sum(ECM(:,:,ix),3)/(NC(3)-1) + ECM(:,:,k));
ecm0 = 2*(sum(ECM(:,:,ix),3) + ECM(:,:,k))/NC(3);
% ecm0 = sum(CC.MD,3)./sum(CC.NN,3);
ecm0 = sum(CC.MD(:,:,ix),3)./sum(CC.NN(:,:,ix),3);
otherwise % LD3, LDA, FDA
if isfield(MODE.hyperparameter,'gamma')
% Tikhonov-style regularization scaled by the mean diagonal element
ecm0 = ecm0 + mean(diag(ecm0))*eye(size(ecm0))*MODE.hyperparameter.gamma;
CC.weights(:,k) = ecm0\dM;
%CC.weights = sparse(CC.weights);
elseif strcmpi(MODE.TYPE,'RDA');
if isfield(MODE,'hyperparameter')
CC.hyperparameter = MODE.hyperparameter;
if ~isfield(CC.hyperparameter,'gamma')
CC.hyperparameter.gamma = 0;
if ~isfield(CC.hyperparameter,'lambda')
CC.hyperparameter.lambda = 1;
nn = ECM0(1,1,1); % number of samples in training set for class k
XC = squeeze(ECM0(:,:,1))/nn; % normalize correlation matrix
M = XC(1,2:NC(2)); % mean
S = XC(2:NC(2),2:NC(2)) - M'*M;% covariance matrix
U0 = v(diag(d)==0,:);
%M = M/nn; S=S/(nn-1);
% ICOV1 = zeros(size(S));
%[M,sd,S,xc,N] = decovm(ECM{k}); %decompose ECM
nn = ECM(1,1,k);% number of samples in training set for class k
XC = squeeze(ECM(:,:,k))/nn;% normalize correlation matrix
M = XC(1,2:NC(2));% mean
S = XC(2:NC(2),2:NC(2)) - M'*M;% covariance matrix
%M = M/nn; S=S/(nn-1);
%ICOV(1) = ICOV(1) + (XC(2:NC(2),2:NC(2)) - )/nn
CC.IR{k} = [-M;eye(NC(2)-1)]*inv(S)*[-M',eye(NC(2)-1)]; % inverse correlation matrix extended by mean
CC.IR0{k} = [-M;eye(NC(2)-1)]*ICOV0*[-M',eye(NC(2)-1)]; % inverse correlation matrix extended by mean
if exist('OCTAVE_VERSION','builtin')
% several alternative scaling factors / log-determinant terms kept for test_sc
CC.logSF(k) = log(nn) - d/2*log(2*pi) - det(S)/2;
CC.logSF2(k) = -2*log(nn/sum(ECM(:,1,1)));
CC.logSF3(k) = d*log(2*pi) + log(det(S));
CC.logSF4(k) = log(det(S)) + 2*log(nn);
CC.logSF5(k) = log(det(S));
CC.logSF6(k) = log(det(S)) - 2*log(nn/sum(ECM(:,1,1)));
CC.logSF7(k) = log(det(S)) + d*log(2*pi) - 2*log(nn/sum(ECM(:,1,1)));
CC.logSF8(k) = sum(log(svd(S))) + log(nn) - log(sum(ECM(:,1,1)));
CC.SF(k) = nn/sqrt((2*pi)^d * det(S));
function [CL101,Labels] = cl101(classlabel)
%% convert classlabels to {-1,1} encoding
% Input:  classlabel as a single column of non-negative integers (1..M),
%         or already in a {-1,0,1} multi-column encoding.
% Output: CL101  one column per class, +1 for members, -1 otherwise
%         Labels list of class labels (one-versus-rest scheme)
% NOTE(review): this excerpt is missing interior lines (loop header,
% if/else/end statements); code is preserved verbatim.
if (all(classlabel>=0) && all(classlabel==fix(classlabel)) && (size(classlabel,2)==1))
% two-class shortcut: class 2 -> +1, class 1 -> -1
CL101 = (classlabel==2)-(classlabel==1);
CL101 = zeros(size(classlabel,1),M);
%% One-versus-Rest scheme
CL101(:,k) = 2*real(classlabel==k) - 1;
CL101(isnan(classlabel),:) = NaN; %% or zero ???
elseif all((classlabel==1) | (classlabel==-1) | (classlabel==0) )
error('format of classlabel unsupported');
function [cl1m, Labels] = CL1M(classlabel)
%% convert classlabels to 1..M encoding
% Input:  classlabel as a single column of non-negative integers,
%         or a {-1,0,1}-encoded matrix (one column per class).
% Output: cl1m   column vector with values 0..M (0 = ignored sample)
%         Labels the list 1:max(cl1m)
% NOTE(review): this excerpt is missing interior lines (if/else/end
% statements); code is preserved verbatim. The warning text below is a
% runtime string and is left unchanged ("what" is presumably "want").
if (all(classlabel>=0) && all(classlabel==fix(classlabel)) && (size(classlabel,2)==1))
elseif all((classlabel==1) | (classlabel==-1) | (classlabel==0) )
M = size(classlabel,2);
if any(sum(classlabel==1,2)>1)
warning('invalid format of classlabel - at most one category may have +1');
% two-column {-1,+1} case: map -1 -> class 1, +1 -> class 2
cl1m = (classlabel==-1) + 2*(classlabel==+1);
[tmp, cl1m] = max(classlabel,[],2);
warning('some class might not be properly represented - you might what to add another column to classlabel = [max(classlabel,[],2)<1,classlabel]');
cl1m(tmp<1)= 0; %% or NaN ???
error('format of classlabel unsupported');
Labels = 1:max(cl1m);