]> Creatis software - clitk.git/blobdiff - registration/clitkBLUTDIRGenericFilter.cxx
Debug RTStruct conversion with empty struc
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
old mode 100755 (executable)
new mode 100644 (file)
index 15afacf..24d4613
@@ -29,6 +29,13 @@ It is distributed under dual licence
 
 #include "clitkBLUTDIRGenericFilter.h"
 #include "clitkBLUTDIRCommandIterationUpdateDVF.h"
+#include "itkCenteredTransformInitializer.h"
+#include "itkLabelStatisticsImageFilter.h"
+#if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6)
+# include "itkTransformToDisplacementFieldSource.h"
+#else
+# include "itkTransformToDisplacementFieldFilter.h"
+#endif
 
 namespace clitk
 {
@@ -239,7 +246,7 @@ namespace clitk
           m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer);
 
           // Set the previous transform parameters to the registration
-          // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
+          // if(m_Initializer->m_Parameters!=ITK_NULLPTR )delete m_Initializer->m_Parameters;
           m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
         }
@@ -322,7 +329,6 @@ namespace clitk
       // The metric region with respect to the extracted transform region:
       // where should the metric be CALCULATED (depends on transform)
       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
-      typename FixedImageType::RegionType::SizeType metricRegionSize=metricRegion.GetSize();
       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
 
@@ -332,8 +338,8 @@ namespace clitk
       //============================================================================
       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
       typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
-      typename MaskType::Pointer        fixedMask = NULL;
-      typename ImageLabelType::Pointer  labels = NULL;
+      typename MaskType::Pointer        fixedMask = ITK_NULLPTR;
+      typename ImageLabelType::Pointer  labels = ITK_NULLPTR;
       if (m_ArgsInfo.referenceMask_given)
       {
         fixedMask = MaskType::New();
@@ -368,11 +374,12 @@ namespace clitk
         fixedMask->SetImage(labels);
 
         // Find the bounding box of the "inside" label
-        typedef itk::LabelGeometryImageFilter<ImageLabelType> GeometryImageFilterType;
-        typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New();
-        geometryImageFilter->SetInput(labels);
-        geometryImageFilter->Update();
-        typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1);
+        typedef itk::LabelStatisticsImageFilter<ImageLabelType, ImageLabelType> StatisticsImageFilterType;
+        typename StatisticsImageFilterType::Pointer statisticsImageFilter=StatisticsImageFilterType::New();
+        statisticsImageFilter->SetInput(labels);
+        statisticsImageFilter->SetLabelInput(labels);
+        statisticsImageFilter->Update();
+        typename StatisticsImageFilterType::BoundingBoxType boundingBox = statisticsImageFilter->GetBoundingBox(1);
 
         // Limit the transform region to the mask
         for (unsigned int i=0; i<InputImageType::ImageDimension; i++)
@@ -387,6 +394,7 @@ namespace clitk
         // Crop the fixedImage to the bounding box to facilitate multi-resolution
         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
+        extractImageFilter->SetDirectionCollapseToSubmatrix();
         extractImageFilter->SetInput(fixedImage);
         extractImageFilter->SetExtractionRegion(transformRegion);
         extractImageFilter->Update();
@@ -405,7 +413,7 @@ namespace clitk
       }
 
       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
-      typename MaskType::Pointer  movingMask=NULL;
+      typename MaskType::Pointer  movingMask=ITK_NULLPTR;
       if (m_ArgsInfo.targetMask_given)
       {
         movingMask= MaskType::New();
@@ -478,7 +486,7 @@ namespace clitk
       // Rigid or Affine Transform
       //=======================================================
       typedef itk::AffineTransform <double,3> RigidTransformType;
-      RigidTransformType::Pointer rigidTransform=NULL;
+      RigidTransformType::Pointer rigidTransform=ITK_NULLPTR;
       if (m_ArgsInfo.initMatrix_given)
       {
         if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."<<std::endl;
@@ -493,6 +501,20 @@ namespace clitk
         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
         rigidTransform->SetTranslation(finalTranslation);
       }
+      else if (m_ArgsInfo.centre_flag)
+      {
+        if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."<<std::endl;
+        
+        rigidTransform=RigidTransformType::New();
+        
+        typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
+        typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
+        initializer->SetTransform( rigidTransform );
+        initializer->SetFixedImage( fixedImage );
+        initializer->SetMovingImage( movingImage );        
+        initializer->GeometryOn();
+        initializer->InitializeTransform();
+      }
 
 
       //=======================================================
@@ -584,6 +606,25 @@ namespace clitk
       transform->SetParameters( parameters );
       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
 
+      //-------------------------------------------------------------------------
+      // DEBUG: use an itk BSpline instead of multilabel BLUTs
+      //-------------------------------------------------------------------------
+      typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
+      RegistrationTransformType::Pointer regTransform(transform);
+      typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
+      typename SingleBSplineTransformType::Pointer sTransform;
+      if(m_ArgsInfo.itkbspline_flag) {
+        if( transform->GetTransforms().size()>1)
+          itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
+        sTransform = SingleBSplineTransformType::New();
+        sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
+        sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
+        sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
+        sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
+        sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
+        regTransform = sTransform;
+        transform = ITK_NULLPTR; // free memory
+      }
 
       //=======================================================
       // Interpolator
@@ -607,16 +648,14 @@ namespace clitk
       typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
       if (movingMask) metric->SetMovingImageMask(movingMask);
-
-#ifdef ITK_USE_OPTIMIZED_REGISTRATION_METHODS
       if (threadsGiven) {
+#if ITK_VERSION_MAJOR <= 4
         metric->SetNumberOfThreads( threads );
-        if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
-      }
 #else
-      if (m_Verbose) std::cout<<"Not setting the number of threads (not compiled with USE_OPTIMIZED_REGISTRATION_METHODS)..."<<std::endl;
+        metric->SetNumberOfWorkUnits( threads );
 #endif
-
+        if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
+      }
 
       //=======================================================
       // Optimizer
