]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
removed headers
[clitk.git] / itk / clitkInvertVFFilter.txx
1 #ifndef __clitkInvertVFFilter_txx
2 #define __clitkInvertVFFilter_txx
3 namespace
4 {
5
6   //=========================================================================================================================
7   //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
8   //=========================================================================================================================
9   template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
10   {
11     
12   public: 
13     /** Standard class typedefs. */
14     typedef HelperClass1  Self;
15     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
16     typedef itk::SmartPointer<Self>         Pointer;
17     typedef itk::SmartPointer<const Self>   ConstPointer;
18     
19     /** Method for creation through the object factory. */
20     itkNewMacro(Self);
21     
22     /** Run-time type information (and related methods) */
23     itkTypeMacro( HelperClass1, ImageToImageFilter );
24         
25     /** Constants for the image dimensions */
26     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
27  
28
29     //Typedefs
30     typedef typename OutputImageType::PixelType        PixelType;
31     typedef itk::Image<double, ImageDimension > WeightsImageType;
32     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
33
34     //===================================================================================   
35     //Set methods
36     void SetWeights(const typename WeightsImageType::Pointer input)
37     {
38       m_Weights = input;
39       this->Modified();
40     }
41     void SetMutexImage(const typename MutexImageType::Pointer input)
42     {
43       m_MutexImage=input;
44       this->Modified();
45       m_ThreadSafe=true;
46     }
47
48     //Get methods
49     typename  WeightsImageType::Pointer GetWeights(){return m_Weights;}
50
51     /** Typedef to describe the output image region type. */
52     typedef typename OutputImageType::RegionType OutputImageRegionType;
53     
54   protected:
55     HelperClass1();
56     ~HelperClass1(){};
57     
58     //the actual processing
59     void BeforeThreadedGenerateData();
60     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
61
62     //member data
63     typename  WeightsImageType::Pointer m_Weights;
64     typename MutexImageType::Pointer m_MutexImage;
65     bool m_ThreadSafe;
66
67   };
68
69
70
71   //=========================================================================================================================
72   //Member functions of the helper class 1
73   //=========================================================================================================================
74
75
76   //=========================================================================================================================
77   //Empty constructor
78   template<class InputImageType, class OutputImageType > 
79   HelperClass1<InputImageType, OutputImageType>::HelperClass1() 
80   {
81     m_ThreadSafe=false; 
82   }
83
84   //=========================================================================================================================
85   //Before threaded data
86   template<class InputImageType, class OutputImageType >
87   void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
88   {
89     //Since we will add, put to zero!
90     this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
91     this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
92   }
93
94   //=========================================================================================================================
95   //update the output for the outputRegionForThread
96   template<class InputImageType, class OutputImageType> 
97   void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
98   {
99  
100     //Get pointer to the input
101     typename InputImageType::ConstPointer inputPtr = this->GetInput();
102     
103     //Get pointer to the output
104     typename OutputImageType::Pointer outputPtr = this->GetOutput();
105     typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
106
107     //Iterator over input
108     typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
109
110     //define them over the outputRegionForThread
111     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
112  
113     //Initialize
114     typename InputImageType::IndexType index;
115     itk::ContinuousIndex<double,ImageDimension> contIndex;
116     typename InputImageType::PointType ipoint;
117     typename OutputImageType::PointType opoint;
118     typedef typename OutputImageType::PixelType DisplacementType;
119     DisplacementType displacement;
120     inputIt.GoToBegin();
121
122     //define some temp variables
123     signed long baseIndex[ImageDimension];
124     double distance[ImageDimension];
125     unsigned int dim, counter, upper;
126     double totalOverlap,overlap;
127     typename OutputImageType::IndexType neighIndex;
128
129     //Find the number of neighbors
130     unsigned int neighbors =  1 << ImageDimension;
131      
132     //==================================================================================================
133     //Loop over the region and add the intensities to the output and the weight to the weights
134     //==================================================================================================
135     while( !inputIt.IsAtEnd() )
136       {
137         // get the input image index
138         index = inputIt.GetIndex();
139         inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
140
141         // get the required displacement
142         displacement = inputIt.Get();
143         
144         // compute the required output image point
145         for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
146
147         // Update the output and the weights
148         if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) )
149           {         
150             for(dim = 0; dim < ImageDimension; dim++)
151               {
152                 // The following  block is equivalent to the following line without
153                 // having to call floor. (Only for positive inputs, we already now that is in the image)
154                 // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
155         
156                     baseIndex[dim] = (long) contIndex[dim];
157                     distance[dim] = contIndex[dim] - double( baseIndex[dim] );
158               }
159
160             //Add contribution for each neighbor
161             totalOverlap = itk::NumericTraits<double>::Zero;
162             for( counter = 0; counter < neighbors ; counter++ )
163               {         
164                 overlap = 1.0;          // fraction overlap
165                 upper = counter;  // each bit indicates upper/lower neighbour
166                                 
167                 // get neighbor index and overlap fraction
168                 for( dim = 0; dim < 3; dim++ )
169                   {
170                     if ( upper & 1 )
171                       {
172                         neighIndex[dim] = baseIndex[dim] + 1;
173                         overlap *= distance[dim];
174                       }
175                     else
176                       {
177                         neighIndex[dim] = baseIndex[dim];
178                         overlap *= 1.0 - distance[dim];
179                       }
180                     upper >>= 1;
181                   }
182
183         
184                 
185                 //Set neighbor value only if overlap is not zero
186                 if( (overlap>0.0)) // && 
187                   //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) && 
188                   //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) && 
189                   //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
190                   //                    (neighIndex[0]>=0) &&
191                   //                    (neighIndex[1]>=0) &&
192                       //                        (neighIndex[2]>=0) )
193                   {
194                     //what to store? the original displacement vector?
195                     if (! m_ThreadSafe)
196                       {
197                         //Set the pixel and weight at neighIndex 
198                         outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
199                         m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
200                       }
201
202                     else 
203                       {
204                         //Entering critilal section: shared memory
205                         m_MutexImage->GetPixel(neighIndex).Lock();
206                         
207                         //Set the pixel and weight at neighIndex 
208                         outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
209                         m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
210
211                         //Unlock
212                         m_MutexImage->GetPixel(neighIndex).Unlock();
213                             
214                       }
215                     //Add to total overlap
216                     totalOverlap += overlap;
217                   }
218                 
219                 if( totalOverlap == 1.0 )
220                   {
221                     // finished
222                     break;
223                   }
224               }      
225           }
226         
227         ++inputIt;
228       }
229
230   }
231
232
233
234   //=========================================================================================================================
235   //helper class 2 to allow a threaded execution of normalisation by the weights
236   //=========================================================================================================================
237   template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
238   {
239     
240   public: 
241     /** Standard class typedefs. */
242     typedef HelperClass2  Self;
243     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
244     typedef itk::SmartPointer<Self>         Pointer;
245     typedef itk::SmartPointer<const Self>   ConstPointer;
246     
247     /** Method for creation through the object factory. */
248     itkNewMacro(Self);
249     
250     /** Run-time type information (and related methods) */
251     itkTypeMacro( HelperClass2, ImageToImageFilter );
252         
253     /** Constants for the image dimensions */
254     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
255      
256     //Typedefs
257     typedef typename OutputImageType::PixelType        PixelType;
258     typedef itk::Image<double,ImageDimension> WeightsImageType;
259    
260       //Set methods
261     void SetWeights(const typename WeightsImageType::Pointer input)
262     {
263       m_Weights = input;
264       this->Modified();
265     }
266     void SetEdgePaddingValue(PixelType value)
267     {
268       m_EdgePaddingValue = value;
269       this->Modified();
270     }
271
272     /** Typedef to describe the output image region type. */
273     typedef typename OutputImageType::RegionType OutputImageRegionType;
274     
275   protected:
276     HelperClass2();
277     ~HelperClass2(){};
278     
279     
280     //the actual processing
281     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
282
283
284     //member data
285     typename     WeightsImageType::Pointer m_Weights;
286     PixelType m_EdgePaddingValue;
287
288   } ;
289
290
291
292   //=========================================================================================================================
293   //Member functions of the helper class 2
294   //=========================================================================================================================
295   
296   
297   //=========================================================================================================================
298   //Empty constructor
299   template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
300   {
301     m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; 
302   }
303   
304
305   //=========================================================================================================================
306   //update the output for the outputRegionForThread
307   template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
308   {
309     
310     //Get pointer to the input
311     typename InputImageType::ConstPointer inputPtr = this->GetInput();
312     
313     //Get pointer to the output
314     typename OutputImageType::Pointer outputPtr = this->GetOutput();
315
316     //Iterators over input, weigths  and output
317     typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
318     typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
319     typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
320
321     //define them over the outputRegionForThread
322     OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
323     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
324     WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
325
326
327     //==================================================================================================
328     //loop over the output and normalize the input, remove holes
329     PixelType neighValue;
330     double  zero = itk::NumericTraits<double>::Zero;
331     while (!outputIt.IsAtEnd())
332       {
333         //the weight is not zero
334         if (weightsIt.Get() != zero)
335           {
336             //divide by the weight
337             outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
338           }
339
340         //copy the value of the  neighbour that was just processed
341         else 
342           {
343             if(!outputIt.IsAtBegin())
344               {
345                 //go back
346                 --outputIt;
347
348                 //Neighbour cannot have zero weight because it should be filled already
349                 neighValue=outputIt.Get();
350                 ++outputIt;
351                 outputIt.Set(neighValue);
352                 //DD("hole filled");
353               }
354             else{
355               //DD("is at begin, setting edgepadding value");
356               outputIt.Set(m_EdgePaddingValue);
357             }
358           }
359         ++weightsIt; 
360         ++outputIt;
361         ++inputIt;
362
363       }//end while
364   }//end member
365
366
367 }//end nameless namespace
368
369
370
371 namespace clitk
372 {
373
374   //=========================================================================================================================
375   // The rest is the InvertVFFilter
376   //=========================================================================================================================
377
378   //=========================================================================================================================
379   //constructor
380   template <class InputImageType, class OutputImageType> 
381   InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
382   {
383     m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
384     m_ThreadSafe=false;
385     m_Verbose=false;
386   }
387
388
389   //=========================================================================================================================
390   //Update
391   template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
392   {
393
394     //Get the properties of the input
395     typename InputImageType::ConstPointer inputPtr=this->GetInput();
396     typename WeightsImageType::RegionType region;
397     typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
398     region.SetSize(size);
399     typename OutputImageType::IndexType start;
400     for (unsigned int i=0; i< ImageDimension; i++) start[i]=0;
401     region.SetIndex(start);
402     PixelType zero = itk::NumericTraits<double>::Zero;
403  
404   
405     //Allocate the weights
406     typename WeightsImageType::Pointer weights=WeightsImageType::New();
407     weights->SetRegions(region);
408     weights->Allocate();
409     weights->SetSpacing(inputPtr->GetSpacing());
410
411     //=========================================================================== 
412     //Inversion is divided in in two loops, for each we will call a threaded helper class
413     //1. add contribution of input to output and update weights
414     //2. normalize the output by the weight and remove holes
415     //=========================================================================== 
416
417
418     //=========================================================================== 
419     //1. add contribution of input to output and update weights
420
421     //Define an internal image type
422
423     typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
424
425     //Call threaded helper class 1
426     typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
427     typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
428     
429     //Set input
430     if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
431     helper1->SetInput(inputPtr);
432     helper1->SetWeights(weights);
433
434     //Threadsafe?
435     if(m_ThreadSafe)
436       {
437         //Allocate the mutex image
438         typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
439         mutex->SetRegions(region);
440         mutex->Allocate();
441         mutex->SetSpacing(inputPtr->GetSpacing());
442         helper1->SetMutexImage(mutex);
443         if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
444       }
445     else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
446
447     //Execute helper class
448     helper1->Update();
449
450     //Get the output
451     typename InternalImageType::Pointer temp= helper1->GetOutput();
452     weights=helper1->GetWeights();
453
454
455     //=========================================================================== 
456     //2. Normalize the output by the weights and remove holes 
457     //Call threaded helper class 
458     typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
459     typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
460     
461     //Set temporary output as input
462     helper2->SetInput(temp);
463     helper2->SetWeights(weights);
464     helper2->SetEdgePaddingValue(m_EdgePaddingValue);
465
466     //Execute helper class
467     if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
468     helper2->Update();
469
470     //Set the output
471     this->SetNthOutput(0, helper2->GetOutput());
472   }
473   
474   
475     
476 }
477
478 #endif