]> Creatis software - clitk.git/blob - clitkMultiResolutionPDEDeformableRegistration.txx
5c6b0aeaf55c38e2e3361f2d683a15a4c7d1b927
[clitk.git] / clitkMultiResolutionPDEDeformableRegistration.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef _clitkMultiResolutionPDEDeformableRegistration_txx
19 #define _clitkMultiResolutionPDEDeformableRegistration_txx
20 #include "clitkMultiResolutionPDEDeformableRegistration.h"
21
22 #include "itkRecursiveMultiResolutionPyramidImageFilter.h"
23 #include "itkImageRegionIterator.h"
24 #include "vnl/vnl_math.h"
25
26 namespace clitk {
27
28 /*
29  * Default constructor
30  */
31 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
32 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
33 ::MultiResolutionPDEDeformableRegistration()
34 {
35  
36   this->SetNumberOfRequiredInputs(2);
37
38   typename DefaultRegistrationType::Pointer registrator =
39     DefaultRegistrationType::New();
40   m_RegistrationFilter = static_cast<RegistrationType*>(
41     registrator.GetPointer() );
42
43   m_MovingImagePyramid  = MovingImagePyramidType::New();
44   m_FixedImagePyramid     = FixedImagePyramidType::New();
45   m_FieldExpander     = FieldExpanderType::New();
46   m_InitialDeformationField = NULL;
47
48   m_NumberOfLevels = 3;
49   m_NumberOfIterations.resize( m_NumberOfLevels );
50   m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
51   m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
52
53   unsigned int ilevel;
54   for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
55     {
56     m_NumberOfIterations[ilevel] = 10;
57     }
58   m_CurrentLevel = 0;
59
60   m_StopRegistrationFlag = false;
61
62 }
63
64
65 /*
66  * Set the moving image image.
67  */
68 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
69 void
70 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
71 ::SetMovingImage(
72 const MovingImageType * ptr )
73 {
74   this->itk::ProcessObject::SetNthInput( 2, const_cast< MovingImageType * >( ptr ) );
75 }
76
77
78 /*
79  * Get the moving image image.
80  */
81 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
82 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
83 ::MovingImageType *
84 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
85 ::GetMovingImage(void) const
86 {
87   return dynamic_cast< const MovingImageType * >
88     ( this->itk::ProcessObject::GetInput( 2 ) );
89 }
90
91
92 /*
93  * Set the fixed image.
94  */
95 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
96 void
97 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
98 ::SetFixedImage(
99 const FixedImageType * ptr )
100 {
101   this->itk::ProcessObject::SetNthInput( 1, const_cast< FixedImageType * >( ptr ) );
102 }
103
104
105 /*
106  * Get the fixed image.
107  */
108 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
109 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
110 ::FixedImageType *
111 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
112 ::GetFixedImage(void) const
113 {
114   return dynamic_cast< const FixedImageType * >
115     ( this->itk::ProcessObject::GetInput( 1 ) );
116 }
117
118 /*
119  * 
120  */
121 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
122 std::vector<itk::SmartPointer<itk::DataObject> >::size_type
123 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
124 ::GetNumberOfValidRequiredInputs() const
125 {
126   typename std::vector<itk::SmartPointer<itk::DataObject> >::size_type num = 0;
127
128   if (this->GetFixedImage())
129     {
130     num++;
131     }
132
133   if (this->GetMovingImage())
134     {
135     num++;
136     }
137   
138   return num;
139 }
140
141
142 /*
143  * Set the number of multi-resolution levels
144  */
145 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
146 void
147 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
148 ::SetNumberOfLevels(
149 unsigned int num )
150 {
151   if( m_NumberOfLevels != num )
152     {
153     this->Modified();
154     m_NumberOfLevels = num;
155     m_NumberOfIterations.resize( m_NumberOfLevels );
156     }
157
158   if( m_MovingImagePyramid && m_MovingImagePyramid->GetNumberOfLevels() != num )
159     {
160     m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
161     }
162   if( m_FixedImagePyramid && m_FixedImagePyramid->GetNumberOfLevels() != num )
163     {
164     m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
165     }  
166 }
167
168
169 /*
170  * Standard PrintSelf method.
171  */
172 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
173 void
174 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
175 ::PrintSelf(std::ostream& os, itk::Indent indent) const
176 {
177   Superclass::PrintSelf(os, indent);
178   os << indent << "NumberOfLevels: " << m_NumberOfLevels << std::endl;
179   os << indent << "CurrentLevel: " << m_CurrentLevel << std::endl;
180
181   os << indent << "NumberOfIterations: [";
182   unsigned int ilevel;
183   for( ilevel = 0; ilevel < m_NumberOfLevels - 1; ilevel++ )
184     {
185     os << m_NumberOfIterations[ilevel] << ", ";
186     }
187   os << m_NumberOfIterations[ilevel] << "]" << std::endl;
188   
189   os << indent << "RegistrationFilter: ";
190   os << m_RegistrationFilter.GetPointer() << std::endl;
191   os << indent << "MovingImagePyramid: ";
192   os << m_MovingImagePyramid.GetPointer() << std::endl;
193   os << indent << "FixedImagePyramid: ";
194   os << m_FixedImagePyramid.GetPointer() << std::endl;
195
196   os << indent << "FieldExpander: ";
197   os << m_FieldExpander.GetPointer() << std::endl;
198
199   os << indent << "StopRegistrationFlag: ";
200   os << m_StopRegistrationFlag << std::endl;
201
202 }
203
204 /*
205  * Perform a the deformable registration using a multiresolution scheme
206  * using an internal mini-pipeline
207  *
208  *  ref_pyramid ->  registrator  ->  field_expander --|| tempField
209  * test_pyramid ->           |                              |
210  *                           |                              |
211  *                           --------------------------------    
212  *
213  * A tempField image is used to break the cycle between the
214  * registrator and field_expander.
215  *
216  */                              
217 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
218 void
219 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
220 ::GenerateData()
221 {
222   // Check for NULL images and pointers
223   MovingImageConstPointer movingImage = this->GetMovingImage();
224   FixedImageConstPointer  fixedImage = this->GetFixedImage();
225
226   if( !movingImage || !fixedImage )
227     {
228     itkExceptionMacro( << "Fixed and/or moving image not set" );
229     }
230
231   if( !m_MovingImagePyramid || !m_FixedImagePyramid )
232     {
233     itkExceptionMacro( << "Fixed and/or moving pyramid not set" );
234     }
235
236   if( !m_RegistrationFilter )
237     {
238     itkExceptionMacro( << "Registration filter not set" );
239     }
240   
241   if( this->m_InitialDeformationField && this->GetInput(0) )
242     {
243     itkExceptionMacro( << "Only one initial deformation can be given. "
244                        << "SetInitialDeformationField should not be used in "
245                        << "cunjunction with SetArbitraryInitialDeformationField "
246                        << "or SetInput.");
247     }
248
249   //Update the number of levels for the pyramids
250   this->SetNumberOfLevels(m_NumberOfLevels);
251
252   // Create the image pyramids.
253   m_MovingImagePyramid->SetInput( movingImage );
254   m_MovingImagePyramid->UpdateLargestPossibleRegion();
255
256   m_FixedImagePyramid->SetInput( fixedImage );
257   m_FixedImagePyramid->UpdateLargestPossibleRegion();
258  
259   // Initializations
260   m_CurrentLevel = 0;
261   m_StopRegistrationFlag = false;
262
263   unsigned int movingLevel = vnl_math_min( (int) m_CurrentLevel, 
264                                            (int) m_MovingImagePyramid->GetNumberOfLevels() );
265
266   unsigned int fixedLevel = vnl_math_min( (int) m_CurrentLevel, 
267                                           (int) m_FixedImagePyramid->GetNumberOfLevels() );
268
269   DeformationFieldPointer tempField = NULL;
270
271   DeformationFieldPointer inputPtr =
272     const_cast< DeformationFieldType * >( this->GetInput(0) );
273   
274   if ( this->m_InitialDeformationField )
275     {
276     tempField = this->m_InitialDeformationField;
277     }
278   else if( inputPtr )
279     {
280     // Arbitrary initial deformation field is set.
281     // smooth it and resample
282
283     // First smooth it
284     tempField = inputPtr;
285       
286     typedef itk::RecursiveGaussianImageFilter< DeformationFieldType,
287       DeformationFieldType> GaussianFilterType;
288     typename GaussianFilterType::Pointer smoother
289       = GaussianFilterType::New();
290       
291     for (unsigned int dim=0; dim<DeformationFieldType::ImageDimension; ++dim)
292       {
293       // sigma accounts for the subsampling of the pyramid
294       double sigma = 0.5 * static_cast<float>(
295         m_FixedImagePyramid->GetSchedule()[fixedLevel][dim] );
296
297       // but also for a possible discrepancy in the spacing
298       sigma *= fixedImage->GetSpacing()[dim]
299         / inputPtr->GetSpacing()[dim];
300       
301       smoother->SetInput( tempField );
302       smoother->SetSigma( sigma );
303       smoother->SetDirection( dim );
304       
305       smoother->Update();
306       
307       tempField = smoother->GetOutput();
308       tempField->DisconnectPipeline();
309       }
310       
311       
312     // Now resample
313     m_FieldExpander->SetInput( tempField );
314     
315     typename FloatImageType::Pointer fi = 
316       m_FixedImagePyramid->GetOutput( fixedLevel );
317     m_FieldExpander->SetSize( 
318       fi->GetLargestPossibleRegion().GetSize() );
319     m_FieldExpander->SetOutputStartIndex(
320       fi->GetLargestPossibleRegion().GetIndex() );
321     m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
322     m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
323     m_FieldExpander->SetOutputDirection( fi->GetDirection());
324
325     m_FieldExpander->UpdateLargestPossibleRegion();
326     m_FieldExpander->SetInput( NULL );
327     tempField = m_FieldExpander->GetOutput();
328     tempField->DisconnectPipeline();
329     }
330
331
332   bool lastShrinkFactorsAllOnes = false;
333   while ( !this->Halt() )
334     {
335       
336       if( tempField.IsNull() )
337         {
338 #if ITK_VERSION_MAJOR >= 4
339           m_RegistrationFilter->SetInitialDisplacementField( NULL );
340 #else
341           m_RegistrationFilter->SetInitialDeformationField( NULL );
342 #endif
343         }
344       else
345         {
346           // Resample the field to be the same size as the fixed image
347           // at the current level
348           m_FieldExpander->SetInput( tempField );
349           
350       typename FloatImageType::Pointer fi = 
351         m_FixedImagePyramid->GetOutput( fixedLevel );
352       m_FieldExpander->SetSize( 
353         fi->GetLargestPossibleRegion().GetSize() );
354       m_FieldExpander->SetOutputStartIndex(
355         fi->GetLargestPossibleRegion().GetIndex() );
356       m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
357       m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
358
359       m_FieldExpander->UpdateLargestPossibleRegion();
360       m_FieldExpander->SetInput( NULL );
361       tempField = m_FieldExpander->GetOutput();
362       tempField->DisconnectPipeline();
363
364 #if ITK_VERSION_MAJOR >= 4
365       m_RegistrationFilter->SetInitialDisplacementField( tempField );
366 #else
367       m_RegistrationFilter->SetInitialDeformationField( tempField );
368 #endif
369
370       }
371
372     // setup registration filter and pyramids 
373     m_RegistrationFilter->SetMovingImage( m_MovingImagePyramid->GetOutput(movingLevel) );
374     m_RegistrationFilter->SetFixedImage( m_FixedImagePyramid->GetOutput(fixedLevel) );
375     m_RegistrationFilter->SetNumberOfIterations(
376       m_NumberOfIterations[m_CurrentLevel] );
377
378     // cache shrink factors for computing the next expand factors.
379     lastShrinkFactorsAllOnes = true;
380     for( unsigned int idim = 0; idim < ImageDimension; idim++ )
381       {
382       if ( m_FixedImagePyramid->GetSchedule()[fixedLevel][idim] > 1 )
383         {
384         lastShrinkFactorsAllOnes = false;
385         break;
386         }
387       }
388
389     // Invoke an iteration event.
390     this->InvokeEvent( itk::IterationEvent() );
391
392     // compute new deformation field
393     m_RegistrationFilter->UpdateLargestPossibleRegion();
394     tempField = m_RegistrationFilter->GetOutput();
395     tempField->DisconnectPipeline();
396
397     // Increment level counter.  
398     m_CurrentLevel++;
399     movingLevel = vnl_math_min( (int) m_CurrentLevel, 
400                                 (int) m_MovingImagePyramid->GetNumberOfLevels() );
401     fixedLevel = vnl_math_min( (int) m_CurrentLevel, 
402                                (int) m_FixedImagePyramid->GetNumberOfLevels() );
403
404     // We can release data from pyramid which are no longer required.
405     if ( movingLevel > 0 )
406       {
407       m_MovingImagePyramid->GetOutput( movingLevel - 1 )->ReleaseData();
408       }
409     if( fixedLevel > 0 )
410       {
411       m_FixedImagePyramid->GetOutput( fixedLevel - 1 )->ReleaseData();
412       }
413
414     } // while not Halt()
415
416     if( !lastShrinkFactorsAllOnes )
417       {
418       // Some of the last shrink factors are not one
419       // graft the output of the expander filter to
420       // to output of this filter
421
422       // resample the field to the same size as the fixed image
423       m_FieldExpander->SetInput( tempField );
424       m_FieldExpander->SetSize( 
425         fixedImage->GetLargestPossibleRegion().GetSize() );
426       m_FieldExpander->SetOutputStartIndex(
427         fixedImage->GetLargestPossibleRegion().GetIndex() );
428       m_FieldExpander->SetOutputOrigin( fixedImage->GetOrigin() );
429       m_FieldExpander->SetOutputSpacing( fixedImage->GetSpacing());
430       m_FieldExpander->UpdateLargestPossibleRegion();
431       this->GraftOutput( m_FieldExpander->GetOutput() );
432       }
433     else
434       {
435       // all the last shrink factors are all ones
436       // graft the output of registration filter to
437       // to output of this filter
438       this->GraftOutput( tempField );
439       }
440
441     // Release memory
442     m_FieldExpander->SetInput( NULL );
443     m_FieldExpander->GetOutput()->ReleaseData();
444     m_RegistrationFilter->SetInput( NULL );
445     m_RegistrationFilter->GetOutput()->ReleaseData();
446
447 }
448
449
450 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
451 void
452 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
453 ::StopRegistration()
454 {
455   m_RegistrationFilter->StopRegistration();
456   m_StopRegistrationFlag = true;
457 }
458
459 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
460 bool
461 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
462 ::Halt()
463 {
464   // Halt the registration after the user-specified number of levels
465   if (m_NumberOfLevels != 0)
466   {
467   this->UpdateProgress( static_cast<float>( m_CurrentLevel ) /
468                         static_cast<float>( m_NumberOfLevels ) );
469   }
470
471   if ( m_CurrentLevel >= m_NumberOfLevels )
472     {
473     return true;
474     }
475   if ( m_StopRegistrationFlag )
476     {
477     return true;
478     }
479   else
480     { 
481     return false; 
482     }
483
484 }
485
486
487 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
488 void
489 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
490 ::GenerateOutputInformation()
491 {
492
493   typename itk::DataObject::Pointer output;
494
495  if( this->GetInput(0) )
496   {
497   // Initial deformation field is set.
498   // Copy information from initial field.
499   this->Superclass::GenerateOutputInformation();
500
501   }
502  else if( this->GetFixedImage() )
503   {
504   // Initial deforamtion field is not set. 
505   // Copy information from the fixed image.
506   for (unsigned int idx = 0; idx < 
507     this->GetNumberOfOutputs(); ++idx )
508     {
509     output = this->GetOutput(idx);
510     if (output)
511       {
512       output->CopyInformation(this->GetFixedImage());
513       }  
514     }
515
516   }
517
518 }
519
520
521 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
522 void
523 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
524 ::GenerateInputRequestedRegion()
525 {
526
527   // call the superclass's implementation
528   Superclass::GenerateInputRequestedRegion();
529
530   // request the largest possible region for the moving image
531   MovingImagePointer movingPtr = 
532     const_cast< MovingImageType * >( this->GetMovingImage() );
533   if( movingPtr )
534     {
535     movingPtr->SetRequestedRegionToLargestPossibleRegion();
536     }
537   
538   // just propagate up the output requested region for
539   // the fixed image and initial deformation field.
540   DeformationFieldPointer inputPtr = 
541       const_cast< DeformationFieldType * >( this->GetInput() );
542   DeformationFieldPointer outputPtr = this->GetOutput();
543   FixedImagePointer fixedPtr = 
544         const_cast< FixedImageType *>( this->GetFixedImage() );
545
546   if( inputPtr )
547     {
548     inputPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
549     }
550
551   if( fixedPtr )
552     {
553     fixedPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
554     }
555
556 }
557
558
559 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
560 void
561 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
562 ::EnlargeOutputRequestedRegion(
563                                itk::DataObject * ptr )
564 {
565   // call the superclass's implementation
566   Superclass::EnlargeOutputRequestedRegion( ptr );
567
568   // set the output requested region to largest possible.
569   DeformationFieldType * outputPtr;
570   outputPtr = dynamic_cast<DeformationFieldType*>( ptr );
571
572   if( outputPtr )
573     {
574     outputPtr->SetRequestedRegionToLargestPossibleRegion();
575     }
576
577 }
578
579
580 } // end namespace itk
581
582 #endif