1 ## Copyright (C) 2006 Michel D. Schmid <michaelschmid@users.sourceforge.net>
4 ## This program is free software; you can redistribute it and/or modify it
5 ## under the terms of the GNU General Public License as published by
6 ## the Free Software Foundation; either version 2, or (at your option)
9 ## This program is distributed in the hope that it will be useful, but
10 ## WITHOUT ANY WARRANTY; without even the implied warranty of
11 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 ## General Public License for more details.
14 ## You should have received a copy of the GNU General Public License
15 ## along with this program; see the file COPYING. If not, see
16 ## <http://www.gnu.org/licenses/>.
19 ## @deftypefn {Function File} {}[@var{netOut}] = __trainlm (@var{net},@var{mInputN},@var{mOutput},@var{[]},@var{[]},@var{VV})
20 ## A neural feed-forward network will be trained with @code{__trainlm}
23 ## [netOut,tr,out,E] = __trainlm(net,mInputN,mOutput,[],[],VV);
27 ## left side arguments:
29 ## netOut: the trained network of the net structure @code{MLPnet}
36 ## right side arguments:
38 ## net : the untrained network, created with @code{newff}
39 ## mInputN: normalized input matrix
40 ## mOutput: output matrix
41 ## [] : unused parameter
42 ## [] : unused parameter
43 ## VV : validation structure
54 ## @seealso{newff,prestd,trastd}
56 ## Author: Michel D. Schmid
58 ## Comments: see in "A neural network toolbox for Octave User's Guide" [4]
59 ## for variable naming... there, inputs and targets have one-letter names,
60 ## e.g. inputs are written as P. In program code this is impractical: you
61 ## can't search for a one-letter variable... that's why they are written
62 ## here as Pp or Tt instead of only P or T.
64 function [net] = __trainlm(net,Im,Pp,Tt,VV)
## Train a feed-forward network with the Levenberg-Marquardt algorithm.
## Inputs: net = untrained network (created with newff); Im = normalized
## input matrix; Tt = target matrix; VV = validation structure (or empty
## to disable validation). Output: net = trained network.
## NOTE(review): Pp appears in the signature but is not referenced in the
## visible code -- confirm against the callers before changing it.

66 ## check range of input arguments
67 error(nargchk(5,5,nargin))

72 ## get parameters for training
73 epochs = net.trainParam.epochs;
74 goal = net.trainParam.goal;
75 maxFail = net.trainParam.max_fail;
76 minGrad = net.trainParam.min_grad;
77 mu = net.trainParam.mu;
78 muInc = net.trainParam.mu_inc;
79 muDec = net.trainParam.mu_dec;
80 muMax = net.trainParam.mu_max;
81 show = net.trainParam.show;
82 time = net.trainParam.time;

## validate all training parameters; checkParameter raises on bad values
85 checkParameter(epochs,goal,maxFail,minGrad,mu,\
86 muInc,muDec,muMax,show,time);

89 shortStr = "TRAINLM"; # TODO: shortStr is longer as TRAINLM !!!!!!!!!!!
## validation is only performed when a validation structure was passed in
90 doValidation = !isempty(VV);

94 #startTime = clock(); # TODO: maybe this row can be placed

97 ## the weights are used in column vector format
98 xx = __getx(net); # x is the variable with respect to, but no
99 # variables with only one letter!!
100 ## define identity matrix
## sized to the weight vector, used in the LM update (J'J + mu*I)
101 muI = eye(length(xx));

103 startTime = clock(); # if the next some tests are OK, I can delete
104 # startTime = clock(); 9 rows above..

106 ## calc performance of the actual net
107 [perf,vE,Aa,Nn] = __calcperf(net,xx,Im,Tt);

109 ## calc performance if validation is used
110 VV.net = net; # save the actual net in the validate
111 # structure... if no train loop will show better validate
112 # performance, this will be the returned net
113 vperf = __calcperf(net,xx,VV.Im,VV.Tt);
115 VV.numFail = 0; # one of the stop criterias

