]> Creatis software - CreaPhase.git/blob - octave_packages/nnet-0.1.13/train.m
Add a useful package (from Source forge) for octave
[CreaPhase.git] / octave_packages / nnet-0.1.13 / train.m
1 ## Copyright (C) 2006 Michel D. Schmid  <michaelschmid@users.sourceforge.net>
2 ##
3 ##
4 ## This program is free software; you can redistribute it and/or modify it
5 ## under the terms of the GNU General Public License as published by
6 ## the Free Software Foundation; either version 2, or (at your option)
7 ## any later version.
8 ##
9 ## This program is distributed in the hope that it will be useful, but
10 ## WITHOUT ANY WARRANTY; without even the implied warranty of
11 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
12 ## General Public License for more details.
13 ##
14 ## You should have received a copy of the GNU General Public License
15 ## along with this program; see the file COPYING.  If not, see
16 ## <http://www.gnu.org/licenses/>.
17
18 ## -*- texinfo -*-
19 ## @deftypefn {Function File} {}[@var{net}] = train (@var{MLPnet},@var{mInputN},@var{mOutput},@var{[]},@var{[]},@var{VV})
20 ## A neural feed-forward network will be trained with @code{train}
21 ##
22 ## @example
23 ## [net,tr,out,E] = train(MLPnet,mInputN,mOutput,[],[],VV);
24 ## @end example
25 ## @noindent
26 ##
27 ## @example
28 ## left side arguments:
29 ##   net: the trained network of the net structure @code{MLPnet}
30 ## @end example
31 ## @noindent
32 ##
33 ## @example
34 ## right side arguments:
35 ##   MLPnet : the untrained network, created with @code{newff}
36 ##   mInputN: normalized input matrix
37 ##   mOutput: output matrix (normalized or not)
38 ##   []     : unused parameter
39 ##   []     : unused parameter
40 ##   VV     : validize structure
41 ## @end example
42 ## @end deftypefn
43
44 ## @seealso{newff,prestd,trastd}
45
46 ## Author: Michel D. Schmid
47
48 ## Comments: see in "A neural network toolbox for Octave User's Guide" [4]
49 ## for variable naming... there have inputs or targets only one letter,
50 ## e.g. for inputs is P written. To write a program, this is stupid, you can't
51 ## search for 1 letter variable... that's why it is written here like Pp, or Tt
52 ## instead only P or T.
53
54 function [net] = train(net,Pp,Tt,notUsed1,notUsed2,VV)
55
56   ## check range of input arguments
57   error(nargchk(3,6,nargin))
58
59   ## set defaults
60   doValidation = 0;
61   if nargin==6
62     # doValidation=1;
63     ## check if VV is in MATLAB(TM) notation
64     [VV, doValidation] = checkVV(VV);
65   endif
66
67   ## check input args
68   checkInputArgs(net,Pp,Tt)
69
70   ## nargin ...
71   switch(nargin)
72   case 3
73     [Pp,Tt] = trainArgs(net,Pp,Tt);
74     VV = [];
75   case 6
76     [Pp,Tt] = trainArgs(net,Pp,Tt);
77     if isempty(VV)
78       VV = [];
79     else
80       if !isfield(VV,"Pp")
81         error("VV.Pp must be defined or VV must be [].")
82       endif
83       if !isfield(VV,"Tt")
84         error("VV.Tt must be defined or VV must be [].")
85       endif
86       [VV.Pp,VV.Tt] = trainArgs(net,VV.Pp,VV.Tt);
87     endif
88   otherwise
89     error("train: impossible code execution in switch(nargin)")
90   endswitch
91
92
93   ## so now, let's start training the network
94   ##===========================================
95
96   ## first let's check if a train function is defined ...
97   if isempty(net.trainFcn)
98     error("train:net.trainFcn not defined")
99   endif
100
101   ## calculate input matrix Im
102   [nRowsInputs, nColumnsInputs] = size(Pp);
103   Im = ones(nRowsInputs,nColumnsInputs).*Pp{1,1};
104
105   if (doValidation)
106     [nRowsVal, nColumnsVal] = size(VV.Pp);
107     VV.Im = ones(nRowsVal,nColumnsVal).*VV.Pp{1,1};
108   endif
109
110   ## make it MATLAB(TM) compatible
111   nLayers = net.numLayers;
112   Tt{nLayers,1} = Tt{1,1};
113   Tt{1,1} = [];
114   if (!isempty(VV))
115     VV.Tt{nLayers,1} = VV.Tt{1,1};
116     VV.Tt{1,1} = [];
117   endif
118
119   ## which training algorithm should be used
120   switch(net.trainFcn)
121     case "trainlm"
122       if !strcmp(net.performFcn,"mse")
123         error("Levenberg-Marquardt algorithm is defined with the MSE performance function, so please set MSE in NEWFF!")
124       endif
125       net = __trainlm(net,Im,Pp,Tt,VV);
126     otherwise
127       error("train algorithm argument is not valid!")
128   endswitch
129
130
131 # =======================================================
132 #
133 # additional check functions...
134 #
135 # =======================================================
136
137   function checkInputArgs(net,Pp,Tt)
138       
139     ## check "net", must be a net structure
140     if !__checknetstruct(net)
141       error("Structure doesn't seem to be a neural network!")
142     endif
143
144     ## check Pp (inputs)
145     nInputSize = net.inputs{1}.size; #only one exists
146     [nRowsPp, nColumnsPp] = size(Pp);
147     if ( (nColumnsPp>0) )
148       if ( nInputSize==nRowsPp )
149       # seems to be everything i.o.
150       else
151         error("Number of rows must be the same, like in net.inputs.size defined!")
152       endif
153     else
154       error("At least one column must exist")
155     endif
156     
157     ## check Tt (targets)
158     [nRowsTt, nColumnsTt] = size(Tt);
159     if ( (nRowsTt | nColumnsTt)==0 )
160       error("No targets defined!")
161     elseif ( nColumnsTt!=nColumnsPp )
162       error("Inputs and targets must have the same number of data sets (columns).")
163     elseif ( net.layers{net.numLayers}.size!=nRowsTt )
164       error("Defined number of output neurons are not identically to targets data sets!")
165     endif
166
167   endfunction
168 # -------------------------------------------------------
169   function [Pp,Tt] = trainArgs(net,Pp,Tt);
170
171     ## check number of arguments
172     error(nargchk(3,3,nargin));
173
174     [PpRows, PpColumns] = size(Pp);
175     Pp = mat2cell(Pp,PpRows,PpColumns);    # mat2cell is the reason
176                                                                            # why octave-2.9.5 doesn't work
177                                                                                    # octave-2.9.x with x>=6 should be
178                                                                                    # ok
179     [TtRows, TtColumns] = size(Tt);
180     Tt = mat2cell(Tt,TtRows,TtColumns);
181
182   endfunction
183
184 # -------------------------------------------------------
185
186   function [VV, doValidation] = checkVV(VV)
187
188     ## check number of arguments
189     error(nargchk(1,1,nargin));
190
191         if (isempty(VV))        
192           doValidation = 0;     
193         else
194           doValidation = 1;
195       ## check if MATLAB(TM) naming convention is used
196       if isfield(VV,"P")
197         VV.Pp = VV.P;
198         VV.P = [];
199       elseif !isfield(VV,"Pp")
200         error("VV is defined but inside exist no VV.P or VV.Pp")
201       endif
202
203       if isfield(VV,"T")
204         VV.Tt = VV.T;
205         VV.T = [];
206       elseif !isfield(VV,"Tt")
207         error("VV is defined but inside exist no VV.TP or VV.Tt")
208       endif
209         
210         endif
211
212
213   endfunction
214
215 # ============================================================
216
217 endfunction