]> Creatis software - CreaPhase.git/blob - octave_packages/nnet-0.1.13/__trainlm.m
Add a useful package (from Source forge) for octave
[CreaPhase.git] / octave_packages / nnet-0.1.13 / __trainlm.m
1 ## Copyright (C) 2006 Michel D. Schmid <michaelschmid@users.sourceforge.net>
2 ##
3 ##
4 ## This program is free software; you can redistribute it and/or modify it
5 ## under the terms of the GNU General Public License as published by
6 ## the Free Software Foundation; either version 2, or (at your option)
7 ## any later version.
8 ##
9 ## This program is distributed in the hope that it will be useful, but
10 ## WITHOUT ANY WARRANTY; without even the implied warranty of
11 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 ## General Public License for more details.
13 ##
14 ## You should have received a copy of the GNU General Public License
15 ## along with this program; see the file COPYING.  If not, see
16 ## <http://www.gnu.org/licenses/>.
17
18 ## -*- texinfo -*-
19 ## @deftypefn {Function File} {}[@var{netOut}] = __trainlm (@var{net},@var{mInputN},@var{mOutput},@var{[]},@var{[]},@var{VV})
20 ## A neural feed-forward network will be trained with @code{__trainlm}
21 ##
22 ## @example
23 ## [netOut,tr,out,E] = __trainlm(net,mInputN,mOutput,[],[],VV);
24 ## @end example
25 ## @noindent
26 ##
27 ## left side arguments:
28 ## @example
29 ## netOut: the trained network of the net structure @code{MLPnet}
30 ## tr :
31 ## out:
32 ## E  : Error
33 ## @end example
34 ## @noindent
35 ##
36 ## right side arguments:
37 ## @example
38 ## net    : the untrained network, created with @code{newff}
39 ## mInputN: normalized input matrix
40 ## mOutput: output matrix
41 ## []     : unused parameter
42 ## []     : unused parameter
43 ## VV     : validize structure
44 ## out:
45 ## E  : Error
46 ## @end example
47 ## @noindent
48 ##
49 ##
50 ## @noindent
51 ## are equivalent.
52 ## @end deftypefn
53
54 ## @seealso{newff,prestd,trastd}
55
56 ## Author: Michel D. Schmid
57
58 ## Comments: see in "A neural network toolbox for Octave User's Guide" [4]
59 ##  for variable naming... there have inputs or targets only one letter,
60 ## e.g. for inputs is P written. To write a program, this is stupid, you can't
61 ## search for 1 letter variable... that's why it is written here like Pp, or Tt
62 ## instead only P or T.
63
64 function [net] = __trainlm(net,Im,Pp,Tt,VV)
65
66   ## check range of input arguments
67   error(nargchk(5,5,nargin))
68
69   ## Initialize
70   ##------------
71
72   ## get parameters for training
73   epochs   = net.trainParam.epochs;
74   goal     = net.trainParam.goal;
75   maxFail  = net.trainParam.max_fail;
76   minGrad  = net.trainParam.min_grad;
77   mu       = net.trainParam.mu;
78   muInc    = net.trainParam.mu_inc;
79   muDec    = net.trainParam.mu_dec;
80   muMax    = net.trainParam.mu_max;
81   show     = net.trainParam.show;
82   time     = net.trainParam.time;
83
84   ## parameter checking
85   checkParameter(epochs,goal,maxFail,minGrad,mu,\
86                        muInc,muDec,muMax,show,time);
87
88   ## Constants
89   shortStr = "TRAINLM";    # TODO: shortStr is longer as TRAINLM !!!!!!!!!!!
90   doValidation = !isempty(VV);
91   stop = "";
92
93
94   #startTime = clock(); # TODO: maybe this row can be placed
95                        # some rows later
96
97   ## the weights are used in column vector format
98   xx = __getx(net); # x is the variable with respect to, but no
99                     # variables with only one letter!!
100   ## define identity matrix
101   muI = eye(length(xx));                  
102
103   startTime = clock();  # if the next some tests are OK, I can delete
104                         # startTime = clock(); 9 rows above..
105
106   ## calc performance of the actual net
107   [perf,vE,Aa,Nn] = __calcperf(net,xx,Im,Tt);
108   if (doValidation)
109     ## calc performance if validation is used
110     VV.net = net; # save the actual net in the validate
111     # structure... if no train loop will show better validate
112     # performance, this will be the returned net
113     vperf = __calcperf(net,xx,VV.Im,VV.Tt);
114     VV.perf = vperf;
115     VV.numFail = 0; # one of the stop criterias
116   endif
117
118   nLayers = net.numLayers;
119   for iEpochs = 0:epochs # longest loop & one of the stop criterias
120     ve = vE{nLayers,1};
121     ## calc jacobian
122     ## Jj is jacobian matrix
123     [Jj] = __calcjacobian(net,Im,Nn,Aa,vE);
124
125     ## rerange error vector for jacobi matrix
126     ve = ve(:);
127
128     Jjve = (Jj' * ve); # will be used to calculate the gradient
129
130     normGradX = sqrt(Jjve'*Jjve);
131
132     ## record training progress for later plotting
133     ## if requested
134     trainRec.perf(iEpochs+1) = perf;
135     trainRec.mu(iEpochs+1) = mu;
136     if (doValidation)
137       trainRec.vperf(iEpochs+1) = VV.perf;
138     endif
139
140     ## stoping criteria
141     [stop,currentTime] = stopifnecessary(stop,startTime,perf,goal,\
142                            iEpochs,epochs,time,normGradX,minGrad,mu,muMax,\
143                            doValidation,VV,maxFail);
144
145     ## show train progress
146     showtrainprogress(show,stop,iEpochs,epochs,time,currentTime, \
147                   goal,perf,minGrad,normGradX,shortStr,net);
148
149     ## show performance plot, if needed
150     if !isnan(show) # if no performance plot is needed
151       ## now make it possible to define after how much loops the
152       ## performance plot should be updated
153       if (mod(iEpochs,show)==0)
154         plot(1:length(trainRec.perf),trainRec.perf);
155         if (doValidation)
156           hold on;
157           plot(1:length(trainRec.vperf),trainRec.vperf,"--g");
158         endif
159       endif
160     endif # if !(strcmp(show,"NaN"))
161 #    legend("Training","Validation");
162
163     ## stop if one of the criterias is reached.
164     if length(stop)
165       if (doValidation)
166         net = VV.net;
167       endif
168       break
169     endif
170
171     ## calculate DeltaX
172     while (mu <= muMax)
173       ## calculate change in x
174       ## see [4], page 12-21
175       dx = -((Jj' * Jj) + (muI*mu)) \ Jjve;
176
177       ## add changes in x to actual x values (xx)
178       x1 = xx + dx;
179       ## now add x1 to a new network to see if performance will be better
180       net1 = __setx(net,x1);
181       ## calc now new performance with the new net
182       [perf1,vE1,Aa1,N1] = __calcperf(net1,x1,Im,Tt);
183
184       if (perf1 < perf)
185         ## this means, net performance with new weight values is better...
186         ## so save the new values
187         xx = x1;
188         net = net1;
189         Nn = N1;
190         Aa = Aa1;
191         vE = vE1;
192         perf = perf1;
193
194         mu = mu * muDec;
195         if (mu < 1e-20)   # 1e-20 is properly the hard coded parameter in MATLAB(TM)
196           mu = 1e-20;
197         endif
198         break
199       endif
200       mu = mu * muInc;
201     endwhile
202
203     ## validate with DeltaX
204     if (doValidation)
205       vperf = __calcperf(net,xx,VV.Im,VV.Tt);
206       if (vperf < VV.perf)
207         VV.perf = vperf;
208         VV.net = net;
209         ## if actual validation performance is better,
210         ## set numFail to zero again
211         VV.numFail = 0;
212       elseif (vperf > VV.perf)
213         VV.numFail = VV.numFail + 1;
214       endif
215     endif
216
217   endfor #for iEpochs = 0:epochs
218
219 #=======================================================
220 #
221 # additional functions
222 #
223 #=======================================================
224   function checkParameter(epochs,goal,maxFail,minGrad,mu,\
225                        muInc, muDec, muMax, show, time)
226     ## Parameter Checking
227
228     ## epochs must be a positive integer
229     if ( !isposint(epochs) )
230       error("Epochs is not a positive integer.")
231     endif
232
233     ## goal can be zero or a positive double
234     if ( (goal<0) || !(isa(goal,"double")) )
235       error("Goal is not zero or a positive real value.")
236     endif
237
238     ## maxFail must be also a positive integer
239     if ( !isposint(maxFail) ) # this will be used, to see if validation can
240       # break the training
241       error("maxFail is not a positive integer.")
242     endif
243
244     if (!isa(minGrad,"double")) || (!isreal(minGrad)) || (!isscalar(minGrad)) || \
245       (minGrad < 0)
246       error("minGrad is not zero or a positive real value.")
247     end
248
249     ## mu must be a positive real value. this parameter is responsible
250     ## for moving from stepest descent to quasi newton
251     if ((!isa(mu,"double")) || (!isreal(mu)) || (any(size(mu)) != 1) || (mu <= 0))
252       error("mu is not a positive real value.")
253     endif
254
255     ## muDec defines the decrement factor
256     if ((!isa(muDec,"double")) || (!isreal(muDec)) || (any(size(muDec)) != 1) || \
257                  (muDec < 0) || (muDec > 1))
258       error("muDec is not a real value between 0 and 1.")
259     endif
260
261     ## muInc defines the increment factor
262     if (~isa(muInc,"double")) || (!isreal(muInc)) || (any(size(muInc)) != 1) || \
263       (muInc < 1)
264       error("muInc is not a real value greater than 1.")
265     endif
266
267     ## muMax is the upper boundary for the mu value
268     if (!isa(muMax,"double")) || (!isreal(muMax)) || (any(size(muMax)) != 1) || \
269       (muMax <= 0)
270       error("muMax is not a positive real value.")
271     endif
272
273     ## check for actual mu value
274     if (mu > muMax)
275       error("mu is greater than muMax.")
276     end
277
278     ## check if show is activated
279     if (!isnan(show))
280           if (!isposint(show))
281         error(["Show is not " "NaN" " or a positive integer."])
282       endif
283     endif
284
285     ## check at last the time argument, must be zero or a positive real value
286     if (!isa(time,"double")) || (!isreal(time)) || (any(size(time)) != 1) || \
287       (time < 0)
288       error("Time is not zero or a positive real value.")
289     end
290
291   endfunction # parameter checking
292
293 #
294 # -----------------------------------------------------------------------------
295 #
296
297   function showtrainprogress(show,stop,iEpochs,epochs,time,currentTime, \
298           goal,perf,minGrad,normGradX,shortStr,net)
299
300     ## check number of inputs
301     error(nargchk(12,12,nargin));
302
303     ## show progress
304     if isfinite(show) && (!rem(iEpochs,show) || length(stop))
305       fprintf(shortStr);   # outputs the training algorithm
306       if isfinite(epochs)
307         fprintf(", Epoch %g/%g",iEpochs, epochs);
308       endif
309       if isfinite(time)
310         fprintf(", Time %4.1f%%",currentTime/time*100);   # \todo: Time wird nicht ausgegeben
311       endif
312       if isfinite(goal)
313         fprintf(", %s %g/%g",upper(net.performFcn),perf,goal); # outputs the performance function
314       endif
315       if isfinite(minGrad)
316         fprintf(", Gradient %g/%g",normGradX,minGrad);
317       endif
318       fprintf("\n")
319       if length(stop)
320         fprintf("%s, %s\n\n",shortStr,stop);
321       endif
322       fflush(stdout); # writes output to stdout as soon as output messages are available
323     endif
324   endfunction
325   
326 #
327 # -----------------------------------------------------------------------------
328 #
329
330   function [stop,currentTime] = stopifnecessary(stop,startTime,perf,goal,\
331                         iEpochs,epochs,time,normGradX,minGrad,mu,muMax,\
332                                                 doValidation,VV,maxFail)
333
334     ## check number of inputs
335     error(nargchk(14,14,nargin));
336
337     currentTime = etime(clock(),startTime);
338     if (perf <= goal)
339       stop = "Performance goal met.";
340     elseif (iEpochs == epochs)
341       stop = "Maximum epoch reached, performance goal was not met.";
342     elseif (currentTime > time)
343       stop = "Maximum time elapsed, performance goal was not met.";
344     elseif (normGradX < minGrad)
345       stop = "Minimum gradient reached, performance goal was not met.";
346     elseif (mu > muMax)
347       stop = "Maximum MU reached, performance goal was not met.";
348     elseif (doValidation) 
349           if (VV.numFail > maxFail)
350         stop = "Validation stop.";
351       endif
352     endif
353   endfunction
354
355 # =====================================================================
356 #
357 # END additional functions
358 #
359 # =====================================================================
360
361 endfunction