]> Creatis software - clitk.git/blob - registration/clitkMultiResolutionPDEDeformableRegistration.txx
Remove vnl_math dependency into registration codes
[clitk.git] / registration / clitkMultiResolutionPDEDeformableRegistration.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef _clitkMultiResolutionPDEDeformableRegistration_txx
19 #define _clitkMultiResolutionPDEDeformableRegistration_txx
20 #include "clitkMultiResolutionPDEDeformableRegistration.h"
21
22 #include "itkRecursiveMultiResolutionPyramidImageFilter.h"
23 #include "itkImageRegionIterator.h"
24
25 namespace clitk {
26
27 /*
28  * Default constructor
29  */
30 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
31 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
32 ::MultiResolutionPDEDeformableRegistration()
33 {
34  
35   this->SetNumberOfRequiredInputs(2);
36
37   typename DefaultRegistrationType::Pointer registrator =
38     DefaultRegistrationType::New();
39   m_RegistrationFilter = static_cast<RegistrationType*>(
40     registrator.GetPointer() );
41
42   m_MovingImagePyramid  = MovingImagePyramidType::New();
43   m_FixedImagePyramid     = FixedImagePyramidType::New();
44   m_FieldExpander     = FieldExpanderType::New();
45   m_InitialDeformationField = ITK_NULLPTR;
46
47   m_NumberOfLevels = 3;
48   m_NumberOfIterations.resize( m_NumberOfLevels );
49   m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
50   m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
51
52   unsigned int ilevel;
53   for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
54     {
55     m_NumberOfIterations[ilevel] = 10;
56     }
57   m_CurrentLevel = 0;
58
59   m_StopRegistrationFlag = false;
60
61 }
62
63
64 /*
65  * Set the moving image image.
66  */
67 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
68 void
69 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
70 ::SetMovingImage(
71 const MovingImageType * ptr )
72 {
73   this->itk::ProcessObject::SetNthInput( 2, const_cast< MovingImageType * >( ptr ) );
74 }
75
76
77 /*
78  * Get the moving image image.
79  */
80 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
81 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
82 ::MovingImageType *
83 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
84 ::GetMovingImage(void) const
85 {
86   return dynamic_cast< const MovingImageType * >
87     ( this->itk::ProcessObject::GetInput( 2 ) );
88 }
89
90
91 /*
92  * Set the fixed image.
93  */
94 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
95 void
96 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
97 ::SetFixedImage(
98 const FixedImageType * ptr )
99 {
100   this->itk::ProcessObject::SetNthInput( 1, const_cast< FixedImageType * >( ptr ) );
101 }
102
103
104 /*
105  * Get the fixed image.
106  */
107 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
108 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
109 ::FixedImageType *
110 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
111 ::GetFixedImage(void) const
112 {
113   return dynamic_cast< const FixedImageType * >
114     ( this->itk::ProcessObject::GetInput( 1 ) );
115 }
116
117 /*
118  * 
119  */
120 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
121 std::vector<itk::SmartPointer<itk::DataObject> >::size_type
122 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
123 ::GetNumberOfValidRequiredInputs() const
124 {
125   typename std::vector<itk::SmartPointer<itk::DataObject> >::size_type num = 0;
126
127   if (this->GetFixedImage())
128     {
129     num++;
130     }
131
132   if (this->GetMovingImage())
133     {
134     num++;
135     }
136   
137   return num;
138 }
139
140
141 /*
142  * Set the number of multi-resolution levels
143  */
144 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
145 void
146 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
147 ::SetNumberOfLevels(
148 unsigned int num )
149 {
150   if( m_NumberOfLevels != num )
151     {
152     this->Modified();
153     m_NumberOfLevels = num;
154     m_NumberOfIterations.resize( m_NumberOfLevels );
155     }
156
157   if( m_MovingImagePyramid && m_MovingImagePyramid->GetNumberOfLevels() != num )
158     {
159     m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
160     }
161   if( m_FixedImagePyramid && m_FixedImagePyramid->GetNumberOfLevels() != num )
162     {
163     m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
164     }  
165 }
166
167
168 /*
169  * Standard PrintSelf method.
170  */
171 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
172 void
173 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
174 ::PrintSelf(std::ostream& os, itk::Indent indent) const
175 {
176   Superclass::PrintSelf(os, indent);
177   os << indent << "NumberOfLevels: " << m_NumberOfLevels << std::endl;
178   os << indent << "CurrentLevel: " << m_CurrentLevel << std::endl;
179
180   os << indent << "NumberOfIterations: [";
181   unsigned int ilevel;
182   for( ilevel = 0; ilevel < m_NumberOfLevels - 1; ilevel++ )
183     {
184     os << m_NumberOfIterations[ilevel] << ", ";
185     }
186   os << m_NumberOfIterations[ilevel] << "]" << std::endl;
187   
188   os << indent << "RegistrationFilter: ";
189   os << m_RegistrationFilter.GetPointer() << std::endl;
190   os << indent << "MovingImagePyramid: ";
191   os << m_MovingImagePyramid.GetPointer() << std::endl;
192   os << indent << "FixedImagePyramid: ";
193   os << m_FixedImagePyramid.GetPointer() << std::endl;
194
195   os << indent << "FieldExpander: ";
196   os << m_FieldExpander.GetPointer() << std::endl;
197
198   os << indent << "StopRegistrationFlag: ";
199   os << m_StopRegistrationFlag << std::endl;
200
201 }
202
203 /*
204  * Perform a the deformable registration using a multiresolution scheme
205  * using an internal mini-pipeline
206  *
207  *  ref_pyramid ->  registrator  ->  field_expander --|| tempField
208  * test_pyramid ->           |                              |
209  *                           |                              |
210  *                           --------------------------------    
211  *
212  * A tempField image is used to break the cycle between the
213  * registrator and field_expander.
214  *
215  */                              
216 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
217 void
218 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
219 ::GenerateData()
220 {
221   // Check for ITK_NULLPTR images and pointers
222   MovingImageConstPointer movingImage = this->GetMovingImage();
223   FixedImageConstPointer  fixedImage = this->GetFixedImage();
224
225   if( !movingImage || !fixedImage )
226     {
227     itkExceptionMacro( << "Fixed and/or moving image not set" );
228     }
229
230   if( !m_MovingImagePyramid || !m_FixedImagePyramid )
231     {
232     itkExceptionMacro( << "Fixed and/or moving pyramid not set" );
233     }
234
235   if( !m_RegistrationFilter )
236     {
237     itkExceptionMacro( << "Registration filter not set" );
238     }
239   
240   if( this->m_InitialDeformationField && this->GetInput(0) )
241     {
242     itkExceptionMacro( << "Only one initial deformation can be given. "
243                        << "SetInitialDeformationField should not be used in "
244                        << "cunjunction with SetArbitraryInitialDeformationField "
245                        << "or SetInput.");
246     }
247
248   //Update the number of levels for the pyramids
249   this->SetNumberOfLevels(m_NumberOfLevels);
250
251   // Create the image pyramids.
252   m_MovingImagePyramid->SetInput( movingImage );
253   m_MovingImagePyramid->UpdateLargestPossibleRegion();
254
255   m_FixedImagePyramid->SetInput( fixedImage );
256   m_FixedImagePyramid->UpdateLargestPossibleRegion();
257  
258   // Initializations
259   m_CurrentLevel = 0;
260   m_StopRegistrationFlag = false;
261
262   unsigned int movingLevel = std::min( (int) m_CurrentLevel,
263                                            (int) m_MovingImagePyramid->GetNumberOfLevels() );
264
265   unsigned int fixedLevel = std::min( (int) m_CurrentLevel,
266                                           (int) m_FixedImagePyramid->GetNumberOfLevels() );
267
268   DeformationFieldPointer tempField = ITK_NULLPTR;
269
270   DeformationFieldPointer inputPtr =
271     const_cast< DeformationFieldType * >( this->GetInput(0) );
272   
273   if ( this->m_InitialDeformationField )
274     {
275     tempField = this->m_InitialDeformationField;
276     }
277   else if( inputPtr )
278     {
279     // Arbitrary initial deformation field is set.
280     // smooth it and resample
281
282     // First smooth it
283     tempField = inputPtr;
284       
285     typedef itk::RecursiveGaussianImageFilter< DeformationFieldType,
286       DeformationFieldType> GaussianFilterType;
287     typename GaussianFilterType::Pointer smoother
288       = GaussianFilterType::New();
289       
290     for (unsigned int dim=0; dim<DeformationFieldType::ImageDimension; ++dim)
291       {
292       // sigma accounts for the subsampling of the pyramid
293       double sigma = 0.5 * static_cast<float>(
294         m_FixedImagePyramid->GetSchedule()[fixedLevel][dim] );
295
296       // but also for a possible discrepancy in the spacing
297       sigma *= fixedImage->GetSpacing()[dim]
298         / inputPtr->GetSpacing()[dim];
299       
300       smoother->SetInput( tempField );
301       smoother->SetSigma( sigma );
302       smoother->SetDirection( dim );
303       
304       smoother->Update();
305       
306       tempField = smoother->GetOutput();
307       tempField->DisconnectPipeline();
308       }
309       
310       
311     // Now resample
312     m_FieldExpander->SetInput( tempField );
313     
314     typename FloatImageType::Pointer fi = 
315       m_FixedImagePyramid->GetOutput( fixedLevel );
316     m_FieldExpander->SetSize( 
317       fi->GetLargestPossibleRegion().GetSize() );
318     m_FieldExpander->SetOutputStartIndex(
319       fi->GetLargestPossibleRegion().GetIndex() );
320     m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
321     m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
322     m_FieldExpander->SetOutputDirection( fi->GetDirection());
323
324     m_FieldExpander->UpdateLargestPossibleRegion();
325     m_FieldExpander->SetInput( ITK_NULLPTR );
326     tempField = m_FieldExpander->GetOutput();
327     tempField->DisconnectPipeline();
328     }
329
330
331   bool lastShrinkFactorsAllOnes = false;
332   while ( !this->Halt() )
333     {
334       
335       if( tempField.IsNull() )
336         {
337           m_RegistrationFilter->SetInitialDisplacementField( ITK_NULLPTR );
338         }
339       else
340         {
341           // Resample the field to be the same size as the fixed image
342           // at the current level
343           m_FieldExpander->SetInput( tempField );
344           
345       typename FloatImageType::Pointer fi = 
346         m_FixedImagePyramid->GetOutput( fixedLevel );
347       m_FieldExpander->SetSize( 
348         fi->GetLargestPossibleRegion().GetSize() );
349       m_FieldExpander->SetOutputStartIndex(
350         fi->GetLargestPossibleRegion().GetIndex() );
351       m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
352       m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
353
354       m_FieldExpander->UpdateLargestPossibleRegion();
355       m_FieldExpander->SetInput( ITK_NULLPTR );
356       tempField = m_FieldExpander->GetOutput();
357       tempField->DisconnectPipeline();
358
359       m_RegistrationFilter->SetInitialDisplacementField( tempField );
360       }
361
362     // setup registration filter and pyramids 
363     m_RegistrationFilter->SetMovingImage( m_MovingImagePyramid->GetOutput(movingLevel) );
364     m_RegistrationFilter->SetFixedImage( m_FixedImagePyramid->GetOutput(fixedLevel) );
365     m_RegistrationFilter->SetNumberOfIterations(
366       m_NumberOfIterations[m_CurrentLevel] );
367
368     // cache shrink factors for computing the next expand factors.
369     lastShrinkFactorsAllOnes = true;
370     for( unsigned int idim = 0; idim < ImageDimension; idim++ )
371       {
372       if ( m_FixedImagePyramid->GetSchedule()[fixedLevel][idim] > 1 )
373         {
374         lastShrinkFactorsAllOnes = false;
375         break;
376         }
377       }
378
379     // Invoke an iteration event.
380     this->InvokeEvent( itk::IterationEvent() );
381
382     // compute new deformation field
383     m_RegistrationFilter->UpdateLargestPossibleRegion();
384     tempField = m_RegistrationFilter->GetOutput();
385     tempField->DisconnectPipeline();
386
387     // Increment level counter.  
388     m_CurrentLevel++;
389     movingLevel = std::min( (int) m_CurrentLevel,
390                                 (int) m_MovingImagePyramid->GetNumberOfLevels() );
391     fixedLevel = std::min( (int) m_CurrentLevel,
392                                (int) m_FixedImagePyramid->GetNumberOfLevels() );
393
394     // We can release data from pyramid which are no longer required.
395     if ( movingLevel > 0 )
396       {
397       m_MovingImagePyramid->GetOutput( movingLevel - 1 )->ReleaseData();
398       }
399     if( fixedLevel > 0 )
400       {
401       m_FixedImagePyramid->GetOutput( fixedLevel - 1 )->ReleaseData();
402       }
403
404     } // while not Halt()
405
406     if( !lastShrinkFactorsAllOnes )
407       {
408       // Some of the last shrink factors are not one
409       // graft the output of the expander filter to
410       // to output of this filter
411
412       // resample the field to the same size as the fixed image
413       m_FieldExpander->SetInput( tempField );
414       m_FieldExpander->SetSize( 
415         fixedImage->GetLargestPossibleRegion().GetSize() );
416       m_FieldExpander->SetOutputStartIndex(
417         fixedImage->GetLargestPossibleRegion().GetIndex() );
418       m_FieldExpander->SetOutputOrigin( fixedImage->GetOrigin() );
419       m_FieldExpander->SetOutputSpacing( fixedImage->GetSpacing());
420       m_FieldExpander->UpdateLargestPossibleRegion();
421       this->GraftOutput( m_FieldExpander->GetOutput() );
422       }
423     else
424       {
425       // all the last shrink factors are all ones
426       // graft the output of registration filter to
427       // to output of this filter
428       this->GraftOutput( tempField );
429       }
430
431     // Release memory
432     m_FieldExpander->SetInput( ITK_NULLPTR );
433     m_FieldExpander->GetOutput()->ReleaseData();
434     m_RegistrationFilter->SetInput( ITK_NULLPTR );
435     m_RegistrationFilter->GetOutput()->ReleaseData();
436
437 }
438
439
440 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
441 void
442 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
443 ::StopRegistration()
444 {
445   m_RegistrationFilter->StopRegistration();
446   m_StopRegistrationFlag = true;
447 }
448
449 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
450 bool
451 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
452 ::Halt()
453 {
454   // Halt the registration after the user-specified number of levels
455   if (m_NumberOfLevels != 0)
456   {
457   this->UpdateProgress( static_cast<float>( m_CurrentLevel ) /
458                         static_cast<float>( m_NumberOfLevels ) );
459   }
460
461   if ( m_CurrentLevel >= m_NumberOfLevels )
462     {
463     return true;
464     }
465   if ( m_StopRegistrationFlag )
466     {
467     return true;
468     }
469   else
470     { 
471     return false; 
472     }
473
474 }
475
476
477 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
478 void
479 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
480 ::GenerateOutputInformation()
481 {
482
483   typename itk::DataObject::Pointer output;
484
485  if( this->GetInput(0) )
486   {
487   // Initial deformation field is set.
488   // Copy information from initial field.
489   this->Superclass::GenerateOutputInformation();
490
491   }
492  else if( this->GetFixedImage() )
493   {
494   // Initial deforamtion field is not set. 
495   // Copy information from the fixed image.
496   for (unsigned int idx = 0; idx < 
497     this->GetNumberOfOutputs(); ++idx )
498     {
499     output = this->GetOutput(idx);
500     if (output)
501       {
502       output->CopyInformation(this->GetFixedImage());
503       }  
504     }
505
506   }
507
508 }
509
510
511 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
512 void
513 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
514 ::GenerateInputRequestedRegion()
515 {
516
517   // call the superclass's implementation
518   Superclass::GenerateInputRequestedRegion();
519
520   // request the largest possible region for the moving image
521   MovingImagePointer movingPtr = 
522     const_cast< MovingImageType * >( this->GetMovingImage() );
523   if( movingPtr )
524     {
525     movingPtr->SetRequestedRegionToLargestPossibleRegion();
526     }
527   
528   // just propagate up the output requested region for
529   // the fixed image and initial deformation field.
530   DeformationFieldPointer inputPtr = 
531       const_cast< DeformationFieldType * >( this->GetInput() );
532   DeformationFieldPointer outputPtr = this->GetOutput();
533   FixedImagePointer fixedPtr = 
534         const_cast< FixedImageType *>( this->GetFixedImage() );
535
536   if( inputPtr )
537     {
538     inputPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
539     }
540
541   if( fixedPtr )
542     {
543     fixedPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
544     }
545
546 }
547
548
549 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
550 void
551 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
552 ::EnlargeOutputRequestedRegion(
553                                itk::DataObject * ptr )
554 {
555   // call the superclass's implementation
556   Superclass::EnlargeOutputRequestedRegion( ptr );
557
558   // set the output requested region to largest possible.
559   DeformationFieldType * outputPtr;
560   outputPtr = dynamic_cast<DeformationFieldType*>( ptr );
561
562   if( outputPtr )
563     {
564     outputPtr->SetRequestedRegionToLargestPossibleRegion();
565     }
566
567 }
568
569
570 } // end namespace itk
571
572 #endif