1 ## Copyright (C) 2006 Michel D. Schmid <michaelschmid@users.sourceforge.net>
4 ## This program is free software; you can redistribute it and/or modify it
5 ## under the terms of the GNU General Public License as published by
6 ## the Free Software Foundation; either version 2, or (at your option)
9 ## This program is distributed in the hope that it will be useful, but
10 ## WITHOUT ANY WARRANTY; without even the implied warranty of
11 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
12 ## General Public License for more details.
14 ## You should have received a copy of the GNU General Public License
15 ## along with this program; see the file COPYING. If not, see
16 ## <http://www.gnu.org/licenses/>.
19 ## @deftypefn {Function File} {}[@var{net}] = train (@var{MLPnet},@var{mInputN},@var{mOutput},@var{[]},@var{[]},@var{VV})
20 ## A neural feed-forward network will be trained with @code{train}
23 ## [net,tr,out,E] = train(MLPnet,mInputN,mOutput,[],[],VV);
28 ## left side arguments:
29 ## net: the trained network of the net structure @code{MLPnet}
34 ## right side arguments:
35 ## MLPnet : the untrained network, created with @code{newff}
36 ## mInputN: normalized input matrix
37 ## mOutput: output matrix (normalized or not)
38 ## [] : unused parameter
39 ## [] : unused parameter
40 ## VV : validation structure
44 ## @seealso{newff,prestd,trastd}
46 ## Author: Michel D. Schmid
48 ## Comments: see "A neural network toolbox for Octave User's Guide" [4]
49 ## for variable naming. There, inputs and targets are named with only one letter,
50 ## e.g. the inputs are written as P. This is impractical when programming, because you can't
51 ## search for a one-letter variable. That's why they are written here as Pp or Tt
52 ## instead of only P or T.
54 function [net] = train(net,Pp,Tt,notUsed1,notUsed2,VV)
## Entry point: validates the arguments, wraps inputs and targets into cell
## arrays with trainArgs, and passes the network to __trainlm.
##
## net       : network structure; the fields read here are trainFcn,
##             performFcn and numLayers
## Pp        : input matrix (checked against the network by checkInputArgs)
## Tt        : target matrix (checked against the network by checkInputArgs)
## notUsed1,
## notUsed2  : not referenced in the visible code
## VV        : validation structure with fields Pp and Tt, or []
##
## Returns the network as handed back by __trainlm.
56 ## accept between 3 and 6 input arguments; nargchk yields a message
## otherwise, which error() then raises
57 error(nargchk(3,6,nargin))
63 ## check if VV is in MATLAB(TM) notation
## NOTE(review): presumably only reached when VV was actually passed; the
## guarding condition is not visible here -- confirm
64 [VV, doValidation] = checkVV(VV);
## validate net, Pp and Tt against each other; raises an error on mismatch
68 checkInputArgs(net,Pp,Tt)
## convert Pp and Tt from matrices to cell arrays
## NOTE(review): the two identical calls below presumably belong to
## different cases of a switch(nargin) (see the error message further
## down) -- confirm
73 [Pp,Tt] = trainArgs(net,Pp,Tt);
76 [Pp,Tt] = trainArgs(net,Pp,Tt);
## validation data must provide both inputs (VV.Pp) and targets (VV.Tt)
81 error("VV.Pp must be defined or VV must be [].")
84 error("VV.Tt must be defined or VV must be [].")
## validation inputs/targets get the same cell-array conversion
86 [VV.Pp,VV.Tt] = trainArgs(net,VV.Pp,VV.Tt);
## fallback error for an argument count that no case handles
89 error("train: impossible code execution in switch(nargin)")
93 ## so now, let's start training the network
94 ##===========================================
96 ## first let's check if a train function is defined ...
97 if isempty(net.trainFcn)
98 error("train:net.trainFcn not defined")
101 ## calculate input matrix Im
## Pp is a cell array at this point; Im is a ones-matrix of the cell
## array's own dimensions multiplied element-wise with the content of
## its first cell
102 [nRowsInputs, nColumnsInputs] = size(Pp);
103 Im = ones(nRowsInputs,nColumnsInputs).*Pp{1,1};
## same computation for the validation inputs
## NOTE(review): presumably only executed when validation is enabled -- confirm
106 [nRowsVal, nColumnsVal] = size(VV.Pp);
107 VV.Im = ones(nRowsVal,nColumnsVal).*VV.Pp{1,1};
110 ## make it MATLAB(TM) compatible
## the targets stored in cell row 1 are copied to cell row nLayers
111 nLayers = net.numLayers;
112 Tt{nLayers,1} = Tt{1,1};
## same copy for the validation targets
115 VV.Tt{nLayers,1} = VV.Tt{1,1};
119 ## which training algorithm should be used
## NOTE(review): the dispatch construct itself is not visible here;
## presumably it selects on net.trainFcn -- confirm
## __trainlm is only called when the performance function is "mse"
122 if !strcmp(net.performFcn,"mse")
123 error("Levenberg-Marquardt algorithm is defined with the MSE performance function, so please set MSE in NEWFF!")
125 net = __trainlm(net,Im,Pp,Tt,VV);
## any other training algorithm is rejected
127 error("train algorithm argument is not valid!")
131 # =======================================================
133 # additional check functions...
135 # =======================================================
137 function checkInputArgs(net,Pp,Tt)
## Validates the network structure, the input matrix Pp and the target
## matrix Tt against each other. Returns nothing; raises an error as soon
## as one of the checks fails.
##
## net : structure that must pass __checknetstruct
## Pp  : input matrix; needs at least one column and as many rows as
##       net.inputs{1}.size
## Tt  : target matrix; needs as many columns as Pp and as many rows as
##       the size of the last layer, net.layers{net.numLayers}.size
139 ## check "net", must be a net structure
140 if !__checknetstruct(net)
141 error("Structure doesn't seem to be a neural network!")
## check Pp (inputs)
145 nInputSize = net.inputs{1}.size; # only the first input entry is read
146 [nRowsPp, nColumnsPp] = size(Pp);
147 if ( (nColumnsPp>0) )
148 if ( nInputSize==nRowsPp )
149 # everything appears to be in order
## row count of Pp differs from the configured input size
151 error("Number of rows must be the same, like in net.inputs.size defined!")
## Pp has no columns
154 error("At least one column must exist")
157 ## check Tt (targets)
158 [nRowsTt, nColumnsTt] = size(Tt);
## (a | b)==0 holds only when both the row and the column count are zero
159 if ( (nRowsTt | nColumnsTt)==0 )
160 error("No targets defined!")
## each column is one data set, so inputs and targets need equal column counts
161 elseif ( nColumnsTt!=nColumnsPp )
162 error("Inputs and targets must have the same number of data sets (columns).")
## target rows must match the number of neurons in the last layer
163 elseif ( net.layers{net.numLayers}.size!=nRowsTt )
164 error("Defined number of output neurons are not identically to targets data sets!")
168 # -------------------------------------------------------
169 function [Pp,Tt] = trainArgs(net,Pp,Tt);
## Wraps the input matrix Pp and the target matrix Tt into cell arrays.
##
## net : not referenced in the visible lines of this function
## Pp  : input matrix; returned as a cell array
## Tt  : target matrix; returned as a cell array
171 ## check number of arguments
## exactly three arguments are required
172 error(nargchk(3,3,nargin));
## the full row and column counts are passed as the block sizes, so the
## whole matrix ends up in a single cell
174 [PpRows, PpColumns] = size(Pp);
175 Pp = mat2cell(Pp,PpRows,PpColumns); # mat2cell is the reason
176 # why octave-2.9.5 doesn't work
177 # octave-2.9.x with x>=6 should be
## the targets are wrapped the same way
179 [TtRows, TtColumns] = size(Tt);
180 Tt = mat2cell(Tt,TtRows,TtColumns);
184 # -------------------------------------------------------
186 function [VV, doValidation] = checkVV(VV)
## Checks that the validation structure VV provides the fields Pp and Tt.
##
## VV           : validation structure; returned to the caller
## doValidation : flag returned to the caller
##                NOTE(review): the statements assigning doValidation are
##                not visible here -- confirm how it is set
188 ## check number of arguments
## exactly one argument is required
189 error(nargchk(1,1,nargin));
195 ## check if MATLAB(TM) naming convention is used
## NOTE(review): judging by the error messages, a field VV.P is presumably
## accepted as an alternative to VV.Pp in branches not visible here -- confirm
199 elseif !isfield(VV,"Pp")
200 error("VV is defined but inside exist no VV.P or VV.Pp")
## NOTE(review): "VV.TP" in the message below looks like a typo for "VV.T"
## -- confirm against the field name actually tested
206 elseif !isfield(VV,"Tt")
207 error("VV is defined but inside exist no VV.TP or VV.Tt")
215 # ============================================================