]> Creatis software - clitk.git/blobdiff - registration/clitkBLUTDIRGenericFilter.cxx
added initial image centralization to BLUTDIR
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
old mode 100755 (executable)
new mode 100644 (file)
index 3e6c6ec..9f808e7
@@ -3,7 +3,7 @@ Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
 
 Authors belong to:
 - University of LYON              http://www.universite-lyon.fr/
-- Léon Bérard cancer center       http://oncora1.lyon.fnclcc.fr
+- Léon Bérard cancer center       http://www.centreleonberard.fr
 - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
 
 This software is distributed WITHOUT ANY WARRANTY; without even
@@ -14,7 +14,7 @@ It is distributed under dual licence
 
 - BSD        See included LICENSE.txt file
 - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
-======================================================================-====*/
+===========================================================================**/
 #ifndef clitkBLUTDIRGenericFilter_cxx
 #define clitkBLUTDIRGenericFilter_cxx
 
@@ -28,7 +28,9 @@ It is distributed under dual licence
  ===================================================*/
 
 #include "clitkBLUTDIRGenericFilter.h"
-
+#include "clitkBLUTDIRCommandIterationUpdateDVF.h"
+#include "itkCenteredTransformInitializer.h"
+  
 namespace clitk
 {
 
@@ -78,7 +80,7 @@ namespace clitk
   {
     InitializeImageType<2>();
     InitializeImageType<3>();
-    m_Verbose=true;
+    m_Verbose=false;
   }
 
   //=========================================================================//
@@ -93,6 +95,8 @@ namespace clitk
     }
 
     if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg);
+    
+    if (m_ArgsInfo.verbose_given) m_Verbose=true;
   }
 
   //=========================================================================//
@@ -127,8 +131,8 @@ namespace clitk
       typedef typename RegistrationType::FixedImageType FixedImageType;
       typedef typename FixedImageType::RegionType RegionType;
       itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension);
-      typedef clitk::BSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
-      typedef clitk::BSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
+      typedef clitk::MultipleBSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
+      typedef clitk::MultipleBSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
       typedef typename InitializerType::CoefficientImageType CoefficientImageType;
       typedef itk::CastImageFilter<CoefficientImageType, CoefficientImageType> CastImageFilterType;
       typedef typename TransformType::ParametersType ParametersType;
@@ -186,14 +190,20 @@ namespace clitk
           registration->SetMetric(metric);
 
           // Get the current coefficient image and make a COPY
-          typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster=itk::ImageDuplicator<CoefficientImageType>::New();
-          caster->SetInputImage(m_Initializer->GetTransform()->GetCoefficientImage());
-          caster->Update();
-          typename CoefficientImageType::Pointer currentCoefficientImage=caster->GetOutput();
+          typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster = itk::ImageDuplicator<CoefficientImageType>::New();
+          std::vector<typename CoefficientImageType::Pointer> currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages();
+          for (unsigned i = 0; i < currentCoefficientImages.size(); ++i)
+          {
+            caster->SetInputImage(currentCoefficientImages[i]);
+            caster->Update();
+            currentCoefficientImages[i] = caster->GetOutput();
+          }
 
+          /*
           // Write the intermediate result?
           if (m_ArgsInfo.intermediate_given>=numberOfLevels)
             writeImage<CoefficientImageType>(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag);
+            */
 
           // Set the new transform properties
           m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1));
@@ -231,7 +241,7 @@ namespace clitk
 
           // Set the previous transform parameters to the registration
           // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
-          m_Initializer->SetInitialParameters(currentCoefficientImage,*newParameters);
+          m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
         }
       }