@@ -625,7 +664,7 @@ namespace clitk
       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
       genericOptimizer->SetArgsInfo(m_ArgsInfo);
       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
-      genericOptimizer->SetNumberOfParameters(transform->GetNumberOfParameters());
+      genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
 
@@ -638,9 +677,13 @@ namespace clitk
       registration->SetMetric(        metric        );
       registration->SetOptimizer(     optimizer     );
       registration->SetInterpolator(  interpolator  );
-      registration->SetTransform (transform);
+      registration->SetTransform (regTransform );
       if(threadsGiven) {
+#if ITK_VERSION_MAJOR <= 4
         registration->SetNumberOfThreads(threads);
+#else
+        registration->SetNumberOfWorkUnits(threads);
+#endif
         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
       }
       registration->SetFixedImage(  croppedFixedImage   );
@@ -648,7 +691,7 @@ namespace clitk
       registration->SetFixedImageRegion( metricRegion );
       registration->SetFixedImagePyramid( fixedImagePyramid );
       registration->SetMovingImagePyramid( movingImagePyramid );
-      registration->SetInitialTransformParameters( transform->GetParameters() );
+      registration->SetInitialTransformParameters( regTransform->GetParameters() );
       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
 
@@ -675,6 +718,10 @@ namespace clitk
 
         if (m_ArgsInfo.coeff_given)
         {
+          if(m_ArgsInfo.itkbspline_flag) {
+            itkExceptionMacro("--coeff and --itkbpline are incompatible");
+          }
+
           std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
           typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
           typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
@@ -693,7 +740,7 @@ namespace clitk
 
       try
       {
-        registration->StartRegistration();
+        registration->Update();
       }
       catch( itk::ExceptionObject & err )
       {
@@ -707,7 +754,7 @@ namespace clitk
       // Get the result
       //=======================================================
       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
-      transform->SetParameters( finalParameters );
+      regTransform->SetParameters( finalParameters );
       if (m_Verbose)
       {
         std::cout<<"Stop condition description: "
@@ -747,21 +794,36 @@ namespace clitk
       // Compute the DVF (only deformable transform)
       //=======================================================
       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
-      typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DeformationFieldType;
-      typedef itk::TransformToDeformationFieldSource<DeformationFieldType, double> ConvertorType;
+      typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
+#if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6)
+      typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
+#else
+      typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldType, double> ConvertorType;
+#endif
       typename ConvertorType::Pointer filter= ConvertorType::New();
+#if ITK_VERSION_MAJOR <= 4
       filter->SetNumberOfThreads(1);
-      transform->SetBulkTransform(NULL);
-      filter->SetTransform(transform);
+#else
+      filter->SetNumberOfWorkUnits(1);
+#endif
+      if(m_ArgsInfo.itkbspline_flag)
+        sTransform->SetBulkTransform(ITK_NULLPTR);
+      else
+        transform->SetBulkTransform(ITK_NULLPTR);
+      filter->SetTransform(regTransform);
+#if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6)
+      filter->SetReferenceImage(fixedImage);
+#else
       filter->SetOutputParametersFromImage(fixedImage);
+#endif
       filter->Update();
-      typename DeformationFieldType::Pointer field = filter->GetOutput();
+      typename DisplacementFieldType::Pointer field = filter->GetOutput();
 
 
       //=======================================================
       // Write the DVF
       //=======================================================
-      typedef itk::ImageFileWriter< DeformationFieldType >  FieldWriterType;
+      typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
       fieldWriter->SetInput( field );
@@ -782,8 +844,13 @@ namespace clitk
       //=======================================================
       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
-      if (rigidTransform) transform->SetBulkTransform(rigidTransform);
-      resampler->SetTransform( transform );
+      if (rigidTransform) {
+        if(m_ArgsInfo.itkbspline_flag)
+          sTransform->SetBulkTransform(rigidTransform);
+        else
+          transform->SetBulkTransform(rigidTransform);
+      }
+      resampler->SetTransform( regTransform );
       resampler->SetInput( movingImage);
       resampler->SetOutputParametersFromImage(fixedImage);
       resampler->Update();