118 nLayers = net.numLayers;
119 for iEpochs = 0:epochs # longest loop & one of the stop criterias

122 ## Jj is jacobian matrix
123 [Jj] = __calcjacobian(net,Im,Nn,Aa,vE);

125 ## rerange error vector for jacobi matrix
## NOTE(review): `ve` below vs `vE` above -- presumably `ve` is the
## re-arranged error vector produced by elided code; confirm its
## definition before editing this section.
128 Jjve = (Jj' * ve); # will be used to calculate the gradient

## Euclidean norm of the gradient, compared against minGrad for stopping
130 normGradX = sqrt(Jjve'*Jjve);

132 ## record training progress for later plotting
134 trainRec.perf(iEpochs+1) = perf;
135 trainRec.mu(iEpochs+1) = mu;
137 trainRec.vperf(iEpochs+1) = VV.perf;

## evaluate every stop criterion (goal, epochs, time, gradient, mu,
## validation failures); a non-empty `stop` string ends training
141 [stop,currentTime] = stopifnecessary(stop,startTime,perf,goal,\
142 iEpochs,epochs,time,normGradX,minGrad,mu,muMax,\
143 doValidation,VV,maxFail);

145 ## show train progress
146 showtrainprogress(show,stop,iEpochs,epochs,time,currentTime, \
147 goal,perf,minGrad,normGradX,shortStr,net);

149 ## show performance plot, if needed
150 if !isnan(show) # if no performance plot is needed
151 ## now make it possible to define after how much loops the
152 ## performance plot should be updated
153 if (mod(iEpochs,show)==0)
154 plot(1:length(trainRec.perf),trainRec.perf);
157 plot(1:length(trainRec.vperf),trainRec.vperf,"--g");
160 endif # if !(strcmp(show,"NaN"))
161 # legend("Training","Validation");

163 ## stop if one of the criterias is reached.

173 ## calculate change in x
174 ## see [4], page 12-21
## Levenberg-Marquardt step: solve (J'J + mu*I) dx = -J'e
175 dx = -((Jj' * Jj) + (muI*mu)) \ Jjve;

177 ## add changes in x to actual x values (xx)
## NOTE(review): `x1` is assigned in elided code (presumably xx + dx);
## confirm before editing.
179 ## now add x1 to a new network to see if performance will be better
180 net1 = __setx(net,x1);
181 ## calc now new performance with the new net
182 [perf1,vE1,Aa1,N1] = __calcperf(net1,x1,Im,Tt);

185 ## this means, net performance with new weight values is better...
186 ## so save the new values

195 if (mu < 1e-20) # 1e-20 is properly the hard coded parameter in MATLAB(TM)

203 ## validate with DeltaX
205 vperf = __calcperf(net,xx,VV.Im,VV.Tt);

209 ## if actual validation performance is better,
210 ## set numFail to zero again
212 elseif (vperf > VV.perf)
## count successive validation failures; compared to maxFail when stopping
213 VV.numFail = VV.numFail + 1;

217 endfor #for iEpochs = 0:epochs
219 #=======================================================
221 # additional functions
223 #=======================================================
224 function checkParameter(epochs,goal,maxFail,minGrad,mu,\
225 muInc, muDec, muMax, show, time)
226 ## Parameter Checking
## Validate every training parameter read from net.trainParam; raises an
## error with a descriptive message on the first invalid value, returns
## nothing on success.

228 ## epochs must be a positive integer
229 if ( !isposint(epochs) )
230 error("Epochs is not a positive integer.")

233 ## goal can be zero or a positive double
234 if ( (goal<0) || !(isa(goal,"double")) )
235 error("Goal is not zero or a positive real value.")

238 ## maxFail must be also a positive integer
239 if ( !isposint(maxFail) ) # this will be used, to see if validation can
241 error("maxFail is not a positive integer.")

## minGrad must be a real scalar double (remaining condition is elided)
244 if (!isa(minGrad,"double")) || (!isreal(minGrad)) || (!isscalar(minGrad)) || \
246 error("minGrad is not zero or a positive real value.")

249 ## mu must be a positive real value. this parameter is responsible
250 ## for moving from stepest descent to quasi newton
251 if ((!isa(mu,"double")) || (!isreal(mu)) || (any(size(mu)) != 1) || (mu <= 0))
252 error("mu is not a positive real value.")

255 ## muDec defines the decrement factor
256 if ((!isa(muDec,"double")) || (!isreal(muDec)) || (any(size(muDec)) != 1) || \
257 (muDec < 0) || (muDec > 1))
258 error("muDec is not a real value between 0 and 1.")

261 ## muInc defines the increment factor
262 if (~isa(muInc,"double")) || (!isreal(muInc)) || (any(size(muInc)) != 1) || \
264 error("muInc is not a real value greater than 1.")

267 ## muMax is the upper boundary for the mu value
268 if (!isa(muMax,"double")) || (!isreal(muMax)) || (any(size(muMax)) != 1) || \
270 error("muMax is not a positive real value.")

273 ## check for actual mu value
## the initial mu must not already exceed its upper bound muMax
275 error("mu is greater than muMax.")

278 ## check if show is activated
## show must be NaN (plotting disabled) or a positive integer
281 error(["Show is not " "NaN" " or a positive integer."])

285 ## check at last the time argument, must be zero or a positive real value
286 if (!isa(time,"double")) || (!isreal(time)) || (any(size(time)) != 1) || \
288 error("Time is not zero or a positive real value.")

291 endfunction # parameter checking
294 # -----------------------------------------------------------------------------
297 function showtrainprogress(show,stop,iEpochs,epochs,time,currentTime, \
298 goal,perf,minGrad,normGradX,shortStr,net)
## Print a one-line training status (epoch, elapsed time, performance,
## gradient) to stdout, and print the stop reason once `stop` is set.

300 ## check number of inputs
301 error(nargchk(12,12,nargin));

## print only every `show` epochs, or whenever a stop reason exists
304 if isfinite(show) && (!rem(iEpochs,show) || length(stop))
305 fprintf(shortStr); # outputs the training algorithm
307 fprintf(", Epoch %g/%g",iEpochs, epochs);
310 fprintf(", Time %4.1f%%",currentTime/time*100); # \todo: time is not printed
313 fprintf(", %s %g/%g",upper(net.performFcn),perf,goal); # outputs the performance function
316 fprintf(", Gradient %g/%g",normGradX,minGrad);
320 fprintf("%s, %s\n\n",shortStr,stop);
322 fflush(stdout); # writes output to stdout as soon as output messages are available
327 # -----------------------------------------------------------------------------
330 function [stop,currentTime] = stopifnecessary(stop,startTime,perf,goal,\
331 iEpochs,epochs,time,normGradX,minGrad,mu,muMax,\
332 doValidation,VV,maxFail)
## Check all stop criteria in priority order; returns `stop` as a
## non-empty descriptive string when the first criterion matches, plus
## the elapsed training time in seconds.

334 ## check number of inputs
335 error(nargchk(14,14,nargin));

## seconds elapsed since training started
337 currentTime = etime(clock(),startTime);

## NOTE(review): the opening `if` condition for the goal test is elided
## here -- presumably (perf <= goal); confirm before editing this chain.
339 stop = "Performance goal met.";
340 elseif (iEpochs == epochs)
341 stop = "Maximum epoch reached, performance goal was not met.";
342 elseif (currentTime > time)
343 stop = "Maximum time elapsed, performance goal was not met.";
344 elseif (normGradX < minGrad)
345 stop = "Minimum gradient reached, performance goal was not met.";
347 stop = "Maximum MU reached, performance goal was not met.";
348 elseif (doValidation)
## stop when validation performance failed to improve maxFail+ times
349 if (VV.numFail > maxFail)
350 stop = "Validation stop.";
355 # =====================================================================
357 # END additional functions
359 # =====================================================================