]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
In SliceBySliceSetBackgroundFromLineSeparation add keepIfEqual option
[clitk.git] / itk / clitkInvertVFFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkInvertVFFilter_txx
19 #define __clitkInvertVFFilter_txx
20
21 namespace
22 {
23
24 //=========================================================================================================================
25 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
26 //=========================================================================================================================
27 template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
28 {
29
30 public:
31   /** Standard class typedefs. */
32   typedef HelperClass1  Self;
33   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
34   typedef itk::SmartPointer<Self>         Pointer;
35   typedef itk::SmartPointer<const Self>   ConstPointer;
36
37   /** Method for creation through the object factory. */
38   itkNewMacro(Self);
39
40   /** Run-time type information (and related methods) */
41   itkTypeMacro( HelperClass1, ImageToImageFilter );
42
43   /** Constants for the image dimensions */
44   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
45
46
47   //Typedefs
48   typedef typename OutputImageType::PixelType        PixelType;
49   typedef itk::Image<double, ImageDimension > WeightsImageType;
50   typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
51
52   //===================================================================================
53   //Set methods
54   void SetWeights(const typename WeightsImageType::Pointer input) {
55     m_Weights = input;
56     this->Modified();
57   }
58   void SetMutexImage(const typename MutexImageType::Pointer input) {
59     m_MutexImage=input;
60     this->Modified();
61     m_ThreadSafe=true;
62   }
63
64   //Get methods
65   typename  WeightsImageType::Pointer GetWeights() {
66     return m_Weights;
67   }
68
69   /** Typedef to describe the output image region type. */
70   typedef typename OutputImageType::RegionType OutputImageRegionType;
71
72 protected:
73   HelperClass1();
74   ~HelperClass1() {};
75
76   //the actual processing
77   void BeforeThreadedGenerateData();
78 #if ITK_VERSION_MAJOR >= 4
79   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId );
80 #else
81   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
82 #endif
83
84   //member data
85   typename  WeightsImageType::Pointer m_Weights;
86   typename MutexImageType::Pointer m_MutexImage;
87   bool m_ThreadSafe;
88
89 };
90
91
92
93 //=========================================================================================================================
94 //Member functions of the helper class 1
95 //=========================================================================================================================
96
97
98 //=========================================================================================================================
99 //Empty constructor
100 template<class InputImageType, class OutputImageType >
101 HelperClass1<InputImageType, OutputImageType>::HelperClass1()
102 {
103   m_ThreadSafe=false;
104 }
105
106 //=========================================================================================================================
107 //Before threaded data
108 template<class InputImageType, class OutputImageType >
109 void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
110 {
111   //std::cout << "HelperClass1::BeforeThreadedGenerateData - IN" << std::endl;
112   //Since we will add, put to zero!
113   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
114   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
115 }
116
117 //=========================================================================================================================
118 //update the output for the outputRegionForThread
119 template<class InputImageType, class OutputImageType>
120 #if ITK_VERSION_MAJOR >= 4
121 void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId )
122 #else
123 void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
124 #endif
125 {
126 //   std::cout << "HelperClass1::ThreadedGenerateData - IN " << threadId << std::endl;
127   //Get pointer to the input
128   typename InputImageType::ConstPointer inputPtr = this->GetInput();
129
130   //Get pointer to the output
131   typename OutputImageType::Pointer outputPtr = this->GetOutput();
132   //typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
133
134   //Iterator over input
135   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
136
137   //define them over the outputRegionForThread
138   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
139
140   //Initialize
141   typename InputImageType::IndexType index;
142   itk::ContinuousIndex<double,ImageDimension> contIndex, inContIndex;
143   typename InputImageType::PointType ipoint;
144   typename OutputImageType::PointType opoint;
145   typedef typename OutputImageType::PixelType DisplacementType;
146   DisplacementType displacement;
147   inputIt.GoToBegin();
148   
149   typename OutputImageType::SizeType size = outputPtr->GetLargestPossibleRegion().GetSize();
150
151   //define some temp variables
152   signed long baseIndex[ImageDimension];
153   double distance[ImageDimension];
154   unsigned int dim, counter, upper;
155   double totalOverlap,overlap;
156   typename OutputImageType::IndexType neighIndex;
157
158   //Find the number of neighbors
159   unsigned int neighbors =  1 << ImageDimension;
160
161   //==================================================================================================
162   //Loop over the output region and add the intensities from the input to the output and the weight to the weights
163   //==================================================================================================
164   while( !inputIt.IsAtEnd() ) {
165     // get the input image index
166     index = inputIt.GetIndex();
167     inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
168
169     // get the required displacement
170     displacement = inputIt.Get();
171
172     // compute the required output image point
173     for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
174
175     // Update the output and the weights
176     if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) ) {
177       for(dim = 0; dim < ImageDimension; dim++) {
178         // The following  block is equivalent to the following line without
179         // having to call floor. (Only for positive inputs, we already now that is in the image)
180         // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
181
182         baseIndex[dim] = (long) contIndex[dim];
183         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
184       }
185
186       //Add contribution for each neighbor
187       totalOverlap = itk::NumericTraits<double>::Zero;
188       for( counter = 0; counter < neighbors ; counter++ ) {
189         overlap = 1.0;          // fraction overlap
190         upper = counter;  // each bit indicates upper/lower neighbour
191
192         // get neighbor index and overlap fraction
193         for( dim = 0; dim < ImageDimension; dim++ ) {
194           if ( upper & 1 ) {
195             neighIndex[dim] = baseIndex[dim] + 1;
196             overlap *= distance[dim];
197           } else {
198             neighIndex[dim] = baseIndex[dim];
199             overlap *= 1.0 - distance[dim];
200           }
201           upper >>= 1;
202           
203           if (neighIndex[dim] >= size[dim])
204             neighIndex[dim] = size[dim] - 1;
205         }
206
207
208         //Set neighbor value only if overlap is not zero
209         if( (overlap>0.0)) // &&
210           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
211           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
212           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
213           //                    (neighIndex[0]>=0) &&
214           //                    (neighIndex[1]>=0) &&
215           //                    (neighIndex[2]>=0) )
216         {
217           //what to store? the original displacement vector?
218           if (! m_ThreadSafe) {
219             //Set the pixel and weight at neighIndex
220             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
221             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
222           }
223
224           else {
225             //Entering critilal section: shared memory
226             m_MutexImage->GetPixel(neighIndex).Lock();
227
228             //Set the pixel and weight at neighIndex
229             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
230             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
231
232             //Unlock
233             m_MutexImage->GetPixel(neighIndex).Unlock();
234
235           }
236           //Add to total overlap
237           totalOverlap += overlap;
238         }
239
240         if( totalOverlap == 1.0 ) {
241           // finished
242           break;
243         }
244       }
245     }
246
247     ++inputIt;
248   }
249
250 //   std::cout << "HelperClass1::ThreadedGenerateData - OUT " << threadId << std::endl;
251 }
252
253
254
255 //=========================================================================================================================
256 //helper class 2 to allow a threaded execution of normalisation by the weights
257 //=========================================================================================================================
258 template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
259 {
260
261 public:
262   /** Standard class typedefs. */
263   typedef HelperClass2  Self;
264   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
265   typedef itk::SmartPointer<Self>         Pointer;
266   typedef itk::SmartPointer<const Self>   ConstPointer;
267
268   /** Method for creation through the object factory. */
269   itkNewMacro(Self);
270
271   /** Run-time type information (and related methods) */
272   itkTypeMacro( HelperClass2, ImageToImageFilter );
273
274   /** Constants for the image dimensions */
275   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
276
277   //Typedefs
278   typedef typename OutputImageType::PixelType        PixelType;
279   typedef itk::Image<double,ImageDimension> WeightsImageType;
280
281   //Set methods
282   void SetWeights(const typename WeightsImageType::Pointer input) {
283     m_Weights = input;
284     this->Modified();
285   }
286   void SetEdgePaddingValue(PixelType value) {
287     m_EdgePaddingValue = value;
288     this->Modified();
289   }
290
291   /** Typedef to describe the output image region type. */
292   typedef typename OutputImageType::RegionType OutputImageRegionType;
293
294 protected:
295   HelperClass2();
296   ~HelperClass2() {};
297
298
299   //the actual processing
300 #if ITK_VERSION_MAJOR >= 4
301   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId );
302 #else
303   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
304 #endif
305
306   //member data
307   typename     WeightsImageType::Pointer m_Weights;
308   PixelType m_EdgePaddingValue;
309
310 } ;
311
312
313
314 //=========================================================================================================================
315 //Member functions of the helper class 2
316 //=========================================================================================================================
317
318
319 //=========================================================================================================================
320 //Empty constructor
321 template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
322 {
323   m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero;
324 }
325
326
327 //=========================================================================================================================
328 //update the output for the outputRegionForThread
329 #if ITK_VERSION_MAJOR >= 4
330 template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId )
331 #else
332 template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
333 #endif
334 {
335 //   std::cout << "HelperClass2::ThreadedGenerateData - IN " << threadId << std::endl;
336   
337   //Get pointer to the input
338   typename InputImageType::ConstPointer inputPtr = this->GetInput();
339
340   //Get pointer to the output
341   typename OutputImageType::Pointer outputPtr = this->GetOutput();
342
343   //Iterators over input, weigths  and output
344   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
345   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
346   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
347
348   //define them over the outputRegionForThread
349   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
350   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
351   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
352
353
354   //==================================================================================================
355   //loop over the output and normalize the input, remove holes
356   PixelType neighValue;
357   double  zero = itk::NumericTraits<double>::Zero;
358   while (!outputIt.IsAtEnd()) {
359     //the weight is not zero
360     if (weightsIt.Get() != zero) {
361       //divide by the weight
362       outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
363     }
364
365     //copy the value of the  neighbour that was just processed
366     else {
367       if(!outputIt.IsAtBegin()) {
368         //go back
369         --outputIt;
370
371         //Neighbour cannot have zero weight because it should be filled already
372         neighValue=outputIt.Get();
373         ++outputIt;
374         outputIt.Set(neighValue);
375         //DD("hole filled");
376       } else {
377         //DD("is at begin, setting edgepadding value");
378         outputIt.Set(m_EdgePaddingValue);
379       }
380     }
381     ++weightsIt;
382     ++outputIt;
383     ++inputIt;
384
385   }//end while
386   
387 //   std::cout << "HelperClass2::ThreadedGenerateData - OUT " << threadId << std::endl;
388   
389 }//end member
390
391
392 }//end nameless namespace
393
394
395
396 namespace clitk
397 {
398
399 //=========================================================================================================================
400 // The rest is the InvertVFFilter
401 //=========================================================================================================================
402
403 //=========================================================================================================================
404 //constructor
405 template <class InputImageType, class OutputImageType>
406 InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
407 {
408   m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
409   m_ThreadSafe=false;
410   m_Verbose=false;
411 }
412
413 //=========================================================================================================================
414 //Update
415 template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
416 {
417   // std::cout << "InvertVFFilter::GenerateData - IN" << std::endl;
418
419   //Get the properties of the input
420   typename InputImageType::ConstPointer inputPtr=this->GetInput();
421   typename WeightsImageType::RegionType region = inputPtr->GetLargestPossibleRegion();
422
423   //Allocate the weights
424   typename WeightsImageType::Pointer weights=WeightsImageType::New();
425   weights->SetOrigin(inputPtr->GetOrigin());
426   weights->SetRegions(region);
427   weights->Allocate();
428   weights->SetSpacing(inputPtr->GetSpacing());
429
430   //===========================================================================
431   //Inversion is divided in in two loops, for each we will call a threaded helper class
432   //1. add contribution of input to output and update weights
433   //2. normalize the output by the weight and remove holes
434   //===========================================================================
435
436
437   //===========================================================================
438   //1. add contribution of input to output and update weights
439
440   //Define an internal image type
441
442   typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
443
444   //Call threaded helper class 1
445   typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
446   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
447
448   //Set input
449   if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
450   helper1->SetInput(inputPtr);
451   helper1->SetWeights(weights);
452
453   //Threadsafe?
454   if(m_ThreadSafe) {
455     //Allocate the mutex image
456     typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
457     mutex->SetRegions(region);
458     mutex->Allocate();
459     mutex->SetSpacing(inputPtr->GetSpacing());
460     helper1->SetMutexImage(mutex);
461     if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
462   } else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
463
464   //Execute helper class
465   helper1->Update();
466
467   //Get the output
468   typename InternalImageType::Pointer temp= helper1->GetOutput();
469   weights=helper1->GetWeights();
470
471
472   //===========================================================================
473   //2. Normalize the output by the weights and remove holes
474   //Call threaded helper class
475   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
476   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
477
478   //Set temporary output as input
479   if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
480   helper2->SetInput(temp);
481   helper2->SetWeights(weights);
482   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
483
484   //Execute helper class
485   if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
486   helper2->Update();
487
488   //Set the output
489   this->SetNthOutput(0, helper2->GetOutput());
490   
491   //std::cout << "InvertVFFilter::GenerateData - OUT" << std::endl;
492 }
493
494
495
496 }
497
498 #endif