]> Creatis software - clitk.git/blob - itk/clitkForwardWarpImageFilter.txx
Reformatted using new coding style
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://oncora1.lyon.fnclcc.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ======================================================================-====*/
18 #ifndef __clitkForwardWarpImageFilter_txx
19 #define __clitkForwardWarpImageFilter_txx
20 #include "clitkForwardWarpImageFilter.h"
21 #include "clitkImageCommon.h"
22
23 // Put the helper classes in an anonymous namespace so that it is not
24 // exposed to the user
25
26 namespace
27 {
28 //nameless namespace
29
30 //=========================================================================================================================
31 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
32 //=========================================================================================================================
33 template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
34 {
35
36 public:
37   /** Standard class typedefs. */
38   typedef HelperClass1  Self;
39   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
40   typedef itk::SmartPointer<Self>         Pointer;
41   typedef itk::SmartPointer<const Self>   ConstPointer;
42
43   /** Method for creation through the object factory. */
44   itkNewMacro(Self);
45
46   /** Run-time type information (and related methods) */
47   itkTypeMacro( HelperClass1, ImageToImageFilter );
48
49   /** Constants for the image dimensions */
50   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
51
52
53   //Typedefs
54   typedef typename OutputImageType::PixelType        OutputPixelType;
55   typedef itk::Image<double, ImageDimension > WeightsImageType;
56   typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
57   //===================================================================================
58   //Set methods
59   void SetWeights(const typename WeightsImageType::Pointer input) {
60     m_Weights = input;
61     this->Modified();
62   }
63   void SetDeformationField(const typename DeformationFieldType::Pointer input) {
64     m_DeformationField=input;
65     this->Modified();
66   }
67   void SetMutexImage(const typename MutexImageType::Pointer input) {
68     m_MutexImage=input;
69     this->Modified();
70     m_ThreadSafe=true;
71   }
72
73   //Get methods
74   typename WeightsImageType::Pointer GetWeights() {
75     return m_Weights;
76   }
77
78   /** Typedef to describe the output image region type. */
79   typedef typename OutputImageType::RegionType OutputImageRegionType;
80
81 protected:
82   HelperClass1();
83   ~HelperClass1() {};
84
85   //the actual processing
86   void BeforeThreadedGenerateData();
87   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
88
89   //member data
90   typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
91   typename DeformationFieldType::Pointer m_DeformationField;
92   typename MutexImageType::Pointer m_MutexImage;
93   bool m_ThreadSafe;
94
95 };
96
97
98
99 //=========================================================================================================================
100 //Member functions of the helper class 1
101 //=========================================================================================================================
102
103
104 //=========================================================================================================================
105 //Empty constructor
106 template<class InputImageType, class OutputImageType, class DeformationFieldType >
107 HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1()
108 {
109   m_ThreadSafe=false;
110 }
111
112
113 //=========================================================================================================================
114 //Before threaded data
115 template<class InputImageType, class OutputImageType, class DeformationFieldType >
116 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
117 {
118   //Since we will add, put to zero!
119   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
120   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
121 }
122
123
124 //=========================================================================================================================
125 //update the output for the outputRegionForThread
126 template<class InputImageType, class OutputImageType, class DeformationFieldType >
127 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
128 {
129
130   //Get pointer to the input
131   typename InputImageType::ConstPointer inputPtr = this->GetInput();
132
133   //Get pointer to the output
134   typename OutputImageType::Pointer outputPtr = this->GetOutput();
135   typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
136
137   //Iterators over input and deformation field
138   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
139   typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
140
141   //define them over the outputRegionForThread
142   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
143   DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
144
145   //Initialize
146   typename InputImageType::IndexType index;
147   itk::ContinuousIndex<double,ImageDimension> contIndex;
148   typename InputImageType::PointType point;
149   typedef typename DeformationFieldType::PixelType DisplacementType;
150   DisplacementType displacement;
151   fieldIt.GoToBegin();
152   inputIt.GoToBegin();
153
154   //define some temp variables
155   signed long baseIndex[ImageDimension];
156   double distance[ImageDimension];
157   unsigned int dim, counter, upper;
158   double overlap, totalOverlap;
159   typename OutputImageType::IndexType neighIndex;
160
161   //Find the number of neighbors
162   unsigned int neighbors =  1 << ImageDimension;
163
164
165   //==================================================================================================
166   //Loop over the region and add the intensities to the output and the weight to the weights
167   //==================================================================================================
168   while( !inputIt.IsAtEnd() ) {
169
170     // get the input image index
171     index = inputIt.GetIndex();
172     inputPtr->TransformIndexToPhysicalPoint( index, point );
173
174     // get the required displacement
175     displacement = fieldIt.Get();
176
177     // compute the required output image point
178     for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
179
180
181     // Update the output and the weights
182     if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) ) {
183       for(dim = 0; dim < ImageDimension; dim++) {
184         // The following  block is equivalent to the following line without
185         // having to call floor. For positive inputs!!!
186         // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
187         baseIndex[dim] = (long) contIndex[dim];
188         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
189       }
190
191       //Add contribution for each neighbor
192       totalOverlap = itk::NumericTraits<double>::Zero;
193       for( counter = 0; counter < neighbors ; counter++ ) {
194         overlap = 1.0;          // fraction overlap
195         upper = counter;  // each bit indicates upper/lower neighbour
196
197         // get neighbor index and overlap fraction
198         for( dim = 0; dim < 3; dim++ ) {
199           if ( upper & 1 ) {
200             neighIndex[dim] = baseIndex[dim] + 1;
201             overlap *= distance[dim];
202           } else {
203             neighIndex[dim] = baseIndex[dim];
204             overlap *= 1.0 - distance[dim];
205           }
206           upper >>= 1;
207         }
208
209         //Set neighbor value only if overlap is not zero
210         if( (overlap>0.0)) // &&
211           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
212           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
213           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
214           //                    (neighIndex[0]>=0) &&
215           //                    (neighIndex[1]>=0) &&
216           //                    (neighIndex[2]>=0) )
217         {
218
219           if (! m_ThreadSafe) {
220             //Set the pixel and weight at neighIndex
221             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
222             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
223
224           } else {
225             //Entering critilal section: shared memory
226             m_MutexImage->GetPixel(neighIndex).Lock();
227
228             //Set the pixel and weight at neighIndex
229             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
230             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
231
232             //Unlock
233             m_MutexImage->GetPixel(neighIndex).Unlock();
234
235           }
236           //Add to total overlap
237           totalOverlap += overlap;
238         }
239
240         //check for totaloverlap: not very likely
241         if( totalOverlap == 1.0 ) {
242           // finished
243           break;
244         }
245       }
246     }
247
248     ++fieldIt;
249     ++inputIt;
250   }
251
252
253 }
254
255
256
257 //=========================================================================================================================
258 //helper class 2 to allow a threaded execution of normalisation by the weights
259 //=========================================================================================================================
260 template<class InputImageType, class OutputImageType>
261 class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
262 {
263
264 public:
265   /** Standard class typedefs. */
266   typedef HelperClass2  Self;
267   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
268   typedef itk::SmartPointer<Self>         Pointer;
269   typedef itk::SmartPointer<const Self>   ConstPointer;
270
271   /** Method for creation through the object factory. */
272   itkNewMacro(Self);
273
274   /** Run-time type information (and related methods) */
275   itkTypeMacro( HelperClass2, ImageToImageFilter );
276
277   /** Constants for the image dimensions */
278   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
279
280   //Typedefs
281   typedef typename  InputImageType::PixelType        InputPixelType;
282   typedef typename  OutputImageType::PixelType        OutputPixelType;
283   typedef itk::Image<double, ImageDimension > WeightsImageType;
284
285
286   //Set methods
287   void SetWeights(const typename WeightsImageType::Pointer input) {
288     m_Weights = input;
289     this->Modified();
290   }
291   void SetEdgePaddingValue(OutputPixelType value) {
292     m_EdgePaddingValue = value;
293     this->Modified();
294   }
295
296   /** Typedef to describe the output image region type. */
297   typedef typename OutputImageType::RegionType OutputImageRegionType;
298
299 protected:
300   HelperClass2();
301   ~HelperClass2() {};
302
303
304   //the actual processing
305   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
306
307
308   //member data
309   typename     WeightsImageType::Pointer m_Weights;
310   OutputPixelType m_EdgePaddingValue;
311 } ;
312
313
314
315 //=========================================================================================================================
316 //Member functions of the helper class 2
317 //=========================================================================================================================
318
319
320 //=========================================================================================================================
321 //Empty constructor
322 template<class InputImageType, class OutputImageType >
323 HelperClass2<InputImageType, OutputImageType>::HelperClass2()
324 {
325   m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
326 }
327
328
329 //=========================================================================================================================
330 //update the output for the outputRegionForThread
331 template<class InputImageType, class OutputImageType > void
332 HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
333 {
334
335   //Get pointer to the input
336   typename InputImageType::ConstPointer inputPtr = this->GetInput();
337
338   //Get pointer to the output
339   typename OutputImageType::Pointer outputPtr = this->GetOutput();
340
341   //Iterators over input, weigths  and output
342   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
343   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
344   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
345
346   //define them over the outputRegionForThread
347   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
348   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
349   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
350
351
352   //==================================================================================================
353   //loop over the output and normalize the input, remove holes
354   OutputPixelType neighValue;
355   double zero = itk::NumericTraits<double>::Zero;
356   while (!outputIt.IsAtEnd()) {
357     //the weight is not zero
358     if (weightsIt.Get() != zero) {
359       //divide by the weight
360       outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
361     }
362
363     //copy the value of the  neighbour that was just processed
364     else {
365       if(!outputIt.IsAtBegin()) {
366         //go back
367         --outputIt;
368         neighValue=outputIt.Get();
369         ++outputIt;
370         outputIt.Set(neighValue);
371       } else {
372         //DD("is at begin, setting edgepadding value");
373         outputIt.Set(m_EdgePaddingValue);
374       }
375     }
376     ++weightsIt;
377     ++outputIt;
378     ++inputIt;
379
380   }//end while
381 }//end member
382
383
384 }//end nameless namespace
385
386
387
388 namespace clitk
389 {
390
391 //=========================================================================================================================
392 // The rest is the ForwardWarpImageFilter
393 //=========================================================================================================================
394
395 //=========================================================================================================================
396 //constructor
397 template <class InputImageType, class OutputImageType, class DeformationFieldType>
398 ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
399 {
400   // mIsUpdated=false;
401   m_NumberOfThreadsIsGiven=false;
402   m_EdgePaddingValue=static_cast<PixelType>(0.0);
403   m_ThreadSafe=false;
404   m_Verbose=false;
405 }
406
407
408 //=========================================================================================================================
409 //Update
410 template <class InputImageType, class OutputImageType, class DeformationFieldType>
411 void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
412 {
413
414   //Get the properties of the input
415   typename InputImageType::ConstPointer inputPtr=this->GetInput();
416   typename WeightsImageType::RegionType region;
417   typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
418   region.SetSize(size);
419   typename OutputImageType::IndexType start;
420   for (unsigned int i =0; i< ImageDimension ; i ++)start[i]=0;
421   region.SetIndex(start);
422
423   //Allocate the weights
424   typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
425   weights->SetRegions(region);
426   weights->Allocate();
427   weights->SetSpacing(inputPtr->GetSpacing());
428
429
430   //===========================================================================
431   //warp is divided in in two loops, for each we call a threaded helper class
432   //1. Add contribution of input to output and update weights
433   //2. Normalize the output by the weight and remove holes
434   //===========================================================================
435
436   //===========================================================================
437   //1. Add contribution of input to output and update weights
438
439   //Define an internal image type in double  precision
440   typedef itk::Image<double, ImageDimension> InternalImageType;
441
442   //Call threaded helper class 1
443   typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
444   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
445
446   //Set input
447   if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
448   helper1->SetInput(inputPtr);
449   helper1->SetDeformationField(m_DeformationField);
450   helper1->SetWeights(weights);
451
452   //Threadsafe?
453   if(m_ThreadSafe) {
454     //Allocate the mutex image
455     typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
456     mutex->SetRegions(region);
457     mutex->Allocate();
458     mutex->SetSpacing(inputPtr->GetSpacing());
459     helper1->SetMutexImage(mutex);
460     if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
461   } else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
462
463   //Execute helper class
464   helper1->Update();
465
466   //Get the output
467   typename InternalImageType::Pointer temp= helper1->GetOutput();
468
469   //For clarity
470   weights=helper1->GetWeights();
471
472
473   //===========================================================================
474   //2. Normalize the output by the weights and remove holes
475   //Call threaded helper class
476   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
477   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
478
479   //Set temporary output as input
480   if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
481   helper2->SetInput(temp);
482   helper2->SetWeights(weights);
483   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
484
485   //Execute helper class
486   helper2->Update();
487
488   //Set the output
489   this->SetNthOutput(0, helper2->GetOutput());
490 }
491
492 }
493
494 #endif