]> Creatis software - clitk.git/blob - itk/clitkForwardWarpImageFilter.txx
removed headers
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
1 #ifndef __clitkForwardWarpImageFilter_txx
2 #define __clitkForwardWarpImageFilter_txx
3 #include "clitkForwardWarpImageFilter.h"
4 #include "clitkImageCommon.h"
5
6 // Put the helper classes in an anonymous namespace so that it is not
7 // exposed to the user
8
9 namespace 
10 {//nameless namespace
11
12   //=========================================================================================================================
13   //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
14   //=========================================================================================================================
15   template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
16   {
17     
18   public: 
19     /** Standard class typedefs. */
20     typedef HelperClass1  Self;
21     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
22     typedef itk::SmartPointer<Self>         Pointer;
23     typedef itk::SmartPointer<const Self>   ConstPointer;
24     
25     /** Method for creation through the object factory. */
26     itkNewMacro(Self);
27     
28     /** Run-time type information (and related methods) */
29     itkTypeMacro( HelperClass1, ImageToImageFilter );
30         
31     /** Constants for the image dimensions */
32     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
33  
34
35     //Typedefs
36     typedef typename OutputImageType::PixelType        OutputPixelType;
37     typedef itk::Image<double, ImageDimension > WeightsImageType;
38     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
39     //===================================================================================   
40     //Set methods
41     void SetWeights(const typename WeightsImageType::Pointer input)
42     {
43       m_Weights = input;
44       this->Modified();
45     }
46     void SetDeformationField(const typename DeformationFieldType::Pointer input)
47     {
48       m_DeformationField=input;
49       this->Modified();
50     }
51     void SetMutexImage(const typename MutexImageType::Pointer input)
52     {
53       m_MutexImage=input;
54       this->Modified();
55       m_ThreadSafe=true;
56     }
57     
58     //Get methods
59     typename WeightsImageType::Pointer GetWeights(){return m_Weights;}
60
61     /** Typedef to describe the output image region type. */
62     typedef typename OutputImageType::RegionType OutputImageRegionType;
63    
64   protected:
65     HelperClass1();
66     ~HelperClass1(){};
67     
68     //the actual processing
69     void BeforeThreadedGenerateData();
70     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
71
72     //member data
73     typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
74     typename DeformationFieldType::Pointer m_DeformationField;
75     typename MutexImageType::Pointer m_MutexImage;
76     bool m_ThreadSafe;
77
78   };
79
80
81
82   //=========================================================================================================================
83   //Member functions of the helper class 1
84   //=========================================================================================================================
85
86
87   //=========================================================================================================================
88   //Empty constructor
89   template<class InputImageType, class OutputImageType, class DeformationFieldType > 
90   HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1() 
91   {
92     m_ThreadSafe=false; 
93   }
94
95
96   //=========================================================================================================================
97   //Before threaded data
98   template<class InputImageType, class OutputImageType, class DeformationFieldType >
99   void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
100   {
101     //Since we will add, put to zero!
102     this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
103     this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
104   }
105   
106
107   //=========================================================================================================================
108   //update the output for the outputRegionForThread
109   template<class InputImageType, class OutputImageType, class DeformationFieldType > 
110   void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
111   {
112
113     //Get pointer to the input
114     typename InputImageType::ConstPointer inputPtr = this->GetInput();
115     
116     //Get pointer to the output
117     typename OutputImageType::Pointer outputPtr = this->GetOutput();
118     typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
119
120     //Iterators over input and deformation field
121     typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
122     typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
123
124     //define them over the outputRegionForThread
125     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
126     DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
127
128     //Initialize
129     typename InputImageType::IndexType index;
130     itk::ContinuousIndex<double,ImageDimension> contIndex;
131     typename InputImageType::PointType point;
132     typedef typename DeformationFieldType::PixelType DisplacementType;
133     DisplacementType displacement;
134     fieldIt.GoToBegin();
135     inputIt.GoToBegin();
136
137     //define some temp variables
138     signed long baseIndex[ImageDimension];
139     double distance[ImageDimension];
140     unsigned int dim, counter, upper;
141     double overlap, totalOverlap;
142     typename OutputImageType::IndexType neighIndex;
143
144     //Find the number of neighbors
145     unsigned int neighbors =  1 << ImageDimension;
146      
147
148     //==================================================================================================
149     //Loop over the region and add the intensities to the output and the weight to the weights
150     //==================================================================================================
151     while( !inputIt.IsAtEnd() )
152           {
153             
154             // get the input image index
155             index = inputIt.GetIndex();
156             inputPtr->TransformIndexToPhysicalPoint( index, point );
157         
158             // get the required displacement
159             displacement = fieldIt.Get();
160         
161             // compute the required output image point
162             for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
163         
164         
165             // Update the output and the weights
166             if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) )
167               {     
168                 for(dim = 0; dim < ImageDimension; dim++)
169                   {
170                     // The following  block is equivalent to the following line without
171                     // having to call floor. For positive inputs!!! 
172                     // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
173                     baseIndex[dim] = (long) contIndex[dim];
174                     distance[dim] = contIndex[dim] - double( baseIndex[dim] );
175                   }
176                 
177                 //Add contribution for each neighbor
178                 totalOverlap = itk::NumericTraits<double>::Zero;
179                 for( counter = 0; counter < neighbors ; counter++ )
180                   {             
181                     overlap = 1.0;          // fraction overlap
182                     upper = counter;  // each bit indicates upper/lower neighbour
183                 
184                     // get neighbor index and overlap fraction
185                     for( dim = 0; dim < 3; dim++ )
186                       {
187                         if ( upper & 1 )
188                           {
189                             neighIndex[dim] = baseIndex[dim] + 1;
190                             overlap *= distance[dim];
191                           }
192                         else
193                           {
194                             neighIndex[dim] = baseIndex[dim];
195                             overlap *= 1.0 - distance[dim];
196                           }
197                         upper >>= 1;
198                       }
199                     
200                     //Set neighbor value only if overlap is not zero
201                     if( (overlap>0.0)) // && 
202                       //                        (static_cast<unsigned int>(neighIndex[0])<size[0]) && 
203                       //                        (static_cast<unsigned int>(neighIndex[1])<size[1]) && 
204                       //                        (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
205                       //                        (neighIndex[0]>=0) &&
206                       //                        (neighIndex[1]>=0) &&
207                       //                        (neighIndex[2]>=0) )
208                       {
209                         
210                         if (! m_ThreadSafe)
211                           {
212                             //Set the pixel and weight at neighIndex
213                             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));                      
214                             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
215                             
216                           }
217                         else 
218                           {
219                             //Entering critilal section: shared memory
220                             m_MutexImage->GetPixel(neighIndex).Lock();
221                             
222                             //Set the pixel and weight at neighIndex
223                             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));   
224                             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
225                             
226                             //Unlock
227                             m_MutexImage->GetPixel(neighIndex).Unlock();
228                             
229                           }
230                         //Add to total overlap
231                         totalOverlap += overlap;
232                       }
233                     
234                     //check for totaloverlap: not very likely
235                     if( totalOverlap == 1.0 )
236                       {
237                         // finished
238                         break;
239                       }
240                   }          
241               }
242           
243             ++fieldIt;
244             ++inputIt;
245           }
246     
247     
248   }
249
250
251
252   //=========================================================================================================================
253   //helper class 2 to allow a threaded execution of normalisation by the weights
254   //=========================================================================================================================
255   template<class InputImageType, class OutputImageType> 
256   class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
257   {
258     
259   public: 
260     /** Standard class typedefs. */
261     typedef HelperClass2  Self;
262     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
263     typedef itk::SmartPointer<Self>         Pointer;
264     typedef itk::SmartPointer<const Self>   ConstPointer;
265     
266     /** Method for creation through the object factory. */
267     itkNewMacro(Self);
268     
269     /** Run-time type information (and related methods) */
270     itkTypeMacro( HelperClass2, ImageToImageFilter );
271         
272     /** Constants for the image dimensions */
273     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
274
275     //Typedefs
276     typedef typename  InputImageType::PixelType        InputPixelType;
277     typedef typename  OutputImageType::PixelType        OutputPixelType;
278     typedef itk::Image<double, ImageDimension > WeightsImageType;
279     
280     
281     //Set methods
282     void SetWeights(const typename WeightsImageType::Pointer input)
283     {
284       m_Weights = input;
285       this->Modified();
286     }
287     void SetEdgePaddingValue(OutputPixelType value)
288     {
289       m_EdgePaddingValue = value;
290       this->Modified();
291     }
292   
293     /** Typedef to describe the output image region type. */
294     typedef typename OutputImageType::RegionType OutputImageRegionType;
295     
296   protected:
297     HelperClass2();
298     ~HelperClass2(){};
299     
300     
301     //the actual processing
302     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
303
304
305     //member data
306     typename     WeightsImageType::Pointer m_Weights;
307     OutputPixelType m_EdgePaddingValue;
308   } ;
309
310
311
312   //=========================================================================================================================
313   //Member functions of the helper class 2
314   //=========================================================================================================================
315   
316   
317   //=========================================================================================================================
318   //Empty constructor
319   template<class InputImageType, class OutputImageType > 
320   HelperClass2<InputImageType, OutputImageType>::HelperClass2()
321   {
322     m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
323   }
324   
325   
326   //=========================================================================================================================
327   //update the output for the outputRegionForThread
328   template<class InputImageType, class OutputImageType > void 
329   HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
330   {
331     
332     //Get pointer to the input
333     typename InputImageType::ConstPointer inputPtr = this->GetInput();
334     
335     //Get pointer to the output
336     typename OutputImageType::Pointer outputPtr = this->GetOutput();
337
338     //Iterators over input, weigths  and output
339     typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
340     typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
341     typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
342
343     //define them over the outputRegionForThread
344     OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
345     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
346     WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
347
348
349     //==================================================================================================
350     //loop over the output and normalize the input, remove holes
351     OutputPixelType neighValue;
352     double zero = itk::NumericTraits<double>::Zero;
353     while (!outputIt.IsAtEnd())
354       {
355         //the weight is not zero
356         if (weightsIt.Get() != zero)
357           {
358             //divide by the weight
359             outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
360           }
361         
362         //copy the value of the  neighbour that was just processed
363         else 
364           {
365             if(!outputIt.IsAtBegin())
366               {
367                 //go back
368                 --outputIt;
369                 neighValue=outputIt.Get();
370                 ++outputIt;
371                 outputIt.Set(neighValue);
372               }
373             else{
374               //DD("is at begin, setting edgepadding value");
375               outputIt.Set(m_EdgePaddingValue);
376             }
377           }
378         ++weightsIt; 
379         ++outputIt;
380         ++inputIt;
381         
382       }//end while
383   }//end member
384   
385
386 }//end nameless namespace
387
388
389
390 namespace clitk
391 {
392
393   //=========================================================================================================================
394   // The rest is the ForwardWarpImageFilter
395   //=========================================================================================================================
396
397   //=========================================================================================================================
398   //constructor
399   template <class InputImageType, class OutputImageType, class DeformationFieldType> 
400   ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
401   {
402     // mIsUpdated=false;
403     m_NumberOfThreadsIsGiven=false;
404     m_EdgePaddingValue=static_cast<PixelType>(0.0);
405     m_ThreadSafe=false;
406     m_Verbose=false;
407   }
408
409
410   //=========================================================================================================================
411   //Update
412   template <class InputImageType, class OutputImageType, class DeformationFieldType> 
413   void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
414   {
415
416     //Get the properties of the input
417     typename InputImageType::ConstPointer inputPtr=this->GetInput();
418     typename WeightsImageType::RegionType region;
419     typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
420     region.SetSize(size);
421     typename OutputImageType::IndexType start;
422     for (unsigned int i =0; i< ImageDimension ;i ++)start[i]=0;
423     region.SetIndex(start);
424     
425     //Allocate the weights
426     typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
427     weights->SetRegions(region);
428     weights->Allocate();
429     weights->SetSpacing(inputPtr->GetSpacing());
430     
431     
432     //=========================================================================== 
433     //warp is divided in in two loops, for each we call a threaded helper class
434     //1. Add contribution of input to output and update weights
435     //2. Normalize the output by the weight and remove holes
436     //=========================================================================== 
437
438     //=========================================================================== 
439     //1. Add contribution of input to output and update weights
440
441     //Define an internal image type in double  precision
442     typedef itk::Image<double, ImageDimension> InternalImageType;
443     
444     //Call threaded helper class 1
445     typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
446     typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
447     
448     //Set input
449     if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
450     helper1->SetInput(inputPtr);
451     helper1->SetDeformationField(m_DeformationField);
452     helper1->SetWeights(weights);
453
454     //Threadsafe?
455     if(m_ThreadSafe)
456       {
457         //Allocate the mutex image
458         typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
459         mutex->SetRegions(region);
460         mutex->Allocate();
461         mutex->SetSpacing(inputPtr->GetSpacing());
462         helper1->SetMutexImage(mutex);
463         if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
464       }
465     else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
466
467     //Execute helper class
468     helper1->Update();
469     
470     //Get the output
471     typename InternalImageType::Pointer temp= helper1->GetOutput();
472    
473     //For clarity
474     weights=helper1->GetWeights();
475    
476    
477     //=========================================================================== 
478     //2. Normalize the output by the weights and remove holes 
479     //Call threaded helper class 
480     typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
481     typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
482     
483     //Set temporary output as input
484     if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
485     helper2->SetInput(temp);
486     helper2->SetWeights(weights);
487     helper2->SetEdgePaddingValue(m_EdgePaddingValue);
488     
489     //Execute helper class
490     helper2->Update();
491     
492     //Set the output
493     this->SetNthOutput(0, helper2->GetOutput());
494   }
495       
496 }
497
498 #endif