]> Creatis software - clitk.git/blob - itk/clitkInvertVFFilter.txx
6bda75eb0d2b65ba4a21b6f438862201c6fffd14
[clitk.git] / itk / clitkInvertVFFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkInvertVFFilter_txx
19 #define __clitkInvertVFFilter_txx
20 namespace
21 {
22
23 //=========================================================================================================================
24 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
25 //=========================================================================================================================
26 template<class InputImageType, class OutputImageType> class ITK_EXPORT HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
27 {
28
29 public:
30   /** Standard class typedefs. */
31   typedef HelperClass1  Self;
32   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
33   typedef itk::SmartPointer<Self>         Pointer;
34   typedef itk::SmartPointer<const Self>   ConstPointer;
35
36   /** Method for creation through the object factory. */
37   itkNewMacro(Self);
38
39   /** Run-time type information (and related methods) */
40   itkTypeMacro( HelperClass1, ImageToImageFilter );
41
42   /** Constants for the image dimensions */
43   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
44
45
46   //Typedefs
47   typedef typename OutputImageType::PixelType        PixelType;
48   typedef itk::Image<double, ImageDimension > WeightsImageType;
49   typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
50
51   //===================================================================================
52   //Set methods
53   void SetWeights(const typename WeightsImageType::Pointer input) {
54     m_Weights = input;
55     this->Modified();
56   }
57   void SetMutexImage(const typename MutexImageType::Pointer input) {
58     m_MutexImage=input;
59     this->Modified();
60     m_ThreadSafe=true;
61   }
62
63   //Get methods
64   typename  WeightsImageType::Pointer GetWeights() {
65     return m_Weights;
66   }
67
68   /** Typedef to describe the output image region type. */
69   typedef typename OutputImageType::RegionType OutputImageRegionType;
70
71 protected:
72   HelperClass1();
73   ~HelperClass1() {};
74
75   //the actual processing
76   void BeforeThreadedGenerateData();
77   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId );
78
79   //member data
80   typename  WeightsImageType::Pointer m_Weights;
81   typename MutexImageType::Pointer m_MutexImage;
82   bool m_ThreadSafe;
83
84 };
85
86
87
88 //=========================================================================================================================
89 //Member functions of the helper class 1
90 //=========================================================================================================================
91
92
93 //=========================================================================================================================
94 //Empty constructor
95 template<class InputImageType, class OutputImageType >
96 HelperClass1<InputImageType, OutputImageType>::HelperClass1()
97 {
98   m_ThreadSafe=false;
99 }
100
101 //=========================================================================================================================
102 //Before threaded data
103 template<class InputImageType, class OutputImageType >
104 void HelperClass1<InputImageType, OutputImageType>::BeforeThreadedGenerateData()
105 {
106   //Since we will add, put to zero!
107   this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
108   this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
109 }
110
111 //=========================================================================================================================
112 //update the output for the outputRegionForThread
113 template<class InputImageType, class OutputImageType>
114 void HelperClass1<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId )
115 {
116
117   //Get pointer to the input
118   typename InputImageType::ConstPointer inputPtr = this->GetInput();
119
120   //Get pointer to the output
121   typename OutputImageType::Pointer outputPtr = this->GetOutput();
122   typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
123
124   //Iterator over input
125   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
126
127   //define them over the outputRegionForThread
128   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
129
130   //Initialize
131   typename InputImageType::IndexType index;
132   itk::ContinuousIndex<double,ImageDimension> contIndex;
133   typename InputImageType::PointType ipoint;
134   typename OutputImageType::PointType opoint;
135   typedef typename OutputImageType::PixelType DisplacementType;
136   DisplacementType displacement;
137   inputIt.GoToBegin();
138
139   //define some temp variables
140   signed long baseIndex[ImageDimension];
141   double distance[ImageDimension];
142   unsigned int dim, counter, upper;
143   double totalOverlap,overlap;
144   typename OutputImageType::IndexType neighIndex;
145
146   //Find the number of neighbors
147   unsigned int neighbors =  1 << ImageDimension;
148
149   //==================================================================================================
150   //Loop over the region and add the intensities to the output and the weight to the weights
151   //==================================================================================================
152   while( !inputIt.IsAtEnd() ) {
153     // get the input image index
154     index = inputIt.GetIndex();
155     inputPtr->TransformIndexToPhysicalPoint( index,ipoint );
156
157     // get the required displacement
158     displacement = inputIt.Get();
159
160     // compute the required output image point
161     for(unsigned int j = 0; j < ImageDimension; j++ ) opoint[j] = ipoint[j] + (double)displacement[j];
162
163     // Update the output and the weights
164     if(outputPtr->TransformPhysicalPointToContinuousIndex(opoint, contIndex ) ) {
165       for(dim = 0; dim < ImageDimension; dim++) {
166         // The following  block is equivalent to the following line without
167         // having to call floor. (Only for positive inputs, we already now that is in the image)
168         // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
169
170         baseIndex[dim] = (long) contIndex[dim];
171         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
172       }
173
174       //Add contribution for each neighbor
175       totalOverlap = itk::NumericTraits<double>::Zero;
176       for( counter = 0; counter < neighbors ; counter++ ) {
177         overlap = 1.0;          // fraction overlap
178         upper = counter;  // each bit indicates upper/lower neighbour
179
180         // get neighbor index and overlap fraction
181         for( dim = 0; dim < 3; dim++ ) {
182           if ( upper & 1 ) {
183             neighIndex[dim] = baseIndex[dim] + 1;
184             overlap *= distance[dim];
185           } else {
186             neighIndex[dim] = baseIndex[dim];
187             overlap *= 1.0 - distance[dim];
188           }
189           upper >>= 1;
190         }
191
192
193
194         //Set neighbor value only if overlap is not zero
195         if( (overlap>0.0)) // &&
196           //                    (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
197           //                    (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
198           //                    (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
199           //                    (neighIndex[0]>=0) &&
200           //                    (neighIndex[1]>=0) &&
201           //                    (neighIndex[2]>=0) )
202         {
203           //what to store? the original displacement vector?
204           if (! m_ThreadSafe) {
205             //Set the pixel and weight at neighIndex
206             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
207             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
208           }
209
210           else {
211             //Entering critilal section: shared memory
212             m_MutexImage->GetPixel(neighIndex).Lock();
213
214             //Set the pixel and weight at neighIndex
215             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) - (displacement*overlap));
216             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
217
218             //Unlock
219             m_MutexImage->GetPixel(neighIndex).Unlock();
220
221           }
222           //Add to total overlap
223           totalOverlap += overlap;
224         }
225
226         if( totalOverlap == 1.0 ) {
227           // finished
228           break;
229         }
230       }
231     }
232
233     ++inputIt;
234   }
235
236 }
237
238
239
240 //=========================================================================================================================
241 //helper class 2 to allow a threaded execution of normalisation by the weights
242 //=========================================================================================================================
243 template<class InputImageType, class OutputImageType> class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
244 {
245
246 public:
247   /** Standard class typedefs. */
248   typedef HelperClass2  Self;
249   typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
250   typedef itk::SmartPointer<Self>         Pointer;
251   typedef itk::SmartPointer<const Self>   ConstPointer;
252
253   /** Method for creation through the object factory. */
254   itkNewMacro(Self);
255
256   /** Run-time type information (and related methods) */
257   itkTypeMacro( HelperClass2, ImageToImageFilter );
258
259   /** Constants for the image dimensions */
260   itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
261
262   //Typedefs
263   typedef typename OutputImageType::PixelType        PixelType;
264   typedef itk::Image<double,ImageDimension> WeightsImageType;
265
266   //Set methods
267   void SetWeights(const typename WeightsImageType::Pointer input) {
268     m_Weights = input;
269     this->Modified();
270   }
271   void SetEdgePaddingValue(PixelType value) {
272     m_EdgePaddingValue = value;
273     this->Modified();
274   }
275
276   /** Typedef to describe the output image region type. */
277   typedef typename OutputImageType::RegionType OutputImageRegionType;
278
279 protected:
280   HelperClass2();
281   ~HelperClass2() {};
282
283
284   //the actual processing
285   void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId );
286
287
288   //member data
289   typename     WeightsImageType::Pointer m_Weights;
290   PixelType m_EdgePaddingValue;
291
292 } ;
293
294
295
296 //=========================================================================================================================
297 //Member functions of the helper class 2
298 //=========================================================================================================================
299
300
301 //=========================================================================================================================
302 //Empty constructor
303 template<class InputImageType, class OutputImageType > HelperClass2<InputImageType, OutputImageType>::HelperClass2()
304 {
305   m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero;
306 }
307
308
309 //=========================================================================================================================
310 //update the output for the outputRegionForThread
311 template<class InputImageType, class OutputImageType > void HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, itk::ThreadIdType threadId )
312 {
313
314   //Get pointer to the input
315   typename InputImageType::ConstPointer inputPtr = this->GetInput();
316
317   //Get pointer to the output
318   typename OutputImageType::Pointer outputPtr = this->GetOutput();
319
320   //Iterators over input, weigths  and output
321   typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
322   typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
323   typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
324
325   //define them over the outputRegionForThread
326   OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
327   InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
328   WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
329
330
331   //==================================================================================================
332   //loop over the output and normalize the input, remove holes
333   PixelType neighValue;
334   double  zero = itk::NumericTraits<double>::Zero;
335   while (!outputIt.IsAtEnd()) {
336     //the weight is not zero
337     if (weightsIt.Get() != zero) {
338       //divide by the weight
339       outputIt.Set(static_cast<PixelType>(inputIt.Get()/weightsIt.Get()));
340     }
341
342     //copy the value of the  neighbour that was just processed
343     else {
344       if(!outputIt.IsAtBegin()) {
345         //go back
346         --outputIt;
347
348         //Neighbour cannot have zero weight because it should be filled already
349         neighValue=outputIt.Get();
350         ++outputIt;
351         outputIt.Set(neighValue);
352         //DD("hole filled");
353       } else {
354         //DD("is at begin, setting edgepadding value");
355         outputIt.Set(m_EdgePaddingValue);
356       }
357     }
358     ++weightsIt;
359     ++outputIt;
360     ++inputIt;
361
362   }//end while
363 }//end member
364
365
366 }//end nameless namespace
367
368
369
370 namespace clitk
371 {
372
373 //=========================================================================================================================
374 // The rest is the InvertVFFilter
375 //=========================================================================================================================
376
377 //=========================================================================================================================
378 //constructor
379 template <class InputImageType, class OutputImageType>
380 InvertVFFilter<InputImageType, OutputImageType>::InvertVFFilter()
381 {
382   m_EdgePaddingValue=itk::NumericTraits<PixelType>::Zero; //no other reasonable value?
383   m_ThreadSafe=false;
384   m_Verbose=false;
385 }
386
387
388 //=========================================================================================================================
389 //Update
390 template <class InputImageType, class OutputImageType> void InvertVFFilter<InputImageType, OutputImageType>::GenerateData()
391 {
392
393   //Get the properties of the input
394   typename InputImageType::ConstPointer inputPtr=this->GetInput();
395   typename WeightsImageType::RegionType region;
396   typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
397   region.SetSize(size);
398   typename OutputImageType::IndexType start;
399   for (unsigned int i=0; i< ImageDimension; i++) start[i]=0;
400   region.SetIndex(start);
401   PixelType zero = itk::NumericTraits<double>::Zero;
402
403
404   //Allocate the weights
405   typename WeightsImageType::Pointer weights=WeightsImageType::New();
406   weights->SetRegions(region);
407   weights->Allocate();
408   weights->SetSpacing(inputPtr->GetSpacing());
409
410   //===========================================================================
411   //Inversion is divided in in two loops, for each we will call a threaded helper class
412   //1. add contribution of input to output and update weights
413   //2. normalize the output by the weight and remove holes
414   //===========================================================================
415
416
417   //===========================================================================
418   //1. add contribution of input to output and update weights
419
420   //Define an internal image type
421
422   typedef itk::Image<itk::Vector<double,ImageDimension>, ImageDimension > InternalImageType;
423
424   //Call threaded helper class 1
425   typedef HelperClass1<InputImageType, InternalImageType > HelperClass1Type;
426   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
427
428   //Set input
429   if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
430   helper1->SetInput(inputPtr);
431   helper1->SetWeights(weights);
432
433   //Threadsafe?
434   if(m_ThreadSafe) {
435     //Allocate the mutex image
436     typename MutexImageType::Pointer mutex=InvertVFFilter::MutexImageType::New();
437     mutex->SetRegions(region);
438     mutex->Allocate();
439     mutex->SetSpacing(inputPtr->GetSpacing());
440     helper1->SetMutexImage(mutex);
441     if (m_Verbose) std::cout <<"Inverting using a thread-safe algorithm" <<std::endl;
442   } else  if(m_Verbose)std::cout <<"Inverting using a thread-unsafe algorithm" <<std::endl;
443
444   //Execute helper class
445   helper1->Update();
446
447   //Get the output
448   typename InternalImageType::Pointer temp= helper1->GetOutput();
449   weights=helper1->GetWeights();
450
451
452   //===========================================================================
453   //2. Normalize the output by the weights and remove holes
454   //Call threaded helper class
455   typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
456   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
457
458   //Set temporary output as input
459   helper2->SetInput(temp);
460   helper2->SetWeights(weights);
461   helper2->SetEdgePaddingValue(m_EdgePaddingValue);
462
463   //Execute helper class
464   if (m_Verbose) std::cout << "Normalizing the output VF..."<<std::endl;
465   helper2->Update();
466
467   //Set the output
468   this->SetNthOutput(0, helper2->GetOutput());
469 }
470
471
472
473 }
474
475 #endif