]> Creatis software - clitk.git/blob - itk/clitkForwardWarpImageFilter.txx
Remove vcl_math calls
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkForwardWarpImageFilter_txx
19 #define __clitkForwardWarpImageFilter_txx
20 #include "clitkForwardWarpImageFilter.h"
21 #include "clitkImageCommon.h"
22
23 // Put the helper classes in an anonymous namespace so that it is not
24 // exposed to the user
25
26 namespace
27 {
28 //nameless namespace
29
30 //=========================================================================================================================
31 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
32 //=========================================================================================================================
33 template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
34 {
35
36 public:
37   /** Standard class typedefs. */
38   typedef HelperClass1  Self;
39   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
40   typedef itk::SmartPointer<Self>         Pointer;
41   typedef itk::SmartPointer<const Self>   ConstPointer;
42
43   /** Method for creation through the object factory. */
44   itkNewMacro(Self);
45
46   /** Run-time type information (and related methods) */
47   itkTypeMacro( HelperClass1, ImageToImageFilter );
48
49   /** Constants for the image dimensions */
50   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
51
52
53   //Typedefs
54   typedef typename OutputImageType::PixelType        OutputPixelType;
55   typedef itk::Image<double, ImageDimension > WeightsImageType;
56   typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
57   //===================================================================================
58   //Set methods
59   void SetWeights(const typename WeightsImageType::Pointer input) {
60     m_Weights = input;
61     this->Modified();
62   }
63   void SetDeformationField(const typename DeformationFieldType::Pointer input) {
64     m_DeformationField=input;
65     this->Modified();
66   }
67   void SetMutexImage(const typename MutexImageType::Pointer input) {
68     m_MutexImage=input;
69     this->Modified();
70     m_ThreadSafe=true;
71   }
72
73   //Get methods
74   typename WeightsImageType::Pointer GetWeights() {
75     return m_Weights;
76   }
77
78   /** Typedef to describe the output image region type. */
79   typedef typename OutputImageType::RegionType OutputImageRegionType;
80
81 protected:
82   HelperClass1();
83   ~HelperClass1() {};
84
85   //the actual processing
86   void BeforeThreadedGenerateData();
87   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
88
89   //member data
90   typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
91   typename DeformationFieldType::Pointer m_DeformationField;
92   typename MutexImageType::Pointer m_MutexImage;
93   bool m_ThreadSafe;
94
95 };
96
97
98
99 //=========================================================================================================================
100 //Member functions of the helper class 1
101 //=========================================================================================================================
102
103
104 //=========================================================================================================================
105 //Empty constructor
106 template<class InputImageType, class OutputImageType, class DeformationFieldType >
107 HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1()
108 {
109   m_ThreadSafe=false;
110 }
111
112
113 //=========================================================================================================================
114 //Before threaded data
115 template<class InputImageType, class OutputImageType, class DeformationFieldType >
116 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
117 {
118   //Since we will add, put to zero!
119   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
120   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
121 }
122
123
124 //=========================================================================================================================
125 //update the output for the outputRegionForThread
126 template<class InputImageType, class OutputImageType, class DeformationFieldType >
127 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
128 {
129
130   //Get pointer to the input
131   typename InputImageType::ConstPointer inputPtr = this->GetInput();
132
133   //Get pointer to the output
134   typename OutputImageType::Pointer outputPtr = this->GetOutput();
135   //typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
136
137   //Iterators over input and deformation field
138   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
139   typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
140
141   //define them over the outputRegionForThread
142   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
143   DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
144
145   //Initialize
146   typename InputImageType::IndexType index;
147   itk::ContinuousIndex<double,ImageDimension> contIndex;
148   typename InputImageType::PointType point;
149   typedef typename DeformationFieldType::PixelType DisplacementType;
150   DisplacementType displacement;
151   fieldIt.GoToBegin();
152   inputIt.GoToBegin();
153
154   //define some temp variables
155   signed long baseIndex[ImageDimension];
156   double distance[ImageDimension];
157   for(unsigned int i=0; i<ImageDimension; i++) distance[i] = 0.0; // to avoid warning
158   unsigned int dim, counter, upper;
159   double overlap, totalOverlap;
160   typename OutputImageType::IndexType neighIndex;
161
162   //Find the number of neighbors
163   unsigned int neighbors =  1 << ImageDimension;
164
165
166   //==================================================================================================
167   //Loop over the region and add the intensities to the output and the weight to the weights
168   //==================================================================================================
169   while( !inputIt.IsAtEnd() ) {
170
171     // get the input image index
172     index = inputIt.GetIndex();
173     inputPtr->TransformIndexToPhysicalPoint( index, point );
174
175     // get the required displacement
176     displacement = fieldIt.Get();
177
178     // compute the required output image point
179     for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
180
181
182     // Update the output and the weights
183     if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) ) {
184       for(dim = 0; dim < ImageDimension; dim++) {
185         // The following  block is equivalent to the following line without
186         // having to call floor. For positive inputs!!!
187         // baseIndex[dim] = (long) std::floor(contIndex[dim] );
188         baseIndex[dim] = (long) contIndex[dim];
189         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
190       }
191
192       //Add contribution for each neighbor
193       totalOverlap = itk::NumericTraits<double>::Zero;
194       for( counter = 0; counter < neighbors ; counter++ ) {
195         overlap = 1.0;          // fraction overlap
196         upper = counter;  // each bit indicates upper/lower neighbour
197
198         // get neighbor index and overlap fraction
199         for( dim = 0; dim < ImageDimension; dim++ ) {
200           if ( upper & 1 ) {
201             neighIndex[dim] = baseIndex[dim] + 1;
202             overlap *= distance[dim];
203           } else {
204             neighIndex[dim] = baseIndex[dim];
205             overlap *= 1.0 - distance[dim];
206           }
207           upper >>= 1;
208         }
209
210         //Set neighbor value only if overlap is not zero
211         if( (overlap>0.0)) // &&
212           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
213           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
214           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
215           //                    (neighIndex[0]>=0) &&
216           //                    (neighIndex[1]>=0) &&
217           //                    (neighIndex[2]>=0) )
218         {
219
220           if (! m_ThreadSafe) {
221             //Set the pixel and weight at neighIndex
222             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
223             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
224
225           } else {
226             //Entering critilal section: shared memory
227             m_MutexImage->GetPixel(neighIndex).Lock();
228
229             //Set the pixel and weight at neighIndex
230             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
231             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
232
233             //Unlock
234             m_MutexImage->GetPixel(neighIndex).Unlock();
235
236           }
237           //Add to total overlap
238           totalOverlap += overlap;
239         }
240
241         //check for totaloverlap: not very likely
242         if( totalOverlap == 1.0 ) {
243           // finished
244           break;
245         }
246       }
247     }
248
249     ++fieldIt;
250     ++inputIt;
251   }
252
253
254 }
255
256
257
258 //=========================================================================================================================
259 //helper class 2 to allow a threaded execution of normalisation by the weights
260 //=========================================================================================================================
261 template<class InputImageType, class OutputImageType>
262 class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
263 {
264
265 public:
266   /** Standard class typedefs. */
267   typedef HelperClass2  Self;
268   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
269   typedef itk::SmartPointer<Self>         Pointer;
270   typedef itk::SmartPointer<const Self>   ConstPointer;
271
272   /** Method for creation through the object factory. */
273   itkNewMacro(Self);
274
275   /** Run-time type information (and related methods) */
276   itkTypeMacro( HelperClass2, ImageToImageFilter );
277
278   /** Constants for the image dimensions */
279   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
280
281   //Typedefs
282   typedef typename  InputImageType::PixelType        InputPixelType;
283   typedef typename  OutputImageType::PixelType        OutputPixelType;
284   typedef itk::Image<double, ImageDimension > WeightsImageType;
285
286
287   //Set methods
288   void SetWeights(const typename WeightsImageType::Pointer input) {
289     m_Weights = input;
290     this->Modified();
291   }
292   void SetEdgePaddingValue(OutputPixelType value) {
293     m_EdgePaddingValue = value;
294     this->Modified();
295   }
296
297   /** Typedef to describe the output image region type. */
298   typedef typename OutputImageType::RegionType OutputImageRegionType;
299
300 protected:
301   HelperClass2();
302   ~HelperClass2() {};
303
304
305   //the actual processing
306   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
307
308
309   //member data
310   typename     WeightsImageType::Pointer m_Weights;
311   OutputPixelType m_EdgePaddingValue;
312 } ;
313
314
315
316 //=========================================================================================================================
317 //Member functions of the helper class 2
318 //=========================================================================================================================
319
320
321 //=========================================================================================================================
322 //Empty constructor
323 template<class InputImageType, class OutputImageType >
324 HelperClass2<InputImageType, OutputImageType>::HelperClass2()
325 {
326   m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
327 }
328
329
330 //=========================================================================================================================
331 //update the output for the outputRegionForThread
332 template<class InputImageType, class OutputImageType > void
333 HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
334 {
335
336   //Get pointer to the input
337   typename InputImageType::ConstPointer inputPtr = this->GetInput();
338
339   //Get pointer to the output
340   typename OutputImageType::Pointer outputPtr = this->GetOutput();
341
342   //Iterators over input, weigths  and output
343   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
344   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
345   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
346
347   //define them over the outputRegionForThread
348   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
349   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
350   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
351
352
353   //==================================================================================================
354   //loop over the output and normalize the input, remove holes
355   OutputPixelType neighValue;
356   double zero = itk::NumericTraits<double>::Zero;
357   while (!outputIt.IsAtEnd()) {
358     //the weight is not zero
359     if (weightsIt.Get() != zero) {
360       //divide by the weight
361       outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
362     }
363
364     //copy the value of the  neighbour that was just processed
365     else {
366       if(!outputIt.IsAtBegin()) {
367         //go back
368         --outputIt;
369         neighValue=outputIt.Get();
370         ++outputIt;
371         outputIt.Set(neighValue);
372       } else {
373         //DD("is at begin, setting edgepadding value");
374         outputIt.Set(m_EdgePaddingValue);
375       }
376     }
377     ++weightsIt;
378     ++outputIt;
379     ++inputIt;
380
381   }//end while
382 }//end member
383
384
385 }//end nameless namespace
386
387
388
389 namespace clitk
390 {
391
392 //=========================================================================================================================
393 // The rest is the ForwardWarpImageFilter
394 //=========================================================================================================================
395
396 //=========================================================================================================================
397 //constructor
398 template <class InputImageType, class OutputImageType, class DeformationFieldType>
399 ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
400 {
401   // mIsUpdated=false;
402   m_NumberOfThreadsIsGiven=false;
403   m_EdgePaddingValue=static_cast<PixelType>(0.0);
404   m_ThreadSafe=false;
405   m_Verbose=false;
406 }
407
408
409 //=========================================================================================================================
410 //Update
411 template <class InputImageType, class OutputImageType, class DeformationFieldType>
412 void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
413 {
414
415   //Get the properties of the input
416   typename InputImageType::ConstPointer inputPtr=this->GetInput();
417   typename WeightsImageType::RegionType region;
418   typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
419   region.SetSize(size);
420   typename OutputImageType::IndexType start;
421   for (unsigned int i =0; i< ImageDimension ; i ++)start[i]=0;
422   region.SetIndex(start);
423
424   //Allocate the weights
425   typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
426   weights->SetRegions(region);
427   weights->Allocate();
428   weights->SetSpacing(inputPtr->GetSpacing());
429
430
431   //===========================================================================
432   //warp is divided in in two loops, for each we call a threaded helper class
433   //1. Add contribution of input to output and update weights
434   //2. Normalize the output by the weight and remove holes
435   //===========================================================================
436
437   //===========================================================================
438   //1. Add contribution of input to output and update weights
439
440   //Define an internal image type in double  precision
441   typedef itk::Image<double, ImageDimension> InternalImageType;
442
443   //Call threaded helper class 1
444   typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
445   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
446
447   //Set input
448   if(m_NumberOfThreadsIsGiven) {
449 #if ITK_VERSION_MAJOR <= 4
450     helper1->SetNumberOfThreads(m_NumberOfThreads);
451 #else
452     helper1->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
453 #endif
454   }
455   helper1->SetInput(inputPtr);
456   helper1->SetDeformationField(m_DeformationField);
457   helper1->SetWeights(weights);
458
459   //Threadsafe?
460   if(m_ThreadSafe) {
461     //Allocate the mutex image
462     typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
463     mutex->SetRegions(region);
464     mutex->Allocate();
465     mutex->SetSpacing(inputPtr->GetSpacing());
466     helper1->SetMutexImage(mutex);
467     if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
468   } else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
469
470   //Execute helper class
471   helper1->Update();
472
473   //Get the output
474   typename InternalImageType::Pointer temp= helper1->GetOutput();
475
476   //For clarity
477   weights=helper1->GetWeights();
478
479
480   //===========================================================================
481   //2. Normalize the output by the weights and remove holes
482   //Call threaded helper class
483   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
484   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
485
486   //Set temporary output as input
487   if(m_NumberOfThreadsIsGiven) {
488 #if ITK_VERSION_MAJOR <= 4
489     helper2->SetNumberOfThreads(m_NumberOfThreads);
490 #else
491     helper2->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
492 #endif
493   }
494   helper2->SetInput(temp);
495   helper2->SetWeights(weights);
496   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
497
498   //Execute helper class
499   helper2->Update();
500
501   //Set the output
502   this->SetNthOutput(0, helper2->GetOutput());
503 }
504
505 }
506
507 #endif