]> Creatis software - clitk.git/blob - itk/clitkForwardWarpImageFilter.txx
38836be401d2bba9afef8d84c850d299eb20c657
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://oncora1.lyon.fnclcc.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ======================================================================-====*/
18 #ifndef __clitkForwardWarpImageFilter_txx
19 #define __clitkForwardWarpImageFilter_txx
20 #include "clitkForwardWarpImageFilter.h"
21 #include "clitkImageCommon.h"
22
23 // Put the helper classes in an anonymous namespace so that it is not
24 // exposed to the user
25
26 namespace 
27 {//nameless namespace
28
29   //=========================================================================================================================
30   //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
31   //=========================================================================================================================
32   template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
33   {
34     
35   public: 
36     /** Standard class typedefs. */
37     typedef HelperClass1  Self;
38     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
39     typedef itk::SmartPointer<Self>         Pointer;
40     typedef itk::SmartPointer<const Self>   ConstPointer;
41     
42     /** Method for creation through the object factory. */
43     itkNewMacro(Self);
44     
45     /** Run-time type information (and related methods) */
46     itkTypeMacro( HelperClass1, ImageToImageFilter );
47         
48     /** Constants for the image dimensions */
49     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
50  
51
52     //Typedefs
53     typedef typename OutputImageType::PixelType        OutputPixelType;
54     typedef itk::Image<double, ImageDimension > WeightsImageType;
55     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
56     //===================================================================================   
57     //Set methods
58     void SetWeights(const typename WeightsImageType::Pointer input)
59     {
60       m_Weights = input;
61       this->Modified();
62     }
63     void SetDeformationField(const typename DeformationFieldType::Pointer input)
64     {
65       m_DeformationField=input;
66       this->Modified();
67     }
68     void SetMutexImage(const typename MutexImageType::Pointer input)
69     {
70       m_MutexImage=input;
71       this->Modified();
72       m_ThreadSafe=true;
73     }
74     
75     //Get methods
76     typename WeightsImageType::Pointer GetWeights(){return m_Weights;}
77
78     /** Typedef to describe the output image region type. */
79     typedef typename OutputImageType::RegionType OutputImageRegionType;
80    
81   protected:
82     HelperClass1();
83     ~HelperClass1(){};
84     
85     //the actual processing
86     void BeforeThreadedGenerateData();
87     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
88
89     //member data
90     typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
91     typename DeformationFieldType::Pointer m_DeformationField;
92     typename MutexImageType::Pointer m_MutexImage;
93     bool m_ThreadSafe;
94
95   };
96
97
98
99   //=========================================================================================================================
100   //Member functions of the helper class 1
101   //=========================================================================================================================
102
103
104   //=========================================================================================================================
105   //Empty constructor
106   template<class InputImageType, class OutputImageType, class DeformationFieldType > 
107   HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1() 
108   {
109     m_ThreadSafe=false; 
110   }
111
112
113   //=========================================================================================================================
114   //Before threaded data
115   template<class InputImageType, class OutputImageType, class DeformationFieldType >
116   void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
117   {
118     //Since we will add, put to zero!
119     this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
120     this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
121   }
122   
123
124   //=========================================================================================================================
125   //update the output for the outputRegionForThread
126   template<class InputImageType, class OutputImageType, class DeformationFieldType > 
127   void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
128   {
129
130     //Get pointer to the input
131     typename InputImageType::ConstPointer inputPtr = this->GetInput();
132     
133     //Get pointer to the output
134     typename OutputImageType::Pointer outputPtr = this->GetOutput();
135     typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
136
137     //Iterators over input and deformation field
138     typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
139     typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
140
141     //define them over the outputRegionForThread
142     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
143     DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
144
145     //Initialize
146     typename InputImageType::IndexType index;
147     itk::ContinuousIndex<double,ImageDimension> contIndex;
148     typename InputImageType::PointType point;
149     typedef typename DeformationFieldType::PixelType DisplacementType;
150     DisplacementType displacement;
151     fieldIt.GoToBegin();
152     inputIt.GoToBegin();
153
154     //define some temp variables
155     signed long baseIndex[ImageDimension];
156     double distance[ImageDimension];
157     unsigned int dim, counter, upper;
158     double overlap, totalOverlap;
159     typename OutputImageType::IndexType neighIndex;
160
161     //Find the number of neighbors
162     unsigned int neighbors =  1 << ImageDimension;
163      
164
165     //==================================================================================================
166     //Loop over the region and add the intensities to the output and the weight to the weights
167     //==================================================================================================
168     while( !inputIt.IsAtEnd() )
169           {
170             
171             // get the input image index
172             index = inputIt.GetIndex();
173             inputPtr->TransformIndexToPhysicalPoint( index, point );
174         
175             // get the required displacement
176             displacement = fieldIt.Get();
177         
178             // compute the required output image point
179             for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
180         
181         
182             // Update the output and the weights
183             if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) )
184               {     
185                 for(dim = 0; dim < ImageDimension; dim++)
186                   {
187                     // The following  block is equivalent to the following line without
188                     // having to call floor. For positive inputs!!! 
189                     // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
190                     baseIndex[dim] = (long) contIndex[dim];
191                     distance[dim] = contIndex[dim] - double( baseIndex[dim] );
192                   }
193                 
194                 //Add contribution for each neighbor
195                 totalOverlap = itk::NumericTraits<double>::Zero;
196                 for( counter = 0; counter < neighbors ; counter++ )
197                   {             
198                     overlap = 1.0;          // fraction overlap
199                     upper = counter;  // each bit indicates upper/lower neighbour
200                 
201                     // get neighbor index and overlap fraction
202                     for( dim = 0; dim < 3; dim++ )
203                       {
204                         if ( upper & 1 )
205                           {
206                             neighIndex[dim] = baseIndex[dim] + 1;
207                             overlap *= distance[dim];
208                           }
209                         else
210                           {
211                             neighIndex[dim] = baseIndex[dim];
212                             overlap *= 1.0 - distance[dim];
213                           }
214                         upper >>= 1;
215                       }
216                     
217                     //Set neighbor value only if overlap is not zero
218                     if( (overlap>0.0)) // && 
219                       //                        (static_cast<unsigned int>(neighIndex[0])<size[0]) && 
220                       //                        (static_cast<unsigned int>(neighIndex[1])<size[1]) && 
221                       //                        (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
222                       //                        (neighIndex[0]>=0) &&
223                       //                        (neighIndex[1]>=0) &&
224                       //                        (neighIndex[2]>=0) )
225                       {
226                         
227                         if (! m_ThreadSafe)
228                           {
229                             //Set the pixel and weight at neighIndex
230                             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));                      
231                             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
232                             
233                           }
234                         else 
235                           {
236                             //Entering critilal section: shared memory
237                             m_MutexImage->GetPixel(neighIndex).Lock();
238                             
239                             //Set the pixel and weight at neighIndex
240                             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));   
241                             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
242                             
243                             //Unlock
244                             m_MutexImage->GetPixel(neighIndex).Unlock();
245                             
246                           }
247                         //Add to total overlap
248                         totalOverlap += overlap;
249                       }
250                     
251                     //check for totaloverlap: not very likely
252                     if( totalOverlap == 1.0 )
253                       {
254                         // finished
255                         break;
256                       }
257                   }          
258               }
259           
260             ++fieldIt;
261             ++inputIt;
262           }
263     
264     
265   }
266
267
268
269   //=========================================================================================================================
270   //helper class 2 to allow a threaded execution of normalisation by the weights
271   //=========================================================================================================================
272   template<class InputImageType, class OutputImageType> 
273   class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
274   {
275     
276   public: 
277     /** Standard class typedefs. */
278     typedef HelperClass2  Self;
279     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
280     typedef itk::SmartPointer<Self>         Pointer;
281     typedef itk::SmartPointer<const Self>   ConstPointer;
282     
283     /** Method for creation through the object factory. */
284     itkNewMacro(Self);
285     
286     /** Run-time type information (and related methods) */
287     itkTypeMacro( HelperClass2, ImageToImageFilter );
288         
289     /** Constants for the image dimensions */
290     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
291
292     //Typedefs
293     typedef typename  InputImageType::PixelType        InputPixelType;
294     typedef typename  OutputImageType::PixelType        OutputPixelType;
295     typedef itk::Image<double, ImageDimension > WeightsImageType;
296     
297     
298     //Set methods
299     void SetWeights(const typename WeightsImageType::Pointer input)
300     {
301       m_Weights = input;
302       this->Modified();
303     }
304     void SetEdgePaddingValue(OutputPixelType value)
305     {
306       m_EdgePaddingValue = value;
307       this->Modified();
308     }
309   
310     /** Typedef to describe the output image region type. */
311     typedef typename OutputImageType::RegionType OutputImageRegionType;
312     
313   protected:
314     HelperClass2();
315     ~HelperClass2(){};
316     
317     
318     //the actual processing
319     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
320
321
322     //member data
323     typename     WeightsImageType::Pointer m_Weights;
324     OutputPixelType m_EdgePaddingValue;
325   } ;
326
327
328
329   //=========================================================================================================================
330   //Member functions of the helper class 2
331   //=========================================================================================================================
332   
333   
334   //=========================================================================================================================
335   //Empty constructor
336   template<class InputImageType, class OutputImageType > 
337   HelperClass2<InputImageType, OutputImageType>::HelperClass2()
338   {
339     m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
340   }
341   
342   
343   //=========================================================================================================================
344   //update the output for the outputRegionForThread
345   template<class InputImageType, class OutputImageType > void 
346   HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
347   {
348     
349     //Get pointer to the input
350     typename InputImageType::ConstPointer inputPtr = this->GetInput();
351     
352     //Get pointer to the output
353     typename OutputImageType::Pointer outputPtr = this->GetOutput();
354
355     //Iterators over input, weigths  and output
356     typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
357     typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
358     typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
359
360     //define them over the outputRegionForThread
361     OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
362     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
363     WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
364
365
366     //==================================================================================================
367     //loop over the output and normalize the input, remove holes
368     OutputPixelType neighValue;
369     double zero = itk::NumericTraits<double>::Zero;
370     while (!outputIt.IsAtEnd())
371       {
372         //the weight is not zero
373         if (weightsIt.Get() != zero)
374           {
375             //divide by the weight
376             outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
377           }
378         
379         //copy the value of the  neighbour that was just processed
380         else 
381           {
382             if(!outputIt.IsAtBegin())
383               {
384                 //go back
385                 --outputIt;
386                 neighValue=outputIt.Get();
387                 ++outputIt;
388                 outputIt.Set(neighValue);
389               }
390             else{
391               //DD("is at begin, setting edgepadding value");
392               outputIt.Set(m_EdgePaddingValue);
393             }
394           }
395         ++weightsIt; 
396         ++outputIt;
397         ++inputIt;
398         
399       }//end while
400   }//end member
401   
402
403 }//end nameless namespace
404
405
406
407 namespace clitk
408 {
409
410   //=========================================================================================================================
411   // The rest is the ForwardWarpImageFilter
412   //=========================================================================================================================
413
414   //=========================================================================================================================
415   //constructor
416   template <class InputImageType, class OutputImageType, class DeformationFieldType> 
417   ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
418   {
419     // mIsUpdated=false;
420     m_NumberOfThreadsIsGiven=false;
421     m_EdgePaddingValue=static_cast<PixelType>(0.0);
422     m_ThreadSafe=false;
423     m_Verbose=false;
424   }
425
426
427   //=========================================================================================================================
428   //Update
429   template <class InputImageType, class OutputImageType, class DeformationFieldType> 
430   void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
431   {
432
433     //Get the properties of the input
434     typename InputImageType::ConstPointer inputPtr=this->GetInput();
435     typename WeightsImageType::RegionType region;
436     typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
437     region.SetSize(size);
438     typename OutputImageType::IndexType start;
439     for (unsigned int i =0; i< ImageDimension ;i ++)start[i]=0;
440     region.SetIndex(start);
441     
442     //Allocate the weights
443     typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
444     weights->SetRegions(region);
445     weights->Allocate();
446     weights->SetSpacing(inputPtr->GetSpacing());
447     
448     
449     //=========================================================================== 
450     //warp is divided in in two loops, for each we call a threaded helper class
451     //1. Add contribution of input to output and update weights
452     //2. Normalize the output by the weight and remove holes
453     //=========================================================================== 
454
455     //=========================================================================== 
456     //1. Add contribution of input to output and update weights
457
458     //Define an internal image type in double  precision
459     typedef itk::Image<double, ImageDimension> InternalImageType;
460     
461     //Call threaded helper class 1
462     typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
463     typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
464     
465     //Set input
466     if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
467     helper1->SetInput(inputPtr);
468     helper1->SetDeformationField(m_DeformationField);
469     helper1->SetWeights(weights);
470
471     //Threadsafe?
472     if(m_ThreadSafe)
473       {
474         //Allocate the mutex image
475         typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
476         mutex->SetRegions(region);
477         mutex->Allocate();
478         mutex->SetSpacing(inputPtr->GetSpacing());
479         helper1->SetMutexImage(mutex);
480         if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
481       }
482     else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
483
484     //Execute helper class
485     helper1->Update();
486     
487     //Get the output
488     typename InternalImageType::Pointer temp= helper1->GetOutput();
489    
490     //For clarity
491     weights=helper1->GetWeights();
492    
493    
494     //=========================================================================== 
495     //2. Normalize the output by the weights and remove holes 
496     //Call threaded helper class 
497     typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
498     typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
499     
500     //Set temporary output as input
501     if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
502     helper2->SetInput(temp);
503     helper2->SetWeights(weights);
504     helper2->SetEdgePaddingValue(m_EdgePaddingValue);
505     
506     //Execute helper class
507     helper2->Update();
508     
509     //Set the output
510     this->SetNthOutput(0, helper2->GetOutput());
511   }
512       
513 }
514
515 #endif