]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
Initial revision
[clitk.git] / itk / clitkInvertVFFilter.txx
1 #ifndef __clitkInvertVFFilter_txx
2 #define __clitkInvertVFFilter_txx
3
4
5 // Put the helper classes in an anonymous namespace so that it is not
6 // exposed to the user
7
8 namespace
9 {
10
11   //=========================================================================================================================
12   //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
13   //=========================================================================================================================
14   template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
15   {
16     
17   public: 
18     /** Standard class typedefs. */
19     typedef HelperClass1  Self;
20     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
21     typedef itk::SmartPointer<Self>         Pointer;
22     typedef itk::SmartPointer<const Self>   ConstPointer;
23     
24     /** Method for creation through the object factory. */
25     itkNewMacro(Self);
26     
27     /** Run-time type information (and related methods) */
28     itkTypeMacro( HelperClass1, ImageToImageFilter );
29         
30     /** Constants for the image dimensions */
31     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
32  
33
34     //Typedefs
35     typedef typename OutputImageType::PixelType        PixelType;
36     typedef itk::Image<double, ImageDimension > WeightsImageType;
37     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
38
39     //===================================================================================   
40     //Set methods
41     void SetWeights(const typename WeightsImageType::Pointer input)
42     {
43       m_Weights = input;
44       this->Modified();
45     }
46     void SetMutexImage(const typename MutexImageType::Pointer input)
47     {
48       m_MutexImage=input;
49       this->Modified();
50       m_ThreadSafe=true;
51     }
52
53     //Get methods
54     typename  WeightsImageType::Pointer GetWeights(){return m_Weights;}
55
56     /** Typedef to describe the output image region type. */
57     typedef typename OutputImageType::RegionType OutputImageRegionType;
58     
59   protected:
60     HelperClass1();
61     ~HelperClass1(){};
62     
63     //the actual processing
64     void BeforeThreadedGenerateData();
65     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
66
67     //member data
68     typename  WeightsImageType::Pointer m_Weights;
69     typename MutexImageType::Pointer m_MutexImage;
70     bool m_ThreadSafe;
71
72   };
73
74
75
76   //=========================================================================================================================
77   //Member functions of the helper class 1
78   //=========================================================================================================================
79
80
81   //=========================================================================================================================
82   //Empty constructor
83   template<class InputImageType, class OutputImageType > 
84   HelperClass1<InputImageType, OutputImageType>::HelperClass1() 
85   {
86     m_ThreadSafe=false; 
87   }
88
89   //=========================================================================================================================
90   //Before threaded data
91   template<class InputImageType, class OutputImageType >
92   void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
93   {
94     //Since we will add, put to zero!
95     this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
96     this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
97   }
98
99   //=========================================================================================================================
100   //update the output for the outputRegionForThread
101   template<class InputImageType, class OutputImageType> 
102   void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
103   {
104  
105     //Get pointer to the input
106     typename InputImageType::ConstPointer inputPtr = this->GetInput();
107     
108     //Get pointer to the output
109     typename OutputImageType::Pointer outputPtr = this->GetOutput();
110     typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
111
112     //Iterator over input
113     typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
114
115     //define them over the outputRegionForThread
116     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
117  
118     //Initialize
119     typename InputImageType::IndexType index;
120     itk::ContinuousIndex<double,ImageDimension> contIndex;
121     typename InputImageType::PointType ipoint;
122     typename OutputImageType::PointType opoint;
123     typedef typename OutputImageType::PixelType DisplacementType;
124     DisplacementType displacement;
125     inputIt.GoToBegin();
126
127     //define some temp variables
128     signed long baseIndex[ImageDimension];
129     double distance[ImageDimension];
130     unsigned int dim, counter, upper;
131     double totalOverlap,overlap;
132     typename OutputImageType::IndexType neighIndex;
133
134     //Find the number of neighbors
135     unsigned int neighbors =  1 << ImageDimension;
136      
137     //==================================================================================================
138     //Loop over the region and add the intensities to the output and the weight to the weights
139     //==================================================================================================
140     while( !inputIt.IsAtEnd() )
141       {
142         // get the input image index
143         index = inputIt.GetIndex();
144         inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
145
146         // get the required displacement
147         displacement = inputIt.Get();
148         
149         // compute the required output image point
150         for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
151
152         // Update the output and the weights
153         if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) )
154           {         
155             for(dim = 0; dim < ImageDimension; dim++)
156               {
157                 // The following  block is equivalent to the following line without
158                 // having to call floor. (Only for positive inputs, we already now that is in the image)
159                 // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
160         
161                     baseIndex[dim] = (long) contIndex[dim];
162                     distance[dim] = contIndex[dim] - double( baseIndex[dim] );
163               }
164
165             //Add contribution for each neighbor
166             totalOverlap = itk::NumericTraits<double>::Zero;
167             for( counter = 0; counter < neighbors ; counter++ )
168               {         
169                 overlap = 1.0;          // fraction overlap
170                 upper = counter;  // each bit indicates upper/lower neighbour
171                                 
172                 // get neighbor index and overlap fraction
173                 for( dim = 0; dim < 3; dim++ )
174                   {
175                     if ( upper & 1 )
176                       {
177                         neighIndex[dim] = baseIndex[dim] + 1;
178                         overlap *= distance[dim];
179                       }
180                     else
181                       {
182                         neighIndex[dim] = baseIndex[dim];
183                         overlap *= 1.0 - distance[dim];
184                       }
185                     upper >>= 1;
186                   }
187
188         
189                 
190                 //Set neighbor value only if overlap is not zero
191                 if( (overlap>0.0)) // && 
192                   //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) && 
193                   //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) && 
194                   //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
195                   //                    (neighIndex[0]>=0) &&
196                   //                    (neighIndex[1]>=0) &&
197                       //                        (neighIndex[2]>=0) )
198                   {
199                     //what to store? the original displacement vector?
200                     if (! m_ThreadSafe)
201                       {
202                         //Set the pixel and weight at neighIndex 
203                         outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
204                         m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
205                       }
206
207                     else 
208                       {
209                         //Entering critilal section: shared memory
210                         m_MutexImage->GetPixel(neighIndex).Lock();
211                         
212                         //Set the pixel and weight at neighIndex 
213                         outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
214                         m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
215
216                         //Unlock
217                         m_MutexImage->GetPixel(neighIndex).Unlock();
218                             
219                       }
220                     //Add to total overlap
221                     totalOverlap += overlap;
222                   }
223                 
224                 if( totalOverlap == 1.0 )
225                   {
226                     // finished
227                     break;
228                   }
229               }      
230           }
231         
232         ++inputIt;
233       }
234
235   }
236
237
238
239   //=========================================================================================================================
240   //helper class 2 to allow a threaded execution of normalisation by the weights
241   //=========================================================================================================================
242   template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
243   {
244     
245   public: 
246     /** Standard class typedefs. */
247     typedef HelperClass2  Self;
248     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
249     typedef itk::SmartPointer<Self>         Pointer;
250     typedef itk::SmartPointer<const Self>   ConstPointer;
251     
252     /** Method for creation through the object factory. */
253     itkNewMacro(Self);
254     
255     /** Run-time type information (and related methods) */
256     itkTypeMacro( HelperClass2, ImageToImageFilter );
257         
258     /** Constants for the image dimensions */
259     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
260      
261     //Typedefs
262     typedef typename OutputImageType::PixelType        PixelType;
263     typedef itk::Image<double,ImageDimension> WeightsImageType;
264    
265       //Set methods
266     void SetWeights(const typename WeightsImageType::Pointer input)
267     {
268       m_Weights = input;
269       this->Modified();
270     }
271     void SetEdgePaddingValue(PixelType value)
272     {
273       m_EdgePaddingValue = value;
274       this->Modified();
275     }
276
277     /** Typedef to describe the output image region type. */
278     typedef typename OutputImageType::RegionType OutputImageRegionType;
279     
280   protected:
281     HelperClass2();
282     ~HelperClass2(){};
283     
284     
285     //the actual processing
286     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
287
288
289     //member data
290     typename     WeightsImageType::Pointer m_Weights;
291     PixelType m_EdgePaddingValue;
292
293   } ;
294
295
296
297   //=========================================================================================================================
298   //Member functions of the helper class 2
299   //=========================================================================================================================
300   
301   
302   //=========================================================================================================================
303   //Empty constructor
304   template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
305   {
306     m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; 
307   }
308   
309
310   //=========================================================================================================================
311   //update the output for the outputRegionForThread
312   template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
313   {
314     
315     //Get pointer to the input
316     typename InputImageType::ConstPointer inputPtr = this->GetInput();
317     
318     //Get pointer to the output
319     typename OutputImageType::Pointer outputPtr = this->GetOutput();
320
321     //Iterators over input, weigths  and output
322     typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
323     typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
324     typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
325
326     //define them over the outputRegionForThread
327     OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
328     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
329     WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
330
331
332     //==================================================================================================
333     //loop over the output and normalize the input, remove holes
334     PixelType neighValue;
335     double  zero = itk::NumericTraits<double>::Zero;
336     while (!outputIt.IsAtEnd())
337       {
338         //the weight is not zero
339         if (weightsIt.Get() != zero)
340           {
341             //divide by the weight
342             outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
343           }
344
345         //copy the value of the  neighbour that was just processed
346         else 
347           {
348             if(!outputIt.IsAtBegin())
349               {
350                 //go back
351                 --outputIt;
352
353                 //Neighbour cannot have zero weight because it should be filled already
354                 neighValue=outputIt.Get();
355                 ++outputIt;
356                 outputIt.Set(neighValue);
357                 //DD("hole filled");
358               }
359             else{
360               //DD("is at begin, setting edgepadding value");
361               outputIt.Set(m_EdgePaddingValue);
362             }
363           }
364         ++weightsIt; 
365         ++outputIt;
366         ++inputIt;
367
368       }//end while
369   }//end member
370
371
372 }//end nameless namespace
373
374
375
376 namespace clitk
377 {
378
379   //=========================================================================================================================
380   // The rest is the InvertVFFilter
381   //=========================================================================================================================
382
383   //=========================================================================================================================
384   //constructor
385   template <class InputImageType, class OutputImageType> 
386   InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
387   {
388     m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
389     m_ThreadSafe=false;
390     m_Verbose=false;
391   }
392
393
394   //=========================================================================================================================
395   //Update
396   template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
397   {
398
399     //Get the properties of the input
400     typename InputImageType::ConstPointer inputPtr=this->GetInput();
401     typename WeightsImageType::RegionType region;
402     typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
403     region.SetSize(size);
404     typename OutputImageType::IndexType start;
405     for (unsigned int i=0; i< ImageDimension; i++) start[i]=0;
406     region.SetIndex(start);
407     PixelType zero = itk::NumericTraits<double>::Zero;
408  
409   
410     //Allocate the weights
411     typename WeightsImageType::Pointer weights=WeightsImageType::New();
412     weights->SetRegions(region);
413     weights->Allocate();
414     weights->SetSpacing(inputPtr->GetSpacing());
415
416     //=========================================================================== 
417     //Inversion is divided in in two loops, for each we will call a threaded helper class
418     //1. add contribution of input to output and update weights
419     //2. normalize the output by the weight and remove holes
420     //=========================================================================== 
421
422
423     //=========================================================================== 
424     //1. add contribution of input to output and update weights
425
426     //Define an internal image type
427
428     typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
429
430     //Call threaded helper class 1
431     typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
432     typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
433     
434     //Set input
435     if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
436     helper1->SetInput(inputPtr);
437     helper1->SetWeights(weights);
438
439     //Threadsafe?
440     if(m_ThreadSafe)
441       {
442         //Allocate the mutex image
443         typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
444         mutex->SetRegions(region);
445         mutex->Allocate();
446         mutex->SetSpacing(inputPtr->GetSpacing());
447         helper1->SetMutexImage(mutex);
448         if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
449       }
450     else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
451
452     //Execute helper class
453     helper1->Update();
454
455     //Get the output
456     typename InternalImageType::Pointer temp= helper1->GetOutput();
457     weights=helper1->GetWeights();
458
459
460     //=========================================================================== 
461     //2. Normalize the output by the weights and remove holes 
462     //Call threaded helper class 
463     typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
464     typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
465     
466     //Set temporary output as input
467     helper2->SetInput(temp);
468     helper2->SetWeights(weights);
469     helper2->SetEdgePaddingValue(m_EdgePaddingValue);
470
471     //Execute helper class
472     if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
473     helper2->Update();
474
475     //Set the output
476     this->SetNthOutput(0, helper2->GetOutput());
477   }
478   
479   
480     
481 }
482
483 #endif