]> Creatis software - CreaPhase.git/blob - octave_packages/linear-algebra-2.2.0/nmf_bpas.m
Add a useful package (from Source forge) for octave
[CreaPhase.git] / octave_packages / linear-algebra-2.2.0 / nmf_bpas.m
1 ## Copyright (c) 2012 by Jingu Kim and Haesun Park <jingu@cc.gatech.edu>
2 ##
3 ##    This program is free software: you can redistribute it and/or modify
4 ##    it under the terms of the GNU General Public License as published by
5 ##    the Free Software Foundation, either version 3 of the License, or
6 ##    any later version.
7 ##
8 ##    This program is distributed in the hope that it will be useful,
9 ##    but WITHOUT ANY WARRANTY; without even the implied warranty of
10 ##    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 ##    GNU General Public License for more details.
12 ##
13 ##    You should have received a copy of the GNU General Public License
14 ##    along with this program. If not, see <http://www.gnu.org/licenses/>.
15
16 ## -*- texinfo -*-
17 ## @deftypefn {Function File} {[@var{W}, @var{H}, @var{iter}, @var{HIS}] = } nmf_bpas (@var{A}, @var{k})
18 ## Nonnegative Matrix Factorization by Alternating Nonnegativity Constrained Least Squares
19 ## using Block Principal Pivoting/Active Set method.
20 ##
21 ## This function solves one the following problems: given @var{A} and @var{k}, find @var{W} and @var{H} such that
22 ##     (1) minimize 1/2 * || @var{A}-@var{W}@var{H} ||_F^2
23 ##     (2) minimize 1/2 * ( || @var{A}-@var{W}@var{H} ||_F^2 + alpha * || @var{W} ||_F^2 + beta * || @var{H} ||_F^2 )
24 ##     (3) minimize 1/2 * ( || @var{A}-@var{W}@var{H} ||_F^2 + alpha * || @var{W} ||_F^2 + beta * (sum_(i=1)^n || @var{H}(:,i) ||_1^2 ) )
25 ##     where @var{W}>=0 and @var{H}>=0 elementwise.
26 ## The input arguments are @var{A} : Input data matrix (m x n) and @var{k} : Target low-rank.
27 ##
28 ##
29 ## @strong{Optional Inputs}
30 ## @table @samp
31 ## @item Type : Default is 'regularized', which is recommended for quick application testing unless 'sparse' or 'plain' is explicitly needed. If sparsity is needed for 'W' factor, then apply this function for the transpose of 'A' with formulation (3). Then, exchange 'W' and 'H' and obtain the transpose of them. Imposing sparsity for both factors is not recommended and thus not included in this software.
32 ## @table @asis
33 ## @item 'plain' to use formulation (1)
34 ## @item 'regularized' to use formulation (2)
35 ## @item 'sparse' to use formulation (3)
36 ## @end table
37 ##
38 ## @item NNLSSolver : Default is 'bp', which is in general faster.
39 ## @table @asis
40 ## item 'bp' to use the algorithm in [1]
41 ## item 'as' to use the algorithm in [2]
42 ## @end table
43 ##
44 ## @item Alpha      : Parameter alpha in the formulation (2) or (3). Default is the average of all elements in A. No good justfication for this default value, and you might want to try other values.
45 ## @item  Beta      : Parameter beta in the formulation (2) or (3).
46 ##               Default is the average of all elements in A. No good justfication for this default value, and you might want to try other values.
47 ## @item MaxIter    : Maximum number of iterations. Default is 100.
48 ## @item MinIter    : Minimum number of iterations. Default is 20.
49 ## @item MaxTime    : Maximum amount of time in seconds. Default is 100,000.
50 ## @item Winit      : (m x k) initial value for W.
51 ## @item Hinit      : (k x n) initial value for H.
52 ## @item Tol        : Stopping tolerance. Default is 1e-3. If you want to obtain a more accurate solution, decrease TOL and increase MAX_ITER at the same time.
53 ## @item Verbose  :
54 ## @table @asis
55 ## @item 0 (default) - No debugging information is collected.@*
56 ## @item 1 (debugging purpose) - History of computation is returned by 'HIS' variable.
57 ## @item 2 (debugging purpose) - History of computation is additionally printed on screen.
58 ## @end table
59 ## @end table
60 ##
61 ## @strong{Outputs}
62 ## @table @samp
63 ## @item W    : Obtained basis matrix (m x k)
64 ## @item H    : Obtained coefficients matrix (k x n)
65 ## @item iter : Number of iterations
66 ## @item HIS  : (debugging purpose) History of computation
67 ## @end table
68 ##
69 ## Usage Examples:
70 ## @example
71 ##  nmf(A,10)
72 ##  nmf(A,20,'verbose',2)
73 ##  nmf(A,30,'verbose',2,'nnls_solver','as')
74 ##  nmf(A,5,'verbose',2,'type','sparse')
75 ##  nmf(A,60,'verbose',1,'type','plain','w_init',rand(m,k))
76 ##  nmf(A,70,'verbose',2,'type','sparse','nnls_solver','bp','alpha',1.1,'beta',1.3)
77 ## @end example
78 ##
79 ## References:
80 ##  [1] For using this software, please cite:@*
81 ##      Jingu Kim and Haesun Park, Toward Faster Nonnegative Matrix Factorization: A New Algorithm and Comparisons,@*
82 ##      In Proceedings of the 2008 Eighth IEEE International Conference on Data Mining (ICDM'08), 353-362, 2008@*
83 ##  [2] If you use 'nnls_solver'='as' (see below), please cite:@*
84 ##      Hyunsoo Kim and Haesun Park, Nonnegative Matrix Factorization Based @*
85 ##      on Alternating Nonnegativity Constrained Least Squares and Active Set Method, @*
86 ##      SIAM Journal on Matrix Analysis and Applications, 2008, 30, 713-730
87 ##
88 ## Check original code at @url{http://www.cc.gatech.edu/~jingu}
89 ##
90 ## @seealso{nmf_pg}
91 ## @end deftypefn
92
93 ## 2012 - Modified and adapted to Octave 3.6.1 by
94 ## Juan Pablo Carbajal <carbajal@ifi.uzh.ch>
95
96 # TODO
97 # - Format code.
98 # - Vectorize loops.
99
100 function [W, H, iter, HIS] = nmf_bpas (A, k , varargin)
101   page_screen_output (0, "local");
102
103   [m,n] = size(A);
104   ST_RULE = 1;
105
106   # --- Parse arguments --- #
107   parser = inputParser ();
108   parser.FunctionName = "nmf_bpas";
109   parser = addParamValue (parser,'Winit', rand(m,k), @ismatrix);
110   parser = addParamValue (parser,'Hinit', rand(k,n), @ismatrix);
111   parser = addParamValue (parser,'Tol', 1e-3, @(x)x>0);
112   parser = addParamValue (parser,'Alpha', mean (A(:)), @(x)x>=0);
113   parser = addParamValue (parser,'Beta', mean (A(:)), @(x)x>=0);
114   parser = addParamValue (parser,'MaxIter', 100, @(x)x>0);
115   parser = addParamValue (parser,'MaxTime', 1e3, @(x)x>0);
116   parser = addParamValue (parser,'Verbose', false);
117
118   val_type = @(x,c) ischar (x) && any (strcmpi (x,c));
119   parser = addParamValue (parser,'Type', 'regularized', ...
120                             @(x)val_type (x,{'regularized', 'sparse','plain'}));
121   parser = addParamValue (parser,'NNLSSolver', 'bp', ...
122                                                  @(x)val_type (x,{'bp', 'as'}));
123
124   parser = parse(parser,varargin{:});
125
126   % Default configuration
127   par.m           = m;
128   par.n           = n;
129   par.type        = parser.Results.Type;
130   par.nnls_solver = parser.Results.NNLSSolver;
131   par.alpha       = parser.Results.Alpha;
132   par.beta        = parser.Results.Beta;
133   par.max_iter    = parser.Results.MaxIter;
134   par.min_iter    = 20;
135   par.max_time    = parser.Results.MaxTime;
136   par.tol         = parser.Results.Tol;
137   par.verbose     = parser.Results.Verbose;
138   W               = parser.Results.Winit;
139   H               = parser.Results.Hinit;
140
141   # TODO check if can be removed
142   argAlpha        = par.alpha;
143   argBeta         = par.beta;
144
145   clear parser val_type
146
147 ### PARSING TYPE
148 # TODO add callbacks here to use during main loop. See [1]
149   % for regularized/sparse case
150   salphaI = sqrt (par.alpha) * eye (k);
151   zerokm = zeros (k,m);
152
153   if strcmpi (par.type, 'regularized')
154     sbetaI = sqrt (par.beta) * eye (k);
155     zerokn = zeros (k,n);
156
157   elseif strcmpi (par.type, 'sparse')
158     sbetaE = sqrt (par.beta) * ones (1,k);
159     betaI  = par.beta * ones (k,k);
160     zero1n = zeros (1,n);
161
162   end
163 ###
164
165 # Verbosity
166   display(par);
167 ### Done till here Sun Mar 25 19:00:26 2012
168
169     HIS = 0;
170     if par.verbose          % collect information for analysis/debugging
171         [gradW,gradH] = getGradient(A,W,H,par.type,par.alpha,par.beta);
172         initGrNormW = norm(gradW,'fro');
173         initGrNormH = norm(gradH,'fro');
174         initNorm = norm(A,'fro');
175         numSC = 3;
176         initSCs = zeros(numSC,1);
177         for j=1:numSC
178             initSCs(j) = getInitCriterion(j,A,W,H,par.type,par.alpha,par.beta,gradW,gradH);
179         end
180 %---(1)------(2)--------(3)--------(4)--------(5)---------(6)----------(7)------(8)-----(9)-------(10)--------------(11)-------
181 % iter # | elapsed | totalTime | subIterW | subIterH | rel. obj.(%) | NM_GRAD | GRAD | DELTA | W density (%) | H density (%)
182 %------------------------------------------------------------------------------------------------------------------------------
183         HIS = zeros(1,11);
184         HIS(1,[1:5])=0;
185         ver.initGrNormW = initGrNormW;
186         ver.initGrNormH = initGrNormH;
187         ver.initNorm = initNorm;                            HIS(1,6) = ver.initNorm;
188         ver.SC1 = initSCs(1);                               HIS(1,7) = ver.SC1;
189         ver.SC2 = initSCs(2);                               HIS(1,8) = ver.SC2;
190         ver.SC3 = initSCs(3);                               HIS(1,9) = ver.SC3;
191         ver.W_density = length(find(W>0))/(m*k);            HIS(1,10) = ver.W_density;
192         ver.H_density = length(find(H>0))/(n*k);            HIS(1,11) = ver.H_density;
193         if par.verbose == 2
194           disp (ver);
195         end
196         tPrev = cputime;
197     end
198
199     tStart = cputime;
200     tTotal = 0;
201     initSC = getInitCriterion(ST_RULE,A,W,H,par.type,par.alpha,par.beta);
202     SCconv = 0;
203     SC_COUNT = 3;
204
205 #TODO: [1] Replace with callbacks avoid switching each time
206     for iter=1:par.max_iter
207         switch par.type
208             case 'plain'
209                 [H,gradHX,subIterH] = nnlsm(W,A,H,par.nnls_solver);
210                 [W,gradW,subIterW] = nnlsm(H',A',W',par.nnls_solver);, W=W';, gradW=gradW';
211                 gradH = (W'*W)*H - W'*A;
212             case 'regularized'
213                 [H,gradHX,subIterH] = nnlsm([W;sbetaI],[A;zerokn],H,par.nnls_solver);
214                 [W,gradW,subIterW] = nnlsm([H';salphaI],[A';zerokm],W',par.nnls_solver);, W=W';, gradW=gradW';
215                 gradH = (W'*W)*H - W'*A + par.beta*H;
216             case 'sparse'
217                 [H,gradHX,subIterH] = nnlsm([W;sbetaE],[A;zero1n],H,par.nnls_solver);
218                 [W,gradW,subIterW] = nnlsm([H';salphaI],[A';zerokm],W',par.nnls_solver);, W=W';, gradW=gradW';
219                 gradH = (W'*W)*H - W'*A + betaI*H;
220         end
221
222         if par.verbose          % collect information for analysis/debugging
223             elapsed = cputime-tPrev;
224             tTotal = tTotal + elapsed;
225             ver = [];
226             idx = iter+1;
227 %---(1)------(2)--------(3)--------(4)--------(5)---------(6)----------(7)------(8)-----(9)-------(10)--------------(11)-------
228 % iter # | elapsed | totalTime | subIterW | subIterH | rel. obj.(%) | NM_GRAD | GRAD | DELTA | W density (%) | H density (%)
229 %------------------------------------------------------------------------------------------------------------------------------
230             ver.iter = iter;                                    HIS(idx,1)=iter;
231             ver.elapsed = elapsed;                              HIS(idx,2)=elapsed;
232             ver.tTotal = tTotal;                                HIS(idx,3)=tTotal;
233             ver.subIterW = subIterW;                            HIS(idx,4)=subIterW;
234             ver.subIterH = subIterH;                            HIS(idx,5)=subIterH;
235             ver.relError = norm(A-W*H,'fro')/initNorm;          HIS(idx,6)=ver.relError;
236             ver.SC1 = getStopCriterion(1,A,W,H,par.type,par.alpha,par.beta,gradW,gradH)/initSCs(1);     HIS(idx,7)=ver.SC1;
237             ver.SC2 = getStopCriterion(2,A,W,H,par.type,par.alpha,par.beta,gradW,gradH)/initSCs(2);     HIS(idx,8)=ver.SC2;
238             ver.SC3 = getStopCriterion(3,A,W,H,par.type,par.alpha,par.beta,gradW,gradH)/initSCs(3);     HIS(idx,9)=ver.SC3;
239             ver.W_density = length(find(W>0))/(m*k);            HIS(idx,10)=ver.W_density;
240             ver.H_density = length(find(H>0))/(n*k);            HIS(idx,11)=ver.H_density;
241             if par.verbose == 2, display(ver);, end
242             tPrev = cputime;
243         end
244
245         if (iter > par.min_iter)
246             SC = getStopCriterion(ST_RULE,A,W,H,par.type,par.alpha,par.beta,gradW,gradH);
247             if (par.verbose && (tTotal > par.max_time)) || (~par.verbose && ((cputime-tStart)>par.max_time))
248                 break;
249             elseif (SC/initSC <= par.tol)
250                 SCconv = SCconv + 1;
251                 if (SCconv >= SC_COUNT)
252                   break;
253                 end
254             else
255                 SCconv = 0;
256             end
257         end
258     end
259     [m,n]=size(A);
260     norm2=sqrt(sum(W.^2,1));
261     toNormalize = norm2>0;
262     W(:,toNormalize) = W(:,toNormalize)./repmat(norm2(toNormalize),m,1);
263     H(toNormalize,:) = H(toNormalize,:).*repmat(norm2(toNormalize)',1,n);
264
265     final.iterations = iter;
266     if par.verbose
267         final.elapsed_total = tTotal;
268     else
269         final.elapsed_total = cputime-tStart;
270     end
271     final.relative_error = norm(A-W*H,'fro')/norm(A,'fro');
272     final.W_density = length(find(W>0))/(m*k);
273     final.H_density = length(find(H>0))/(n*k);
274     display(final);
275
276 endfunction
277
278 %------------------------------------------------------------------------------------------------------------------------
279 %                                    Utility Functions
280 %-------------------------------------------------------------------------------
281 function retVal = getInitCriterion(stopRule,A,W,H,type,alpha,beta,gradW,gradH)
282 % STOPPING_RULE : 1 - Normalized proj. gradient
283 %                 2 - Proj. gradient
284 %                 3 - Delta by H. Kim
285 %                 0 - None (want to stop by MAX_ITER or MAX_TIME)
286     if nargin~=9
287         [gradW,gradH] = getGradient(A,W,H,type,alpha,beta);
288     end
289     [m,k]=size(W);, [k,n]=size(H);, numAll=(m*k)+(k*n);
290     switch stopRule
291         case 1
292             retVal = norm([gradW; gradH'],'fro')/numAll;
293         case 2
294             retVal = norm([gradW; gradH'],'fro');
295         case 3
296             retVal = getStopCriterion(3,A,W,H,type,alpha,beta,gradW,gradH);
297         case 0
298             retVal = 1;
299     end
300
301 endfunction
302 %-------------------------------------------------------------------------------
303 function retVal = getStopCriterion(stopRule,A,W,H,type,alpha,beta,gradW,gradH)
304 % STOPPING_RULE : 1 - Normalized proj. gradient
305 %                 2 - Proj. gradient
306 %                 3 - Delta by H. Kim
307 %                 0 - None (want to stop by MAX_ITER or MAX_TIME)
308     if nargin~=9
309         [gradW,gradH] = getGradient(A,W,H,type,alpha,beta);
310     end
311
312     switch stopRule
313         case 1
314             pGradW = gradW(gradW<0|W>0);
315             pGradH = gradH(gradH<0|H>0);
316             pGrad = [gradW(gradW<0|W>0); gradH(gradH<0|H>0)];
317             pGradNorm = norm(pGrad);
318             retVal = pGradNorm/length(pGrad);
319         case 2
320             pGradW = gradW(gradW<0|W>0);
321             pGradH = gradH(gradH<0|H>0);
322             pGrad = [gradW(gradW<0|W>0); gradH(gradH<0|H>0)];
323             retVal = norm(pGrad);
324         case 3
325             resmat=min(H,gradH); resvec=resmat(:);
326             resmat=min(W,gradW); resvec=[resvec; resmat(:)];
327             deltao=norm(resvec,1); %L1-norm
328             num_notconv=length(find(abs(resvec)>0));
329             retVal=deltao/num_notconv;
330         case 0
331             retVal = 1e100;
332     end
333
334 endfunction
335 %-------------------------------------------------------------------------------
336 function [gradW,gradH] = getGradient(A,W,H,type,alpha,beta)
337     switch type
338         case 'plain'
339             gradW = W*(H*H') - A*H';
340             gradH = (W'*W)*H - W'*A;
341         case 'regularized'
342             gradW = W*(H*H') - A*H' + alpha*W;
343             gradH = (W'*W)*H - W'*A + beta*H;
344         case 'sparse'
345             k=size(W,2);
346             betaI = beta*ones(k,k);
347             gradW = W*(H*H') - A*H' + alpha*W;
348             gradH = (W'*W)*H - W'*A + betaI*H;
349     end
350
351 endfunction
352
353 %------------------------------------------------------------------------------------------------------------------------
354 function [X,grad,iter] = nnlsm(A,B,init,solver)
355     switch solver
356         case 'bp'
357             [X,grad,iter] = nnlsm_blockpivot(A,B,0,init);
358         case 'as'
359             [X,grad,iter] = nnlsm_activeset(A,B,1,0,init);
360     end
361
362 endfunction
363 %------------------------------------------------------------------------------------------------------------------------
364 function [ X,Y,iter,success ] = nnlsm_activeset( A, B, overwrite, isInputProd, init)
365 % Nonnegativity Constrained Least Squares with Multiple Righthand Sides
366 %      using Active Set method
367 %
368 % This software solves the following problem: given A and B, find X such that
369 %            minimize || AX-B ||_F^2 where X>=0 elementwise.
370 %
371 % Reference:
372 %      Charles L. Lawson and Richard J. Hanson, Solving Least Squares Problems,
373 %            Society for Industrial and Applied Mathematics, 1995
374 %      M. H. Van Benthem and M. R. Keenan,
375 %            Fast Algorithm for the Solution of Large-scale Non-negativity-constrained Least Squares Problems,
376 %            J. Chemometrics 2004; 18: 441-450
377 %
378 % Written by Jingu Kim (jingu@cc.gatech.edu)
379 %               School of Computational Science and Engineering,
380 %               Georgia Institute of Technology
381 %
382 % Last updated Feb-20-2010
383 %
384 % <Inputs>
385 %        A : input matrix (m x n) (by default), or A'*A (n x n) if isInputProd==1
386 %        B : input matrix (m x k) (by default), or A'*B (n x k) if isInputProd==1
387 %        overwrite : (optional, default:0) if turned on, unconstrained least squares solution is computed in the beginning
388 %        isInputProd : (optional, default:0) if turned on, use (A'*A,A'*B) as input instead of (A,B)
389 %        init : (optional) initial value for X
390 % <Outputs>
391 %        X : the solution (n x k)
392 %        Y : A'*A*X - A'*B where X is the solution (n x k)
393 %        iter : number of iterations
394 %        success : 1 for success, 0 for failure.
395 %                  Failure could only happen on a numericall very ill-conditioned problem.
396
397     if nargin<3, overwrite=0;, end
398     if nargin<4, isInputProd=0;, end
399
400     if isInputProd
401         AtA=A;,AtB=B;
402     else
403         AtA=A'*A;, AtB=A'*B;
404     end
405
406     [n,k]=size(AtB);
407     MAX_ITER = n*5;
408     % set initial feasible solution
409     if overwrite
410         [X,iter] = solveNormalEqComb(AtA,AtB);
411         PassSet = (X > 0);
412         NotOptSet = any(X<0);
413     else
414         if nargin<5
415             X = zeros(n,k);
416             PassSet = false(n,k);
417             NotOptSet = true(1,k);
418         else
419             X = init;
420             PassSet = (X > 0);
421             NotOptSet = any(X<0);
422         end
423         iter = 0;
424     end
425
426     Y = zeros(n,k);
427     Y(:,~NotOptSet)=AtA*X(:,~NotOptSet) - AtB(:,~NotOptSet);
428     NotOptCols = find(NotOptSet);
429
430     bigIter = 0;, success=1;
431     while(~isempty(NotOptCols))
432         bigIter = bigIter+1;
433         if ((MAX_ITER >0) && (bigIter > MAX_ITER))   % set max_iter for ill-conditioned (numerically unstable) case
434             success = 0;, bigIter, break
435         end
436
437         % find unconstrained LS solution for the passive set
438         Z = zeros(n,length(NotOptCols));
439         [ Z,subiter ] = solveNormalEqComb(AtA,AtB(:,NotOptCols),PassSet(:,NotOptCols));
440         iter = iter + subiter;
441         %Z(abs(Z)<1e-12) = 0;                 % One can uncomment this line for numerical stability.
442         InfeaSubSet = Z < 0;
443         InfeaSubCols = find(any(InfeaSubSet));
444         FeaSubCols = find(all(~InfeaSubSet));
445
446         if ~isempty(InfeaSubCols)               % for infeasible cols
447             ZInfea = Z(:,InfeaSubCols);
448             InfeaCols = NotOptCols(InfeaSubCols);
449             Alpha = zeros(n,length(InfeaSubCols));, Alpha(:,:) = Inf;
450             InfeaSubSet(:,InfeaSubCols);
451             [i,j] = find(InfeaSubSet(:,InfeaSubCols));
452             InfeaSubIx = sub2ind(size(Alpha),i,j);
453             if length(InfeaCols) == 1
454                 InfeaIx = sub2ind([n,k],i,InfeaCols * ones(length(j),1));
455             else
456                 InfeaIx = sub2ind([n,k],i,InfeaCols(j)');
457             end
458             Alpha(InfeaSubIx) = X(InfeaIx)./(X(InfeaIx)-ZInfea(InfeaSubIx));
459
460             [minVal,minIx] = min(Alpha);
461             Alpha(:,:) = repmat(minVal,n,1);
462             X(:,InfeaCols) = X(:,InfeaCols)+Alpha.*(ZInfea-X(:,InfeaCols));
463             IxToActive = sub2ind([n,k],minIx,InfeaCols);
464             X(IxToActive) = 0;
465             PassSet(IxToActive) = false;
466         end
467         if ~isempty(FeaSubCols)                 % for feasible cols
468             FeaCols = NotOptCols(FeaSubCols);
469             X(:,FeaCols) = Z(:,FeaSubCols);
470             Y(:,FeaCols) = AtA * X(:,FeaCols) - AtB(:,FeaCols);
471             %Y( abs(Y)<1e-12 ) = 0;               % One can uncomment this line for numerical stability.
472
473             NotOptSubSet = (Y(:,FeaCols) < 0) & ~PassSet(:,FeaCols);
474             NewOptCols = FeaCols(all(~NotOptSubSet));
475             UpdateNotOptCols = FeaCols(any(NotOptSubSet));
476             if ~isempty(UpdateNotOptCols)
477                 [minVal,minIx] = min(Y(:,UpdateNotOptCols).*~PassSet(:,UpdateNotOptCols));
478                 PassSet(sub2ind([n,k],minIx,UpdateNotOptCols)) = true;
479             end
480             NotOptSet(NewOptCols) = false;
481             NotOptCols = find(NotOptSet);
482         end
483     end
484
485 endfunction
486 %------------------------------------------------------------------------------------------------------------------------
487 function [ X,Y,iter,success ] = nnlsm_blockpivot( A, B, isInputProd, init )
488 % Nonnegativity Constrained Least Squares with Multiple Righthand Sides
489 %      using Block Principal Pivoting method
490 %
491 % This software solves the following problem: given A and B, find X such that
492 %              minimize || AX-B ||_F^2 where X>=0 elementwise.
493 %
494 % Reference:
495 %      Jingu Kim and Haesun Park, Toward Faster Nonnegative Matrix Factorization: A New Algorithm and Comparisons,
496 %      In Proceedings of the 2008 Eighth IEEE International Conference on Data Mining (ICDM'08), 353-362, 2008
497 %
498 % Written by Jingu Kim (jingu@cc.gatech.edu)
499 % Copyright 2008-2009 by Jingu Kim and Haesun Park,
500 %                        School of Computational Science and Engineering,
501 %                        Georgia Institute of Technology
502 %
503 % Check updated code at http://www.cc.gatech.edu/~jingu
504 % Please send bug reports, comments, or questions to Jingu Kim.
505 % This code comes with no guarantee or warranty of any kind. Note that this algorithm assumes that the
506 %      input matrix A has full column rank.
507 %
508 % Last modified Feb-20-2009
509 %
510 % <Inputs>
511 %        A : input matrix (m x n) (by default), or A'*A (n x n) if isInputProd==1
512 %        B : input matrix (m x k) (by default), or A'*B (n x k) if isInputProd==1
513 %        isInputProd : (optional, default:0) if turned on, use (A'*A,A'*B) as input instead of (A,B)
514 %        init : (optional) initial value for X
515 % <Outputs>
516 %        X : the solution (n x k)
517 %        Y : A'*A*X - A'*B where X is the solution (n x k)
518 %        iter : number of iterations
519 %        success : 1 for success, 0 for failure.
520 %                  Failure could only happen on a numericall very ill-conditioned problem.
521
522     if nargin<3, isInputProd=0;, end
523     if isInputProd
524         AtA = A;, AtB = B;
525     else
526         AtA = A'*A;, AtB = A'*B;
527     end
528
529     [n,k]=size(AtB);
530     MAX_ITER = n*5;
531     % set initial feasible solution
532     X = zeros(n,k);
533     if nargin<4
534         Y = - AtB;
535         PassiveSet = false(n,k);
536         iter = 0;
537     else
538         PassiveSet = (init > 0);
539         [ X,iter ] = solveNormalEqComb(AtA,AtB,PassiveSet);
540         Y = AtA * X - AtB;
541     end
542     % parameters
543     pbar = 3;
544     P = zeros(1,k);, P(:) = pbar;
545     Ninf = zeros(1,k);, Ninf(:) = n+1;
546     iter = 0;
547
548     NonOptSet = (Y < 0) & ~PassiveSet;
549     InfeaSet = (X < 0) & PassiveSet;
550     NotGood = sum(NonOptSet)+sum(InfeaSet);
551     NotOptCols = NotGood > 0;
552
553     bigIter = 0;, success=1;
554     while(~isempty(find(NotOptCols)))
555         bigIter = bigIter+1;
556         if ((MAX_ITER >0) && (bigIter > MAX_ITER))   % set max_iter for ill-conditioned (numerically unstable) case
557             success = 0;, break
558         end
559
560         Cols1 = NotOptCols & (NotGood < Ninf);
561         Cols2 = NotOptCols & (NotGood >= Ninf) & (P >= 1);
562         Cols3Ix = find(NotOptCols & ~Cols1 & ~Cols2);
563         if ~isempty(find(Cols1))
564             P(Cols1) = pbar;,Ninf(Cols1) = NotGood(Cols1);
565             PassiveSet(NonOptSet & repmat(Cols1,n,1)) = true;
566             PassiveSet(InfeaSet & repmat(Cols1,n,1)) = false;
567         end
568         if ~isempty(find(Cols2))
569             P(Cols2) = P(Cols2)-1;
570             PassiveSet(NonOptSet & repmat(Cols2,n,1)) = true;
571             PassiveSet(InfeaSet & repmat(Cols2,n,1)) = false;
572         end
573         if ~isempty(Cols3Ix)
574             for i=1:length(Cols3Ix)
575                 Ix = Cols3Ix(i);
576                 toChange = max(find( NonOptSet(:,Ix)|InfeaSet(:,Ix) ));
577                 if PassiveSet(toChange,Ix)
578                     PassiveSet(toChange,Ix)=false;
579                 else
580                     PassiveSet(toChange,Ix)=true;
581                 end
582             end
583         end
584         NotOptMask = repmat(NotOptCols,n,1);
585         [ X(:,NotOptCols),subiter ] = solveNormalEqComb(AtA,AtB(:,NotOptCols),PassiveSet(:,NotOptCols));
586         iter = iter + subiter;
587         X(abs(X)<1e-12) = 0;            % for numerical stability
588         Y(:,NotOptCols) = AtA * X(:,NotOptCols) - AtB(:,NotOptCols);
589         Y(abs(Y)<1e-12) = 0;            % for numerical stability
590
591         % check optimality
592         NonOptSet = NotOptMask & (Y < 0) & ~PassiveSet;
593         InfeaSet = NotOptMask & (X < 0) & PassiveSet;
594         NotGood = sum(NonOptSet)+sum(InfeaSet);
595         NotOptCols = NotGood > 0;
596     end
597 endfunction
598 %------------------------------------------------------------------------------------------------------------------------
599 function [ Z,iter ] = solveNormalEqComb( AtA,AtB,PassSet )
600 % Solve normal equations using combinatorial grouping.
601 % Although this function was originally adopted from the code of
602 % "M. H. Van Benthem and M. R. Keenan, J. Chemometrics 2004; 18: 441-450",
603 % important modifications were made to fix bugs.
604 %
605 % Modified by Jingu Kim (jingu@cc.gatech.edu)
606 %             School of Computational Science and Engineering,
607 %             Georgia Institute of Technology
608 %
609 % Last updated Aug-12-2009
610
611     iter = 0;
612     if (nargin ==2) || isempty(PassSet) || all(PassSet(:))
613         Z = AtA\AtB;
614         iter = iter + 1;
615     else
616         Z = zeros(size(AtB));
617         [n,k1] = size(PassSet);
618
619         ## Fixed on Aug-12-2009
620         if k1==1
621             Z(PassSet)=AtA(PassSet,PassSet)\AtB(PassSet);
622         else
623             ## Fixed on Aug-12-2009
624             % The following bug was identified by investigating a bug report by Hanseung Lee.
625             [sortedPassSet,sortIx] = sortrows(PassSet');
626             breaks = any(diff(sortedPassSet)');
627             breakIx = [0 find(breaks) k1];
628             % codedPassSet = 2.^(n-1:-1:0)*PassSet;
629             % [sortedPassSet,sortIx] = sort(codedPassSet);
630             % breaks = diff(sortedPassSet);
631             % breakIx = [0 find(breaks) k1];
632
633             for k=1:length(breakIx)-1
634                 cols = sortIx(breakIx(k)+1:breakIx(k+1));
635                 vars = PassSet(:,sortIx(breakIx(k)+1));
636                 Z(vars,cols) = AtA(vars,vars)\AtB(vars,cols);
637                 iter = iter + 1;
638             end
639         end
640     end
641 endfunction
642
643 %!shared m, n, k, A
644 %! m = 30;
645 %! n = 20;
646 %! k = 10;
647 %! A = rand(m,n);
648
649 %!test
650 %! [W,H,iter,HIS]=nmf_bpas(A,k);
651
652 %!test
653 %! [W,H,iter,HIS]=nmf_bpas(A,k,'verbose',2);
654
655 %!test
656 %! [W,H,iter,HIS]=nmf_bpas(A,k,'verbose',1,'nnlssolver','as');
657
658 %!test
659 %! [W,H,iter,HIS]=nmf_bpas(A,k,'verbose',1,'type','sparse');
660
661 %!test
662 %! [W,H,iter,HIS]=nmf_bpas(A,k,'verbose',1,'type','sparse','nnlssolver','bp','alpha',1.1,'beta',1.3);
663
664 %!test
665 %! [W,H,iter,HIS]=nmf_bpas(A,k,'verbose',2,'type','plain','winit',rand(m,k));
666
667 %!demo
668 %! m = 300;
669 %! n = 200;
670 %! k = 10;
671 %!
672 %! W_org = rand(m,k);, W_org(rand(m,k)>0.5)=0;
673 %! H_org = rand(k,n);, H_org(rand(k,n)>0.5)=0;
674 %!
675 %! % normalize W, since 'nmf' normalizes W before return
676 %! norm2=sqrt(sum(W_org.^2,1));
677 %! toNormalize = norm2>0;
678 %! W_org(:,toNormalize) = W_org(:,toNormalize)./repmat(norm2(toNormalize),m,1);
679 %!
680 %! A = W_org * H_org;
681 %!
682 %! [W,H,iter,HIS]=nmf_bpas (A,k,'type','plain','tol',1e-4);
683 %!
684 %! % -------------------- column reordering before computing difference
685 %! reorder = zeros(k,1);
686 %! selected = zeros(k,1);
687 %! for i=1:k
688 %!    for j=1:k
689 %!        if ~selected(j), break, end
690 %!    end
691 %!    minIx = j;
692 %!
693 %!    for j=minIx+1:k
694 %!        if ~selected(j)
695 %!            d1 = norm(W(:,i)-W_org(:,minIx));
696 %!            d2 = norm(W(:,i)-W_org(:,j));
697 %!            if (d2<d1)
698 %!                minIx = j;
699 %!            end
700 %!        end
701 %!    end
702 %!    reorder(i) = minIx;
703 %!    selected(minIx) = 1;
704 %! end
705 %!
706 %! W_org = W_org(:,reorder);
707 %! H_org = H_org(reorder,:);
708 %! % ---------------------------------------------------------------------
709 %!
710 %! recovery_error_W = norm(W_org-W)/norm(W_org)
711 %! recovery_error_H = norm(H_org-H)/norm(H_org)