@@ -267,6 +277,8 @@ namespace clitk
   template<class InputImageType>
     void BLUTDIRGenericFilter::UpdateWithInputImageType()
     {
+      if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl;
+      
       //=============================================================================
       //Input
       //=============================================================================
@@ -311,7 +323,6 @@ namespace clitk
       // The metric region with respect to the extracted transform region:
       // where should the metric be CALCULATED (depends on transform)
       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
-      typename FixedImageType::RegionType::SizeType metricRegionSize=metricRegion.GetSize();
       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
 
@@ -320,33 +331,46 @@ namespace clitk
       // If given, we connect a mask to reference or target
       //============================================================================
       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
-      typename MaskType::Pointer  fixedMask=NULL;
+      typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
+      typename MaskType::Pointer        fixedMask = NULL;
+      typename ImageLabelType::Pointer  labels = NULL;
       if (m_ArgsInfo.referenceMask_given)
       {
-        fixedMask= MaskType::New();
-        typedef itk::Image< unsigned char,InputImageType::ImageDimension >   ImageMaskType;
-        typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
-        typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
-        maskReader->SetFileName(m_ArgsInfo.referenceMask_arg);
+        fixedMask = MaskType::New();
+        labels = ImageLabelType::New();
+        typedef itk::ImageFileReader< ImageLabelType >    LabelReaderType;
+        typename LabelReaderType::Pointer  labelReader = LabelReaderType::New();
+        labelReader->SetFileName(m_ArgsInfo.referenceMask_arg);
         try
         {
-          maskReader->Update();
+          labelReader->Update();
         }
         catch( itk::ExceptionObject & err )
         {
-          std::cerr << "ExceptionObject caught while reading mask !" << std::endl;
+          std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl;
           std::cerr << err << std::endl;
           return;
         }
         if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
 
+        // Resample labels
+        typedef itk::ResampleImageFilter<ImageLabelType, ImageLabelType> ResamplerType;
+        typename ResamplerType::Pointer resampler = ResamplerType::New();
+        typedef itk::NearestNeighborInterpolateImageFunction<ImageLabelType, TCoordRep> InterpolatorType;
+        typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
+        resampler->SetOutputParametersFromImage(fixedImage);
+        resampler->SetInterpolator(interpolator);
+        resampler->SetInput(labelReader->GetOutput());
+        resampler->Update();
+        labels = resampler->GetOutput();
+
         // Set the image to the spatialObject
-        fixedMask->SetImage( maskReader->GetOutput() );
+        fixedMask->SetImage(labels);
 
         // Find the bounding box of the "inside" label
-        typedef itk::LabelGeometryImageFilter<ImageMaskType> GeometryImageFilterType;
+        typedef itk::LabelGeometryImageFilter<ImageLabelType> GeometryImageFilterType;
         typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New();
-        geometryImageFilter->SetInput(maskReader->GetOutput());
+        geometryImageFilter->SetInput(labels);
         geometryImageFilter->Update();
         typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1);
 
@@ -363,6 +387,9 @@ namespace clitk
         // Crop the fixedImage to the bounding box to facilitate multi-resolution
         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
+#if ITK_VERSION_MAJOR == 4
+        extractImageFilter->SetDirectionCollapseToSubmatrix();
+#endif
         extractImageFilter->SetInput(fixedImage);
         extractImageFilter->SetExtractionRegion(transformRegion);
         extractImageFilter->Update();
@@ -469,20 +496,34 @@ namespace clitk
         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
         rigidTransform->SetTranslation(finalTranslation);
       }
+      else
+      {
+        if(m_Verbose) std::cout<<"No itinial matrix given. Centering all images..."<<std::endl;
+        
+        rigidTransform=RigidTransformType::New();
+        
+        typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
+        typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
+        initializer->SetTransform( rigidTransform );
+        initializer->SetFixedImage( fixedImage );
+        initializer->SetMovingImage( movingImage );        
+        initializer->GeometryOn();
+        initializer->InitializeTransform();
+      }
 
 
       //=======================================================
       // B-LUT FFD Transform
       //=======================================================
-      typedef  clitk::BSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
-      typename TransformType::Pointer transform= TransformType::New();
-      if (fixedMask) transform->SetMask( fixedMask );
-      if (rigidTransform) transform->SetBulkTransform( rigidTransform );
+      typedef  clitk::MultipleBSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
+      typename TransformType::Pointer transform = TransformType::New();
+      if (labels) transform->SetLabels(labels);
+      if (rigidTransform) transform->SetBulkTransform(rigidTransform);
 
       //-------------------------------------------------------------------------
       // The transform initializer
       //-------------------------------------------------------------------------
-      typedef clitk::BSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
+      typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
       typename InitializerType::Pointer initializer = InitializerType::New();
       initializer->SetVerbose(m_Verbose);
       initializer->SetImage(fixedImagePyramid->GetOutput(0));
@@ -560,6 +601,25 @@ namespace clitk
       transform->SetParameters( parameters );
       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
 
+      //-------------------------------------------------------------------------
+      // DEBUG: use an itk BSpline instead of multilabel BLUTs
+      //-------------------------------------------------------------------------
+      typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
+      RegistrationTransformType::Pointer regTransform(transform);
+      typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
+      typename SingleBSplineTransformType::Pointer sTransform;
+      if(m_ArgsInfo.itkbspline_flag) {
+        if( transform->GetTransforms().size()>1)
+          itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
+        sTransform = SingleBSplineTransformType::New();
+        sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
+        sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
+        sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
+        sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
+        sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
+        regTransform = sTransform;
+        transform = NULL; // free memory
+      }
 
       //=======================================================
       // Interpolator
@@ -584,7 +644,7 @@ namespace clitk
       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
       if (movingMask) metric->SetMovingImageMask(movingMask);
 
