]> Creatis software - clitk.git/blob - itk/clitkForwardWarpImageFilter.txx
Change itkSimpleFastMutexLock to std::mutex
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkForwardWarpImageFilter_txx
19 #define __clitkForwardWarpImageFilter_txx
20 #include "clitkForwardWarpImageFilter.h"
21 #include "clitkImageCommon.h"
22
23 // Put the helper classes in an anonymous namespace so that it is not
24 // exposed to the user
25
26 namespace
27 {
28 //nameless namespace
29
30 //=========================================================================================================================
31 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
32 //=========================================================================================================================
33 template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
34 {
35
36 public:
37   /** Standard class typedefs. */
38   typedef HelperClass1  Self;
39   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
40   typedef itk::SmartPointer<Self>         Pointer;
41   typedef itk::SmartPointer<const Self>   ConstPointer;
42
43   /** Method for creation through the object factory. */
44   itkNewMacro(Self);
45
46   /** Run-time type information (and related methods) */
47   itkTypeMacro( HelperClass1, ImageToImageFilter );
48
49   /** Constants for the image dimensions */
50   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
51
52
53   //Typedefs
54   typedef typename OutputImageType::PixelType        OutputPixelType;
55   typedef itk::Image<double, ImageDimension > WeightsImageType;
56 #if ITK_VERSION_MAJOR <= 4
57     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension> MutexImageType;
58 #endif
59   //===================================================================================
60   //Set methods
61   void SetWeights(const typename WeightsImageType::Pointer input) {
62     m_Weights = input;
63     this->Modified();
64   }
65   void SetDeformationField(const typename DeformationFieldType::Pointer input) {
66     m_DeformationField=input;
67     this->Modified();
68   }
69 #if ITK_VERSION_MAJOR <= 4
70   void SetMutexImage(const typename MutexImageType::Pointer input) {
71     m_MutexImage=input;
72     this->Modified();
73     m_ThreadSafe=true;
74   }
75 #else
76   void SetMutexImage() {
77     this->Modified();
78     m_ThreadSafe=true;
79   }
80 #endif
81
82   //Get methods
83   typename WeightsImageType::Pointer GetWeights() {
84     return m_Weights;
85   }
86
87   /** Typedef to describe the output image region type. */
88   typedef typename OutputImageType::RegionType OutputImageRegionType;
89
90 protected:
91   HelperClass1();
92   ~HelperClass1() {};
93
94   //the actual processing
95   void BeforeThreadedGenerateData();
96   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
97
98   //member data
99   typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
100   typename DeformationFieldType::Pointer m_DeformationField;
101 #if ITK_VERSION_MAJOR <= 4
102   typename MutexImageType::Pointer m_MutexImage;
103 #else
104   std::mutex m_Mutex;
105 #endif
106   bool m_ThreadSafe;
107
108 };
109
110
111
112 //=========================================================================================================================
113 //Member functions of the helper class 1
114 //=========================================================================================================================
115
116
117 //=========================================================================================================================
118 //Empty constructor
119 template<class InputImageType, class OutputImageType, class DeformationFieldType >
120 HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1()
121 {
122   m_ThreadSafe=false;
123 }
124
125
126 //=========================================================================================================================
127 //Before threaded data
128 template<class InputImageType, class OutputImageType, class DeformationFieldType >
129 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
130 {
131   //Since we will add, put to zero!
132   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
133   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
134 }
135
136
137 //=========================================================================================================================
138 //update the output for the outputRegionForThread
139 template<class InputImageType, class OutputImageType, class DeformationFieldType >
140 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
141 {
142
143   //Get pointer to the input
144   typename InputImageType::ConstPointer inputPtr = this->GetInput();
145
146   //Get pointer to the output
147   typename OutputImageType::Pointer outputPtr = this->GetOutput();
148   //typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
149
150   //Iterators over input and deformation field
151   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
152   typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
153
154   //define them over the outputRegionForThread
155   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
156   DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
157
158   //Initialize
159   typename InputImageType::IndexType index;
160   itk::ContinuousIndex<double,ImageDimension> contIndex;
161   typename InputImageType::PointType point;
162   typedef typename DeformationFieldType::PixelType DisplacementType;
163   DisplacementType displacement;
164   fieldIt.GoToBegin();
165   inputIt.GoToBegin();
166
167   //define some temp variables
168   signed long baseIndex[ImageDimension];
169   double distance[ImageDimension];
170   for(unsigned int i=0; i<ImageDimension; i++) distance[i] = 0.0; // to avoid warning
171   unsigned int dim, counter, upper;
172   double overlap, totalOverlap;
173   typename OutputImageType::IndexType neighIndex;
174
175   //Find the number of neighbors
176   unsigned int neighbors =  1 << ImageDimension;
177
178
179   //==================================================================================================
180   //Loop over the region and add the intensities to the output and the weight to the weights
181   //==================================================================================================
182   while( !inputIt.IsAtEnd() ) {
183
184     // get the input image index
185     index = inputIt.GetIndex();
186     inputPtr->TransformIndexToPhysicalPoint( index, point );
187
188     // get the required displacement
189     displacement = fieldIt.Get();
190
191     // compute the required output image point
192     for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
193
194
195     // Update the output and the weights
196     if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) ) {
197       for(dim = 0; dim < ImageDimension; dim++) {
198         // The following  block is equivalent to the following line without
199         // having to call floor. For positive inputs!!!
200         // baseIndex[dim] = (long) std::floor(contIndex[dim] );
201         baseIndex[dim] = (long) contIndex[dim];
202         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
203       }
204
205       //Add contribution for each neighbor
206       totalOverlap = itk::NumericTraits<double>::Zero;
207       for( counter = 0; counter < neighbors ; counter++ ) {
208         overlap = 1.0;          // fraction overlap
209         upper = counter;  // each bit indicates upper/lower neighbour
210
211         // get neighbor index and overlap fraction
212         for( dim = 0; dim < ImageDimension; dim++ ) {
213           if ( upper & 1 ) {
214             neighIndex[dim] = baseIndex[dim] + 1;
215             overlap *= distance[dim];
216           } else {
217             neighIndex[dim] = baseIndex[dim];
218             overlap *= 1.0 - distance[dim];
219           }
220           upper >>= 1;
221         }
222
223         //Set neighbor value only if overlap is not zero
224         if( (overlap>0.0)) // &&
225           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
226           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
227           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
228           //                    (neighIndex[0]>=0) &&
229           //                    (neighIndex[1]>=0) &&
230           //                    (neighIndex[2]>=0) )
231         {
232
233           if (! m_ThreadSafe) {
234             //Set the pixel and weight at neighIndex
235             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
236             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
237
238           } else {
239             //Entering critilal section: shared memory
240 #if ITK_VERSION_MAJOR <= 4
241             m_MutexImage->GetPixel(neighIndex).Lock();
242 #else
243             m_Mutex.lock();
244 #endif
245
246             //Set the pixel and weight at neighIndex
247             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
248             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
249
250             //Unlock
251 #if ITK_VERSION_MAJOR <= 4
252             m_MutexImage->GetPixel(neighIndex).Unlock();
253 #else
254             m_Mutex.unlock();
255 #endif
256
257           }
258           //Add to total overlap
259           totalOverlap += overlap;
260         }
261
262         //check for totaloverlap: not very likely
263         if( totalOverlap == 1.0 ) {
264           // finished
265           break;
266         }
267       }
268     }
269
270     ++fieldIt;
271     ++inputIt;
272   }
273
274
275 }
276
277
278
279 //=========================================================================================================================
280 //helper class 2 to allow a threaded execution of normalisation by the weights
281 //=========================================================================================================================
282 template<class InputImageType, class OutputImageType>
283 class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
284 {
285
286 public:
287   /** Standard class typedefs. */
288   typedef HelperClass2  Self;
289   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
290   typedef itk::SmartPointer<Self>         Pointer;
291   typedef itk::SmartPointer<const Self>   ConstPointer;
292
293   /** Method for creation through the object factory. */
294   itkNewMacro(Self);
295
296   /** Run-time type information (and related methods) */
297   itkTypeMacro( HelperClass2, ImageToImageFilter );
298
299   /** Constants for the image dimensions */
300   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
301
302   //Typedefs
303   typedef typename  InputImageType::PixelType        InputPixelType;
304   typedef typename  OutputImageType::PixelType        OutputPixelType;
305   typedef itk::Image<double, ImageDimension > WeightsImageType;
306
307
308   //Set methods
309   void SetWeights(const typename WeightsImageType::Pointer input) {
310     m_Weights = input;
311     this->Modified();
312   }
313   void SetEdgePaddingValue(OutputPixelType value) {
314     m_EdgePaddingValue = value;
315     this->Modified();
316   }
317
318   /** Typedef to describe the output image region type. */
319   typedef typename OutputImageType::RegionType OutputImageRegionType;
320
321 protected:
322   HelperClass2();
323   ~HelperClass2() {};
324
325
326   //the actual processing
327   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
328
329
330   //member data
331   typename     WeightsImageType::Pointer m_Weights;
332   OutputPixelType m_EdgePaddingValue;
333 } ;
334
335
336
337 //=========================================================================================================================
338 //Member functions of the helper class 2
339 //=========================================================================================================================
340
341
342 //=========================================================================================================================
343 //Empty constructor
344 template<class InputImageType, class OutputImageType >
345 HelperClass2<InputImageType, OutputImageType>::HelperClass2()
346 {
347   m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
348 }
349
350
351 //=========================================================================================================================
352 //update the output for the outputRegionForThread
353 template<class InputImageType, class OutputImageType > void
354 HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
355 {
356
357   //Get pointer to the input
358   typename InputImageType::ConstPointer inputPtr = this->GetInput();
359
360   //Get pointer to the output
361   typename OutputImageType::Pointer outputPtr = this->GetOutput();
362
363   //Iterators over input, weigths  and output
364   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
365   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
366   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
367
368   //define them over the outputRegionForThread
369   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
370   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
371   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
372
373
374   //==================================================================================================
375   //loop over the output and normalize the input, remove holes
376   OutputPixelType neighValue;
377   double zero = itk::NumericTraits<double>::Zero;
378   while (!outputIt.IsAtEnd()) {
379     //the weight is not zero
380     if (weightsIt.Get() != zero) {
381       //divide by the weight
382       outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
383     }
384
385     //copy the value of the  neighbour that was just processed
386     else {
387       if(!outputIt.IsAtBegin()) {
388         //go back
389         --outputIt;
390         neighValue=outputIt.Get();
391         ++outputIt;
392         outputIt.Set(neighValue);
393       } else {
394         //DD("is at begin, setting edgepadding value");
395         outputIt.Set(m_EdgePaddingValue);
396       }
397     }
398     ++weightsIt;
399     ++outputIt;
400     ++inputIt;
401
402   }//end while
403 }//end member
404
405
406 }//end nameless namespace
407
408
409
410 namespace clitk
411 {
412
413 //=========================================================================================================================
414 // The rest is the ForwardWarpImageFilter
415 //=========================================================================================================================
416
417 //=========================================================================================================================
418 //constructor
419 template <class InputImageType, class OutputImageType, class DeformationFieldType>
420 ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
421 {
422   // mIsUpdated=false;
423   m_NumberOfThreadsIsGiven=false;
424   m_EdgePaddingValue=static_cast<PixelType>(0.0);
425   m_ThreadSafe=false;
426   m_Verbose=false;
427 }
428
429
430 //=========================================================================================================================
431 //Update
432 template <class InputImageType, class OutputImageType, class DeformationFieldType>
433 void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
434 {
435
436   //Get the properties of the input
437   typename InputImageType::ConstPointer inputPtr=this->GetInput();
438   typename WeightsImageType::RegionType region;
439   typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
440   region.SetSize(size);
441   typename OutputImageType::IndexType start;
442   for (unsigned int i =0; i< ImageDimension ; i ++)start[i]=0;
443   region.SetIndex(start);
444
445   //Allocate the weights
446   typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
447   weights->SetRegions(region);
448   weights->Allocate();
449   weights->SetSpacing(inputPtr->GetSpacing());
450
451
452   //===========================================================================
453   //warp is divided in in two loops, for each we call a threaded helper class
454   //1. Add contribution of input to output and update weights
455   //2. Normalize the output by the weight and remove holes
456   //===========================================================================
457
458   //===========================================================================
459   //1. Add contribution of input to output and update weights
460
461   //Define an internal image type in double  precision
462   typedef itk::Image<double, ImageDimension> InternalImageType;
463
464   //Call threaded helper class 1
465   typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
466   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
467
468   //Set input
469   if(m_NumberOfThreadsIsGiven) {
470 #if ITK_VERSION_MAJOR <= 4
471     helper1->SetNumberOfThreads(m_NumberOfThreads);
472 #else
473     helper1->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
474 #endif
475   }
476   helper1->SetInput(inputPtr);
477   helper1->SetDeformationField(m_DeformationField);
478   helper1->SetWeights(weights);
479
480   //Threadsafe?
481   if(m_ThreadSafe) {
482     //Allocate the mutex image
483 #if ITK_VERSION_MAJOR <= 4
484     typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
485     mutex->SetRegions(region);
486     mutex->Allocate();
487     mutex->SetSpacing(inputPtr->GetSpacing());
488     helper1->SetMutexImage(mutex);
489 #else
490     helper1->SetMutexImage();
491 #endif
492     if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
493   } else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
494
495   //Execute helper class
496   helper1->Update();
497
498   //Get the output
499   typename InternalImageType::Pointer temp= helper1->GetOutput();
500
501   //For clarity
502   weights=helper1->GetWeights();
503
504
505   //===========================================================================
506   //2. Normalize the output by the weights and remove holes
507   //Call threaded helper class
508   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
509   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
510
511   //Set temporary output as input
512   if(m_NumberOfThreadsIsGiven) {
513 #if ITK_VERSION_MAJOR <= 4
514     helper2->SetNumberOfThreads(m_NumberOfThreads);
515 #else
516     helper2->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
517 #endif
518   }
519   helper2->SetInput(temp);
520   helper2->SetWeights(weights);
521   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
522
523   //Execute helper class
524   helper2->Update();
525
526   //Set the output
527   this->SetNthOutput(0, helper2->GetOutput());
528 }
529
530 }
531
532 #endif