]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
invert VF on given image grid
[clitk.git] / itk / clitkInvertVFFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkInvertVFFilter_txx
19 #define __clitkInvertVFFilter_txx
20
21 namespace
22 {
23
24 //=========================================================================================================================
25 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
26 //=========================================================================================================================
27 template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
28 {
29
30 public:
31   /** Standard class typedefs. */
32   typedef HelperClass1  Self;
33   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
34   typedef itk::SmartPointer<Self>         Pointer;
35   typedef itk::SmartPointer<const Self>   ConstPointer;
36
37   /** Method for creation through the object factory. */
38   itkNewMacro(Self);
39
40   /** Run-time type information (and related methods) */
41   itkTypeMacro( HelperClass1, ImageToImageFilter );
42
43   /** Constants for the image dimensions */
44   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
45
46
47   //Typedefs
48   typedef typename OutputImageType::PixelType        PixelType;
49   typedef itk::Image<double, ImageDimension > WeightsImageType;
50   typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
51
52   //===================================================================================
53   //Set methods
54   void SetWeights(const typename WeightsImageType::Pointer input) {
55     m_Weights = input;
56     this->Modified();
57   }
58   void SetMutexImage(const typename MutexImageType::Pointer input) {
59     m_MutexImage=input;
60     this->Modified();
61     m_ThreadSafe=true;
62   }
63
64   //Get methods
65   typename  WeightsImageType::Pointer GetWeights() {
66     return m_Weights;
67   }
68
69   /** Typedef to describe the output image region type. */
70   typedef typename OutputImageType::RegionType OutputImageRegionType;
71
72 protected:
73   HelperClass1();
74   ~HelperClass1() {};
75
76   //the actual processing
77   void GenerateInputRequestedRegion();
78   void GenerateOutputInformation();  
79   void BeforeThreadedGenerateData();
80   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
81
82   //member data
83   typename  WeightsImageType::Pointer m_Weights;
84   typename MutexImageType::Pointer m_MutexImage;
85   bool m_ThreadSafe;
86
87 };
88
89
90
91 //=========================================================================================================================
92 //Member functions of the helper class 1
93 //=========================================================================================================================
94
95
96 //=========================================================================================================================
97 //Empty constructor
98 template<class InputImageType, class OutputImageType >
99 HelperClass1<InputImageType, OutputImageType>::HelperClass1()
100 {
101   m_ThreadSafe=false;
102 }
103
104 //=========================================================================================================================
105 //Set input image params
106 template <class InputImageType, class OutputImageType>
107 void HelperClass1<InputImageType, OutputImageType>::GenerateInputRequestedRegion()
108 {
109   //std::cout << "HelperClass1::GenerateInputRequestedRegion - IN" << std::endl;
110
111   typename InputImageType::Pointer inputPtr = const_cast< InputImageType * >(this->GetInput());
112   inputPtr->SetRequestedRegion(inputPtr->GetLargestPossibleRegion());
113   
114   //std::cout << "HelperClass1::GenerateInputRequestedRegion - OUT" << std::endl;
115 }
116
117 //=========================================================================================================================
118 //Set output image params
119 template<class InputImageType, class OutputImageType >
120 void HelperClass1<InputImageType, OutputImageType>::GenerateOutputInformation()
121 {
122   //std::cout << "HelperClass1::GenerateOutputInformation - IN" << std::endl;
123   Superclass::GenerateOutputInformation();
124
125   // it's the weight image that determines the size of the output
126   typename OutputImageType::Pointer outputPtr = this->GetOutput();
127   outputPtr->SetLargestPossibleRegion( m_Weights->GetLargestPossibleRegion() );
128   outputPtr->SetSpacing(m_Weights->GetSpacing());
129 }
130
131 //=========================================================================================================================
132 //Before threaded data
133 template<class InputImageType, class OutputImageType >
134 void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
135 {
136   //std::cout << "HelperClass1::BeforeThreadedGenerateData - IN" << std::endl;
137   //Since we will add, put to zero!
138   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
139   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
140 }
141
142 //=========================================================================================================================
143 //update the output for the outputRegionForThread
144 template<class InputImageType, class OutputImageType>
145 void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
146 {
147   //std::cout << "HelperClass1::ThreadedGenerateData - IN" << std::endl;
148   //Get pointer to the input
149   typename InputImageType::ConstPointer inputPtr = this->GetInput();
150
151   //Get pointer to the output
152   typename OutputImageType::Pointer outputPtr = this->GetOutput();
153 /*  outputPtr->SetLargestPossibleRegion( m_Weights->GetLargestPossibleRegion() );
154   outputPtr->SetSpacing(m_Weights->GetSpacing());*/
155   typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
156   //std::cout << outputPtr->GetSpacing() << size << std::endl;
157
158   //Iterator over input
159   //typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
160   typedef itk::ImageRegionConstIteratorWithIndex<OutputImageType> InputImageIteratorType;
161
162   //define them over the outputRegionForThread
163   //InputImageIteratorType outputIt(inputPtr, outputRegionForThread);
164   InputImageIteratorType outputIt(outputPtr, outputPtr->GetLargestPossibleRegion());
165
166   //Initialize
167   typename InputImageType::IndexType index;
168   itk::ContinuousIndex<double,ImageDimension> contIndex, inContIndex;
169   typename InputImageType::PointType ipoint;
170   typename OutputImageType::PointType opoint;
171   typedef typename OutputImageType::PixelType DisplacementType;
172   DisplacementType displacement;
173   outputIt.GoToBegin();
174
175   //define some temp variables
176   signed long baseIndex[ImageDimension];
177   double distance[ImageDimension];
178   unsigned int dim, counter, upper;
179   double totalOverlap,overlap;
180   typename OutputImageType::IndexType neighIndex, outIndex;
181
182   //Find the number of neighbors
183   unsigned int neighbors =  1 << ImageDimension;
184
185   //==================================================================================================
186   //Loop over the output region and add the intensities from the input to the output and the weight to the weights
187   //==================================================================================================
188   while( !outputIt.IsAtEnd() ) {
189     // get the input image index
190     outIndex = outputIt.GetIndex();
191     outputPtr->TransformIndexToPhysicalPoint( outIndex,opoint );
192     for(unsigned int j = 0; j < ImageDimension; j++ ) ipoint[j] = opoint[j];
193     inputPtr->TransformPhysicalPointToIndex(ipoint, index);
194     inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
195
196     // get the required displacement
197     //displacement = outputIt.Get();
198     displacement = inputPtr->GetPixel(index);
199
200     // compute the required output image point
201     for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
202
203     // Update the output and the weights
204     if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) ) {
205       for(dim = 0; dim < ImageDimension; dim++) {
206         // The following  block is equivalent to the following line without
207         // having to call floor. (Only for positive inputs, we already now that is in the image)
208         // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
209
210         baseIndex[dim] = (long) contIndex[dim];
211         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
212       }
213
214       //Add contribution for each neighbor
215       totalOverlap = itk::NumericTraits<double>::Zero;
216       for( counter = 0; counter < neighbors ; counter++ ) {
217         overlap = 1.0;          // fraction overlap
218         upper = counter;  // each bit indicates upper/lower neighbour
219
220         // get neighbor index and overlap fraction
221         for( dim = 0; dim < ImageDimension; dim++ ) {
222           if ( upper & 1 ) {
223             neighIndex[dim] = baseIndex[dim] + 1;
224             overlap *= distance[dim];
225           } else {
226             neighIndex[dim] = baseIndex[dim];
227             overlap *= 1.0 - distance[dim];
228           }
229           upper >>= 1;
230         }
231
232
233
234         //Set neighbor value only if overlap is not zero
235         if( (overlap>0.0)) // &&
236           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
237           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
238           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
239           //                    (neighIndex[0]>=0) &&
240           //                    (neighIndex[1]>=0) &&
241           //                    (neighIndex[2]>=0) )
242         {
243           //what to store? the original displacement vector?
244           if (! m_ThreadSafe) {
245             //Set the pixel and weight at neighIndex
246             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
247             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
248           }
249
250           else {
251             //Entering critilal section: shared memory
252             m_MutexImage->GetPixel(neighIndex).Lock();
253
254             //Set the pixel and weight at neighIndex
255             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
256             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
257
258             //Unlock
259             m_MutexImage->GetPixel(neighIndex).Unlock();
260
261           }
262           //Add to total overlap
263           totalOverlap += overlap;
264         }
265
266         if( totalOverlap == 1.0 ) {
267           // finished
268           break;
269         }
270       }
271     }
272
273     ++outputIt;
274   }
275
276   //std::cout << "HelperClass1::ThreadedGenerateData - OUT" << std::endl;
277 }
278
279
280
281 //=========================================================================================================================
282 //helper class 2 to allow a threaded execution of normalisation by the weights
283 //=========================================================================================================================
284 template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
285 {
286
287 public:
288   /** Standard class typedefs. */
289   typedef HelperClass2  Self;
290   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
291   typedef itk::SmartPointer<Self>         Pointer;
292   typedef itk::SmartPointer<const Self>   ConstPointer;
293
294   /** Method for creation through the object factory. */
295   itkNewMacro(Self);
296
297   /** Run-time type information (and related methods) */
298   itkTypeMacro( HelperClass2, ImageToImageFilter );
299
300   /** Constants for the image dimensions */
301   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
302
303   //Typedefs
304   typedef typename OutputImageType::PixelType        PixelType;
305   typedef itk::Image<double,ImageDimension> WeightsImageType;
306
307   //Set methods
308   void SetWeights(const typename WeightsImageType::Pointer input) {
309     m_Weights = input;
310     this->Modified();
311   }
312   void SetEdgePaddingValue(PixelType value) {
313     m_EdgePaddingValue = value;
314     this->Modified();
315   }
316
317   /** Typedef to describe the output image region type. */
318   typedef typename OutputImageType::RegionType OutputImageRegionType;
319
320 protected:
321   HelperClass2();
322   ~HelperClass2() {};
323
324
325   //the actual processing
326   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
327
328
329   //member data
330   typename     WeightsImageType::Pointer m_Weights;
331   PixelType m_EdgePaddingValue;
332
333 } ;
334
335
336
337 //=========================================================================================================================
338 //Member functions of the helper class 2
339 //=========================================================================================================================
340
341
342 //=========================================================================================================================
343 //Empty constructor
344 template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
345 {
346   m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero;
347 }
348
349
350 //=========================================================================================================================
351 //update the output for the outputRegionForThread
352 template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
353 {
354   //std::cout << "HelperClass2::ThreadedGenerateData - IN" << std::endl;
355   
356   //Get pointer to the input
357   typename InputImageType::ConstPointer inputPtr = this->GetInput();
358
359   //Get pointer to the output
360   typename OutputImageType::Pointer outputPtr = this->GetOutput();
361
362   //Iterators over input, weigths  and output
363   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
364   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
365   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
366
367   //define them over the outputRegionForThread
368   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
369   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
370   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
371
372
373   //==================================================================================================
374   //loop over the output and normalize the input, remove holes
375   PixelType neighValue;
376   double  zero = itk::NumericTraits<double>::Zero;
377   while (!outputIt.IsAtEnd()) {
378     //the weight is not zero
379     if (weightsIt.Get() != zero) {
380       //divide by the weight
381       outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
382     }
383
384     //copy the value of the  neighbour that was just processed
385     else {
386       if(!outputIt.IsAtBegin()) {
387         //go back
388         --outputIt;
389
390         //Neighbour cannot have zero weight because it should be filled already
391         neighValue=outputIt.Get();
392         ++outputIt;
393         outputIt.Set(neighValue);
394         //DD("hole filled");
395       } else {
396         //DD("is at begin, setting edgepadding value");
397         outputIt.Set(m_EdgePaddingValue);
398       }
399     }
400     ++weightsIt;
401     ++outputIt;
402     ++inputIt;
403
404   }//end while
405   
406   //std::cout << "HelperClass2::ThreadedGenerateData - OUT" << std::endl;
407   
408 }//end member
409
410
411 }//end nameless namespace
412
413
414
415 namespace clitk
416 {
417
418 //=========================================================================================================================
419 // The rest is the InvertVFFilter
420 //=========================================================================================================================
421
422 //=========================================================================================================================
423 //constructor
424 template <class InputImageType, class OutputImageType>
425 InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
426 {
427   m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
428   m_ThreadSafe=false;
429   m_Verbose=false;
430 }
431
432 //=========================================================================================================================
433 //Set input image params
434 template <class InputImageType, class OutputImageType>
435 void InvertVFFilter<InputImageType, OutputImageType>::GenerateInputRequestedRegion()
436 {
437   //std::cout << "InvertVFFilter::GenerateInputRequestedRegion - IN" << std::endl;
438   // call the superclass' implementation of this method
439   // Superclass::GenerateInputRequestedRegion();
440
441   typename InputImageType::Pointer inputPtr = const_cast< InputImageType * >(this->GetInput());
442   inputPtr->SetRequestedRegion(inputPtr->GetLargestPossibleRegion());
443   
444   //std::cout << "InvertVFFilter::GenerateInputRequestedRegion - OUT" << std::endl;
445 }
446
447 //=========================================================================================================================
448 //Set output image params
449 template<class InputImageType, class OutputImageType >
450 void InvertVFFilter<InputImageType, OutputImageType>::GenerateOutputInformation()
451 {
452   //std::cout << "InvertVFFilter::GenerateOutputInformation - IN" << std::endl;
453   Superclass::GenerateOutputInformation();
454
455   typename InputImageType::ConstPointer inputPtr=this->GetInput();
456   typename WeightsImageType::RegionType region = inputPtr->GetLargestPossibleRegion();
457   region.SetSize(m_OutputSize);
458
459   typename OutputImageType::Pointer outputPtr = this->GetOutput();
460   outputPtr->SetRegions( region );
461   outputPtr->SetSpacing(m_OutputSpacing);
462   outputPtr->SetOrigin(inputPtr->GetOrigin());
463
464   //std::cout << "InvertVFFilter::GenerateOutputInformation - OUT" << std::endl;
465 }
466
467 //=========================================================================================================================
468 //Update
469 template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
470 {
471   //std::cout << "InvertVFFilter::GenerateData - IN" << std::endl;
472
473   //Get the properties of the input
474   typename InputImageType::ConstPointer inputPtr=this->GetInput();
475   typename WeightsImageType::RegionType region = inputPtr->GetLargestPossibleRegion();
476   //region.SetSize(size);
477   region.SetSize(m_OutputSize);
478
479   //Allocate the weights
480   typename WeightsImageType::Pointer weights=WeightsImageType::New();
481   weights->SetOrigin(inputPtr->GetOrigin());
482   weights->SetRegions(region);
483   weights->SetSpacing(m_OutputSpacing);
484   weights->Allocate();
485   //weights->SetSpacing(inputPtr->GetSpacing());
486
487   //===========================================================================
488   //Inversion is divided in in two loops, for each we will call a threaded helper class
489   //1. add contribution of input to output and update weights
490   //2. normalize the output by the weight and remove holes
491   //===========================================================================
492
493
494   //===========================================================================
495   //1. add contribution of input to output and update weights
496
497   //Define an internal image type
498
499   typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
500
501   //Call threaded helper class 1
502   typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
503   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
504
505   //Set input
506   if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
507   helper1->SetInput(inputPtr);
508   helper1->SetWeights(weights);
509
510   //Threadsafe?
511   if(m_ThreadSafe) {
512     //Allocate the mutex image
513     typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
514     mutex->SetRegions(region);
515     mutex->Allocate();
516     //mutex->SetSpacing(inputPtr->GetSpacing());
517     mutex->SetSpacing(m_OutputSpacing);
518     helper1->SetMutexImage(mutex);
519     if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
520   } else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
521
522   //Execute helper class
523   helper1->Update();
524
525   //Get the output
526   typename InternalImageType::Pointer temp= helper1->GetOutput();
527   weights=helper1->GetWeights();
528
529
530   //===========================================================================
531   //2. Normalize the output by the weights and remove holes
532   //Call threaded helper class
533   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
534   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
535
536   //Set temporary output as input
537   //std::cout << temp->GetSpacing() << temp->GetLargestPossibleRegion().GetSize() << std::endl;
538   //std::cout << weights->GetSpacing() << weights->GetLargestPossibleRegion().GetSize() << std::endl;
539   helper2->SetInput(temp);
540   helper2->SetWeights(weights);
541   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
542
543   //Execute helper class
544   if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
545   helper2->Update();
546
547   //Set the output
548   this->SetNthOutput(0, helper2->GetOutput());
549 }
550
551
552
553 }
554
555 #endif