1 /*=========================================================================
2 Program: vv http://www.creatis.insa-lyon.fr/rio/vv
5 - University of LYON http://www.universite-lyon.fr/
6 - Léon Bérard cancer center http://www.centreleonberard.fr
7 - CREATIS CNRS laboratory http://www.creatis.insa-lyon.fr
9 This software is distributed WITHOUT ANY WARRANTY; without even
10 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11 PURPOSE. See the copyright notices for more information.
13 It is distributed under dual licence
15 - BSD See included LICENSE.txt file
16 - CeCILL-B http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkForwardWarpImageFilter_txx
19 #define __clitkForwardWarpImageFilter_txx
20 #include "clitkForwardWarpImageFilter.h"
21 #include "clitkImageCommon.h"
23 // Put the helper classes in an anonymous namespace so that they are not
24 // exposed to the user
30 //=========================================================================================================================
31 //helper class 1 to allow a threaded execution: add contributions of input to output and update weights
32 //=========================================================================================================================
// HelperClass1 is the first (scatter) pass of the forward warp. Each input pixel is
// moved by its displacement vector. Its intensity is then added to the
// 2^ImageDimension output voxels surrounding the displaced position, weighted by
// linear overlap fractions. The same overlap fractions are accumulated into a separate
// double-valued weights image, so that a second pass can normalise the output.
33 template<class InputImageType, class OutputImageType, class DeformationFieldType> class HelperClass1 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
37 /** Standard class typedefs. */
38 typedef HelperClass1 Self;
39 typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
40 typedef itk::SmartPointer<Self> Pointer;
41 typedef itk::SmartPointer<const Self> ConstPointer;
43 /** Method for creation through the object factory. */
46 /** Run-time type information (and related methods) */
47 itkTypeMacro( HelperClass1, ImageToImageFilter );
49 /** Constants for the image dimensions */
50 itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
54 typedef typename OutputImageType::PixelType OutputPixelType;
// Accumulator holding, per output voxel, the sum of the overlap fractions it received.
55 typedef itk::Image<double, ImageDimension > WeightsImageType;
// One lock per voxel, used by the thread-safe path to serialise concurrent updates
// of the same output/weight voxel.
56 typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
57 //===================================================================================
// Weights accumulator. It must already be allocated when the filter runs, because
// BeforeThreadedGenerateData() fills its buffer with zeros.
// NOTE(review): the setter body is not visible here.
59 void SetWeights(const typename WeightsImageType::Pointer input) {
// Displacement field. It is iterated over the same region as the input pixels, and each
// displacement is added to the physical point of the corresponding input pixel.
63 void SetDeformationField(const typename DeformationFieldType::Pointer input) {
64 m_DeformationField=input;
// Per-voxel lock image for the thread-safe accumulation path.
// NOTE(review): presumably this setter also switches m_ThreadSafe on. The rest of its
// body is not visible here — TODO confirm.
67 void SetMutexImage(const typename MutexImageType::Pointer input) {
// Accessor for the weights image (body not visible here; presumably returns m_Weights).
74 typename WeightsImageType::Pointer GetWeights() {
78 /** Typedef to describe the output image region type. */
79 typedef typename OutputImageType::RegionType OutputImageRegionType;
85 //the actual processing
86 void BeforeThreadedGenerateData();
// NOTE(review): `int threadId` matches the ITK 3 signature. In ITK 4+ the parameter is
// ThreadIdType, and this declaration would no longer override — confirm against the ITK
// version in use.
87 void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
// NOTE(review): m_ThreadSafe, which ThreadedGenerateData reads, is declared on a line
// not visible here.
90 typename itk::Image< double, ImageDimension>::Pointer m_Weights;
91 typename DeformationFieldType::Pointer m_DeformationField;
92 typename MutexImageType::Pointer m_MutexImage;
99 //=========================================================================================================================
100 //Member functions of the helper class 1
101 //=========================================================================================================================
104 //=========================================================================================================================
// Constructor.
// NOTE(review): its body is not visible here. Presumably it leaves m_ThreadSafe false, so
// the unsynchronised path is used unless SetMutexImage() is called — TODO confirm.
106 template<class InputImageType, class OutputImageType, class DeformationFieldType >
107 HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::HelperClass1()
113 //=========================================================================================================================
114 //Before threaded data
// Called once by ITK before the worker threads start. ThreadedGenerateData() only ever
// adds to output and weight voxels, so both accumulators are reset to zero here.
// Both the output buffer and the weights image must already be allocated.
// The output is filled with a double zero: GenerateData instantiates this helper with
// itk::Image<double, ImageDimension> as its output type.
115 template<class InputImageType, class OutputImageType, class DeformationFieldType >
116 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::BeforeThreadedGenerateData()
118 //Since we will add, put to zero!
119 this->GetOutput()->FillBuffer(itk::NumericTraits<double>::Zero);
120 this->GetWeights()->FillBuffer(itk::NumericTraits<double>::Zero);
124 //=========================================================================================================================
125 //update the output for the outputRegionForThread
// Scatter pass for one thread. Every input pixel inside outputRegionForThread is
// displaced by the deformation field and splatted onto the output and weights images.
// NOTE(review): the input and deformation-field iterators are built over the *output*
// region assigned to this thread. Input, field and output are therefore assumed to share
// the same region layout — TODO confirm this is enforced elsewhere.
// In the visible lines, threadId is not referenced.
126 template<class InputImageType, class OutputImageType, class DeformationFieldType >
127 void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
130 //Get pointer to the input
131 typename InputImageType::ConstPointer inputPtr = this->GetInput();
133 //Get pointer to the output
134 typename OutputImageType::Pointer outputPtr = this->GetOutput();
// NOTE(review): in the visible lines, `size` is only referenced by the commented-out
// bounds check inside the scatter loop.
135 typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
137 //Iterators over input and deformation field
138 typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
139 typedef itk::ImageRegionIterator<DeformationFieldType> DeformationFieldIteratorType;
141 //define them over the outputRegionForThread
142 InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
143 DeformationFieldIteratorType fieldIt(m_DeformationField,outputRegionForThread);
146 typename InputImageType::IndexType index;
147 itk::ContinuousIndex<double,ImageDimension> contIndex;
148 typename InputImageType::PointType point;
149 typedef typename DeformationFieldType::PixelType DisplacementType;
150 DisplacementType displacement;
154 //define some temp variables
// baseIndex: per-dimension integer (lower-corner) index of the displaced position.
155 signed long baseIndex[ImageDimension];
// distance: per-dimension fractional offset of the displaced position from baseIndex.
// It is zero-initialised through aggregate initialisation, which keeps compilers'
// maybe-uninitialised warnings quiet. This replaces a loop that used the non-standard
// `uint` typedef, which is not available on every platform (e.g. MSVC).
156 double distance[ImageDimension] = { 0.0 };
158 unsigned int dim, counter, upper;
159 double overlap, totalOverlap;
160 typename OutputImageType::IndexType neighIndex;
162 //Find the number of neighbors
// A continuous index is surrounded by 2^ImageDimension voxel corners (4 in 2D, 8 in 3D).
163 unsigned int neighbors = 1 << ImageDimension;
166 //==================================================================================================
167 //Loop over the region and add the intensities to the output and the weight to the weights
168 //==================================================================================================
169 while( !inputIt.IsAtEnd() ) {
171 // get the input image index
172 index = inputIt.GetIndex();
173 inputPtr->TransformIndexToPhysicalPoint( index, point );
175 // get the required displacement
176 displacement = fieldIt.Get();
178 // compute the required output image point
// Forward mapping: destination point = source point + displacement (physical units).
179 for(unsigned int j = 0; j < ImageDimension; j++ ) point[j] += displacement[j];
182 // Update the output and the weights
// Only destinations that TransformPhysicalPointToContinuousIndex reports as inside the
// output image contribute. Input pixels displaced outside the output are dropped.
183 if(outputPtr->TransformPhysicalPointToContinuousIndex(point, contIndex ) ) {
184 for(dim = 0; dim < ImageDimension; dim++) {
185 // The following block is equivalent to the following line without
186 // having to call floor. For positive inputs!!!
187 // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
// NOTE(review): the cast truncates toward zero, so it equals floor only when
// contIndex[dim] >= 0. A slightly negative contIndex would give baseIndex 0 and a negative
// distance. TODO confirm that the inside test above excludes that case for the ITK version
// in use.
188 baseIndex[dim] = (long) contIndex[dim];
189 distance[dim] = contIndex[dim] - double( baseIndex[dim] );
192 //Add contribution for each neighbor
193 totalOverlap = itk::NumericTraits<double>::Zero;
194 for( counter = 0; counter < neighbors ; counter++ ) {
195 overlap = 1.0; // fraction overlap
196 upper = counter; // each bit indicates upper/lower neighbour
198 // get neighbor index and overlap fraction
// Linear weights: along each dimension the upper neighbour gets distance and the lower
// neighbour gets 1 - distance. The voxel's overlap is the product over all dimensions.
199 for( dim = 0; dim < ImageDimension; dim++ ) {
201 neighIndex[dim] = baseIndex[dim] + 1;
202 overlap *= distance[dim];
204 neighIndex[dim] = baseIndex[dim];
205 overlap *= 1.0 - distance[dim];
210 //Set neighbor value only if overlap is not zero
// NOTE(review): the explicit bounds check below is commented out, and it only ever
// covered three dimensions. Memory safety therefore relies on the inside test above
// keeping every neighIndex with non-zero overlap within the output buffer.
// At the upper edge (size[dim]-1 < contIndex[dim]) the upper neighbour would be at
// index size[dim]. Whether such points pass the inside test depends on the ITK version's
// continuous-index convention — TODO confirm.
211 if( (overlap>0.0)) // &&
212 // (static_cast<unsigned int>(neighIndex[0])<size[0]) &&
213 // (static_cast<unsigned int>(neighIndex[1])<size[1]) &&
214 // (static_cast<unsigned int>(neighIndex[2])<size[2]) &&
215 // (neighIndex[0]>=0) &&
216 // (neighIndex[1]>=0) &&
217 // (neighIndex[2]>=0) )
// Unsynchronised path: read-modify-write of output and weights without locking.
// Input pixels handled by different threads can land on the same neighIndex. With more
// than one thread, this path can therefore race and lose contributions.
220 if (! m_ThreadSafe) {
221 //Set the pixel and weight at neighIndex
222 outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
223 m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
226 //Entering critical section: shared memory
227 m_MutexImage->GetPixel(neighIndex).Lock();
229 //Set the pixel and weight at neighIndex
230 outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
231 m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
234 m_MutexImage->GetPixel(neighIndex).Unlock();
237 //Add to total overlap
238 totalOverlap += overlap;
241 //check for totaloverlap: not very likely
// Exact floating-point comparison, as the original comment acknowledges. It is
// presumably an early exit once the full unit weight has been distributed; the body is
// not visible here.
242 if( totalOverlap == 1.0 ) {
258 //=========================================================================================================================
259 //helper class 2 to allow a threaded execution of normalisation by the weights
260 //=========================================================================================================================
// HelperClass2 is the second (normalisation) pass of the forward warp. Its input is the
// accumulated weighted sum produced by HelperClass1. Each output voxel becomes that sum
// divided by the voxel's accumulated weight. Voxels that received no weight ("holes")
// are filled instead; see ThreadedGenerateData.
261 template<class InputImageType, class OutputImageType>
262 class HelperClass2 : public itk::ImageToImageFilter<InputImageType, OutputImageType>
266 /** Standard class typedefs. */
267 typedef HelperClass2 Self;
268 typedef itk::ImageToImageFilter<InputImageType,OutputImageType> Superclass;
269 typedef itk::SmartPointer<Self> Pointer;
270 typedef itk::SmartPointer<const Self> ConstPointer;
272 /** Method for creation through the object factory. */
275 /** Run-time type information (and related methods) */
276 itkTypeMacro( HelperClass2, ImageToImageFilter );
278 /** Constants for the image dimensions */
279 itkStaticConstMacro(ImageDimension, unsigned int,InputImageType::ImageDimension);
282 typedef typename InputImageType::PixelType InputPixelType;
283 typedef typename OutputImageType::PixelType OutputPixelType;
284 typedef itk::Image<double, ImageDimension > WeightsImageType;
// Weights accumulated by the first pass. NOTE(review): the setter body is not visible here.
288 void SetWeights(const typename WeightsImageType::Pointer input) {
// Value written into a hole that has no previously processed voxel to copy from.
292 void SetEdgePaddingValue(OutputPixelType value) {
293 m_EdgePaddingValue = value;
297 /** Typedef to describe the output image region type. */
298 typedef typename OutputImageType::RegionType OutputImageRegionType;
305 //the actual processing
306 void ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId );
310 typename WeightsImageType::Pointer m_Weights;
311 OutputPixelType m_EdgePaddingValue;
316 //=========================================================================================================================
317 //Member functions of the helper class 2
318 //=========================================================================================================================
321 //=========================================================================================================================
// Constructor: the edge padding value defaults to zero.
323 template<class InputImageType, class OutputImageType >
324 HelperClass2<InputImageType, OutputImageType>::HelperClass2()
326 m_EdgePaddingValue=static_cast<OutputPixelType>(0.0);
330 //=========================================================================================================================
331 //update the output for the outputRegionForThread
// Normalisation pass for one thread. Every voxel with non-zero weight gets
// output = accumulated sum / accumulated weight. Voxels with zero weight are holes and
// are filled as described below.
332 template<class InputImageType, class OutputImageType > void
333 HelperClass2<InputImageType, OutputImageType>::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, int threadId )
336 //Get pointer to the input
337 typename InputImageType::ConstPointer inputPtr = this->GetInput();
339 //Get pointer to the output
340 typename OutputImageType::Pointer outputPtr = this->GetOutput();
342 //Iterators over input, weights and output
343 typedef itk::ImageRegionConstIterator<InputImageType> InputImageIteratorType;
344 typedef itk::ImageRegionIterator<OutputImageType> OutputImageIteratorType;
345 typedef itk::ImageRegionIterator<WeightsImageType> WeightsImageIteratorType;
347 //define them over the outputRegionForThread
348 OutputImageIteratorType outputIt(outputPtr, outputRegionForThread);
349 InputImageIteratorType inputIt(inputPtr, outputRegionForThread);
350 WeightsImageIteratorType weightsIt(m_Weights, outputRegionForThread);
353 //==================================================================================================
354 //loop over the output and normalize the input, remove holes
355 OutputPixelType neighValue;
356 double zero = itk::NumericTraits<double>::Zero;
357 while (!outputIt.IsAtEnd()) {
358 //the weight is not zero
// The exact comparison is intended: only voxels that received no contribution at all
// count as holes.
359 if (weightsIt.Get() != zero) {
360 //divide by the weight
361 outputIt.Set(static_cast<OutputPixelType>(inputIt.Get()/weightsIt.Get()));
364 //copy the value of the neighbour that was just processed
// NOTE(review): "just processed" means the previous voxel in this thread's scan order.
// IsAtBegin() refers to the iterator over outputRegionForThread. A hole at the start of
// each thread's region therefore gets m_EdgePaddingValue instead of its predecessor, so
// the result can depend on the number of threads. The lines that move the iterator to the
// neighbour and back are not visible here.
366 if(!outputIt.IsAtBegin()) {
369 neighValue=outputIt.Get();
371 outputIt.Set(neighValue);
373 //DD("is at begin, setting edgepadding value");
374 outputIt.Set(m_EdgePaddingValue);
385 }//end nameless namespace
392 //=========================================================================================================================
393 // The rest is the ForwardWarpImageFilter
394 //=========================================================================================================================
396 //=========================================================================================================================
// Constructor. By default the helper filters keep ITK's default thread count, because
// GenerateData only calls SetNumberOfThreads when m_NumberOfThreadsIsGiven is set.
// The edge padding value defaults to zero.
// NOTE(review): other members (e.g. m_ThreadSafe, m_Verbose) are presumably initialised
// on lines not visible here — TODO confirm.
398 template <class InputImageType, class OutputImageType, class DeformationFieldType>
399 ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::ForwardWarpImageFilter()
402 m_NumberOfThreadsIsGiven=false;
403 m_EdgePaddingValue=static_cast<PixelType>(0.0);
409 //=========================================================================================================================
// Forward-warps the input through m_DeformationField in two threaded passes.
// First, HelperClass1 scatters intensities and overlap weights onto a double-precision
// image. Second, HelperClass2 divides by the weights and fills holes. The normalised
// image becomes output 0 of this filter.
411 template <class InputImageType, class OutputImageType, class DeformationFieldType>
412 void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldType>::GenerateData()
415 //Get the properties of the input
416 typename InputImageType::ConstPointer inputPtr=this->GetInput();
// The weights image (and the mutex image, in thread-safe mode) takes the size of the
// input's largest possible region, with a start index of zero.
// NOTE(review): pass 1 indexes these images with output-image indices. If the input's
// largest region does not start at index 0, they will not line up — TODO confirm
// inputs always start at 0.
417 typename WeightsImageType::RegionType region;
418 typename WeightsImageType::RegionType::SizeType size=inputPtr->GetLargestPossibleRegion().GetSize();
419 region.SetSize(size);
420 typename OutputImageType::IndexType start;
421 for (unsigned int i =0; i< ImageDimension ; i ++)start[i]=0;
422 region.SetIndex(start);
424 //Allocate the weights
425 typename WeightsImageType::Pointer weights=ForwardWarpImageFilter::WeightsImageType::New();
426 weights->SetRegions(region);
428 weights->SetSpacing(inputPtr->GetSpacing());
431 //===========================================================================
432 //warp is divided in two loops, for each we call a threaded helper class
433 //1. Add contribution of input to output and update weights
434 //2. Normalize the output by the weight and remove holes
435 //===========================================================================
437 //===========================================================================
438 //1. Add contribution of input to output and update weights
440 //Define an internal image type in double precision
441 typedef itk::Image<double, ImageDimension> InternalImageType;
443 //Call threaded helper class 1
444 typedef HelperClass1<InputImageType, InternalImageType, DeformationFieldType> HelperClass1Type;
445 typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
448 if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
449 helper1->SetInput(inputPtr);
450 helper1->SetDeformationField(m_DeformationField);
451 helper1->SetWeights(weights);
// Thread-safe mode: a per-voxel mutex image with the same region as the weights image is
// handed to helper 1. The enclosing condition is on lines not visible here.
455 //Allocate the mutex image
456 typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
457 mutex->SetRegions(region);
459 mutex->SetSpacing(inputPtr->GetSpacing());
460 helper1->SetMutexImage(mutex);
// NOTE(review): both log messages read "Forwarp" (typo for "Forward"). They are left
// unchanged here because they are runtime output.
461 if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
462 } else if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
464 //Execute helper class
468 typename InternalImageType::Pointer temp= helper1->GetOutput();
471 weights=helper1->GetWeights();
474 //===========================================================================
475 //2. Normalize the output by the weights and remove holes
476 //Call threaded helper class
477 typedef HelperClass2<InternalImageType, OutputImageType> HelperClass2Type;
478 typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
480 //Set temporary output as input
481 if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
482 helper2->SetInput(temp);
483 helper2->SetWeights(weights);
484 helper2->SetEdgePaddingValue(m_EdgePaddingValue);
486 //Execute helper class
// The normalised image produced by helper 2 is exposed as this filter's output 0.
490 this->SetNthOutput(0, helper2->GetOutput());