]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
added the new headers
[clitk.git] / itk / clitkInvertVFFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to: 
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://oncora1.lyon.fnclcc.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ======================================================================-====*/
18 #ifndef __clitkInvertVFFilter_txx
19 #define __clitkInvertVFFilter_txx
20 namespace
21 {
22
23   //=========================================================================================================================
24   //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
25   //=========================================================================================================================
26   template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
27   {
28     
29   public: 
30     /** Standard class typedefs. */
31     typedef HelperClass1  Self;
32     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
33     typedef itk::SmartPointer<Self>         Pointer;
34     typedef itk::SmartPointer<const Self>   ConstPointer;
35     
36     /** Method for creation through the object factory. */
37     itkNewMacro(Self);
38     
39     /** Run-time type information (and related methods) */
40     itkTypeMacro( HelperClass1, ImageToImageFilter );
41         
42     /** Constants for the image dimensions */
43     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
44  
45
46     //Typedefs
47     typedef typename OutputImageType::PixelType        PixelType;
48     typedef itk::Image<double, ImageDimension > WeightsImageType;
49     typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
50
51     //===================================================================================   
52     //Set methods
53     void SetWeights(const typename WeightsImageType::Pointer input)
54     {
55       m_Weights = input;
56       this->Modified();
57     }
58     void SetMutexImage(const typename MutexImageType::Pointer input)
59     {
60       m_MutexImage=input;
61       this->Modified();
62       m_ThreadSafe=true;
63     }
64
65     //Get methods
66     typename  WeightsImageType::Pointer GetWeights(){return m_Weights;}
67
68     /** Typedef to describe the output image region type. */
69     typedef typename OutputImageType::RegionType OutputImageRegionType;
70     
71   protected:
72     HelperClass1();
73     ~HelperClass1(){};
74     
75     //the actual processing
76     void BeforeThreadedGenerateData();
77     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
78
79     //member data
80     typename  WeightsImageType::Pointer m_Weights;
81     typename MutexImageType::Pointer m_MutexImage;
82     bool m_ThreadSafe;
83
84   };
85
86
87
88   //=========================================================================================================================
89   //Member functions of the helper class 1
90   //=========================================================================================================================
91
92
93   //=========================================================================================================================
94   //Empty constructor
95   template<class InputImageType, class OutputImageType > 
96   HelperClass1<InputImageType, OutputImageType>::HelperClass1() 
97   {
98     m_ThreadSafe=false; 
99   }
100
101   //=========================================================================================================================
102   //Before threaded data
103   template<class InputImageType, class OutputImageType >
104   void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
105   {
106     //Since we will add, put to zero!
107     this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
108     this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
109   }
110
111   //=========================================================================================================================
112   //update the output for the outputRegionForThread
113   template<class InputImageType, class OutputImageType> 
114   void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
115   {
116  
117     //Get pointer to the input
118     typename InputImageType::ConstPointer inputPtr = this->GetInput();
119     
120     //Get pointer to the output
121     typename OutputImageType::Pointer outputPtr = this->GetOutput();
122     typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
123
124     //Iterator over input
125     typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
126
127     //define them over the outputRegionForThread
128     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
129  
130     //Initialize
131     typename InputImageType::IndexType index;
132     itk::ContinuousIndex<double,ImageDimension> contIndex;
133     typename InputImageType::PointType ipoint;
134     typename OutputImageType::PointType opoint;
135     typedef typename OutputImageType::PixelType DisplacementType;
136     DisplacementType displacement;
137     inputIt.GoToBegin();
138
139     //define some temp variables
140     signed long baseIndex[ImageDimension];
141     double distance[ImageDimension];
142     unsigned int dim, counter, upper;
143     double totalOverlap,overlap;
144     typename OutputImageType::IndexType neighIndex;
145
146     //Find the number of neighbors
147     unsigned int neighbors =  1 << ImageDimension;
148      
149     //==================================================================================================
150     //Loop over the region and add the intensities to the output and the weight to the weights
151     //==================================================================================================
152     while( !inputIt.IsAtEnd() )
153       {
154         // get the input image index
155         index = inputIt.GetIndex();
156         inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
157
158         // get the required displacement
159         displacement = inputIt.Get();
160         
161         // compute the required output image point
162         for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
163
164         // Update the output and the weights
165         if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) )
166           {         
167             for(dim = 0; dim < ImageDimension; dim++)
168               {
169                 // The following  block is equivalent to the following line without
170                 // having to call floor. (Only for positive inputs, we already now that is in the image)
171                 // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
172         
173                     baseIndex[dim] = (long) contIndex[dim];
174                     distance[dim] = contIndex[dim] - double( baseIndex[dim] );
175               }
176
177             //Add contribution for each neighbor
178             totalOverlap = itk::NumericTraits<double>::Zero;
179             for( counter = 0; counter < neighbors ; counter++ )
180               {         
181                 overlap = 1.0;          // fraction overlap
182                 upper = counter;  // each bit indicates upper/lower neighbour
183                                 
184                 // get neighbor index and overlap fraction
185                 for( dim = 0; dim < 3; dim++ )
186                   {
187                     if ( upper & 1 )
188                       {
189                         neighIndex[dim] = baseIndex[dim] + 1;
190                         overlap *= distance[dim];
191                       }
192                     else
193                       {
194                         neighIndex[dim] = baseIndex[dim];
195                         overlap *= 1.0 - distance[dim];
196                       }
197                     upper >>= 1;
198                   }
199
200         
201                 
202                 //Set neighbor value only if overlap is not zero
203                 if( (overlap>0.0)) // && 
204                   //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) && 
205                   //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) && 
206                   //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
207                   //                    (neighIndex[0]>=0) &&
208                   //                    (neighIndex[1]>=0) &&
209                       //                        (neighIndex[2]>=0) )
210                   {
211                     //what to store? the original displacement vector?
212                     if (! m_ThreadSafe)
213                       {
214                         //Set the pixel and weight at neighIndex 
215                         outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
216                         m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
217                       }
218
219                     else 
220                       {
221                         //Entering critilal section: shared memory
222                         m_MutexImage->GetPixel(neighIndex).Lock();
223                         
224                         //Set the pixel and weight at neighIndex 
225                         outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
226                         m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
227
228                         //Unlock
229                         m_MutexImage->GetPixel(neighIndex).Unlock();
230                             
231                       }
232                     //Add to total overlap
233                     totalOverlap += overlap;
234                   }
235                 
236                 if( totalOverlap == 1.0 )
237                   {
238                     // finished
239                     break;
240                   }
241               }      
242           }
243         
244         ++inputIt;
245       }
246
247   }
248
249
250
251   //=========================================================================================================================
252   //helper class 2 to allow a threaded execution of normalisation by the weights
253   //=========================================================================================================================
254   template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
255   {
256     
257   public: 
258     /** Standard class typedefs. */
259     typedef HelperClass2  Self;
260     typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
261     typedef itk::SmartPointer<Self>         Pointer;
262     typedef itk::SmartPointer<const Self>   ConstPointer;
263     
264     /** Method for creation through the object factory. */
265     itkNewMacro(Self);
266     
267     /** Run-time type information (and related methods) */
268     itkTypeMacro( HelperClass2, ImageToImageFilter );
269         
270     /** Constants for the image dimensions */
271     itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
272      
273     //Typedefs
274     typedef typename OutputImageType::PixelType        PixelType;
275     typedef itk::Image<double,ImageDimension> WeightsImageType;
276    
277       //Set methods
278     void SetWeights(const typename WeightsImageType::Pointer input)
279     {
280       m_Weights = input;
281       this->Modified();
282     }
283     void SetEdgePaddingValue(PixelType value)
284     {
285       m_EdgePaddingValue = value;
286       this->Modified();
287     }
288
289     /** Typedef to describe the output image region type. */
290     typedef typename OutputImageType::RegionType OutputImageRegionType;
291     
292   protected:
293     HelperClass2();
294     ~HelperClass2(){};
295     
296     
297     //the actual processing
298     void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
299
300
301     //member data
302     typename     WeightsImageType::Pointer m_Weights;
303     PixelType m_EdgePaddingValue;
304
305   } ;
306
307
308
309   //=========================================================================================================================
310   //Member functions of the helper class 2
311   //=========================================================================================================================
312   
313   
314   //=========================================================================================================================
315   //Empty constructor
316   template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
317   {
318     m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; 
319   }
320   
321
322   //=========================================================================================================================
323   //update the output for the outputRegionForThread
324   template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
325   {
326     
327     //Get pointer to the input
328     typename InputImageType::ConstPointer inputPtr = this->GetInput();
329     
330     //Get pointer to the output
331     typename OutputImageType::Pointer outputPtr = this->GetOutput();
332
333     //Iterators over input, weigths  and output
334     typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
335     typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
336     typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
337
338     //define them over the outputRegionForThread
339     OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
340     InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
341     WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
342
343
344     //==================================================================================================
345     //loop over the output and normalize the input, remove holes
346     PixelType neighValue;
347     double  zero = itk::NumericTraits<double>::Zero;
348     while (!outputIt.IsAtEnd())
349       {
350         //the weight is not zero
351         if (weightsIt.Get() != zero)
352           {
353             //divide by the weight
354             outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
355           }
356
357         //copy the value of the  neighbour that was just processed
358         else 
359           {
360             if(!outputIt.IsAtBegin())
361               {
362                 //go back
363                 --outputIt;
364
365                 //Neighbour cannot have zero weight because it should be filled already
366                 neighValue=outputIt.Get();
367                 ++outputIt;
368                 outputIt.Set(neighValue);
369                 //DD("hole filled");
370               }
371             else{
372               //DD("is at begin, setting edgepadding value");
373               outputIt.Set(m_EdgePaddingValue);
374             }
375           }
376         ++weightsIt; 
377         ++outputIt;
378         ++inputIt;
379
380       }//end while
381   }//end member
382
383
384 }//end nameless namespace
385
386
387
388 namespace clitk
389 {
390
391   //=========================================================================================================================
392   // The rest is the InvertVFFilter
393   //=========================================================================================================================
394
395   //=========================================================================================================================
396   //constructor
397   template <class InputImageType, class OutputImageType> 
398   InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
399   {
400     m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
401     m_ThreadSafe=false;
402     m_Verbose=false;
403   }
404
405
406   //=========================================================================================================================
407   //Update
408   template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
409   {
410
411     //Get the properties of the input
412     typename InputImageType::ConstPointer inputPtr=this->GetInput();
413     typename WeightsImageType::RegionType region;
414     typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
415     region.SetSize(size);
416     typename OutputImageType::IndexType start;
417     for (unsigned int i=0; i< ImageDimension; i++) start[i]=0;
418     region.SetIndex(start);
419     PixelType zero = itk::NumericTraits<double>::Zero;
420  
421   
422     //Allocate the weights
423     typename WeightsImageType::Pointer weights=WeightsImageType::New();
424     weights->SetRegions(region);
425     weights->Allocate();
426     weights->SetSpacing(inputPtr->GetSpacing());
427
428     //=========================================================================== 
429     //Inversion is divided in in two loops, for each we will call a threaded helper class
430     //1. add contribution of input to output and update weights
431     //2. normalize the output by the weight and remove holes
432     //=========================================================================== 
433
434
435     //=========================================================================== 
436     //1. add contribution of input to output and update weights
437
438     //Define an internal image type
439
440     typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
441
442     //Call threaded helper class 1
443     typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
444     typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
445     
446     //Set input
447     if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
448     helper1->SetInput(inputPtr);
449     helper1->SetWeights(weights);
450
451     //Threadsafe?
452     if(m_ThreadSafe)
453       {
454         //Allocate the mutex image
455         typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
456         mutex->SetRegions(region);
457         mutex->Allocate();
458         mutex->SetSpacing(inputPtr->GetSpacing());
459         helper1->SetMutexImage(mutex);
460         if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
461       }
462     else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
463
464     //Execute helper class
465     helper1->Update();
466
467     //Get the output
468     typename InternalImageType::Pointer temp= helper1->GetOutput();
469     weights=helper1->GetWeights();
470
471
472     //=========================================================================== 
473     //2. Normalize the output by the weights and remove holes 
474     //Call threaded helper class 
475     typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
476     typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
477     
478     //Set temporary output as input
479     helper2->SetInput(temp);
480     helper2->SetWeights(weights);
481     helper2->SetEdgePaddingValue(m_EdgePaddingValue);
482
483     //Execute helper class
484     if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
485     helper2->Update();
486
487     //Set the output
488     this->SetNthOutput(0, helper2->GetOutput());
489   }
490   
491   
492     
493 }
494
495 #endif