]> Creatis software - clitk.git/blob - registration/clitkMultiResolutionPDEDeformableRegistration.txx
remove debug DD
[clitk.git] / registration / clitkMultiResolutionPDEDeformableRegistration.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://oncora1.lyon.fnclcc.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ======================================================================-====*/
18 #ifndef _clitkMultiResolutionPDEDeformableRegistration_txx
19 #define _clitkMultiResolutionPDEDeformableRegistration_txx
20 #include "clitkMultiResolutionPDEDeformableRegistration.h"
21
22 #include "itkRecursiveMultiResolutionPyramidImageFilter.h"
23 #include "itkImageRegionIterator.h"
24 #include "vnl/vnl_math.h"
25
26 namespace clitk {
27
28 /*
29  * Default constructor
30  */
31 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
32 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
33 ::MultiResolutionPDEDeformableRegistration()
34 {
35  
36   this->SetNumberOfRequiredInputs(2);
37
38   typename DefaultRegistrationType::Pointer registrator =
39     DefaultRegistrationType::New();
40   m_RegistrationFilter = static_cast<RegistrationType*>(
41     registrator.GetPointer() );
42
43   m_MovingImagePyramid  = MovingImagePyramidType::New();
44   m_FixedImagePyramid     = FixedImagePyramidType::New();
45   m_FieldExpander     = FieldExpanderType::New();
46   m_InitialDeformationField = NULL;
47
48   m_NumberOfLevels = 3;
49   m_NumberOfIterations.resize( m_NumberOfLevels );
50   m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
51   m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
52
53   unsigned int ilevel;
54   for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
55     {
56     m_NumberOfIterations[ilevel] = 10;
57     }
58   m_CurrentLevel = 0;
59
60   m_StopRegistrationFlag = false;
61
62 }
63
64
65 /*
66  * Set the moving image image.
67  */
68 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
69 void
70 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
71 ::SetMovingImage(
72 const MovingImageType * ptr )
73 {
74   this->itk::ProcessObject::SetNthInput( 2, const_cast< MovingImageType * >( ptr ) );
75 }
76
77
78 /*
79  * Get the moving image image.
80  */
81 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
82 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
83 ::MovingImageType *
84 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
85 ::GetMovingImage(void) const
86 {
87   return dynamic_cast< const MovingImageType * >
88     ( this->itk::ProcessObject::GetInput( 2 ) );
89 }
90
91
92 /*
93  * Set the fixed image.
94  */
95 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
96 void
97 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
98 ::SetFixedImage(
99 const FixedImageType * ptr )
100 {
101   this->itk::ProcessObject::SetNthInput( 1, const_cast< FixedImageType * >( ptr ) );
102 }
103
104
105 /*
106  * Get the fixed image.
107  */
108 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
109 const typename MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
110 ::FixedImageType *
111 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
112 ::GetFixedImage(void) const
113 {
114   return dynamic_cast< const FixedImageType * >
115     ( this->itk::ProcessObject::GetInput( 1 ) );
116 }
117
118 /*
119  * 
120  */
121 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
122 std::vector<itk::SmartPointer<itk::DataObject> >::size_type
123 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
124 ::GetNumberOfValidRequiredInputs() const
125 {
126   typename std::vector<itk::SmartPointer<itk::DataObject> >::size_type num = 0;
127
128   if (this->GetFixedImage())
129     {
130     num++;
131     }
132
133   if (this->GetMovingImage())
134     {
135     num++;
136     }
137   
138   return num;
139 }
140
141
142 /*
143  * Set the number of multi-resolution levels
144  */
145 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
146 void
147 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
148 ::SetNumberOfLevels(
149 unsigned int num )
150 {
151   if( m_NumberOfLevels != num )
152     {
153     this->Modified();
154     m_NumberOfLevels = num;
155     m_NumberOfIterations.resize( m_NumberOfLevels );
156     }
157
158   if( m_MovingImagePyramid && m_MovingImagePyramid->GetNumberOfLevels() != num )
159     {
160     m_MovingImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
161     }
162   if( m_FixedImagePyramid && m_FixedImagePyramid->GetNumberOfLevels() != num )
163     {
164     m_FixedImagePyramid->SetNumberOfLevels( m_NumberOfLevels );
165     }  
166 }
167
168
169 /*
170  * Standard PrintSelf method.
171  */
172 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
173 void
174 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
175 ::PrintSelf(std::ostream& os, itk::Indent indent) const
176 {
177   Superclass::PrintSelf(os, indent);
178   os << indent << "NumberOfLevels: " << m_NumberOfLevels << std::endl;
179   os << indent << "CurrentLevel: " << m_CurrentLevel << std::endl;
180
181   os << indent << "NumberOfIterations: [";
182   unsigned int ilevel;
183   for( ilevel = 0; ilevel < m_NumberOfLevels - 1; ilevel++ )
184     {
185     os << m_NumberOfIterations[ilevel] << ", ";
186     }
187   os << m_NumberOfIterations[ilevel] << "]" << std::endl;
188   
189   os << indent << "RegistrationFilter: ";
190   os << m_RegistrationFilter.GetPointer() << std::endl;
191   os << indent << "MovingImagePyramid: ";
192   os << m_MovingImagePyramid.GetPointer() << std::endl;
193   os << indent << "FixedImagePyramid: ";
194   os << m_FixedImagePyramid.GetPointer() << std::endl;
195
196   os << indent << "FieldExpander: ";
197   os << m_FieldExpander.GetPointer() << std::endl;
198
199   os << indent << "StopRegistrationFlag: ";
200   os << m_StopRegistrationFlag << std::endl;
201
202 }
203
204 /*
205  * Perform a the deformable registration using a multiresolution scheme
206  * using an internal mini-pipeline
207  *
208  *  ref_pyramid ->  registrator  ->  field_expander --|| tempField
209  * test_pyramid ->           |                              |
210  *                           |                              |
211  *                           --------------------------------    
212  *
213  * A tempField image is used to break the cycle between the
214  * registrator and field_expander.
215  *
216  */                              
217 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
218 void
219 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
220 ::GenerateData()
221 {
222   // Check for NULL images and pointers
223   MovingImageConstPointer movingImage = this->GetMovingImage();
224   FixedImageConstPointer  fixedImage = this->GetFixedImage();
225
226   if( !movingImage || !fixedImage )
227     {
228     itkExceptionMacro( << "Fixed and/or moving image not set" );
229     }
230
231   if( !m_MovingImagePyramid || !m_FixedImagePyramid )
232     {
233     itkExceptionMacro( << "Fixed and/or moving pyramid not set" );
234     }
235
236   if( !m_RegistrationFilter )
237     {
238     itkExceptionMacro( << "Registration filter not set" );
239     }
240   
241   if( this->m_InitialDeformationField && this->GetInput(0) )
242     {
243     itkExceptionMacro( << "Only one initial deformation can be given. "
244                        << "SetInitialDeformationField should not be used in "
245                        << "cunjunction with SetArbitraryInitialDeformationField "
246                        << "or SetInput.");
247     }
248
249   //Update the number of levels for the pyramids
250   this->SetNumberOfLevels(m_NumberOfLevels);
251
252   // Create the image pyramids.
253   m_MovingImagePyramid->SetInput( movingImage );
254   m_MovingImagePyramid->UpdateLargestPossibleRegion();
255
256   m_FixedImagePyramid->SetInput( fixedImage );
257   m_FixedImagePyramid->UpdateLargestPossibleRegion();
258  
259   // Initializations
260   m_CurrentLevel = 0;
261   m_StopRegistrationFlag = false;
262
263   unsigned int movingLevel = vnl_math_min( (int) m_CurrentLevel, 
264                                            (int) m_MovingImagePyramid->GetNumberOfLevels() );
265
266   unsigned int fixedLevel = vnl_math_min( (int) m_CurrentLevel, 
267                                           (int) m_FixedImagePyramid->GetNumberOfLevels() );
268
269   DeformationFieldPointer tempField = NULL;
270
271   DeformationFieldPointer inputPtr =
272     const_cast< DeformationFieldType * >( this->GetInput(0) );
273   
274   if ( this->m_InitialDeformationField )
275     {
276     tempField = this->m_InitialDeformationField;
277     }
278   else if( inputPtr )
279     {
280     // Arbitrary initial deformation field is set.
281     // smooth it and resample
282
283     // First smooth it
284     tempField = inputPtr;
285       
286     typedef itk::RecursiveGaussianImageFilter< DeformationFieldType,
287       DeformationFieldType> GaussianFilterType;
288     typename GaussianFilterType::Pointer smoother
289       = GaussianFilterType::New();
290       
291     for (unsigned int dim=0; dim<DeformationFieldType::ImageDimension; ++dim)
292       {
293       // sigma accounts for the subsampling of the pyramid
294       double sigma = 0.5 * static_cast<float>(
295         m_FixedImagePyramid->GetSchedule()[fixedLevel][dim] );
296
297       // but also for a possible discrepancy in the spacing
298       sigma *= fixedImage->GetSpacing()[dim]
299         / inputPtr->GetSpacing()[dim];
300       
301       smoother->SetInput( tempField );
302       smoother->SetSigma( sigma );
303       smoother->SetDirection( dim );
304       
305       smoother->Update();
306       
307       tempField = smoother->GetOutput();
308       tempField->DisconnectPipeline();
309       }
310       
311       
312     // Now resample
313     m_FieldExpander->SetInput( tempField );
314     
315     typename FloatImageType::Pointer fi = 
316       m_FixedImagePyramid->GetOutput( fixedLevel );
317     m_FieldExpander->SetSize( 
318       fi->GetLargestPossibleRegion().GetSize() );
319     m_FieldExpander->SetOutputStartIndex(
320       fi->GetLargestPossibleRegion().GetIndex() );
321     m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
322     m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
323     m_FieldExpander->SetOutputDirection( fi->GetDirection());
324
325     m_FieldExpander->UpdateLargestPossibleRegion();
326     m_FieldExpander->SetInput( NULL );
327     tempField = m_FieldExpander->GetOutput();
328     tempField->DisconnectPipeline();
329     }
330
331
332   bool lastShrinkFactorsAllOnes = false;
333   while ( !this->Halt() )
334     {
335       
336       if( tempField.IsNull() )
337         {
338           m_RegistrationFilter->SetInitialDeformationField( NULL );
339         }
340       else
341         {
342           // Resample the field to be the same size as the fixed image
343           // at the current level
344           m_FieldExpander->SetInput( tempField );
345           
346       typename FloatImageType::Pointer fi = 
347         m_FixedImagePyramid->GetOutput( fixedLevel );
348       m_FieldExpander->SetSize( 
349         fi->GetLargestPossibleRegion().GetSize() );
350       m_FieldExpander->SetOutputStartIndex(
351         fi->GetLargestPossibleRegion().GetIndex() );
352       m_FieldExpander->SetOutputOrigin( fi->GetOrigin() );
353       m_FieldExpander->SetOutputSpacing( fi->GetSpacing());
354
355       m_FieldExpander->UpdateLargestPossibleRegion();
356       m_FieldExpander->SetInput( NULL );
357       tempField = m_FieldExpander->GetOutput();
358       tempField->DisconnectPipeline();
359
360       m_RegistrationFilter->SetInitialDeformationField( tempField );
361
362       }
363
364     // setup registration filter and pyramids 
365     m_RegistrationFilter->SetMovingImage( m_MovingImagePyramid->GetOutput(movingLevel) );
366     m_RegistrationFilter->SetFixedImage( m_FixedImagePyramid->GetOutput(fixedLevel) );
367     m_RegistrationFilter->SetNumberOfIterations(
368       m_NumberOfIterations[m_CurrentLevel] );
369
370     // cache shrink factors for computing the next expand factors.
371     lastShrinkFactorsAllOnes = true;
372     for( unsigned int idim = 0; idim < ImageDimension; idim++ )
373       {
374       if ( m_FixedImagePyramid->GetSchedule()[fixedLevel][idim] > 1 )
375         {
376         lastShrinkFactorsAllOnes = false;
377         break;
378         }
379       }
380
381     // Invoke an iteration event.
382     this->InvokeEvent( itk::IterationEvent() );
383
384     // compute new deformation field
385     m_RegistrationFilter->UpdateLargestPossibleRegion();
386     tempField = m_RegistrationFilter->GetOutput();
387     tempField->DisconnectPipeline();
388
389     // Increment level counter.  
390     m_CurrentLevel++;
391     movingLevel = vnl_math_min( (int) m_CurrentLevel, 
392                                 (int) m_MovingImagePyramid->GetNumberOfLevels() );
393     fixedLevel = vnl_math_min( (int) m_CurrentLevel, 
394                                (int) m_FixedImagePyramid->GetNumberOfLevels() );
395
396     // We can release data from pyramid which are no longer required.
397     if ( movingLevel > 0 )
398       {
399       m_MovingImagePyramid->GetOutput( movingLevel - 1 )->ReleaseData();
400       }
401     if( fixedLevel > 0 )
402       {
403       m_FixedImagePyramid->GetOutput( fixedLevel - 1 )->ReleaseData();
404       }
405
406     } // while not Halt()
407
408     if( !lastShrinkFactorsAllOnes )
409       {
410       // Some of the last shrink factors are not one
411       // graft the output of the expander filter to
412       // to output of this filter
413
414       // resample the field to the same size as the fixed image
415       m_FieldExpander->SetInput( tempField );
416       m_FieldExpander->SetSize( 
417         fixedImage->GetLargestPossibleRegion().GetSize() );
418       m_FieldExpander->SetOutputStartIndex(
419         fixedImage->GetLargestPossibleRegion().GetIndex() );
420       m_FieldExpander->SetOutputOrigin( fixedImage->GetOrigin() );
421       m_FieldExpander->SetOutputSpacing( fixedImage->GetSpacing());
422       m_FieldExpander->UpdateLargestPossibleRegion();
423       this->GraftOutput( m_FieldExpander->GetOutput() );
424       }
425     else
426       {
427       // all the last shrink factors are all ones
428       // graft the output of registration filter to
429       // to output of this filter
430       this->GraftOutput( tempField );
431       }
432
433     // Release memory
434     m_FieldExpander->SetInput( NULL );
435     m_FieldExpander->GetOutput()->ReleaseData();
436     m_RegistrationFilter->SetInput( NULL );
437     m_RegistrationFilter->GetOutput()->ReleaseData();
438
439 }
440
441
442 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
443 void
444 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
445 ::StopRegistration()
446 {
447   m_RegistrationFilter->StopRegistration();
448   m_StopRegistrationFlag = true;
449 }
450
451 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
452 bool
453 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
454 ::Halt()
455 {
456   // Halt the registration after the user-specified number of levels
457   if (m_NumberOfLevels != 0)
458   {
459   this->UpdateProgress( static_cast<float>( m_CurrentLevel ) /
460                         static_cast<float>( m_NumberOfLevels ) );
461   }
462
463   if ( m_CurrentLevel >= m_NumberOfLevels )
464     {
465     return true;
466     }
467   if ( m_StopRegistrationFlag )
468     {
469     return true;
470     }
471   else
472     { 
473     return false; 
474     }
475
476 }
477
478
479 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
480 void
481 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
482 ::GenerateOutputInformation()
483 {
484
485   typename itk::DataObject::Pointer output;
486
487  if( this->GetInput(0) )
488   {
489   // Initial deformation field is set.
490   // Copy information from initial field.
491   this->Superclass::GenerateOutputInformation();
492
493   }
494  else if( this->GetFixedImage() )
495   {
496   // Initial deforamtion field is not set. 
497   // Copy information from the fixed image.
498   for (unsigned int idx = 0; idx < 
499     this->GetNumberOfOutputs(); ++idx )
500     {
501     output = this->GetOutput(idx);
502     if (output)
503       {
504       output->CopyInformation(this->GetFixedImage());
505       }  
506     }
507
508   }
509
510 }
511
512
513 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
514 void
515 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
516 ::GenerateInputRequestedRegion()
517 {
518
519   // call the superclass's implementation
520   Superclass::GenerateInputRequestedRegion();
521
522   // request the largest possible region for the moving image
523   MovingImagePointer movingPtr = 
524     const_cast< MovingImageType * >( this->GetMovingImage() );
525   if( movingPtr )
526     {
527     movingPtr->SetRequestedRegionToLargestPossibleRegion();
528     }
529   
530   // just propagate up the output requested region for
531   // the fixed image and initial deformation field.
532   DeformationFieldPointer inputPtr = 
533       const_cast< DeformationFieldType * >( this->GetInput() );
534   DeformationFieldPointer outputPtr = this->GetOutput();
535   FixedImagePointer fixedPtr = 
536         const_cast< FixedImageType *>( this->GetFixedImage() );
537
538   if( inputPtr )
539     {
540     inputPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
541     }
542
543   if( fixedPtr )
544     {
545     fixedPtr->SetRequestedRegion( outputPtr->GetRequestedRegion() );
546     }
547
548 }
549
550
551 template <class TFixedImage, class TMovingImage, class TDeformationField, class TRealType>
552 void
553 MultiResolutionPDEDeformableRegistration<TFixedImage,TMovingImage,TDeformationField,TRealType>
554 ::EnlargeOutputRequestedRegion(
555                                itk::DataObject * ptr )
556 {
557   // call the superclass's implementation
558   Superclass::EnlargeOutputRequestedRegion( ptr );
559
560   // set the output requested region to largest possible.
561   DeformationFieldType * outputPtr;
562   outputPtr = dynamic_cast<DeformationFieldType*>( ptr );
563
564   if( outputPtr )
565     {
566     outputPtr->SetRequestedRegionToLargestPossibleRegion();
567     }
568
569 }
570
571
572 } // end namespace itk
573
574 #endif