]> Creatis software - clitk.git/blob - registration/clitkSpatioTemporalMultiResolutionPyramidImageFilter.txx
Remove vcl_math calls
[clitk.git] / registration / clitkSpatioTemporalMultiResolutionPyramidImageFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkSpatioTemporalMultiResolutionPyramidImageFilter_txx
19 #define __clitkSpatioTemporalMultiResolutionPyramidImageFilter_txx
20 #include "clitkSpatioTemporalMultiResolutionPyramidImageFilter.h"
21 #include "itkGaussianOperator.h"
22 #include "itkCastImageFilter.h"
23 #include "itkDiscreteGaussianImageFilter.h"
24 #include "itkExceptionObject.h"
25 #include "itkResampleImageFilter.h"
26 #include "itkShrinkImageFilter.h"
27 #include "itkIdentityTransform.h"
28
29 #include "vnl/vnl_math.h"
30
31 namespace clitk
32 {
33
34 /**
35  * Constructor
36  */
37 template <class TInputImage, class TOutputImage>
38 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
39 ::SpatioTemporalMultiResolutionPyramidImageFilter()
40 {
41   m_NumberOfLevels = 0;
42   this->SetNumberOfLevels( 2 );
43   m_MaximumError = 0.1;
44   m_UseShrinkImageFilter = false;
45 }
46
47
48 /**
49  * Set the number of computation levels
50  */
51 template <class TInputImage, class TOutputImage>
52 void
53 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
54 ::SetNumberOfLevels(
55   unsigned int num )
56 {
57   if( m_NumberOfLevels == num )
58     {
59     return;
60     }
61
62   this->Modified();
63
64   // clamp value to be at least one
65   m_NumberOfLevels = num;
66   if( m_NumberOfLevels < 1 ) m_NumberOfLevels = 1;
67
68   // resize the schedules
69   ScheduleType temp( m_NumberOfLevels, ImageDimension );
70   temp.Fill( 0 );
71   m_Schedule = temp;
72
73   // determine initial shrink factor
74   unsigned int startfactor = 1;
75   startfactor = startfactor << ( m_NumberOfLevels - 1 );
76
77   // set the starting shrink factors
78   this->SetStartingShrinkFactors( startfactor );
79
80   // set the required number of outputs
81   this->SetNumberOfRequiredOutputs( m_NumberOfLevels );
82
83   unsigned int numOutputs = static_cast<unsigned int>( this->GetNumberOfOutputs() );
84   unsigned int idx;
85   if( numOutputs < m_NumberOfLevels )
86     {
87     // add extra outputs
88     for( idx = numOutputs; idx < m_NumberOfLevels; idx++ )
89       {
90       typename itk::DataObject::Pointer output =
91         this->MakeOutput( idx );
92       this->SetNthOutput( idx, output.GetPointer() );
93       }
94
95     }
96   else if( numOutputs > m_NumberOfLevels )
97     {
98     // remove extra outputs
99     for( idx = m_NumberOfLevels; idx < numOutputs; idx++ )
100       {
101       typename itk::DataObject::Pointer output =
102         this->GetOutputs()[idx];
103       this->RemoveOutput( output );
104       }
105     }
106
107 }
108
109
110 /*
111  * Set the starting shrink factors
112  */
113 template <class TInputImage, class TOutputImage>
114 void
115 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
116 ::SetStartingShrinkFactors(
117   unsigned int factor )
118 {
119
120   unsigned int array[ImageDimension];
121   //JV temporal dimension always 1
122   for( unsigned int dim = 0; dim < ImageDimension-1; ++dim )
123     {
124     array[dim] = factor;
125     }
126   array[ImageDimension-1]=1;
127
128   this->SetStartingShrinkFactors( array );
129
130 }
131
132
133 /**
134  * Set the starting shrink factors
135  */
136 template <class TInputImage, class TOutputImage>
137 void
138 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
139 ::SetStartingShrinkFactors(
140   unsigned int * factors )
141 {
142
143   for( unsigned int dim = 0; dim < ImageDimension-1; ++dim )
144     {
145     m_Schedule[0][dim] = factors[dim];
146     if( m_Schedule[0][dim] == 0 )
147       {
148       m_Schedule[0][dim] = 1;
149       }
150     }
151   //JV temporal dimension always 1
152   m_Schedule[0][ImageDimension-1]=1;
153
154   for( unsigned int level = 1; level < m_NumberOfLevels; ++level )
155     {
156       //JV temporal dimension always 1
157       for( unsigned int dim = 0; dim < ImageDimension-1; ++dim )
158            {
159           m_Schedule[level][dim] = m_Schedule[level-1][dim] / 2;
160           if( m_Schedule[level][dim] == 0 )
161             {
162               m_Schedule[level][dim] = 1;
163             }
164         }
165       m_Schedule[level][ImageDimension-1]=1;
166     }
167
168   this->Modified();
169
170 }
171
172
173 /*
174  * Get the starting shrink factors
175  */
176 template <class TInputImage, class TOutputImage>
177 const unsigned int *
178 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
179 ::GetStartingShrinkFactors() const
180 {
181   return ( m_Schedule.data_block() );
182 }
183
184
185 /*
186  * Set the multi-resolution schedule
187  */
188 template <class TInputImage, class TOutputImage>
189 void
190 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
191 ::SetSchedule(
192   const ScheduleType& schedule )
193 {
194
195   if( schedule.rows() != m_NumberOfLevels ||
196       schedule.columns() != ImageDimension )
197     {
198     itkDebugMacro(<< "Schedule has wrong dimensions" );
199     return;
200     }
201
202   if( schedule == m_Schedule )
203     {
204     return;
205     }
206
207   this->Modified();
208   unsigned int level, dim;
209   for( level = 0; level < m_NumberOfLevels; level++ )
210     {
211       //JV temporal dimension always 1
212       for( dim = 0; dim < ImageDimension-1; dim++ )
213         {
214           
215           m_Schedule[level][dim] = schedule[level][dim];
216
217           // set schedule to max( 1, min(schedule[level],
218           //  schedule[level-1] );
219           if( level > 0 )
220             {
221               m_Schedule[level][dim] = vnl_math_min( m_Schedule[level][dim], m_Schedule[level-1][dim] );
222             }
223           
224           if( m_Schedule[level][dim] < 1 )
225             {
226               m_Schedule[level][dim] = 1;
227             }
228         }
229       m_Schedule[level][ImageDimension-1]=1;
230     }
231 }
232
233
234 /*
235  * Is the schedule downward divisible ?
236  */
237 template <class TInputImage, class TOutputImage>
238 bool
239 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
240 ::IsScheduleDownwardDivisible( const ScheduleType& schedule )
241 {
242
243   unsigned int ilevel, idim;
244   for( ilevel = 0; ilevel < schedule.rows() - 1; ilevel++ )
245     {
246     for( idim = 0; idim < schedule.columns(); idim++ )
247       {
248       if( schedule[ilevel][idim] == 0 )
249         {
250         return false;
251         }
252       if( ( schedule[ilevel][idim] % schedule[ilevel+1][idim] ) > 0 )
253         {
254         return false;
255         }
256       }
257     }
258
259   return true;
260 }
261
262 /*
263  * GenerateData for non downward divisible schedules
264  */
265 template <class TInputImage, class TOutputImage>
266 void
267 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
268 ::GenerateData()
269 {
270   // Get the input and output pointers
271   InputImageConstPointer  inputPtr = this->GetInput();
272
273   // Create caster, smoother and resampleShrinker filters
274   typedef itk::CastImageFilter<TInputImage, TOutputImage>              CasterType;
275   typedef itk::DiscreteGaussianImageFilter<TOutputImage, TOutputImage> SmootherType;
276
277   typedef itk::ImageToImageFilter<TOutputImage,TOutputImage>           ImageToImageType;
278   typedef itk::ResampleImageFilter<TOutputImage,TOutputImage>          ResampleShrinkerType;
279   typedef itk::ShrinkImageFilter<TOutputImage,TOutputImage>            ShrinkerType;
280
281   typename CasterType::Pointer caster = CasterType::New();
282   typename SmootherType::Pointer smoother = SmootherType::New();
283
284   typename ImageToImageType::Pointer shrinkerFilter;
285   //
286   // only one of these pointers is going to be valid, depending on the
287   // value of UseShrinkImageFilter flag
288   typename ResampleShrinkerType::Pointer resampleShrinker;
289   typename ShrinkerType::Pointer shrinker;
290
291   if(this->GetUseShrinkImageFilter())
292     {
293     shrinker = ShrinkerType::New();
294     shrinkerFilter = shrinker.GetPointer();
295     }
296   else
297     {
298     resampleShrinker = ResampleShrinkerType::New();
299     typedef itk::LinearInterpolateImageFunction< OutputImageType, double >
300       LinearInterpolatorType;
301     typename LinearInterpolatorType::Pointer interpolator = 
302       LinearInterpolatorType::New();
303     resampleShrinker->SetInterpolator( interpolator );
304     resampleShrinker->SetDefaultPixelValue( 0 );
305     shrinkerFilter = resampleShrinker.GetPointer();
306     }
307   // Setup the filters
308   caster->SetInput( inputPtr );
309
310   smoother->SetUseImageSpacing( false );
311   smoother->SetInput( caster->GetOutput() );
312   smoother->SetMaximumError( m_MaximumError );
313
314   shrinkerFilter->SetInput( smoother->GetOutput() );
315
316   unsigned int ilevel, idim;
317   unsigned int factors[ImageDimension];
318   double       variance[ImageDimension];
319
320   for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
321     {
322     this->UpdateProgress( static_cast<float>( ilevel ) /
323                           static_cast<float>( m_NumberOfLevels ) );
324
325     // Allocate memory for each output
326     OutputImagePointer outputPtr = this->GetOutput( ilevel );
327     outputPtr->SetBufferedRegion( outputPtr->GetRequestedRegion() );
328     outputPtr->Allocate();
329
330     // compute shrink factors and variances
331     for( idim = 0; idim < ImageDimension; idim++ )
332       {
333       factors[idim] = m_Schedule[ilevel][idim];
334       variance[idim] = vnl_math_sqr( 0.5 *
335                                      static_cast<float>( factors[idim] ) );
336       }
337
338     if(!this->GetUseShrinkImageFilter())
339       {
340       typedef itk::IdentityTransform<double,OutputImageType::ImageDimension> 
341         IdentityTransformType;
342       typename IdentityTransformType::Pointer identityTransform =
343         IdentityTransformType::New();
344       resampleShrinker->SetOutputParametersFromImage( outputPtr );
345       resampleShrinker->SetTransform(identityTransform);
346       }
347     else
348       {
349       shrinker->SetShrinkFactors(factors);
350       }
351     // use mini-pipeline to compute output
352     smoother->SetVariance( variance );
353
354     shrinkerFilter->GraftOutput( outputPtr );
355
356     // force to always update in case shrink factors are the same
357     shrinkerFilter->Modified();
358     shrinkerFilter->UpdateLargestPossibleRegion();
359     this->GraftNthOutput( ilevel, shrinkerFilter->GetOutput() );
360     }
361 }
362
363 /**
364  * PrintSelf method
365  */
366 template <class TInputImage, class TOutputImage>
367 void
368 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
369 ::PrintSelf(std::ostream& os, itk::Indent indent) const
370 {
371   Superclass::PrintSelf(os,indent);
372
373   os << indent << "MaximumError: " << m_MaximumError << std::endl;
374   os << indent << "No. levels: " << m_NumberOfLevels << std::endl;
375   os << indent << "Schedule: " << std::endl;
376   os << m_Schedule << std::endl;
377   os << "Use ShrinkImageFilter= " << m_UseShrinkImageFilter << std::endl;
378 }
379
380
381 /*
382  * GenerateOutputInformation
383  */
384 template <class TInputImage, class TOutputImage>
385 void
386 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
387 ::GenerateOutputInformation()
388 {
389
390   // call the superclass's implementation of this method
391   Superclass::GenerateOutputInformation();
392
393   // get pointers to the input and output
394   InputImageConstPointer inputPtr = this->GetInput();
395
396   if ( !inputPtr  )
397     {
398     itkExceptionMacro( << "Input has not been set" );
399     }
400
401   const typename InputImageType::PointType&
402     inputOrigin = inputPtr->GetOrigin();
403   const typename InputImageType::SpacingType&
404     inputSpacing = inputPtr->GetSpacing();
405   const typename InputImageType::DirectionType&
406     inputDirection = inputPtr->GetDirection();
407   const typename InputImageType::SizeType& inputSize =
408     inputPtr->GetLargestPossibleRegion().GetSize();
409   const typename InputImageType::IndexType& inputStartIndex =
410     inputPtr->GetLargestPossibleRegion().GetIndex();
411
412   typedef typename OutputImageType::SizeType  SizeType;
413   typedef typename SizeType::SizeValueType    SizeValueType;
414   typedef typename OutputImageType::IndexType IndexType;
415   typedef typename IndexType::IndexValueType  IndexValueType;
416
417   OutputImagePointer outputPtr;
418   typename OutputImageType::PointType   outputOrigin;
419   typename OutputImageType::SpacingType outputSpacing;
420   SizeType    outputSize;
421   IndexType   outputStartIndex;
422
423   // we need to compute the output spacing, the output image size,
424   // and the output image start index
425   for(unsigned int ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
426     {
427     outputPtr = this->GetOutput( ilevel );
428     if( !outputPtr ) { continue; }
429
430     for(unsigned int idim = 0; idim < OutputImageType::ImageDimension; idim++ )
431       {
432       const double shrinkFactor = static_cast<double>( m_Schedule[ilevel][idim] );
433       outputSpacing[idim] = inputSpacing[idim] * shrinkFactor;
434
435       outputSize[idim] = static_cast<SizeValueType>(
436         std::floor(static_cast<double>(inputSize[idim]) / shrinkFactor ) );
437       if( outputSize[idim] < 1 ) { outputSize[idim] = 1; }
438
439       outputStartIndex[idim] = static_cast<IndexValueType>(
440         std::ceil(static_cast<double>(inputStartIndex[idim]) / shrinkFactor ) );
441       }
442     //Now compute the new shifted origin for the updated levels;
443     const typename OutputImageType::PointType::VectorType outputOriginOffset
444          =(inputDirection*(outputSpacing-inputSpacing))*0.5;
445     for(unsigned int idim = 0; idim < OutputImageType::ImageDimension; idim++ )
446       {
447         outputOrigin[idim]=inputOrigin[idim]+outputOriginOffset[idim];
448       }
449
450     typename OutputImageType::RegionType outputLargestPossibleRegion;
451     outputLargestPossibleRegion.SetSize( outputSize );
452     outputLargestPossibleRegion.SetIndex( outputStartIndex );
453
454     outputPtr->SetLargestPossibleRegion( outputLargestPossibleRegion );
455     outputPtr->SetOrigin ( outputOrigin );
456     outputPtr->SetSpacing( outputSpacing );
457     outputPtr->SetDirection( inputDirection );//Output Direction should be same as input.
458     }
459 }
460
461
462 /*
463  * GenerateOutputRequestedRegion
464  */
465 template <class TInputImage, class TOutputImage>
466 void
467 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
468 ::GenerateOutputRequestedRegion(itk::DataObject * refOutput )
469 {
470   // call the superclass's implementation of this method
471   Superclass::GenerateOutputRequestedRegion( refOutput );
472
473   // find the index for this output
474   unsigned int refLevel = refOutput->GetSourceOutputIndex();
475
476   // compute baseIndex and baseSize
477   typedef typename OutputImageType::SizeType    SizeType;
478   typedef typename SizeType::SizeValueType      SizeValueType;
479   typedef typename OutputImageType::IndexType   IndexType;
480   typedef typename IndexType::IndexValueType    IndexValueType;
481   typedef typename OutputImageType::RegionType  RegionType;
482
483   TOutputImage * ptr = static_cast<TOutputImage*>( refOutput );
484   if( !ptr )
485     {
486     itkExceptionMacro( << "Could not cast refOutput to TOutputImage*." );
487     }
488
489   unsigned int ilevel, idim;
490
491   if ( ptr->GetRequestedRegion() == ptr->GetLargestPossibleRegion() )
492     {
493     // set the requested regions for the other outputs to their
494     // requested region
495
496     for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
497       {
498       if( ilevel == refLevel ) { continue; }
499       if( !this->GetOutput(ilevel) ) { continue; }
500       this->GetOutput(ilevel)->SetRequestedRegionToLargestPossibleRegion();
501       }
502     }
503   else
504     {
505     // compute requested regions for the other outputs based on
506     // the requested region of the reference output
507     IndexType outputIndex;
508     SizeType  outputSize;
509     RegionType outputRegion;
510     IndexType  baseIndex = ptr->GetRequestedRegion().GetIndex();
511     SizeType   baseSize  = ptr->GetRequestedRegion().GetSize();
512
513     for( idim = 0; idim < TOutputImage::ImageDimension; idim++ )
514       {
515       unsigned int factor = m_Schedule[refLevel][idim];
516       baseIndex[idim] *= static_cast<IndexValueType>( factor );
517       baseSize[idim] *= static_cast<SizeValueType>( factor );
518       }
519
520     for( ilevel = 0; ilevel < m_NumberOfLevels; ilevel++ )
521       {
522       if( ilevel == refLevel ) { continue; }
523       if( !this->GetOutput(ilevel) ) { continue; }
524
525       for( idim = 0; idim < TOutputImage::ImageDimension; idim++ )
526         {
527
528         double factor = static_cast<double>( m_Schedule[ilevel][idim] );
529
530         outputSize[idim] = static_cast<SizeValueType>(
531           std::floor(static_cast<double>(baseSize[idim]) / factor ) );
532         if( outputSize[idim] < 1 ) { outputSize[idim] = 1; }
533
534         outputIndex[idim] = static_cast<IndexValueType>(
535           std::ceil(static_cast<double>(baseIndex[idim]) / factor ) );
536
537         }
538
539       outputRegion.SetIndex( outputIndex );
540       outputRegion.SetSize( outputSize );
541
542       // make sure the region is within the largest possible region
543       outputRegion.Crop( this->GetOutput( ilevel )->
544                          GetLargestPossibleRegion() );
545       // set the requested region
546       this->GetOutput( ilevel )->SetRequestedRegion( outputRegion );
547       }
548
549     }
550 }
551
552
553 /**
554  * GenerateInputRequestedRegion
555  */
556 template <class TInputImage, class TOutputImage>
557 void
558 SpatioTemporalMultiResolutionPyramidImageFilter<TInputImage, TOutputImage>
559 ::GenerateInputRequestedRegion()
560 {
561   // call the superclass' implementation of this method
562   Superclass::GenerateInputRequestedRegion();
563
564   // get pointers to the input and output
565   InputImagePointer  inputPtr =
566     const_cast< InputImageType * >( this->GetInput() );
567   if ( !inputPtr )
568     {
569     itkExceptionMacro( << "Input has not been set." );
570     }
571
572   // compute baseIndex and baseSize
573   typedef typename OutputImageType::SizeType    SizeType;
574   typedef typename SizeType::SizeValueType      SizeValueType;
575   typedef typename OutputImageType::IndexType   IndexType;
576   typedef typename IndexType::IndexValueType    IndexValueType;
577   typedef typename OutputImageType::RegionType  RegionType;
578
579   unsigned int refLevel = m_NumberOfLevels - 1;
580   SizeType baseSize = this->GetOutput(refLevel)->GetRequestedRegion().GetSize();
581   IndexType baseIndex = this->GetOutput(refLevel)->GetRequestedRegion().GetIndex();
582   RegionType baseRegion;
583
584   unsigned int idim;
585   for( idim = 0; idim < ImageDimension; idim++ )
586     {
587     unsigned int factor = m_Schedule[refLevel][idim];
588     baseIndex[idim] *= static_cast<IndexValueType>( factor );
589     baseSize[idim] *= static_cast<SizeValueType>( factor );
590     }
591   baseRegion.SetIndex( baseIndex );
592   baseRegion.SetSize( baseSize );
593
594   // compute requirements for the smoothing part
595   typedef typename TOutputImage::PixelType                 OutputPixelType;
596   typedef typename itk::GaussianOperator<OutputPixelType,ImageDimension> OperatorType;
597
598   OperatorType *oper = new OperatorType;
599
600   typename TInputImage::SizeType radius;
601
602   RegionType inputRequestedRegion = baseRegion;
603   refLevel = 0;
604
605   for( idim = 0; idim < TInputImage::ImageDimension; idim++ )
606     {
607     oper->SetDirection(idim);
608     oper->SetVariance( vnl_math_sqr( 0.5 * static_cast<float>(
609                                        m_Schedule[refLevel][idim] ) ) );
610     oper->SetMaximumError( m_MaximumError );
611     oper->CreateDirectional();
612     radius[idim] = oper->GetRadius()[idim];
613     }
614   delete oper;
615
616   inputRequestedRegion.PadByRadius( radius );
617
618   // make sure the requested region is within the largest possible
619   inputRequestedRegion.Crop( inputPtr->GetLargestPossibleRegion() );
620
621   // set the input requested region
622   inputPtr->SetRequestedRegion( inputRequestedRegion );
623
624 }
625
626
627 } // namespace clitk
628
629 #endif