1 #ifndef __clitkForwardWarpImageFilter_txx
2 #define __clitkForwardWarpImageFilter_txx
4 #include "clitkForwardWarpImageFilter.h"
5 #include "clitkImageCommon.h"
7 // Put the helper classes in an anonymous namespace so that it is not
13 //=========================================================================================================================
14 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
15 //=========================================================================================================================
// Template parameters:
//   InputImageType       - image whose pixel values are read through GetInput() and scattered.
//   OutputImageType      - accumulation image written by ThreadedGenerateData.
//   DeformationFieldType - image whose pixels are read as per-pixel displacement vectors.
16 template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
20 /** Standard class typedefs. */
21 typedef HelperClass1 Self;
22 typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
23 typedef itk::SmartPointer<Self> Pointer;
24 typedef itk::SmartPointer<const Self> ConstPointer;
26 /** Method for creation through the object factory. */
29 /** Run-time type information (and related methods) */
30 itkTypeMacro( HelperClass1, ImageToImageFilter );
32 /** Constants for the image dimensions */
33 itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
// Pixel type of the accumulation image, a double-valued weights image and an
// image holding one itk::SimpleFastMutexLock per pixel.
37 typedef typename OutputImageType::PixelType OutputPixelType;
38 typedef itk::Image<double, ImageDimension > WeightsImageType;
39 typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
40 //===================================================================================
// Setters for the externally owned helper images (smart pointers are stored, not copied images).
42 void SetWeights(const typename WeightsImageType::Pointer input)
// Stores the displacement field read by ThreadedGenerateData.
47 void SetDeformationField(const typename DeformationFieldType::Pointer input)
49 m_DeformationField=input;
// Supplies the per-pixel lock image used by the locked update path of ThreadedGenerateData.
52 void SetMutexImage(const typename MutexImageType::Pointer input)
// Returns the weights image pointer held by this helper.
60 typename WeightsImageType::Pointer GetWeights(){return m_Weights;}
62 /** Typedef to describe the output image region type. */
63 typedef typename OutputImageType::RegionType OutputImageRegionType;
69 //the actual processing
// BeforeThreadedGenerateData zeroes output and weights; ThreadedGenerateData does the scatter.
70 void BeforeThreadedGenerateData();
71 void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
// Member data: accumulated weights, displacement field and per-pixel mutexes.
74 typename itk::Image< double, ImageDimension>::Pointer m_Weights;
75 typename DeformationFieldType::Pointer m_DeformationField;
76 typename MutexImageType::Pointer m_MutexImage;
83 //=========================================================================================================================
84 //Member functions of the helper class 1
85 //=========================================================================================================================
88 //=========================================================================================================================
// Default constructor of HelperClass1 (takes no arguments).
90 template<class InputImageType, class OutputImageType, class DeformationFieldType >
91 HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1()
97 //=========================================================================================================================
98 //Before threaded data
// Resets both accumulation buffers before the threaded scatter pass starts:
// the filter output and the weights image are filled with zero.
99 template<class InputImageType, class OutputImageType, class DeformationFieldType >
100 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
102 //Since we will add, put to zero!
// NOTE(review): the output is filled with a double zero; this presumably relies on
// OutputImageType having a pixel type constructible from double - confirm.
103 this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
104 this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
108 //=========================================================================================================================
109 //update the output for the outputRegionForThread
// For every input pixel inside outputRegionForThread:
//   1. convert its index to a physical point and add the displacement read from the field,
//   2. convert the displaced point to a continuous index of the output image,
//   3. distribute the input value over the 2^ImageDimension surrounding output pixels,
//      weighting each by its overlap fraction, and add the same fraction to the weights image.
// Parameters:
//   outputRegionForThread - region this thread iterates over (used for input and field iterators).
//   threadId              - not used by the visible code.
110 template<class InputImageType, class OutputImageType, class DeformationFieldType >
111 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
114 //Get pointer to the input
115 typename InputImageType::ConstPointer inputPtr = this->GetInput();
117 //Get pointer to the output
118 typename OutputImageType::Pointer outputPtr = this->GetOutput();
119 typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
121 //Iterators over input and deformation field
122 typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
123 typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
125 //define them over the outputRegionForThread
126 InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
127 DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
130 typename InputImageType::IndexType index;
131 itk::ContinuousIndex<double,ImageDimension> contIndex;
132 typename InputImageType::PointType point;
133 typedef typename DeformationFieldType::PixelType DisplacementType;
134 DisplacementType displacement;
138 //define some temp variables
139 signed long baseIndex[ImageDimension];
140 double distance[ImageDimension];
141 unsigned int dim, counter, upper;
142 double overlap, totalOverlap;
143 typename OutputImageType::IndexType neighIndex;
145 //Find the number of neighbors
// One neighbour per corner of the enclosing cell: 2^ImageDimension.
146 unsigned int neighbors = 1 << ImageDimension;
149 //==================================================================================================
150 //Loop over the region and add the intensities to the output and the weight to the weights
151 //==================================================================================================
152 while( !inputIt.IsAtEnd() )
155 // get the input image index
156 index = inputIt.GetIndex();
157 inputPtr->TransformIndexToPhysicalPoint( index, point );
159 // get the required displacement
160 displacement = fieldIt.Get();
162 // compute the required output image point
163 for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
166 // Update the output and the weights
// Only points for which the output image reports a valid continuous index are scattered.
167 if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) )
169 for(dim = 0; dim < ImageDimension; dim++)
171 // The following block is equivalent to the following line without
172 // having to call floor. For positive inputs!!!
173 // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
// NOTE(review): truncation differs from floor for negative continuous indices - confirm
// that such values cannot reach this point.
174 baseIndex[dim] = (long) contIndex[dim];
175 distance[dim] = contIndex[dim] - double( baseIndex[dim] );
178 //Add contribution for each neighbor
179 totalOverlap = itk::NumericTraits<double>::Zero;
180 for( counter = 0; counter < neighbors ; counter++ )
182 overlap = 1.0; // fraction overlap
183 upper = counter; // each bit indicates upper/lower neighbour
185 // get neighbor index and overlap fraction
// Iterate over every image dimension (was hard-coded to 3): baseIndex, distance and
// neighIndex hold exactly ImageDimension entries, so a fixed bound of 3 overruns them
// for 2D images and leaves components unset for images with more than 3 dimensions.
186 for( dim = 0; dim < ImageDimension; dim++ )
190 neighIndex[dim] = baseIndex[dim] + 1;
191 overlap *= distance[dim];
195 neighIndex[dim] = baseIndex[dim];
196 overlap *= 1.0 - distance[dim];
201 //Set neighbor value only if overlap is not zero
// NOTE(review): the bounds test below is commented out (and is written for 3 dimensions);
// neighIndex can be baseIndex + 1, so verify it always lies inside the output buffer.
202 if( (overlap>0.0)) // &&
203 // (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
204 // (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
205 // (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
206 // (neighIndex[0]>=0) &&
207 // (neighIndex[1]>=0) &&
208 // (neighIndex[2]>=0) )
// Unlocked update path (the condition selecting it is not shown here - presumably a
// thread-safety flag; confirm).
213 //Set the pixel and weight at neighIndex
214 outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
215 m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
// Locked update path: the same read-modify-write, guarded by the mutex stored at neighIndex.
220 //Entering critical section: shared memory
221 m_MutexImage->GetPixel(neighIndex).Lock();
223 //Set the pixel and weight at neighIndex
224 outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
225 m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
228 m_MutexImage->GetPixel(neighIndex).Unlock();
231 //Add to total overlap
232 totalOverlap += overlap;
235 //check for totaloverlap: not very likely
// Exact floating-point comparison: true only when the accumulated fractions sum to exactly 1.0.
236 if( totalOverlap == 1.0 )
253 //=========================================================================================================================
254 //helper class 2 to allow a threaded execution of normalisation by the weights
255 //=========================================================================================================================
// Template parameters:
//   InputImageType  - accumulated image read through GetInput().
//   OutputImageType - normalised image written by ThreadedGenerateData.
256 template<class InputImageType, class OutputImageType>
257 class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
261 /** Standard class typedefs. */
262 typedef HelperClass2 Self;
263 typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
264 typedef itk::SmartPointer<Self> Pointer;
265 typedef itk::SmartPointer<const Self> ConstPointer;
267 /** Method for creation through the object factory. */
270 /** Run-time type information (and related methods) */
271 itkTypeMacro( HelperClass2, ImageToImageFilter );
273 /** Constants for the image dimensions */
274 itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
// Pixel types of input and output, and the double-valued weights image type.
277 typedef typename InputImageType::PixelType InputPixelType;
278 typedef typename OutputImageType::PixelType OutputPixelType;
279 typedef itk::Image<double, ImageDimension > WeightsImageType;
// Supplies the weights image used as divisor in ThreadedGenerateData.
283 void SetWeights(const typename WeightsImageType::Pointer input)
// Value written to a zero-weight pixel when the output iterator is at the start of its region.
288 void SetEdgePaddingValue(OutputPixelType value)
290 m_EdgePaddingValue = value;
294 /** Typedef to describe the output image region type. */
295 typedef typename OutputImageType::RegionType OutputImageRegionType;
302 //the actual processing
303 void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
// Member data: weights image and edge padding value.
307 typename WeightsImageType::Pointer m_Weights;
308 OutputPixelType m_EdgePaddingValue;
313 //=========================================================================================================================
314 //Member functions of the helper class 2
315 //=========================================================================================================================
318 //=========================================================================================================================
// Default constructor: the edge padding value starts at zero (cast to the output pixel type).
320 template<class InputImageType, class OutputImageType >
321 HelperClass2<InputImageType, OutputImageType>::HelperClass2()
323 m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
327 //=========================================================================================================================
328 //update the output for the outputRegionForThread
// Normalises the accumulated input by the weights image over outputRegionForThread.
// Pixels with a non-zero weight receive input/weight; pixels with zero weight ("holes")
// receive a value read back from the output iterator, or the edge padding value when the
// output iterator is at the beginning of its region.
// Parameters:
//   outputRegionForThread - region iterated by the output, input and weights iterators.
//   threadId              - not used by the visible code.
329 template<class InputImageType, class OutputImageType > void
330 HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
333 //Get pointer to the input
334 typename InputImageType::ConstPointer inputPtr = this->GetInput();
336 //Get pointer to the output
337 typename OutputImageType::Pointer outputPtr = this->GetOutput();
339 //Iterators over input, weights and output
340 typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
341 typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
342 typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
344 //define them over the outputRegionForThread
345 OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
346 InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
347 WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
350 //==================================================================================================
351 //loop over the output and normalize the input, remove holes
352 OutputPixelType neighValue;
353 double zero = itk::NumericTraits<double>::Zero;
354 while (!outputIt.IsAtEnd())
356 //the weight is not zero
357 if (weightsIt.Get() != zero)
359 //divide by the weight
360 outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
363 //copy the value of the neighbour that was just processed
// NOTE(review): the iterator movement around the Get() below is not visible here;
// presumably the iterator steps back one pixel and forward again - confirm.
366 if(!outputIt.IsAtBegin())
370 neighValue=outputIt.Get();
372 outputIt.Set(neighValue);
// No previously processed pixel in this region: fall back to the edge padding value.
375 //DD("is at begin, setting edgepadding value");
376 outputIt.Set(m_EdgePaddingValue);
387 }//end nameless namespace
394 //=========================================================================================================================
395 // The rest is the ForwardWarpImageFilter
396 //=========================================================================================================================
398 //=========================================================================================================================
// Default constructor: no explicit thread count has been given yet and the
// edge padding value starts at zero (cast to PixelType).
400 template <class InputImageType, class OutputImageType, class DeformationFieldType>
401 ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
404 m_NumberOfThreadsIsGiven=false;
405 m_EdgePaddingValue=static_cast<PixelType>(0.0);
411 //=========================================================================================================================
// Runs the forward warp as two helper-filter passes:
//   1. HelperClass1 accumulates weighted input contributions into a double-precision
//      internal image and a weights image;
//   2. HelperClass2 divides by the weights, fills zero-weight pixels and produces the
//      image that is set as output 0 of this filter.
413 template <class InputImageType, class OutputImageType, class DeformationFieldType>
414 void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
417 //Get the properties of the input
// The helper images cover a region with the input's largest-possible size and a zero start index.
418 typename InputImageType::ConstPointer inputPtr=this->GetInput();
419 typename WeightsImageType::RegionType region;
420 typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
421 region.SetSize(size);
422 typename OutputImageType::IndexType start;
423 for (unsigned int i =0; i< ImageDimension ;i ++)start[i]=0;
424 region.SetIndex(start);
426 //Allocate the weights
427 typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
428 weights->SetRegions(region);
430 weights->SetSpacing(inputPtr->GetSpacing());
433 //===========================================================================
434 //warp is divided in two loops, for each we call a threaded helper class
435 //1. Add contribution of input to output and update weights
436 //2. Normalize the output by the weight and remove holes
437 //===========================================================================
439 //===========================================================================
440 //1. Add contribution of input to output and update weights
442 //Define an internal image type in double precision
443 typedef itk::Image<double, ImageDimension> InternalImageType;
445 //Call threaded helper class 1
446 typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
447 typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
// The thread count is forwarded only when the caller supplied one.
450 if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
451 helper1->SetInput(inputPtr);
452 helper1->SetDeformationField(m_DeformationField);
453 helper1->SetWeights(weights);
458 //Allocate the mutex image
// One lock per pixel, laid out on the same region and spacing as the weights image.
459 typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
460 mutex->SetRegions(region);
462 mutex->SetSpacing(inputPtr->GetSpacing());
463 helper1->SetMutexImage(mutex);
// NOTE(review): the condition choosing between the two messages below is not visible here
// (presumably a thread-safety option - confirm). "Forwarp" in both messages looks like a
// typo for "Forward"; the strings are left unchanged.
464 if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
466 else if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
468 //Execute helper class
// Intermediate results of pass 1: accumulated image and the weights held by the helper.
472 typename InternalImageType::Pointer temp= helper1->GetOutput();
475 weights=helper1->GetWeights();
478 //===========================================================================
479 //2. Normalize the output by the weights and remove holes
480 //Call threaded helper class
481 typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
482 typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
484 //Set temporary output as input
485 if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
486 helper2->SetInput(temp);
487 helper2->SetWeights(weights);
488 helper2->SetEdgePaddingValue(m_EdgePaddingValue);
490 //Execute helper class
// Hand the normalised image over as output 0 of this filter.
494 this->SetNthOutput(0, helper2->GetOutput());