]> Creatis software - clitk.git/blobdiff - itk/clitkForwardWarpImageFilter.txx
Change itkSimpleFastMutexLock to std::mutex
[clitk.git] / itk / clitkForwardWarpImageFilter.txx
index 5e1bb132d05470e9fd3c3ab69f2ec7028b431b66..4c38b5a5e3d5a627668de493c6cefccb6e2d5021 100644 (file)
@@ -53,7 +53,9 @@ public:
   //Typedefs
   typedef typename OutputImageType::PixelType        OutputPixelType;
   typedef itk::Image<double, ImageDimension > WeightsImageType;
-  typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension > MutexImageType;
+#if ITK_VERSION_MAJOR <= 4
+    typedef itk::Image<itk::SimpleFastMutexLock, ImageDimension> MutexImageType;
+#endif
   //===================================================================================
   //Set methods
   void SetWeights(const typename WeightsImageType::Pointer input) {
@@ -64,11 +66,18 @@ public:
     m_DeformationField=input;
     this->Modified();
   }
+#if ITK_VERSION_MAJOR <= 4
   void SetMutexImage(const typename MutexImageType::Pointer input) {
     m_MutexImage=input;
     this->Modified();
     m_ThreadSafe=true;
   }
+#else
+  void SetMutexImage() {
+    this->Modified();
+    m_ThreadSafe=true;
+  }
+#endif
 
   //Get methods
   typename WeightsImageType::Pointer GetWeights() {
@@ -89,7 +98,11 @@ protected:
   //member data
   typename  itk::Image< double, ImageDimension>::Pointer m_Weights;
   typename DeformationFieldType::Pointer m_DeformationField;
+#if ITK_VERSION_MAJOR <= 4
   typename MutexImageType::Pointer m_MutexImage;
+#else
+  std::mutex m_Mutex;
+#endif
   bool m_ThreadSafe;
 
 };
@@ -132,7 +145,7 @@ void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::Thread
 
   //Get pointer to the output
   typename OutputImageType::Pointer outputPtr = this->GetOutput();
-  typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
+  //typename OutputImageType::SizeType size=outputPtr->GetLargestPossibleRegion().GetSize();
 
   //Iterators over input and deformation field
   typedef itk::ImageRegionConstIteratorWithIndex<InputImageType> InputImageIteratorType;
@@ -154,6 +167,7 @@ void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::Thread
   //define some temp variables
   signed long baseIndex[ImageDimension];
   double distance[ImageDimension];
+  for(unsigned int i=0; i<ImageDimension; i++) distance[i] = 0.0; // to avoid warning
   unsigned int dim, counter, upper;
   double overlap, totalOverlap;
   typename OutputImageType::IndexType neighIndex;
@@ -183,7 +197,7 @@ void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::Thread
       for(dim = 0; dim < ImageDimension; dim++) {
         // The following  block is equivalent to the following line without
         // having to call floor. For positive inputs!!!
-        // baseIndex[dim] = (long) vcl_floor(contIndex[dim] );
+        // baseIndex[dim] = (long) std::floor(contIndex[dim] );
         baseIndex[dim] = (long) contIndex[dim];
         distance[dim] = contIndex[dim] - double( baseIndex[dim] );
       }
@@ -195,7 +209,7 @@ void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::Thread
         upper = counter;  // each bit indicates upper/lower neighbour
 
         // get neighbor index and overlap fraction
-        for( dim = 0; dim < 3; dim++ ) {
+        for( dim = 0; dim < ImageDimension; dim++ ) {
           if ( upper & 1 ) {
             neighIndex[dim] = baseIndex[dim] + 1;
             overlap *= distance[dim];
@@ -223,14 +237,22 @@ void HelperClass1<InputImageType, OutputImageType, DeformationFieldType>::Thread
 
           } else {
             //Entering critilal section: shared memory
+#if ITK_VERSION_MAJOR <= 4
             m_MutexImage->GetPixel(neighIndex).Lock();
+#else
+            m_Mutex.lock();
+#endif
 
             //Set the pixel and weight at neighIndex
             outputPtr->SetPixel(neighIndex, outputPtr->GetPixel(neighIndex) + overlap * static_cast<OutputPixelType>(inputIt.Get()));
             m_Weights->SetPixel(neighIndex, m_Weights->GetPixel(neighIndex) + overlap);
 
             //Unlock
+#if ITK_VERSION_MAJOR <= 4
             m_MutexImage->GetPixel(neighIndex).Unlock();
+#else
+            m_Mutex.unlock();
+#endif
 
           }
           //Add to total overlap
@@ -444,7 +466,13 @@ void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldTyp
   typename HelperClass1Type::Pointer helper1=HelperClass1Type::New();
 
   //Set input
-  if(m_NumberOfThreadsIsGiven)helper1->SetNumberOfThreads(m_NumberOfThreads);
+  if(m_NumberOfThreadsIsGiven) {
+#if ITK_VERSION_MAJOR <= 4
+    helper1->SetNumberOfThreads(m_NumberOfThreads);
+#else
+    helper1->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
+#endif
+  }
   helper1->SetInput(inputPtr);
   helper1->SetDeformationField(m_DeformationField);
   helper1->SetWeights(weights);
@@ -452,11 +480,15 @@ void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldTyp
   //Threadsafe?
   if(m_ThreadSafe) {
     //Allocate the mutex image
+#if ITK_VERSION_MAJOR <= 4
     typename MutexImageType::Pointer mutex=ForwardWarpImageFilter::MutexImageType::New();
     mutex->SetRegions(region);
     mutex->Allocate();
     mutex->SetSpacing(inputPtr->GetSpacing());
     helper1->SetMutexImage(mutex);
+#else
+    helper1->SetMutexImage();
+#endif
     if (m_Verbose) std::cout <<"Forwarp warping using a thread-safe algorithm" <<std::endl;
   } else  if(m_Verbose)std::cout <<"Forwarp warping using a thread-unsafe algorithm" <<std::endl;
 
@@ -477,7 +509,13 @@ void ForwardWarpImageFilter<InputImageType, OutputImageType, DeformationFieldTyp
   typename HelperClass2Type::Pointer helper2=HelperClass2Type::New();
 
   //Set temporary output as input
-  if(m_NumberOfThreadsIsGiven)helper2->SetNumberOfThreads(m_NumberOfThreads);
+  if(m_NumberOfThreadsIsGiven) {
+#if ITK_VERSION_MAJOR <= 4
+    helper2->SetNumberOfThreads(m_NumberOfThreads);
+#else
+    helper2->SetNumberOfWorkUnits(m_NumberOfWorkUnits);
+#endif
+  }
   helper2->SetInput(temp);
   helper2->SetWeights(weights);
   helper2->SetEdgePaddingValue(m_EdgePaddingValue);