]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
Change itkSimpleFastMutexLock to std::mutex
[clitk.git] / itk / clitkInvertVFFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkInvertVFFilter_txx
19 #define __clitkInvertVFFilter_txx
20
21 namespace
22 {
23
24 //=========================================================================================================================
25 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
26 //=========================================================================================================================
27 template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
28 {
29
30 public:
31   /** Standard class typedefs. */
32   typedef HelperClass1  Self;
33   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
34   typedef itk::SmartPointer<Self>         Pointer;
35   typedef itk::SmartPointer<const Self>   ConstPointer;
36
37   /** Method for creation through the object factory. */
38   itkNewMacro(Self);
39
40   /** Run-time type information (and related methods) */
41   itkTypeMacro( HelperClass1, ImageToImageFilter );
42
43   /** Constants for the image dimensions */
44   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
45
46
47   //Typedefs
48   typedef typename OutputImageType::PixelType        PixelType;
49   typedef itk::Image<double, ImageDimension > WeightsImageType;
50 #if ITK_VERSION_MAJOR <= 4
51   typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension> MutexImageType;
52 #endif
53
54   //===================================================================================
55   //Set methods
56   void SetWeights(const typename WeightsImageType::Pointer input) {
57     m_Weights = input;
58     this->Modified();
59   }
60 #if ITK_VERSION_MAJOR <= 4
61   void SetMutexImage(const typename MutexImageType::Pointer input) {
62     m_MutexImage=input;
63     this->Modified();
64     m_ThreadSafe=true;
65   }
66 #else
67   void SetMutexImage() {
68     this->Modified();
69     m_ThreadSafe=true;
70   }
71 #endif
72
73   //Get methods
74   typename  WeightsImageType::Pointer GetWeights() {
75     return m_Weights;
76   }
77
78   /** Typedef to describe the output image region type. */
79   typedef typename OutputImageType::RegionType OutputImageRegionType;
80
81 protected:
82   HelperClass1();
83   ~HelperClass1() {};
84
85   //the actual processing
86   void BeforeThreadedGenerateData() ITK_OVERRIDE;
87   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId ) ITK_OVERRIDE;
88
89   //member data
90   typename  WeightsImageType::Pointer m_Weights;
91 #if ITK_VERSION_MAJOR <= 4
92   typename MutexImageType::Pointer m_MutexImage;
93 #else
94   std::mutex m_Mutex;
95 #endif
96   bool m_ThreadSafe;
97
98 };
99
100
101
102 //=========================================================================================================================
103 //Member functions of the helper class 1
104 //=========================================================================================================================
105
106
107 //=========================================================================================================================
108 //Empty constructor
109 template<class InputImageType, class OutputImageType >
110 HelperClass1<InputImageType, OutputImageType>::HelperClass1()
111 {
112   m_ThreadSafe=false;
113 }
114
115 //=========================================================================================================================
116 //Before threaded data
117 template<class InputImageType, class OutputImageType >
118 void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
119 {
120   //std::cout << "HelperClass1::BeforeThreadedGenerateData - IN" << std::endl;
121   //Since we will add, put to zero!
122   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
123   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
124 }
125
126 //=========================================================================================================================
127 //update the output for the outputRegionForThread
128 template<class InputImageType, class OutputImageType>
129 void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId )
130 {
131 //   std::cout << "HelperClass1::ThreadedGenerateData - IN " << threadId << std::endl;
132   //Get pointer to the input
133   typename InputImageType::ConstPointer inputPtr = this->GetInput();
134
135   //Get pointer to the output
136   typename OutputImageType::Pointer outputPtr = this->GetOutput();
137   //typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
138
139   //Iterator over input
140   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
141
142   //define them over the outputRegionForThread
143   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
144
145   //Initialize
146   typename InputImageType::IndexType index;
147   itk::ContinuousIndex<double,ImageDimension> contIndex, inContIndex;
148   typename InputImageType::PointType ipoint;
149   typename OutputImageType::PointType opoint;
150   typedef typename OutputImageType::PixelType DisplacementType;
151   DisplacementType displacement;
152   inputIt.GoToBegin();
153
154   typename OutputImageType::SizeType size = outputPtr->GetLargestPossibleRegion().GetSize();
155
156   //define some temp variables
157   signed long baseIndex[ImageDimension];
158   double distance[ImageDimension];
159   unsigned int dim, counter, upper;
160   double totalOverlap,overlap;
161   typename OutputImageType::IndexType neighIndex;
162
163   //Find the number of neighbors
164   unsigned int neighbors =  1 << ImageDimension;
165
166   //==================================================================================================
167   //Loop over the output region and add the intensities from the input to the output and the weight to the weights
168   //==================================================================================================
169   while( !inputIt.IsAtEnd() ) {
170     // get the input image index
171     index = inputIt.GetIndex();
172     inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
173
174     // get the required displacement
175     displacement = inputIt.Get();
176
177     // compute the required output image point
178     for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
179
180     // Update the output and the weights
181     if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) ) {
182       for(dim = 0; dim < ImageDimension; dim++) {
183         // The following  block is equivalent to the following line without
184         // having to call floor. (Only for positive inputs, we already now that is in the image)
185         // baseIndex[dim] = (long) std::floor(contIndex[dim] );
186
187         baseIndex[dim] = (long) contIndex[dim];
188         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
189       }
190
191       //Add contribution for each neighbor
192       totalOverlap = itk::NumericTraits<double>::Zero;
193       for( counter = 0; counter < neighbors ; counter++ ) {
194         overlap = 1.0;          // fraction overlap
195         upper = counter;  // each bit indicates upper/lower neighbour
196
197         // get neighbor index and overlap fraction
198         for( dim = 0; dim < ImageDimension; dim++ ) {
199           if ( upper & 1 ) {
200             neighIndex[dim] = baseIndex[dim] + 1;
201             overlap *= distance[dim];
202           } else {
203             neighIndex[dim] = baseIndex[dim];
204             overlap *= 1.0 - distance[dim];
205           }
206           upper >>= 1;
207
208           if (neighIndex[dim] >= size[dim])
209             neighIndex[dim] = size[dim] - 1;
210         }
211
212
213         //Set neighbor value only if overlap is not zero
214         if( (overlap>0.0)) // &&
215           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
216           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
217           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
218           //                    (neighIndex[0]>=0) &&
219           //                    (neighIndex[1]>=0) &&
220           //                    (neighIndex[2]>=0) )
221         {
222           //what to store? the original displacement vector?
223           if (! m_ThreadSafe) {
224             //Set the pixel and weight at neighIndex
225             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
226             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
227           }
228
229           else {
230             //Entering critilal section: shared memory
231 #if ITK_VERSION_MAJOR <= 4
232             m_MutexImage->GetPixel(neighIndex).Lock();
233 #else
234             m_Mutex.lock();
235 #endif
236
237             //Set the pixel and weight at neighIndex
238             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
239             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
240
241             //Unlock
242 #if ITK_VERSION_MAJOR <= 4
243             m_MutexImage->GetPixel(neighIndex).Unlock();
244 #else
245             m_Mutex.unlock();
246 #endif
247
248           }
249           //Add to total overlap
250           totalOverlap += overlap;
251         }
252
253         if( totalOverlap == 1.0 ) {
254           // finished
255           break;
256         }
257       }
258     }
259
260     ++inputIt;
261   }
262
263 //   std::cout << "HelperClass1::ThreadedGenerateData - OUT " << threadId << std::endl;
264 }
265
266
267
268 //=========================================================================================================================
269 //helper class 2 to allow a threaded execution of normalisation by the weights
270 //=========================================================================================================================
271 template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
272 {
273
274 public:
275   /** Standard class typedefs. */
276   typedef HelperClass2  Self;
277   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
278   typedef itk::SmartPointer<Self>         Pointer;
279   typedef itk::SmartPointer<const Self>   ConstPointer;
280
281   /** Method for creation through the object factory. */
282   itkNewMacro(Self);
283
284   /** Run-time type information (and related methods) */
285   itkTypeMacro( HelperClass2, ImageToImageFilter );
286
287   /** Constants for the image dimensions */
288   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
289
290   //Typedefs
291   typedef typename OutputImageType::PixelType        PixelType;
292   typedef itk::Image<double,ImageDimension> WeightsImageType;
293
294   //Set methods
295   void SetWeights(const typename WeightsImageType::Pointer input) {
296     m_Weights = input;
297     this->Modified();
298   }
299   void SetEdgePaddingValue(PixelType value) {
300     m_EdgePaddingValue = value;
301     this->Modified();
302   }
303
304   /** Typedef to describe the output image region type. */
305   typedef typename OutputImageType::RegionType OutputImageRegionType;
306
307 protected:
308   HelperClass2();
309   ~HelperClass2() {};
310
311
312   //the actual processing
313   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId ) ITK_OVERRIDE;
314
315   //member data
316   typename     WeightsImageType::Pointer m_Weights;
317   PixelType m_EdgePaddingValue;
318
319 } ;
320
321
322
323 //=========================================================================================================================
324 //Member functions of the helper class 2
325 //=========================================================================================================================
326
327
328 //=========================================================================================================================
329 //Empty constructor
330 template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
331 {
332   PixelType zero;
333   for(unsigned int i=0;i <PixelType::Dimension; i++) zero[i] = 0.0;
334   m_EdgePaddingValue=zero;
335   //m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero;
336 }
337
338
339 //=========================================================================================================================
340 //update the output for the outputRegionForThread
341 template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId )
342 {
343 //   std::cout << "HelperClass2::ThreadedGenerateData - IN " << threadId << std::endl;
344
345   //Get pointer to the input
346   typename InputImageType::ConstPointer inputPtr = this->GetInput();
347
348   //Get pointer to the output
349   typename OutputImageType::Pointer outputPtr = this->GetOutput();
350
351   //Iterators over input, weigths  and output
352   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
353   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
354   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
355
356   //define them over the outputRegionForThread
357   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
358   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
359   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
360
361
362   //==================================================================================================
363   //loop over the output and normalize the input, remove holes
364   PixelType neighValue;
365   double  zero = itk::NumericTraits<double>::Zero;
366   while (!outputIt.IsAtEnd()) {
367     //the weight is not zero
368     if (weightsIt.Get() != zero) {
369       //divide by the weight
370       outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
371     }
372
373     //copy the value of the  neighbour that was just processed
374     else {
375       if(!outputIt.IsAtBegin()) {
376         //go back
377         --outputIt;
378
379         //Neighbour cannot have zero weight because it should be filled already
380         neighValue=outputIt.Get();
381         ++outputIt;
382         outputIt.Set(neighValue);
383         //DD("hole filled");
384       } else {
385         //DD("is at begin, setting edgepadding value");
386         outputIt.Set(m_EdgePaddingValue);
387       }
388     }
389     ++weightsIt;
390     ++outputIt;
391     ++inputIt;
392
393   }//end while
394
395 //   std::cout << "HelperClass2::ThreadedGenerateData - OUT " << threadId << std::endl;
396
397 }//end member
398
399
400 }//end nameless namespace
401
402
403
404 namespace clitk
405 {
406
407 //=========================================================================================================================
408 // The rest is the InvertVFFilter
409 //=========================================================================================================================
410
411 //=========================================================================================================================
412 //constructor
413 template <class InputImageType, class OutputImageType>
414 InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
415 {
416
417   //m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
418   PixelType zero;
419   for(unsigned int i=0;i <PixelType::Dimension; i++) zero[i] = 0.0;
420   m_EdgePaddingValue=zero; //no other reasonable value?
421
422   m_ThreadSafe=false;
423   m_Verbose=false;
424 }
425
426 //=========================================================================================================================
427 //Update
428 template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
429 {
430   // std::cout << "InvertVFFilter::GenerateData - IN" << std::endl;
431
432   //Get the properties of the input
433   typename InputImageType::ConstPointer inputPtr=this->GetInput();
434   typename WeightsImageType::RegionType region = inputPtr->GetLargestPossibleRegion();
435
436   //Allocate the weights
437   typename WeightsImageType::Pointer weights=WeightsImageType::New();
438   weights->SetOrigin(inputPtr->GetOrigin());
439   weights->SetRegions(region);
440   weights->Allocate();
441   weights->SetSpacing(inputPtr->GetSpacing());
442
443   //===========================================================================
444   //Inversion is divided in in two loops, for each we will call a threaded helper class
445   //1. add contribution of input to output and update weights
446   //2. normalize the output by the weight and remove holes
447   //===========================================================================
448
449
450   //===========================================================================
451   //1. add contribution of input to output and update weights
452
453   //Define an internal image type
454
455   typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
456
457   //Call threaded helper class 1
458   typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
459   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
460
461   //Set input
462   if(m_NumberOfThreadsIsGiven) {
463 #if ITK_VERSION_MAJOR <= 4
464     helper1->SetNumberOfThreads(m_NumberOfThreads);
465 #else
466     helper1->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
467 #endif
468   }
469   helper1->SetInput(inputPtr);
470   helper1->SetWeights(weights);
471
472   //Threadsafe?
473   if(m_ThreadSafe) {
474     //Allocate the mutex image
475 #if ITK_VERSION_MAJOR <= 4
476     typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
477     mutex->SetRegions(region);
478     mutex->Allocate();
479     mutex->SetSpacing(inputPtr->GetSpacing());
480     helper1->SetMutexImage(mutex);
481 #else
482     helper1->SetMutexImage();
483 #endif
484     if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
485   } else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
486
487   //Execute helper class
488   helper1->Update();
489
490   //Get the output
491   typename InternalImageType::Pointer temp= helper1->GetOutput();
492   weights=helper1->GetWeights();
493
494
495   //===========================================================================
496   //2. Normalize the output by the weights and remove holes
497   //Call threaded helper class
498   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
499   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
500
501   //Set temporary output as input
502   if(m_NumberOfThreadsIsGiven) {
503 #if ITK_VERSION_MAJOR <= 4
504     helper2->SetNumberOfThreads(m_NumberOfThreads);
505 #else
506     helper2->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
507 #endif
508   }
509   helper2->SetInput(temp);
510   helper2->SetWeights(weights);
511   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
512
513   //Execute helper class
514   if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
515   helper2->Update();
516
517   //Set the output
518   this->SetNthOutput(0, helper2->GetOutput());
519
520   //std::cout << "InvertVFFilter::GenerateData - OUT" << std::endl;
521 }
522
523
524
525 }
526
527 #endif