]> Creatis software - clitk.git/blob - registration/clitkMultiResolutionPDEDeformableRegistration.txx
c8312e6e173085225a1a53181e01897c83123a6c
[clitk.git] / registration / clitkMultiResolutionPDEDeformableRegistration.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef _clitkMultiResolutionPDEDeformableRegistration_txx
19 #define _clitkMultiResolutionPDEDeformableRegistration_txx
20 #include "clitkMultiResolutionPDEDeformableRegistration.h"
21
22 #include "itkRecursiveMultiResolutionPyramidImageFilter.h"
23 #include "itkImageRegionIterator.h"
24 #include "vnl/vnl_math.h"
25
26 namespace clitk {
27
28 /*
29  * Default constructor
30  */
31 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
32 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
33 ::MultiResolutionPDEDeformableRegistration()
34 {
35  
36   this->SetNumberOfRequiredInputs(2);
37
38   typename DefaultRegistrationType::Pointer registrator =
39     DefaultRegistrationType::New();
40   m_RegistrationFilter = static_cast<RegistrationType*>(
41     registrator.GetPointer() );
42
43   m_MovingImagePyramid  = MovingImagePyramidType::New();
44   m_FixedImagePyramid     = FixedImagePyramidType::New();
45   m_FieldExpander     = FieldExpanderType::New();
46   m_InitialDeformationField = ITK_NULLPTR;
47
48   m_NumberOfLevels = 3;
49   m_NumberOfIterations.resize( m_NumberOfLevels );
50   m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
51   m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
52
53   unsigned int ilevel;
54   for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
55     {
56     m_NumberOfIterations[ilevel] = 10;
57     }
58   m_CurrentLevel = 0;
59
60   m_StopRegistrationFlag = false;
61
62 }
63
64
65 /*
66  * Set the moving image image.
67  */
68 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
69 void
70 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
71 ::SetMovingImage(
72 const MovingImageType * ptr )
73 {
74   this->itk::ProcessObject::SetNthInput( 2, const_cast< MovingImageType * >( ptr ) );
75 }
76
77
78 /*
79  * Get the moving image image.
80  */
81 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
82 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
83 ::MovingImageType *
84 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
85 ::GetMovingImage(void) const
86 {
87   return dynamic_cast< const MovingImageType * >
88     ( this->itk::ProcessObject::GetInput( 2 ) );
89 }
90
91
92 /*
93  * Set the fixed image.
94  */
95 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
96 void
97 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
98 ::SetFixedImage(
99 const FixedImageType * ptr )
100 {
101   this->itk::ProcessObject::SetNthInput( 1, const_cast< FixedImageType * >( ptr ) );
102 }
103
104
105 /*
106  * Get the fixed image.
107  */
108 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
109 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
110 ::FixedImageType *
111 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
112 ::GetFixedImage(void) const
113 {
114   return dynamic_cast< const FixedImageType * >
115     ( this->itk::ProcessObject::GetInput( 1 ) );
116 }
117
118 /*
119  * 
120  */
121 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
122 std::vector<itk::SmartPointer<itk::DataObject> >::size_type
123 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
124 ::GetNumberOfValidRequiredInputs() const
125 {
126   typename std::vector<itk::SmartPointer<itk::DataObject> >::size_type num = 0;
127
128   if (this->GetFixedImage())
129     {
130     num++;
131     }
132
133   if (this->GetMovingImage())
134     {
135     num++;
136     }
137   
138   return num;
139 }
140
141
142 /*
143  * Set the number of multi-resolution levels
144  */
145 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
146 void
147 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
148 ::SetNumberOfLevels(
149 unsigned int num )
150 {
151   if( m_NumberOfLevels != num )
152     {
153     this->Modified();
154     m_NumberOfLevels = num;
155     m_NumberOfIterations.resize( m_NumberOfLevels );
156     }
157
158   if( m_MovingImagePyramid && m_MovingImagePyramid->GetNumberOfLevels() != num )
159     {
160     m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
161     }
162   if( m_FixedImagePyramid && m_FixedImagePyramid->GetNumberOfLevels() != num )
163     {
164     m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
165     }  
166 }
167
168
169 /*
170  * Standard PrintSelf method.
171  */
172 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
173 void
174 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
175 ::PrintSelf(std::ostream& os, itk::Indent indent) const
176 {
177   Superclass::PrintSelf(os, indent);
178   os << indent << "NumberOfLevels: " << m_NumberOfLevels << std::endl;
179   os << indent << "CurrentLevel: " << m_CurrentLevel << std::endl;
180
181   os << indent << "NumberOfIterations: [";
182   unsigned int ilevel;
183   for( ilevel = 0; ilevel < m_NumberOfLevels - 1; ilevel++ )
184     {
185     os << m_NumberOfIterations[ilevel] << ", ";
186     }
187   os << m_NumberOfIterations[ilevel] << "]" << std::endl;
188   
189   os << indent << "RegistrationFilter: ";
190   os << m_RegistrationFilter.GetPointer() << std::endl;
191   os << indent << "MovingImagePyramid: ";
192   os << m_MovingImagePyramid.GetPointer() << std::endl;
193   os << indent << "FixedImagePyramid: ";
194   os << m_FixedImagePyramid.GetPointer() << std::endl;
195
196   os << indent << "FieldExpander: ";
197   os << m_FieldExpander.GetPointer() << std::endl;
198
199   os << indent << "StopRegistrationFlag: ";
200   os << m_StopRegistrationFlag << std::endl;
201
202 }
203
204 /*
205  * Perform a the deformable registration using a multiresolution scheme
206  * using an internal mini-pipeline
207  *
208  *  ref_pyramid ->  registrator  ->  field_expander --|| tempField
209  * test_pyramid ->           |                              |
210  *                           |                              |
211  *                           --------------------------------    
212  *
213  * A tempField image is used to break the cycle between the
214  * registrator and field_expander.
215  *
216  */                              
217 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
218 void
219 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
220 ::GenerateData()
221 {
222   // Check for ITK_NULLPTR images and pointers
223   MovingImageConstPointer movingImage = this->GetMovingImage();
224   FixedImageConstPointer  fixedImage = this->GetFixedImage();
225
226   if( !movingImage || !fixedImage )
227     {
228     itkExceptionMacro( << "Fixed and/or moving image not set" );
229     }
230
231   if( !m_MovingImagePyramid || !m_FixedImagePyramid )
232     {
233     itkExceptionMacro( << "Fixed and/or moving pyramid not set" );
234     }
235
236   if( !m_RegistrationFilter )
237     {
238     itkExceptionMacro( << "Registration filter not set" );
239     }
240   
241   if( this->m_InitialDeformationField && this->GetInput(0) )
242     {
243     itkExceptionMacro( << "Only one initial deformation can be given. "
244                        << "SetInitialDeformationField should not be used in "
245                        << "cunjunction with SetArbitraryInitialDeformationField "
246                        << "or SetInput.");
247     }
248
249   //Update the number of levels for the pyramids
250   this->SetNumberOfLevels(m_NumberOfLevels);
251
252   // Create the image pyramids.
253   m_MovingImagePyramid->SetInput( movingImage );
254   m_MovingImagePyramid->UpdateLargestPossibleRegion();
255
256   m_FixedImagePyramid->SetInput( fixedImage );
257   m_FixedImagePyramid->UpdateLargestPossibleRegion();
258  
259   // Initializations
260   m_CurrentLevel = 0;
261   m_StopRegistrationFlag = false;
262
263   unsigned int movingLevel = vnl_math_min( (int) m_CurrentLevel, 
264                                            (int) m_MovingImagePyramid->GetNumberOfLevels() );
265
266   unsigned int fixedLevel = vnl_math_min( (int) m_CurrentLevel, 
267                                           (int) m_FixedImagePyramid->GetNumberOfLevels() );
268
269   DeformationFieldPointer tempField = ITK_NULLPTR;
270
271   DeformationFieldPointer inputPtr =
272     const_cast< DeformationFieldType * >( this->GetInput(0) );
273   
274   if ( this->m_InitialDeformationField )
275     {
276     tempField = this->m_InitialDeformationField;
277     }
278   else if( inputPtr )
279     {
280     // Arbitrary initial deformation field is set.
281     // smooth it and resample
282
283     // First smooth it
284     tempField = inputPtr;
285       
286     typedef itk::RecursiveGaussianImageFilter< DeformationFieldType,
287       DeformationFieldType> GaussianFilterType;
288     typename GaussianFilterType::Pointer smoother
289       = GaussianFilterType::New();
290       
291     for (unsigned int dim=0; dim<DeformationFieldType::ImageDimension; ++dim)
292       {
293       // sigma accounts for the subsampling of the pyramid
294       double sigma = 0.5 * static_cast<float>(
295         m_FixedImagePyramid->GetSchedule()[fixedLevel][dim] );
296
297       // but also for a possible discrepancy in the spacing
298       sigma *= fixedImage->GetSpacing()[dim]
299         / inputPtr->GetSpacing()[dim];
300       
301       smoother->SetInput( tempField );
302       smoother->SetSigma( sigma );
303       smoother->SetDirection( dim );
304       
305       smoother->Update();
306       
307       tempField = smoother->GetOutput();
308       tempField->DisconnectPipeline();
309       }
310       
311       
312     // Now resample
313     m_FieldExpander->SetInput( tempField );
314     
315     typename FloatImageType::Pointer fi = 
316       m_FixedImagePyramid->GetOutput( fixedLevel );
317     m_FieldExpander->SetSize( 
318       fi->GetLargestPossibleRegion().GetSize() );
319     m_FieldExpander->SetOutputStartIndex(
320       fi->GetLargestPossibleRegion().GetIndex() );
321     m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
322     m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
323     m_FieldExpander->SetOutputDirection( fi->GetDirection());
324
325     m_FieldExpander->UpdateLargestPossibleRegion();
326     m_FieldExpander->SetInput( ITK_NULLPTR );
327     tempField = m_FieldExpander->GetOutput();
328     tempField->DisconnectPipeline();
329     }
330
331
332   bool lastShrinkFactorsAllOnes = false;
333   while ( !this->Halt() )
334     {
335       
336       if( tempField.IsNull() )
337         {
338           m_RegistrationFilter->SetInitialDisplacementField( ITK_NULLPTR );
339         }
340       else
341         {
342           // Resample the field to be the same size as the fixed image
343           // at the current level
344           m_FieldExpander->SetInput( tempField );
345           
346       typename FloatImageType::Pointer fi = 
347         m_FixedImagePyramid->GetOutput( fixedLevel );
348       m_FieldExpander->SetSize( 
349         fi->GetLargestPossibleRegion().GetSize() );
350       m_FieldExpander->SetOutputStartIndex(
351         fi->GetLargestPossibleRegion().GetIndex() );
352       m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
353       m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
354
355       m_FieldExpander->UpdateLargestPossibleRegion();
356       m_FieldExpander->SetInput( ITK_NULLPTR );
357       tempField = m_FieldExpander->GetOutput();
358       tempField->DisconnectPipeline();
359
360       m_RegistrationFilter->SetInitialDisplacementField( tempField );
361       }
362
363     // setup registration filter and pyramids 
364     m_RegistrationFilter->SetMovingImage( m_MovingImagePyramid->GetOutput(movingLevel) );
365     m_RegistrationFilter->SetFixedImage( m_FixedImagePyramid->GetOutput(fixedLevel) );
366     m_RegistrationFilter->SetNumberOfIterations(
367       m_NumberOfIterations[m_CurrentLevel] );
368
369     // cache shrink factors for computing the next expand factors.
370     lastShrinkFactorsAllOnes = true;
371     for( unsigned int idim = 0; idim < ImageDimension; idim++ )
372       {
373       if ( m_FixedImagePyramid->GetSchedule()[fixedLevel][idim] > 1 )
374         {
375         lastShrinkFactorsAllOnes = false;
376         break;
377         }
378       }
379
380     // Invoke an iteration event.
381     this->InvokeEvent( itk::IterationEvent() );
382
383     // compute new deformation field
384     m_RegistrationFilter->UpdateLargestPossibleRegion();
385     tempField = m_RegistrationFilter->GetOutput();
386     tempField->DisconnectPipeline();
387
388     // Increment level counter.  
389     m_CurrentLevel++;
390     movingLevel = vnl_math_min( (int) m_CurrentLevel, 
391                                 (int) m_MovingImagePyramid->GetNumberOfLevels() );
392     fixedLevel = vnl_math_min( (int) m_CurrentLevel, 
393                                (int) m_FixedImagePyramid->GetNumberOfLevels() );
394
395     // We can release data from pyramid which are no longer required.
396     if ( movingLevel > 0 )
397       {
398       m_MovingImagePyramid->GetOutput( movingLevel - 1 )->ReleaseData();
399       }
400     if( fixedLevel > 0 )
401       {
402       m_FixedImagePyramid->GetOutput( fixedLevel - 1 )->ReleaseData();
403       }
404
405     } // while not Halt()
406
407     if( !lastShrinkFactorsAllOnes )
408       {
409       // Some of the last shrink factors are not one
410       // graft the output of the expander filter to
411       // to output of this filter
412
413       // resample the field to the same size as the fixed image
414       m_FieldExpander->SetInput( tempField );
415       m_FieldExpander->SetSize( 
416         fixedImage->GetLargestPossibleRegion().GetSize() );
417       m_FieldExpander->SetOutputStartIndex(
418         fixedImage->GetLargestPossibleRegion().GetIndex() );
419       m_FieldExpander->SetOutputOrigin( fixedImage->GetOrigin() );
420       m_FieldExpander->SetOutputSpacing( fixedImage->GetSpacing());
421       m_FieldExpander->UpdateLargestPossibleRegion();
422       this->GraftOutput( m_FieldExpander->GetOutput() );
423       }
424     else
425       {
426       // all the last shrink factors are all ones
427       // graft the output of registration filter to
428       // to output of this filter
429       this->GraftOutput( tempField );
430       }
431
432     // Release memory
433     m_FieldExpander->SetInput( ITK_NULLPTR );
434     m_FieldExpander->GetOutput()->ReleaseData();
435     m_RegistrationFilter->SetInput( ITK_NULLPTR );
436     m_RegistrationFilter->GetOutput()->ReleaseData();
437
438 }
439
440
441 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
442 void
443 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
444 ::StopRegistration()
445 {
446   m_RegistrationFilter->StopRegistration();
447   m_StopRegistrationFlag = true;
448 }
449
450 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
451 bool
452 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
453 ::Halt()
454 {
455   // Halt the registration after the user-specified number of levels
456   if (m_NumberOfLevels != 0)
457   {
458   this->UpdateProgress( static_cast<float>( m_CurrentLevel ) /
459                         static_cast<float>( m_NumberOfLevels ) );
460   }
461
462   if ( m_CurrentLevel >= m_NumberOfLevels )
463     {
464     return true;
465     }
466   if ( m_StopRegistrationFlag )
467     {
468     return true;
469     }
470   else
471     { 
472     return false; 
473     }
474
475 }
476
477
478 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
479 void
480 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
481 ::GenerateOutputInformation()
482 {
483
484   typename itk::DataObject::Pointer output;
485
486  if( this->GetInput(0) )
487   {
488   // Initial deformation field is set.
489   // Copy information from initial field.
490   this->Superclass::GenerateOutputInformation();
491
492   }
493  else if( this->GetFixedImage() )
494   {
495   // Initial deforamtion field is not set. 
496   // Copy information from the fixed image.
497   for (unsigned int idx = 0; idx < 
498     this->GetNumberOfOutputs(); ++idx )
499     {
500     output = this->GetOutput(idx);
501     if (output)
502       {
503       output->CopyInformation(this->GetFixedImage());
504       }  
505     }
506
507   }
508
509 }
510
511
512 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
513 void
514 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
515 ::GenerateInputRequestedRegion()
516 {
517
518   // call the superclass's implementation
519   Superclass::GenerateInputRequestedRegion();
520
521   // request the largest possible region for the moving image
522   MovingImagePointer movingPtr = 
523     const_cast< MovingImageType * >( this->GetMovingImage() );
524   if( movingPtr )
525     {
526     movingPtr->SetRequestedRegionToLargestPossibleRegion();
527     }
528   
529   // just propagate up the output requested region for
530   // the fixed image and initial deformation field.
531   DeformationFieldPointer inputPtr = 
532       const_cast< DeformationFieldType * >( this->GetInput() );
533   DeformationFieldPointer outputPtr = this->GetOutput();
534   FixedImagePointer fixedPtr = 
535         const_cast< FixedImageType *>( this->GetFixedImage() );
536
537   if( inputPtr )
538     {
539     inputPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
540     }
541
542   if( fixedPtr )
543     {
544     fixedPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
545     }
546
547 }
548
549
550 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
551 void
552 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
553 ::EnlargeOutputRequestedRegion(
554                                itk::DataObject * ptr )
555 {
556   // call the superclass's implementation
557   Superclass::EnlargeOutputRequestedRegion( ptr );
558
559   // set the output requested region to largest possible.
560   DeformationFieldType * outputPtr;
561   outputPtr = dynamic_cast<DeformationFieldType*>( ptr );
562
563   if( outputPtr )
564     {
565     outputPtr->SetRequestedRegionToLargestPossibleRegion();
566     }
567
568 }
569
570
571 } // end namespace itk
572
573 #endif