]> Creatis software - clitk.git/blob - registration/clitkSpatioTemporalMultiResolutionImageRegistrationMethod.txx
Merge branch 'master' of /home/dsarrut/clitk3.server
[clitk.git] / registration / clitkSpatioTemporalMultiResolutionImageRegistrationMethod.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkSpatioTemporalMultiResolutionImageRegistrationMethod_txx
19 #define __clitkSpatioTemporalMultiResolutionImageRegistrationMethod_txx
20 #include "clitkSpatioTemporalMultiResolutionImageRegistrationMethod.h"
21 #include "clitkRecursiveSpatioTemporalMultiResolutionPyramidImageFilter.h"
22
23 namespace clitk
24 {
25
26 /**
27  * Constructor
28  */
29 template < typename TFixedImage, typename TMovingImage >
30 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
31 ::SpatioTemporalMultiResolutionImageRegistrationMethod()
32 {
33   this->SetNumberOfRequiredOutputs( 1 );  // for the Transform
34
35   m_FixedImage   = 0; // has to be provided by the user.
36   m_MovingImage  = 0; // has to be provided by the user.
37   m_Transform    = 0; // has to be provided by the user.
38   m_Interpolator = 0; // has to be provided by the user.
39   m_Metric       = 0; // has to be provided by the user.
40   m_Optimizer    = 0; // has to be provided by the user.
41
42   // Use SpatioTemporalMultiResolutionPyramidImageFilter as the default
43   // image pyramids.
44   m_FixedImagePyramid  = FixedImagePyramidType::New(); 
45   m_MovingImagePyramid = MovingImagePyramidType::New();
46
47   m_NumberOfLevels = 1;
48   m_CurrentLevel = 0;
49
50   m_Stop = false;
51
52   m_ScheduleSpecified = false;
53   m_NumberOfLevelsSpecified = false;
54
55   m_InitialTransformParameters = ParametersType(1);
56   m_InitialTransformParametersOfNextLevel = ParametersType(1);
57   m_LastTransformParameters = ParametersType(1);
58
59   m_InitialTransformParameters.Fill( 0.0f );
60   m_InitialTransformParametersOfNextLevel.Fill( 0.0f );
61   m_LastTransformParameters.Fill( 0.0f );
62
63
64   TransformOutputPointer transformDecorator = 
65                  static_cast< TransformOutputType * >( 
66                                   this->MakeOutput(0).GetPointer() );
67
68   this->ProcessObject::SetNthOutput( 0, transformDecorator.GetPointer() );
69 }
70
71
72 /*
73  * Initialize by setting the interconnects between components. 
74  */
75 template < typename TFixedImage, typename TMovingImage >
76 void
77 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
78 ::Initialize() throw (ExceptionObject)
79 {
80
81   // Sanity checks
82   if ( !m_Metric )
83     {
84     itkExceptionMacro(<<"Metric is not present" );
85     }
86
87   if ( !m_Optimizer )
88     {
89     itkExceptionMacro(<<"Optimizer is not present" );
90     }
91
92   if( !m_Transform )
93     {
94     itkExceptionMacro(<<"Transform is not present");
95     }
96
97   if( !m_Interpolator )
98     {
99     itkExceptionMacro(<<"Interpolator is not present");
100     }
101
102   // Setup the metric
103   m_Metric->SetMovingImage( m_MovingImagePyramid->GetOutput(m_CurrentLevel) );
104   m_Metric->SetFixedImage( m_FixedImagePyramid->GetOutput(m_CurrentLevel) );
105   m_Metric->SetTransform( m_Transform );
106   m_Metric->SetInterpolator( m_Interpolator );
107   m_Metric->SetFixedImageRegion( m_FixedImageRegionPyramid[ m_CurrentLevel ] );
108   m_Metric->Initialize();
109
110   // Setup the optimizer
111   m_Optimizer->SetCostFunction( m_Metric );
112   m_Optimizer->SetInitialPosition( m_InitialTransformParametersOfNextLevel );
113
114   //
115   // Connect the transform to the Decorator.
116   //
117   TransformOutputType * transformOutput =  
118      static_cast< TransformOutputType * >( this->ProcessObject::GetOutput(0) );
119
120   transformOutput->Set( m_Transform.GetPointer() );
121
122 }
123
124
125 /*
126  * Stop the Registration Process
127  */
128 template < typename TFixedImage, typename TMovingImage >
129 void
130 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
131 ::StopRegistration( void )
132 {
133   m_Stop = true;
134 }
135
136 /**
137  * Set the schedules for the fixed and moving image pyramid
138  */
139 template < typename TFixedImage, typename TMovingImage >
140 void
141 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
142 ::SetSchedules( const ScheduleType & fixedImagePyramidSchedule,
143                const ScheduleType & movingImagePyramidSchedule )
144 {
145   if( m_NumberOfLevelsSpecified )
146     {
147     itkExceptionMacro( "SetSchedules should not be used " 
148            << "if numberOfLevelves are specified using SetNumberOfLevels" );
149     }
150   m_FixedImagePyramidSchedule = fixedImagePyramidSchedule;
151   m_MovingImagePyramidSchedule = movingImagePyramidSchedule;
152   m_ScheduleSpecified = true;
153
154   //Set the number of levels based on the pyramid schedule specified
155   if ( m_FixedImagePyramidSchedule.rows() != 
156         m_MovingImagePyramidSchedule.rows())
157     {
158     itkExceptionMacro("The specified schedules contain unequal number of levels");
159     }
160   else
161     {
162     m_NumberOfLevels = m_FixedImagePyramidSchedule.rows();
163     }
164
165   this->Modified();
166 }
167
168 /**
169  * Set the number of levels  
170  */
171 template < typename TFixedImage, typename TMovingImage >
172 void
173 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
174 ::SetNumberOfLevels( unsigned long numberOfLevels )
175 {
176   if( m_ScheduleSpecified )
177     {
178     itkExceptionMacro( "SetNumberOfLevels should not be used " 
179       << "if schedules have been specified using SetSchedules method " );
180     }
181
182   m_NumberOfLevels = numberOfLevels;
183   m_NumberOfLevelsSpecified = true;
184   this->Modified();
185 }
186
187 /**
188  * Stop the Registration Process
189  */
190 template < typename TFixedImage, typename TMovingImage >
191 void
192 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
193 ::PreparePyramids( void )
194 {
195
196   if( !m_Transform )
197     {
198     itkExceptionMacro(<<"Transform is not present");
199     }
200
201   m_InitialTransformParametersOfNextLevel = m_InitialTransformParameters;
202
203   if ( m_InitialTransformParametersOfNextLevel.Size() != 
204        m_Transform->GetNumberOfParameters() )
205     {
206     itkExceptionMacro(<<"Size mismatch between initial parameter and transform"); 
207     }
208
209   // Sanity checks
210   if( !m_FixedImage )
211     {
212     itkExceptionMacro(<<"FixedImage is not present");
213     }
214
215   if( !m_MovingImage )
216     {
217     itkExceptionMacro(<<"MovingImage is not present");
218     }
219
220   if( !m_FixedImagePyramid )
221     {
222     itkExceptionMacro(<<"Fixed image pyramid is not present");
223     }
224
225   if( !m_MovingImagePyramid )
226     {
227     itkExceptionMacro(<<"Moving image pyramid is not present");
228     }
229
230   // Setup the fixed and moving image pyramid
231   if( m_NumberOfLevelsSpecified )
232     {
233     m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
234     m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
235     }
236
237   if( m_ScheduleSpecified )
238     {
239     m_FixedImagePyramid->SetNumberOfLevels( m_FixedImagePyramidSchedule.rows());
240     m_FixedImagePyramid->SetSchedule( m_FixedImagePyramidSchedule );
241
242     m_MovingImagePyramid->SetNumberOfLevels( m_MovingImagePyramidSchedule.rows());
243     m_MovingImagePyramid->SetSchedule( m_MovingImagePyramidSchedule );
244     }
245
246   m_FixedImagePyramid->SetInput( m_FixedImage );
247   m_FixedImagePyramid->UpdateLargestPossibleRegion();
248
249   // Setup the moving image pyramid
250   m_MovingImagePyramid->SetInput( m_MovingImage );
251   m_MovingImagePyramid->UpdateLargestPossibleRegion();
252
253   typedef typename FixedImageRegionType::SizeType         SizeType;
254   typedef typename FixedImageRegionType::IndexType        IndexType;
255
256   ScheduleType schedule = m_FixedImagePyramid->GetSchedule();
257   std::cout << "FixedImage schedule: " << schedule << std::endl;
258
259   ScheduleType movingschedule = m_MovingImagePyramid->GetSchedule();
260   std::cout << "MovingImage schedule: " << movingschedule << std::endl;
261
262   SizeType  inputSize  = m_FixedImageRegion.GetSize();
263   IndexType inputStart = m_FixedImageRegion.GetIndex();
264
265   const unsigned long numberOfLevels = 
266           m_FixedImagePyramid->GetNumberOfLevels(); 
267
268   m_FixedImageRegionPyramid.reserve( numberOfLevels );
269   m_FixedImageRegionPyramid.resize( numberOfLevels );
270
271   // Compute the FixedImageRegion corresponding to each level of the 
272   // pyramid. This uses the same algorithm of the ShrinkImageFilter 
273   // since the regions should be compatible. 
274   for ( unsigned int level=0; level < numberOfLevels; level++ )
275     {
276     SizeType  size;
277     IndexType start;
278     for ( unsigned int dim = 0; dim < TFixedImage::ImageDimension; dim++)
279       {
280       const float scaleFactor = static_cast<float>( schedule[ level ][ dim ] );
281
282       size[ dim ] = static_cast<typename SizeType::SizeValueType>(
283         vcl_floor(static_cast<float>( inputSize[ dim ] ) / scaleFactor ) );
284       if( size[ dim ] < 1 )
285         {
286         size[ dim ] = 1;
287         }
288       
289       start[ dim ] = static_cast<typename IndexType::IndexValueType>(
290         vcl_ceil(static_cast<float>( inputStart[ dim ] ) / scaleFactor ) ); 
291       }
292     m_FixedImageRegionPyramid[ level ].SetSize( size );
293     m_FixedImageRegionPyramid[ level ].SetIndex( start );
294     }
295
296 }
297
298 /*
299  * Starts the Registration Process
300  */
301 template < typename TFixedImage, typename TMovingImage >
302 void
303 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
304 ::StartRegistration( void )
305
306
307   // StartRegistration is an old API from before
308   // this egistrationMethod was a subclass of ProcessObject.
309   // Historically, one could call StartRegistration() instead of
310   // calling Update().  However, when called directly by the user, the
311   // inputs to the RegistrationMethod may not be up to date.  This
312   // may cause an unexpected behavior.
313   //
314   // Since we cannot eliminate StartRegistration for backward
315   // compability reasons, we check whether StartRegistration was
316   // called directly or whether Update() (which in turn called 
317   // StartRegistration()).
318   if (!m_Updating)
319     {
320     this->Update();
321     }
322   else
323     {
324     m_Stop = false;
325     
326     this->PreparePyramids();
327     
328     for ( m_CurrentLevel = 0; m_CurrentLevel < m_NumberOfLevels;
329           m_CurrentLevel++ )
330       {
331       
332       // Invoke an iteration event.
333       // This allows a UI to reset any of the components between
334       // resolution level.
335       this->InvokeEvent( IterationEvent() );
336       
337       // Check if there has been a stop request
338       if ( m_Stop ) 
339         {
340         break;
341         }
342       
343       try
344         {
345         // initialize the interconnects between components
346         this->Initialize();
347         }
348       catch( ExceptionObject& err )
349         {
350         m_LastTransformParameters = ParametersType(1);
351         m_LastTransformParameters.Fill( 0.0f );
352         
353         // pass exception to caller
354         throw err;
355         }
356       
357       try
358         {
359         // do the optimization
360         m_Optimizer->StartOptimization();
361         }
362       catch( ExceptionObject& err )
363         {
364         // An error has occurred in the optimization.
365         // Update the parameters
366         m_LastTransformParameters = m_Optimizer->GetCurrentPosition();
367         
368         // Pass exception to caller
369         throw err;
370         }
371       
372       // get the results
373       m_LastTransformParameters = m_Optimizer->GetCurrentPosition();
374       m_Transform->SetParameters( m_LastTransformParameters );
375       
376       // setup the initial parameters for next level
377       if ( m_CurrentLevel < m_NumberOfLevels - 1 )
378         {
379         m_InitialTransformParametersOfNextLevel =
380           m_LastTransformParameters;
381         }
382       }
383     }
384
385 }
386
387
388 /*
389  * PrintSelf
390  */
391 template < typename TFixedImage, typename TMovingImage >
392 void
393 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
394 ::PrintSelf(std::ostream& os, itk::Indent indent) const
395 {
396   Superclass::PrintSelf( os, indent );
397   os << indent << "Metric: " << m_Metric.GetPointer() << std::endl;
398   os << indent << "Optimizer: " << m_Optimizer.GetPointer() << std::endl;
399   os << indent << "Transform: " << m_Transform.GetPointer() << std::endl;
400   os << indent << "Interpolator: " << m_Interpolator.GetPointer() << std::endl;
401   os << indent << "FixedImage: " << m_FixedImage.GetPointer() << std::endl;
402   os << indent << "MovingImage: " << m_MovingImage.GetPointer() << std::endl;
403   os << indent << "FixedImagePyramid: ";
404   os << m_FixedImagePyramid.GetPointer() << std::endl;
405   os << indent << "MovingImagePyramid: ";
406   os << m_MovingImagePyramid.GetPointer() << std::endl;
407
408   os << indent << "NumberOfLevels: ";
409   os << m_NumberOfLevels << std::endl;
410
411   os << indent << "CurrentLevel: ";
412   os << m_CurrentLevel << std::endl;  
413
414   os << indent << "InitialTransformParameters: ";
415   os << m_InitialTransformParameters << std::endl;
416   os << indent << "InitialTransformParametersOfNextLevel: ";
417   os << m_InitialTransformParametersOfNextLevel << std::endl;
418   os << indent << "LastTransformParameters: ";
419   os << m_LastTransformParameters << std::endl;
420   os << indent << "FixedImageRegion: ";
421   os << m_FixedImageRegion << std::endl;
422   for(unsigned int level=0; level< m_FixedImageRegionPyramid.size(); level++)
423     {
424     os << indent << "FixedImageRegion at level " << level << ": ";
425     os << m_FixedImageRegionPyramid[level] << std::endl;
426     }
427   os << indent << "FixedImagePyramidSchedule : " << std::endl;
428   os << m_FixedImagePyramidSchedule << std::endl;
429   os << indent << "MovingImagePyramidSchedule : " << std::endl;
430   os << m_MovingImagePyramidSchedule << std::endl;
431
432 }
433
434
435 /*
436  * Generate Data
437  */
438 template < typename TFixedImage, typename TMovingImage >
439 void
440 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
441 ::GenerateData()
442 {
443   this->StartRegistration();
444 }
445
446 template < typename TFixedImage, typename TMovingImage >
447 unsigned long
448 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
449 ::GetMTime() const
450 {
451   unsigned long mtime = Superclass::GetMTime();
452   unsigned long m;
453
454
455   // Some of the following should be removed once ivars are put in the
456   // input and output lists
457   
458   if (m_Transform)
459     {
460     m = m_Transform->GetMTime();
461     mtime = (m > mtime ? m : mtime);
462     }
463
464   if (m_Interpolator)
465     {
466     m = m_Interpolator->GetMTime();
467     mtime = (m > mtime ? m : mtime);
468     }
469
470   if (m_Metric)
471     {
472     m = m_Metric->GetMTime();
473     mtime = (m > mtime ? m : mtime);
474     }
475
476   if (m_Optimizer)
477     {
478     m = m_Optimizer->GetMTime();
479     mtime = (m > mtime ? m : mtime);
480     }
481
482   if (m_FixedImage)
483     {
484     m = m_FixedImage->GetMTime();
485     mtime = (m > mtime ? m : mtime);
486     }
487
488   if (m_MovingImage)
489     {
490     m = m_MovingImage->GetMTime();
491     mtime = (m > mtime ? m : mtime);
492     }
493
494   return mtime;
495   
496 }
497
498 /*
499  *  Get Output
500  */
501 template < typename TFixedImage, typename TMovingImage >
502 const typename SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>::TransformOutputType *
503 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
504 ::GetOutput() const
505 {
506   return static_cast< const TransformOutputType * >( this->ProcessObject::GetOutput(0) );
507 }
508
509 template < typename TFixedImage, typename TMovingImage >
510 DataObject::Pointer
511 SpatioTemporalMultiResolutionImageRegistrationMethod<TFixedImage,TMovingImage>
512 ::MakeOutput(unsigned int output)
513 {
514   switch (output)
515     {
516     case 0:
517       return static_cast<DataObject*>(TransformOutputType::New().GetPointer());
518       break;
519     default:
520       itkExceptionMacro("MakeOutput request for an output number larger than the expected number of outputs");
521       return 0;
522     }
523 }
524
525 } // end namespace clitk
526
527
528 #endif