]> Creatis software - clitk.git/blob - itk/clitkForwardWarpImageFilter.txx
cleanup / better handling of errors
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
1 #ifndef __clitkForwardWarpImageFilter_txx
2 #define __clitkForwardWarpImageFilter_txx
3
4 #include "clitkForwardWarpImageFilter.h"
5 #include "clitkImageCommon.h"
6
7 // Put the helper classes in an anonymous namespace so that it is not
8 // exposed to the user
9
10 namespace 
11 {//nameless namespace
12
13   //=========================================================================================================================
14   //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
15   //=========================================================================================================================
16   template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
17   {
18     
19   public: 
20     /** Standard class typedefs. */
21     typedef HelperClass1  Self;
22     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
23     typedef itk::SmartPointer<Self>         Pointer;
24     typedef itk::SmartPointer<const Self>   ConstPointer;
25     
26     /** Method for creation through the object factory. */
27     itkNewMacro(Self);
28     
29     /** Run-time type information (and related methods) */
30     itkTypeMacro( HelperClass1, ImageToImageFilter );
31         
32     /** Constants for the image dimensions */
33     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
34  
35
36     //Typedefs
37     typedef typename OutputImageType::PixelType        OutputPixelType;
38     typedef itk::Image<double, ImageDimension > WeightsImageType;
39     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
40     //===================================================================================   
41     //Set methods
42     void SetWeights(const typename WeightsImageType::Pointer input)
43     {
44       m_Weights = input;
45       this->Modified();
46     }
47     void SetDeformationField(const typename DeformationFieldType::Pointer input)
48     {
49       m_DeformationField=input;
50       this->Modified();
51     }
52     void SetMutexImage(const typename MutexImageType::Pointer input)
53     {
54       m_MutexImage=input;
55       this->Modified();
56       m_ThreadSafe=true;
57     }
58     
59     //Get methods
60     typename WeightsImageType::Pointer GetWeights(){return m_Weights;}
61
62     /** Typedef to describe the output image region type. */
63     typedef typename OutputImageType::RegionType OutputImageRegionType;
64    
65   protected:
66     HelperClass1();
67     ~HelperClass1(){};
68     
69     //the actual processing
70     void BeforeThreadedGenerateData();
71     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
72
73     //member data
74     typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
75     typename DeformationFieldType::Pointer m_DeformationField;
76     typename MutexImageType::Pointer m_MutexImage;
77     bool m_ThreadSafe;
78
79   };
80
81
82
83   //=========================================================================================================================
84   //Member functions of the helper class 1
85   //=========================================================================================================================
86
87
88   //=========================================================================================================================
89   //Empty constructor
90   template<class InputImageType, class OutputImageType, class DeformationFieldType > 
91   HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1() 
92   {
93     m_ThreadSafe=false; 
94   }
95
96
97   //=========================================================================================================================
98   //Before threaded data
99   template<class InputImageType, class OutputImageType, class DeformationFieldType >
100   void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
101   {
102     //Since we will add, put to zero!
103     this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
104     this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
105   }
106   
107
108   //=========================================================================================================================
109   //update the output for the outputRegionForThread
110   template<class InputImageType, class OutputImageType, class DeformationFieldType > 
111   void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
112   {
113
114     //Get pointer to the input
115     typename InputImageType::ConstPointer inputPtr = this->GetInput();
116     
117     //Get pointer to the output
118     typename OutputImageType::Pointer outputPtr = this->GetOutput();
119     typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
120
121     //Iterators over input and deformation field
122     typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
123     typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
124
125     //define them over the outputRegionForThread
126     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
127     DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
128
129     //Initialize
130     typename InputImageType::IndexType index;
131     itk::ContinuousIndex<double,ImageDimension> contIndex;
132     typename InputImageType::PointType point;
133     typedef typename DeformationFieldType::PixelType DisplacementType;
134     DisplacementType displacement;
135     fieldIt.GoToBegin();
136     inputIt.GoToBegin();
137
138     //define some temp variables
139     signed long baseIndex[ImageDimension];
140     double distance[ImageDimension];
141     unsigned int dim, counter, upper;
142     double overlap, totalOverlap;
143     typename OutputImageType::IndexType neighIndex;
144
145     //Find the number of neighbors
146     unsigned int neighbors =  1 << ImageDimension;
147      
148
149     //==================================================================================================
150     //Loop over the region and add the intensities to the output and the weight to the weights
151     //==================================================================================================
152     while( !inputIt.IsAtEnd() )
153           {
154             
155             // get the input image index
156             index = inputIt.GetIndex();
157             inputPtr->TransformIndexToPhysicalPoint( index, point );
158         
159             // get the required displacement
160             displacement = fieldIt.Get();
161         
162             // compute the required output image point
163             for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
164         
165         
166             // Update the output and the weights
167             if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) )
168               {     
169                 for(dim = 0; dim < ImageDimension; dim++)
170                   {
171                     // The following  block is equivalent to the following line without
172                     // having to call floor. For positive inputs!!! 
173                     // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
174                     baseIndex[dim] = (long) contIndex[dim];
175                     distance[dim] = contIndex[dim] - double( baseIndex[dim] );
176                   }
177                 
178                 //Add contribution for each neighbor
179                 totalOverlap = itk::NumericTraits<double>::Zero;
180                 for( counter = 0; counter < neighbors ; counter++ )
181                   {             
182                     overlap = 1.0;          // fraction overlap
183                     upper = counter;  // each bit indicates upper/lower neighbour
184                 
185                     // get neighbor index and overlap fraction
186                     for( dim = 0; dim < 3; dim++ )
187                       {
188                         if ( upper & 1 )
189                           {
190                             neighIndex[dim] = baseIndex[dim] + 1;
191                             overlap *= distance[dim];
192                           }
193                         else
194                           {
195                             neighIndex[dim] = baseIndex[dim];
196                             overlap *= 1.0 - distance[dim];
197                           }
198                         upper >>= 1;
199                       }
200                     
201                     //Set neighbor value only if overlap is not zero
202                     if( (overlap>0.0)) // && 
203                       //                        (static_cast<unsigned int>(neighIndex[0])<size[0]) && 
204                       //                        (static_cast<unsigned int>(neighIndex[1])<size[1]) && 
205                       //                        (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
206                       //                        (neighIndex[0]>=0) &&
207                       //                        (neighIndex[1]>=0) &&
208                       //                        (neighIndex[2]>=0) )
209                       {
210                         
211                         if (! m_ThreadSafe)
212                           {
213                             //Set the pixel and weight at neighIndex
214                             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));                      
215                             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
216                             
217                           }
218                         else 
219                           {
220                             //Entering critilal section: shared memory
221                             m_MutexImage->GetPixel(neighIndex).Lock();
222                             
223                             //Set the pixel and weight at neighIndex
224                             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));   
225                             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
226                             
227                             //Unlock
228                             m_MutexImage->GetPixel(neighIndex).Unlock();
229                             
230                           }
231                         //Add to total overlap
232                         totalOverlap += overlap;
233                       }
234                     
235                     //check for totaloverlap: not very likely
236                     if( totalOverlap == 1.0 )
237                       {
238                         // finished
239                         break;
240                       }
241                   }          
242               }
243           
244             ++fieldIt;
245             ++inputIt;
246           }
247     
248     
249   }
250
251
252
253   //=========================================================================================================================
254   //helper class 2 to allow a threaded execution of normalisation by the weights
255   //=========================================================================================================================
256   template<class InputImageType, class OutputImageType> 
257   class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
258   {
259     
260   public: 
261     /** Standard class typedefs. */
262     typedef HelperClass2  Self;
263     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
264     typedef itk::SmartPointer<Self>         Pointer;
265     typedef itk::SmartPointer<const Self>   ConstPointer;
266     
267     /** Method for creation through the object factory. */
268     itkNewMacro(Self);
269     
270     /** Run-time type information (and related methods) */
271     itkTypeMacro( HelperClass2, ImageToImageFilter );
272         
273     /** Constants for the image dimensions */
274     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
275
276     //Typedefs
277     typedef typename  InputImageType::PixelType        InputPixelType;
278     typedef typename  OutputImageType::PixelType        OutputPixelType;
279     typedef itk::Image<double, ImageDimension > WeightsImageType;
280     
281     
282     //Set methods
283     void SetWeights(const typename WeightsImageType::Pointer input)
284     {
285       m_Weights = input;
286       this->Modified();
287     }
288     void SetEdgePaddingValue(OutputPixelType value)
289     {
290       m_EdgePaddingValue = value;
291       this->Modified();
292     }
293   
294     /** Typedef to describe the output image region type. */
295     typedef typename OutputImageType::RegionType OutputImageRegionType;
296     
297   protected:
298     HelperClass2();
299     ~HelperClass2(){};
300     
301     
302     //the actual processing
303     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
304
305
306     //member data
307     typename     WeightsImageType::Pointer m_Weights;
308     OutputPixelType m_EdgePaddingValue;
309   } ;
310
311
312
313   //=========================================================================================================================
314   //Member functions of the helper class 2
315   //=========================================================================================================================
316   
317   
318   //=========================================================================================================================
319   //Empty constructor
320   template<class InputImageType, class OutputImageType > 
321   HelperClass2<InputImageType, OutputImageType>::HelperClass2()
322   {
323     m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
324   }
325   
326   
327   //=========================================================================================================================
328   //update the output for the outputRegionForThread
329   template<class InputImageType, class OutputImageType > void 
330   HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
331   {
332     
333     //Get pointer to the input
334     typename InputImageType::ConstPointer inputPtr = this->GetInput();
335     
336     //Get pointer to the output
337     typename OutputImageType::Pointer outputPtr = this->GetOutput();
338
339     //Iterators over input, weigths  and output
340     typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
341     typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
342     typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
343
344     //define them over the outputRegionForThread
345     OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
346     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
347     WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
348
349
350     //==================================================================================================
351     //loop over the output and normalize the input, remove holes
352     OutputPixelType neighValue;
353     double zero = itk::NumericTraits<double>::Zero;
354     while (!outputIt.IsAtEnd())
355       {
356         //the weight is not zero
357         if (weightsIt.Get() != zero)
358           {
359             //divide by the weight
360             outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
361           }
362         
363         //copy the value of the  neighbour that was just processed
364         else 
365           {
366             if(!outputIt.IsAtBegin())
367               {
368                 //go back
369                 --outputIt;
370                 neighValue=outputIt.Get();
371                 ++outputIt;
372                 outputIt.Set(neighValue);
373               }
374             else{
375               //DD("is at begin, setting edgepadding value");
376               outputIt.Set(m_EdgePaddingValue);
377             }
378           }
379         ++weightsIt; 
380         ++outputIt;
381         ++inputIt;
382         
383       }//end while
384   }//end member
385   
386
387 }//end nameless namespace
388
389
390
391 namespace clitk
392 {
393
394   //=========================================================================================================================
395   // The rest is the ForwardWarpImageFilter
396   //=========================================================================================================================
397
398   //=========================================================================================================================
399   //constructor
400   template <class InputImageType, class OutputImageType, class DeformationFieldType> 
401   ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
402   {
403     // mIsUpdated=false;
404     m_NumberOfThreadsIsGiven=false;
405     m_EdgePaddingValue=static_cast<PixelType>(0.0);
406     m_ThreadSafe=false;
407     m_Verbose=false;
408   }
409
410
411   //=========================================================================================================================
412   //Update
413   template <class InputImageType, class OutputImageType, class DeformationFieldType> 
414   void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
415   {
416
417     //Get the properties of the input
418     typename InputImageType::ConstPointer inputPtr=this->GetInput();
419     typename WeightsImageType::RegionType region;
420     typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
421     region.SetSize(size);
422     typename OutputImageType::IndexType start;
423     for (unsigned int i =0; i< ImageDimension ;i ++)start[i]=0;
424     region.SetIndex(start);
425     
426     //Allocate the weights
427     typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
428     weights->SetRegions(region);
429     weights->Allocate();
430     weights->SetSpacing(inputPtr->GetSpacing());
431     
432     
433     //=========================================================================== 
434     //warp is divided in in two loops, for each we call a threaded helper class
435     //1. Add contribution of input to output and update weights
436     //2. Normalize the output by the weight and remove holes
437     //=========================================================================== 
438
439     //=========================================================================== 
440     //1. Add contribution of input to output and update weights
441
442     //Define an internal image type in double  precision
443     typedef itk::Image<double, ImageDimension> InternalImageType;
444     
445     //Call threaded helper class 1
446     typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
447     typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
448     
449     //Set input
450     if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
451     helper1->SetInput(inputPtr);
452     helper1->SetDeformationField(m_DeformationField);
453     helper1->SetWeights(weights);
454
455     //Threadsafe?
456     if(m_ThreadSafe)
457       {
458         //Allocate the mutex image
459         typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
460         mutex->SetRegions(region);
461         mutex->Allocate();
462         mutex->SetSpacing(inputPtr->GetSpacing());
463         helper1->SetMutexImage(mutex);
464         if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
465       }
466     else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
467
468     //Execute helper class
469     helper1->Update();
470     
471     //Get the output
472     typename InternalImageType::Pointer temp= helper1->GetOutput();
473    
474     //For clarity
475     weights=helper1->GetWeights();
476    
477    
478     //=========================================================================== 
479     //2. Normalize the output by the weights and remove holes 
480     //Call threaded helper class 
481     typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
482     typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
483     
484     //Set temporary output as input
485     if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
486     helper2->SetInput(temp);
487     helper2->SetWeights(weights);
488     helper2->SetEdgePaddingValue(m_EdgePaddingValue);
489     
490     //Execute helper class
491     helper2->Update();
492     
493     //Set the output
494     this->SetNthOutput(0, helper2->GetOutput());
495   }
496       
497 }
498
499 #endif