]> Creatis software - clitk.git/blob - tools/clitkCorrelationRatioImageToImageMetric.txx
fftw is not required anymore
[clitk.git] / tools / clitkCorrelationRatioImageToImageMetric.txx
1 /*=========================================================================
2                                                                                
3   Program:   clitk
4   Language:  C++
5                                                                                 
6   Copyright (c) CREATIS (Centre de Recherche et d'Applications en Traitement de
7   l'Image). All rights reserved. See Doc/License.txt or
8   http://www.creatis.insa-lyon.fr/Public/Gdcm/License.html for details.
9                                                                                 
10      This software is distributed WITHOUT ANY WARRANTY; without even
11      the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
12      PURPOSE.  See the above copyright notices for more information.
13                                                                              
14 =========================================================================*/
15
16 #ifndef _clitkCorrelationRatioImageToImageMetric_txx
17 #define _clitkCorrelationRatioImageToImageMetric_txx
18
19 /**
20  * @file   clitkCorrelationRatioImageToImageMetric.txx
21  * @author Jef Vandemeulebroucke <jef@creatis.insa-lyon.fr>
22  * @date   July 30  18:14:53 2007
23  * 
24  * @brief  Compute the correlation ratio between 2 images
25  * 
26  * 
27  */
28
29 #include "clitkCorrelationRatioImageToImageMetric.h"
30 #include "itkImageRegionConstIteratorWithIndex.h"
31 #include "itkImageRegionConstIterator.h"
32 #include "itkImageRegionIterator.h"
33
34 namespace clitk
35 {
36
37 /*
38  * Constructor
39  */
40 template <class TFixedImage, class TMovingImage> 
41 CorrelationRatioImageToImageMetric<TFixedImage,TMovingImage>
42 ::CorrelationRatioImageToImageMetric()
43 {
44   m_NumberOfBins = 50;
45
46 }
47
48 template <class TFixedImage, class TMovingImage> 
49 void
50 CorrelationRatioImageToImageMetric<TFixedImage,TMovingImage>
51 ::Initialize(void) throw ( ExceptionObject )
52 {
53
54   this->Superclass::Initialize();
55   
56   // Compute the minimum and maximum for the FixedImage over the FixedImageRegion.
57   // We can't use StatisticsImageFilter to do this because the filter computes the min/max for the largest possible region
58   double fixedImageMin = NumericTraits<double>::max();
59   double fixedImageMax = NumericTraits<double>::NonpositiveMin();
60
61   typedef ImageRegionConstIterator<FixedImageType> FixedIteratorType;
62   FixedIteratorType fixedImageIterator( 
63     this->m_FixedImage, this->GetFixedImageRegion() );
64
65   for ( fixedImageIterator.GoToBegin(); 
66         !fixedImageIterator.IsAtEnd(); ++fixedImageIterator )
67     {
68
69     double sample = static_cast<double>( fixedImageIterator.Get() );
70
71     if ( sample < fixedImageMin )
72       {
73       fixedImageMin = sample;
74       }
75
76     if ( sample > fixedImageMax )
77       {
78       fixedImageMax = sample;
79       }
80     }
81
82   // Compute binsize for the fixedImage
83   m_FixedImageBinSize = ( fixedImageMax - fixedImageMin ) / m_NumberOfBins;
84   m_FixedImageMin=fixedImageMin;
85   //Allocate mempry and initialise the fixed image bin
86   m_NumberOfPixelsCountedPerBin.resize( m_NumberOfBins, 0 );
87   m_mMSVPB.resize( m_NumberOfBins, 0.0 );
88   m_mSMVPB.resize( m_NumberOfBins, 0.0 );
89 }
90
91
92 /*
93  * Get the match Measure
94  */
95 template <class TFixedImage, class TMovingImage> 
96 typename CorrelationRatioImageToImageMetric<TFixedImage,TMovingImage>::MeasureType
97 CorrelationRatioImageToImageMetric<TFixedImage,TMovingImage>
98 ::GetValue( const TransformParametersType & parameters ) const
99 {
100
101   itkDebugMacro("GetValue( " << parameters << " ) ");
102
103   FixedImageConstPointer fixedImage = this->m_FixedImage;
104
105   if( !fixedImage ) 
106     {
107     itkExceptionMacro( << "Fixed image has not been assigned" );
108     }
109
110   typedef  itk::ImageRegionConstIteratorWithIndex<FixedImageType> FixedIteratorType;
111
112
113   FixedIteratorType ti( fixedImage, this->GetFixedImageRegion() );
114
115   typename FixedImageType::IndexType index;
116
117   MeasureType measure = itk::NumericTraits< MeasureType >::Zero;
118
119   this->m_NumberOfPixelsCounted = 0;
120   this->SetTransformParameters( parameters );
121
122
123   //temporary measures for the calculation 
124   RealType mSMV=0;
125   RealType mMSV=0;
126
127   while(!ti.IsAtEnd())
128     {
129
130     index = ti.GetIndex();
131     
132     typename Superclass::InputPointType inputPoint;
133     fixedImage->TransformIndexToPhysicalPoint( index, inputPoint );
134
135     // Verify that the point is in the fixed Image Mask
136     if( this->m_FixedImageMask && !this->m_FixedImageMask->IsInside( inputPoint ) )
137       {
138       ++ti;
139       continue;
140       }
141
142     typename Superclass::OutputPointType transformedPoint = this->m_Transform->TransformPoint( inputPoint );
143
144     //Verify that the point is in the moving Image Mask
145     if( this->m_MovingImageMask && !this->m_MovingImageMask->IsInside( transformedPoint ) )
146       {
147       ++ti;
148       continue;
149       }
150
151     // Verify is the interpolated value is in the buffer
152     if( this->m_Interpolator->IsInsideBuffer( transformedPoint ) )
153       {
154         //Accumulate calculations for the correlation ratio
155         //For each pixel the is in both masks and the buffer we adapt the following measures:
156         //movingMeanSquaredValue mMSV; movingSquaredMeanValue mSMV; 
157         //movingMeanSquaredValuePerBin[i] mSMVPB; movingSquaredMeanValuePerBin[i] mSMVPB
158         //NumberOfPixelsCounted, NumberOfPixelsCountedPerBin[i]
159  
160         //get the value of the moving image
161         const RealType movingValue  = this->m_Interpolator->Evaluate( transformedPoint );
162         // for the variance of the overlapping moving image we accumulate the following measures
163         const RealType movingSquaredValue=movingValue*movingValue;
164         mMSV+=movingSquaredValue;
165         mSMV+=movingValue;
166
167         //get the fixed value
168         const RealType fixedValue   = ti.Get();
169
170         //check in which bin the fixed value belongs, get the index 
171         const double fixedImageBinTerm =        (fixedValue - m_FixedImageMin) / m_FixedImageBinSize;
172         const unsigned int fixedImageBinIndex = static_cast<unsigned int>( vcl_floor(fixedImageBinTerm ) );
173         //adapt the measures per bin
174         this->m_mMSVPB[fixedImageBinIndex]+=movingSquaredValue;
175         this->m_mSMVPB[fixedImageBinIndex]+=movingValue;
176         //increase the fixed image bin and the total pixel count
177         this->m_NumberOfPixelsCountedPerBin[fixedImageBinIndex]+=1;
178         this->m_NumberOfPixelsCounted++;
179       }
180     
181     ++ti;
182     }
183
184   if( !this->m_NumberOfPixelsCounted )
185     {
186       itkExceptionMacro(<<"All the points mapped to outside of the moving image");
187     }
188   else
189     {
190
191       //apdapt the measures per bin
192       for (unsigned int i=0; i< m_NumberOfBins; i++ ){
193         if (this->m_NumberOfPixelsCountedPerBin[i]>0){
194         measure+=(this->m_mMSVPB[i]-((this->m_mSMVPB[i]*this->m_mSMVPB[i])/this->m_NumberOfPixelsCountedPerBin[i]));
195         }
196       }
197
198       //Normalize with the global measures
199       measure /= (mMSV-((mSMV*mSMV)/ this->m_NumberOfPixelsCounted));
200       return measure;
201
202     }
203 }
204
205
206
207
208
209 /*
210  * Get the Derivative Measure
211  */
212 template < class TFixedImage, class TMovingImage> 
213 void
214 CorrelationRatioImageToImageMetric<TFixedImage,TMovingImage>
215 ::GetDerivative( const TransformParametersType & parameters,
216                  DerivativeType & derivative  ) const
217 {
218
219   itkDebugMacro("GetDerivative( " << parameters << " ) ");
220   
221   if( !this->GetGradientImage() )
222     {
223     itkExceptionMacro(<<"The gradient image is null, maybe you forgot to call Initialize()");
224     }
225
226   FixedImageConstPointer fixedImage = this->m_FixedImage;
227
228   if( !fixedImage ) 
229     {
230     itkExceptionMacro( << "Fixed image has not been assigned" );
231     }
232
233   const unsigned int ImageDimension = FixedImageType::ImageDimension;
234
235
236   typedef  itk::ImageRegionConstIteratorWithIndex<
237     FixedImageType> FixedIteratorType;
238
239   typedef  itk::ImageRegionConstIteratorWithIndex<
240     ITK_TYPENAME Superclass::GradientImageType> GradientIteratorType;
241
242
243   FixedIteratorType ti( fixedImage, this->GetFixedImageRegion() );
244
245   typename FixedImageType::IndexType index;
246
247   this->m_NumberOfPixelsCounted = 0;
248
249   this->SetTransformParameters( parameters );
250
251   const unsigned int ParametersDimension = this->GetNumberOfParameters();
252   derivative = DerivativeType( ParametersDimension );
253   derivative.Fill( itk::NumericTraits<ITK_TYPENAME DerivativeType::ValueType>::Zero );
254
255   ti.GoToBegin();
256
257   while(!ti.IsAtEnd())
258     {
259
260     index = ti.GetIndex();
261     
262     typename Superclass::InputPointType inputPoint;
263     fixedImage->TransformIndexToPhysicalPoint( index, inputPoint );
264
265     if( this->m_FixedImageMask && !this->m_FixedImageMask->IsInside( inputPoint ) )
266       {
267       ++ti;
268       continue;
269       }
270
271     typename Superclass::OutputPointType transformedPoint = this->m_Transform->TransformPoint( inputPoint );
272
273     if( this->m_MovingImageMask && !this->m_MovingImageMask->IsInside( transformedPoint ) )
274       {
275       ++ti;
276       continue;
277       }
278
279     if( this->m_Interpolator->IsInsideBuffer( transformedPoint ) )
280       {
281       const RealType movingValue  = this->m_Interpolator->Evaluate( transformedPoint );
282
283       const TransformJacobianType & jacobian =
284         this->m_Transform->GetJacobian( inputPoint ); 
285
286       
287       const RealType fixedValue     = ti.Value();
288       this->m_NumberOfPixelsCounted++;
289       const RealType diff = movingValue - fixedValue; 
290
291       // Get the gradient by NearestNeighboorInterpolation: 
292       // which is equivalent to round up the point components.
293       typedef typename Superclass::OutputPointType OutputPointType;
294       typedef typename OutputPointType::CoordRepType CoordRepType;
295       typedef ContinuousIndex<CoordRepType,MovingImageType::ImageDimension>
296         MovingImageContinuousIndexType;
297
298       MovingImageContinuousIndexType tempIndex;
299       this->m_MovingImage->TransformPhysicalPointToContinuousIndex( transformedPoint, tempIndex );
300
301       typename MovingImageType::IndexType mappedIndex; 
302       for( unsigned int j = 0; j < MovingImageType::ImageDimension; j++ )
303         {
304         mappedIndex[j] = static_cast<long>( vnl_math_rnd( tempIndex[j] ) );
305         }
306
307       const GradientPixelType gradient = 
308         this->GetGradientImage()->GetPixel( mappedIndex );
309
310       for(unsigned int par=0; par<ParametersDimension; par++)
311         {
312         RealType sum = NumericTraits< RealType >::Zero;
313         for(unsigned int dim=0; dim<ImageDimension; dim++)
314           {
315           sum += 2.0 * diff * jacobian( dim, par ) * gradient[dim];
316           }
317         derivative[par] += sum;
318         }
319       }
320
321     ++ti;
322     }
323
324   if( !this->m_NumberOfPixelsCounted )
325     {
326     itkExceptionMacro(<<"All the points mapped to outside of the moving image");
327     }
328   else
329     {
330     for(unsigned int i=0; i<ParametersDimension; i++)
331       {
332       derivative[i] /= this->m_NumberOfPixelsCounted;
333       }
334     }
335
336 }
337
338
339 /*
340  * Get both the match Measure and the Derivative Measure 
341  */
342 template <class TFixedImage, class TMovingImage> 
343 void
344 CorrelationRatioImageToImageMetric<TFixedImage,TMovingImage>
345 ::GetValueAndDerivative(const TransformParametersType & parameters, 
346                         MeasureType & value, DerivativeType  & derivative) const
347 {
348
349   itkDebugMacro("GetValueAndDerivative( " << parameters << " ) ");
350
351   if( !this->GetGradientImage() )
352     {
353     itkExceptionMacro(<<"The gradient image is null, maybe you forgot to call Initialize()");
354     }
355
356   FixedImageConstPointer fixedImage = this->m_FixedImage;
357
358   if( !fixedImage ) 
359     {
360     itkExceptionMacro( << "Fixed image has not been assigned" );
361     }
362
363   const unsigned int ImageDimension = FixedImageType::ImageDimension;
364
365   typedef  itk::ImageRegionConstIteratorWithIndex<
366     FixedImageType> FixedIteratorType;
367
368   typedef  itk::ImageRegionConstIteratorWithIndex<
369     ITK_TYPENAME Superclass::GradientImageType> GradientIteratorType;
370
371
372   FixedIteratorType ti( fixedImage, this->GetFixedImageRegion() );
373
374   typename FixedImageType::IndexType index;
375
376   MeasureType measure = NumericTraits< MeasureType >::Zero;
377
378   this->m_NumberOfPixelsCounted = 0;
379
380   this->SetTransformParameters( parameters );
381
382   const unsigned int ParametersDimension = this->GetNumberOfParameters();
383   derivative = DerivativeType( ParametersDimension );
384   derivative.Fill( NumericTraits<ITK_TYPENAME DerivativeType::ValueType>::Zero );
385
386   ti.GoToBegin();
387
388   while(!ti.IsAtEnd())
389     {
390
391     index = ti.GetIndex();
392     
393     typename Superclass::InputPointType inputPoint;
394     fixedImage->TransformIndexToPhysicalPoint( index, inputPoint );
395
396     if( this->m_FixedImageMask && !this->m_FixedImageMask->IsInside( inputPoint ) )
397       {
398       ++ti;
399       continue;
400       }
401
402     typename Superclass::OutputPointType transformedPoint = this->m_Transform->TransformPoint( inputPoint );
403
404     if( this->m_MovingImageMask && !this->m_MovingImageMask->IsInside( transformedPoint ) )
405       {
406       ++ti;
407       continue;
408       }
409
410     if( this->m_Interpolator->IsInsideBuffer( transformedPoint ) )
411       {
412       const RealType movingValue  = this->m_Interpolator->Evaluate( transformedPoint );
413
414       const TransformJacobianType & jacobian =
415         this->m_Transform->GetJacobian( inputPoint ); 
416
417       
418       const RealType fixedValue     = ti.Value();
419       this->m_NumberOfPixelsCounted++;
420
421       const RealType diff = movingValue - fixedValue; 
422   
423       measure += diff * diff;
424
425       // Get the gradient by NearestNeighboorInterpolation: 
426       // which is equivalent to round up the point components.
427       typedef typename Superclass::OutputPointType OutputPointType;
428       typedef typename OutputPointType::CoordRepType CoordRepType;
429       typedef ContinuousIndex<CoordRepType,MovingImageType::ImageDimension>
430         MovingImageContinuousIndexType;
431
432       MovingImageContinuousIndexType tempIndex;
433       this->m_MovingImage->TransformPhysicalPointToContinuousIndex( transformedPoint, tempIndex );
434
435       typename MovingImageType::IndexType mappedIndex; 
436       for( unsigned int j = 0; j < MovingImageType::ImageDimension; j++ )
437         {
438         mappedIndex[j] = static_cast<long>( vnl_math_rnd( tempIndex[j] ) );
439         }
440
441       const GradientPixelType gradient = 
442         this->GetGradientImage()->GetPixel( mappedIndex );
443
444       for(unsigned int par=0; par<ParametersDimension; par++)
445         {
446         RealType sum = NumericTraits< RealType >::Zero;
447         for(unsigned int dim=0; dim<ImageDimension; dim++)
448           {
449           sum += 2.0 * diff * jacobian( dim, par ) * gradient[dim];
450           }
451         derivative[par] += sum;
452         }
453       }
454
455     ++ti;
456     }
457
458   if( !this->m_NumberOfPixelsCounted )
459     {
460     itkExceptionMacro(<<"All the points mapped to outside of the moving image");
461     }
462   else
463     {
464     for(unsigned int i=0; i<ParametersDimension; i++)
465       {
466       derivative[i] /= this->m_NumberOfPixelsCounted;
467       }
468     measure /= this->m_NumberOfPixelsCounted;
469     }
470
471   value = measure;
472
473 }
474
475 } // end namespace clitk
476
477
478 #endif
479