-#ifdef ITK_USE_OPTIMIZED_REGISTRATION_METHODS
+#if defined(ITK_USE_OPTIMIZED_REGISTRATION_METHODS) || ITK_VERSION_MAJOR >= 4
       if (threadsGiven) {
         metric->SetNumberOfThreads( threads );
         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
@@ -601,7 +661,7 @@ namespace clitk
       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
       genericOptimizer->SetArgsInfo(m_ArgsInfo);
       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
-      genericOptimizer->SetNumberOfParameters(transform->GetNumberOfParameters());
+      genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
 
@@ -614,7 +674,7 @@ namespace clitk
       registration->SetMetric(        metric        );
       registration->SetOptimizer(     optimizer     );
       registration->SetInterpolator(  interpolator  );
-      registration->SetTransform (transform);
+      registration->SetTransform (regTransform );
       if(threadsGiven) {
         registration->SetNumberOfThreads(threads);
         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
@@ -624,7 +684,7 @@ namespace clitk
       registration->SetFixedImageRegion( metricRegion );
       registration->SetFixedImagePyramid( fixedImagePyramid );
       registration->SetMovingImagePyramid( movingImagePyramid );
-      registration->SetInitialTransformParameters( transform->GetParameters() );
+      registration->SetInitialTransformParameters( regTransform->GetParameters() );
       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
 
@@ -648,6 +708,21 @@ namespace clitk
         command->SetMaximize(genericMetric->GetMaximize());
         command->SetMetricRegion(metricRegion);
         registration->AddObserver( itk::IterationEvent(), command );
+
+        if (m_ArgsInfo.coeff_given)
+        {
+          if(m_ArgsInfo.itkbspline_flag) {
+            itkExceptionMacro("--coeff and --itkbpline are incompatible");
+          }
+
+          std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
+          typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
+          typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
+          observerdvf->SetFixedImage(fixedImage);
+          observerdvf->SetTransform(transform);
+          observerdvf->SetArgsInfo(m_ArgsInfo);
+          optimizer->AddObserver( itk::IterationEvent(), observerdvf );
+        }
       }
 
 
@@ -672,7 +747,7 @@ namespace clitk
       // Get the result
       //=======================================================
       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
-      transform->SetParameters( finalParameters );
+      regTransform->SetParameters( finalParameters );
       if (m_Verbose)
       {
         std::cout<<"Stop condition description: "
@@ -686,12 +761,24 @@ namespace clitk
       if (m_ArgsInfo.coeff_given)
       {
         typedef typename TransformType::CoefficientImageType CoefficientImageType;
-        typename CoefficientImageType::Pointer coefficientImage =transform->GetCoefficientImage();
+        std::vector<typename CoefficientImageType::Pointer> coefficientImages = transform->GetCoefficientImages();
         typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
-        typename CoeffWriterType::Pointer coeffWriter=CoeffWriterType::New();
-        coeffWriter->SetInput(coefficientImage);
-        coeffWriter->SetFileName(m_ArgsInfo.coeff_arg);
-        coeffWriter->Update();
+        typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New();
+        unsigned nLabels = transform->GetnLabels();
+
+        std::string fname(m_ArgsInfo.coeff_arg);
+        int dotpos = fname.length() - 1;
+        while (dotpos >= 0 && fname[dotpos] != '.')
+          dotpos--;
+
+        for (unsigned i = 0; i < nLabels; ++i)
+        {
+          std::ostringstream osfname;
+          osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos);
+          coeffWriter->SetInput(coefficientImages[i]);
+          coeffWriter->SetFileName(osfname.str());
+          coeffWriter->Update();
+        }
       }
 
 
@@ -700,21 +787,28 @@ namespace clitk
       // Compute the DVF (only deformable transform)
       //=======================================================
       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
-      typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DeformationFieldType;
-      typedef itk::TransformToDeformationFieldSource<DeformationFieldType, double> ConvertorType;
+      typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
+#if ITK_VERSION_MAJOR >= 4
+      typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
+#else
+      typedef itk::TransformToDeformationFieldSource<DisplacementFieldType, double> ConvertorType;
+#endif
       typename ConvertorType::Pointer filter= ConvertorType::New();
       filter->SetNumberOfThreads(1);
-      transform->SetBulkTransform(NULL);
-      filter->SetTransform(transform);
+      if(m_ArgsInfo.itkbspline_flag)
+        sTransform->SetBulkTransform(NULL);
+      else
+        transform->SetBulkTransform(NULL);
+      filter->SetTransform(regTransform);
       filter->SetOutputParametersFromImage(fixedImage);
       filter->Update();
-      typename DeformationFieldType::Pointer field = filter->GetOutput();
+      typename DisplacementFieldType::Pointer field = filter->GetOutput();
 
 
       //=======================================================
       // Write the DVF
       //=======================================================
-      typedef itk::ImageFileWriter< DeformationFieldType >  FieldWriterType;
+      typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
       fieldWriter->SetInput( field );
@@ -735,8 +829,13 @@ namespace clitk
       //=======================================================
       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
-      if (rigidTransform) transform->SetBulkTransform(rigidTransform);
-      resampler->SetTransform( transform );
+      if (rigidTransform) {
+        if(m_ArgsInfo.itkbspline_flag)
+          sTransform->SetBulkTransform(rigidTransform);
+        else
+          transform->SetBulkTransform(rigidTransform);
+      }
+      resampler->SetTransform( regTransform );
       resampler->SetInput( movingImage);
       resampler->SetOutputParametersFromImage(fixedImage);
       resampler->Update();