X-Git-Url: https://git.creatis.insa-lyon.fr/pubgit/?a=blobdiff_plain;f=registration%2FclitkBLUTDIRGenericFilter.cxx;h=24d461361f9e27a232e54a9e3fcfd757de02f8db;hb=HEAD;hp=c728fcda9ff090d19fe2dcfb10100d427c9ebed9;hpb=657652a78c2e2717a6f77e027049173442ca29f0;p=clitk.git diff --git a/registration/clitkBLUTDIRGenericFilter.cxx b/registration/clitkBLUTDIRGenericFilter.cxx old mode 100755 new mode 100644 index c728fcd..24d4613 --- a/registration/clitkBLUTDIRGenericFilter.cxx +++ b/registration/clitkBLUTDIRGenericFilter.cxx @@ -1,73 +1,969 @@ /*========================================================================= - Program: vv http://www.creatis.insa-lyon.fr/rio/vv +Program: vv http://www.creatis.insa-lyon.fr/rio/vv - Authors belong to: - - University of LYON http://www.universite-lyon.fr/ - - Léon Bérard cancer center http://oncora1.lyon.fnclcc.fr - - CREATIS CNRS laboratory http://www.creatis.insa-lyon.fr +Authors belong to: +- University of LYON http://www.universite-lyon.fr/ +- Léon Bérard cancer center http://www.centreleonberard.fr +- CREATIS CNRS laboratory http://www.creatis.insa-lyon.fr - This software is distributed WITHOUT ANY WARRANTY; without even - the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR - PURPOSE. See the copyright notices for more information. +This software is distributed WITHOUT ANY WARRANTY; without even +the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR +PURPOSE. See the copyright notices for more information. - It is distributed under dual licence +It is distributed under dual licence - - BSD See included LICENSE.txt file - - CeCILL-B http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html -======================================================================-====*/ +- BSD See included LICENSE.txt file +- CeCILL-B http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html +===========================================================================**/ #ifndef clitkBLUTDIRGenericFilter_cxx #define clitkBLUTDIRGenericFilter_cxx /* ================================================= * @file clitkBLUTDIRGenericFilter.cxx - * @author - * @date - * - * @brief - * + * @author + * @date + * + * @brief + * ===================================================*/ #include "clitkBLUTDIRGenericFilter.h" - +#include "clitkBLUTDIRCommandIterationUpdateDVF.h" +#include "itkCenteredTransformInitializer.h" +#include "itkLabelStatisticsImageFilter.h" +#if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6) +# include "itkTransformToDisplacementFieldSource.h" +#else +# include "itkTransformToDisplacementFieldFilter.h" +#endif namespace clitk { + //============================================================================== + // Creating an observer class that allows output at each iteration + //============================================================================== + class CommandIterationUpdate : public itk::Command + { + public: + typedef CommandIterationUpdate Self; + typedef itk::Command Superclass; + typedef itk::SmartPointer Pointer; + itkNewMacro( Self ); + protected: + CommandIterationUpdate() {}; + public: + typedef clitk::GenericOptimizer OptimizerType; + typedef const OptimizerType * OptimizerPointer; + + // Set the generic optimizer + void SetOptimizer(OptimizerPointer o){m_Optimizer=o;} + + // Execute + void Execute(itk::Object *caller, const itk::EventObject & event) + { + Execute( (const itk::Object *)caller, event); + } + + void Execute(const itk::Object * object, const itk::EventObject & event) + { + if( !(itk::IterationEvent().CheckEvent( &event )) ) + { + return; + } + + m_Optimizer->OutputIterationInfo(); + } + + OptimizerPointer m_Optimizer; + }; - //----------------------------------------------------------- - // Constructor - //----------------------------------------------------------- - BLUTDIRGenericFilter::BLUTDIRGenericFilter() + //===========================================================================// + //Constructor + //==========================================================================// + BLUTDIRGenericFilter::BLUTDIRGenericFilter(): + ImageToImageGenericFilter("Register DIR") { + InitializeImageType<2>(); + InitializeImageType<3>(); m_Verbose=false; - m_ReferenceFileName=""; } + //=========================================================================// + //SetArgsInfo + //==========================================================================// + void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){ + m_ArgsInfo=a; + if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg); - //----------------------------------------------------------- - // Update - //----------------------------------------------------------- - void BLUTDIRGenericFilter::Update() - { - // Read the Dimension and PixelType - int Dimension; - std::string PixelType; - ReadImageDimensionAndPixelType(m_ReferenceFileName, Dimension, PixelType); + if (m_ArgsInfo.target_given) { + AddInputFilename(m_ArgsInfo.target_arg); + } + if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg); - // Call UpdateWithDim - //if(Dimension==2) UpdateWithDim<2>(PixelType); - //else - if(Dimension==3) UpdateWithDim<3>(PixelType); - // else if (Dimension==4)UpdateWithDim<4>(PixelType); - else + if (m_ArgsInfo.verbose_given) m_Verbose=true; + } + + //=========================================================================// + //===========================================================================// + template + void BLUTDIRGenericFilter::InitializeImageType() + { + ADD_DEFAULT_IMAGE_TYPES(3); + } + //-------------------------------------------------------------------- + + //============================================================================== + //Creating an observer class that allows us to change parameters at subsequent levels + //============================================================================== + template + class RegistrationInterfaceCommand : public itk::Command + { + public: + typedef RegistrationInterfaceCommand Self; + typedef itk::Command Superclass; + typedef itk::SmartPointer Pointer; + itkNewMacro( Self ); + protected: + RegistrationInterfaceCommand() { }; + public: + + // Registration + typedef TRegistration RegistrationType; + typedef RegistrationType * RegistrationPointer; + + // Transform + typedef typename RegistrationType::FixedImageType FixedImageType; + typedef typename FixedImageType::RegionType RegionType; + itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension); + typedef clitk::MultipleBSplineDeformableTransform TransformType; + typedef clitk::MultipleBSplineDeformableTransformInitializer InitializerType; + typedef typename InitializerType::CoefficientImageType CoefficientImageType; + typedef itk::CastImageFilter CastImageFilterType; + typedef typename TransformType::ParametersType ParametersType; + typedef typename InitializerType::Pointer InitializerPointer; + typedef CommandIterationUpdate::Pointer CommandIterationUpdatePointer; + + // Optimizer + typedef clitk::GenericOptimizer GenericOptimizerType; + typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer; + + // Metric + typedef typename RegistrationType::FixedImageType InternalImageType; + typedef clitk::GenericMetric GenericMetricType; + typedef typename GenericMetricType::Pointer GenericMetricPointer; + + // Two arguments are passed to the Execute() method: the first + // is the pointer to the object which invoked the event and the + // second is the event that was invoked. + void Execute(itk::Object * object, const itk::EventObject & event) { - std::cout<<"Error, Only for 2 or 3 Dimensions!!!"<( object ); + unsigned int numberOfLevels=registration->GetNumberOfLevels(); + unsigned int currentLevel=registration->GetCurrentLevel()+1; + + // Output the levels + std::cout<1) + { + // fixed image region pyramid + typedef clitk::MultiResolutionPyramidRegionFilter FixedImageRegionPyramidType; + typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New(); + fixedImageRegionPyramid->SetRegion(m_MetricRegion); + fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule()); + + // Reinitialize the metric (!= number of samples) + m_GenericMetric= GenericMetricType::New(); + m_GenericMetric->SetArgsInfo(m_ArgsInfo); + m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel())); + if (m_ArgsInfo.referenceMask_given) m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask()); + m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel())); + typedef itk::ImageToImageMetric< InternalImageType, InternalImageType > MetricType; + typename MetricType::Pointer metric=m_GenericMetric->GetMetricPointer(); + registration->SetMetric(metric); + + // Get the current coefficient image and make a COPY + typename itk::ImageDuplicator::Pointer caster = itk::ImageDuplicator::New(); + std::vector currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages(); + for (unsigned i = 0; i < currentCoefficientImages.size(); ++i) + { + caster->SetInputImage(currentCoefficientImages[i]); + caster->Update(); + currentCoefficientImages[i] = caster->GetOutput(); + } + + /* + // Write the intermediate result? + if (m_ArgsInfo.intermediate_given>=numberOfLevels) + writeImage(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag); + */ + + // Set the new transform properties + m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1)); + if( m_Initializer->m_ControlPointSpacingIsGiven) + m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]); + if( m_Initializer->m_NumberOfControlPointsIsGiven) + m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]); + + // Reinitialize the transform + if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<InitializeTransform(); + ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters()); + + // DS : if we want to skip the last pyramid level, force to only 1 iteration + DD(m_ArgsInfo.skipLastPyramidLevel_flag); + if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) { + DD(m_ArgsInfo.maxIt_arg); + std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl; + m_ArgsInfo.maxIt_arg = 0; + DD(m_ArgsInfo.maxIt_arg); + } + + // Reinitialize an Optimizer (!= number of parameters) + m_GenericOptimizer = GenericOptimizerType::New(); + m_GenericOptimizer->SetArgsInfo(m_ArgsInfo); + m_GenericOptimizer->SetMaximize(m_Maximize); + m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters()); + + + typedef itk::SingleValuedNonLinearOptimizer OptimizerType; + OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer(); + optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate); + registration->SetOptimizer(optimizer); + m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer); + + // Set the previous transform parameters to the registration + // if(m_Initializer->m_Parameters!=ITK_NULLPTR )delete m_Initializer->m_Parameters; + m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters); + registration->SetInitialTransformParametersOfNextLevel(*newParameters); + } + } + + void Execute(const itk::Object * , const itk::EventObject & ) + { return; } + + + // Members + void SetInitializer(InitializerPointer i){m_Initializer=i;} + InitializerPointer m_Initializer; + + void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;} + args_info_clitkBLUTDIR m_ArgsInfo; + + void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;}; + CommandIterationUpdatePointer m_CommandIterationUpdate; + + GenericOptimizerPointer m_GenericOptimizer; + void SetMaximize(bool b){m_Maximize=b;} + bool m_Maximize; + + GenericMetricPointer m_GenericMetric; + void SetMetricRegion(RegionType i){m_MetricRegion=i;} + RegionType m_MetricRegion; + + + }; + + //============================================================================== + // Update with the number of dimensions and pixeltype + //============================================================================== + template + void BLUTDIRGenericFilter::UpdateWithInputImageType() + { + if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl; + + //============================================================================= + //Input + //============================================================================= + bool threadsGiven=m_ArgsInfo.threads_given; + int threads=m_ArgsInfo.threads_arg; + typedef typename InputImageType::PixelType PixelType; + + typedef double TCoordRep; + + typename InputImageType::Pointer fixedImage = this->template GetInput(0); + + typename InputImageType::Pointer inputFixedImage = this->template GetInput(0); + + // typedef input2 + typename InputImageType::Pointer movingImage = this->template GetInput(1); + + typename InputImageType::Pointer inputMovingImage = this->template GetInput(1); + + typedef itk::Image< PixelType,InputImageType::ImageDimension > FixedImageType; + typedef itk::Image< PixelType, InputImageType::ImageDimension> MovingImageType; + const unsigned int SpaceDimension = InputImageType::ImageDimension; + //Whatever the pixel type, internally we work with an image represented in float + //Reading reference image + if (m_Verbose) std::cout<<"Reading images..."<GetLargestPossibleRegion(); + + // The transform region with respect to the input region: + // where should the transform be DEFINED (depends on mask) + typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion(); + typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize(); + typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex(); + typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin(); + + // The metric region with respect to the extracted transform region: + // where should the metric be CALCULATED (depends on transform) + typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion(); + typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex(); + typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin(); + + + //=========================================================================== + // If given, we connect a mask to reference or target + //============================================================================ + typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension > MaskType; + typedef itk::Image< unsigned char, InputImageType::ImageDimension > ImageLabelType; + typename MaskType::Pointer fixedMask = ITK_NULLPTR; + typename ImageLabelType::Pointer labels = ITK_NULLPTR; + if (m_ArgsInfo.referenceMask_given) + { + fixedMask = MaskType::New(); + labels = ImageLabelType::New(); + typedef itk::ImageFileReader< ImageLabelType > LabelReaderType; + typename LabelReaderType::Pointer labelReader = LabelReaderType::New(); + labelReader->SetFileName(m_ArgsInfo.referenceMask_arg); + try + { + labelReader->Update(); + } + catch( itk::ExceptionObject & err ) + { + std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl; + std::cerr << err << std::endl; + return; + } + if (m_Verbose)std::cout <<"Reference image mask was read..." < ResamplerType; + typename ResamplerType::Pointer resampler = ResamplerType::New(); + typedef itk::NearestNeighborInterpolateImageFunction InterpolatorType; + typename InterpolatorType::Pointer interpolator = InterpolatorType::New(); + resampler->SetOutputParametersFromImage(fixedImage); + resampler->SetInterpolator(interpolator); + resampler->SetInput(labelReader->GetOutput()); + resampler->Update(); + labels = resampler->GetOutput(); + + // Set the image to the spatialObject + fixedMask->SetImage(labels); + + // Find the bounding box of the "inside" label + typedef itk::LabelStatisticsImageFilter StatisticsImageFilterType; + typename StatisticsImageFilterType::Pointer statisticsImageFilter=StatisticsImageFilterType::New(); + statisticsImageFilter->SetInput(labels); + statisticsImageFilter->SetLabelInput(labels); + statisticsImageFilter->Update(); + typename StatisticsImageFilterType::BoundingBoxType boundingBox = statisticsImageFilter->GetBoundingBox(1); + + // Limit the transform region to the mask + for (unsigned int i=0; iTransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin); + + // Crop the fixedImage to the bounding box to facilitate multi-resolution + typedef itk::ExtractImageFilter ExtractImageFilterType; + typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New(); + extractImageFilter->SetDirectionCollapseToSubmatrix(); + extractImageFilter->SetInput(fixedImage); + extractImageFilter->SetExtractionRegion(transformRegion); + extractImageFilter->Update(); + croppedFixedImage=extractImageFilter->GetOutput(); + + // Update the metric region + metricRegion = croppedFixedImage->GetLargestPossibleRegion(); + metricRegionIndex=metricRegion.GetIndex(); + croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin); + + // Set start index to zero (with respect to croppedFixedImage/transform region) + metricRegionIndex.Fill(0); + metricRegion.SetIndex(metricRegionIndex); + croppedFixedImage->SetRegions(metricRegion); + croppedFixedImage->SetOrigin(metricRegionOrigin); + } + + typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension > MaskType; + typename MaskType::Pointer movingMask=ITK_NULLPTR; + if (m_ArgsInfo.targetMask_given) + { + movingMask= MaskType::New(); + typedef itk::Image< unsigned char, InputImageType::ImageDimension > ImageMaskType; + typedef itk::ImageFileReader< ImageMaskType > MaskReaderType; + typename MaskReaderType::Pointer maskReader = MaskReaderType::New(); + maskReader->SetFileName(m_ArgsInfo.targetMask_arg); + try + { + maskReader->Update(); + } + catch( itk::ExceptionObject & err ) + { + std::cerr << "ExceptionObject caught !" << std::endl; + std::cerr << err << std::endl; + } + if (m_Verbose)std::cout <<"Target image mask was read..." <SetImage( maskReader->GetOutput() ); + } + + + //======================================================= + // Output Regions + //======================================================= + + if (m_Verbose) + { + // Fixed image region + std::cout<<"The fixed image has its origin at "<GetOrigin()< FixedImagePyramidType; + typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType> MovingImagePyramidType; + typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New(); + typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New(); + fixedImagePyramid->SetUseShrinkImageFilter(false); + fixedImagePyramid->SetInput(croppedFixedImage); + fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg); + movingImagePyramid->SetUseShrinkImageFilter(false); + movingImagePyramid->SetInput(movingImage); + movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg); + if (m_Verbose) std::cout<<"Creating the image pyramid..."<Update(); + movingImagePyramid->Update(); + typedef clitk::MultiResolutionPyramidRegionFilter FixedImageRegionPyramidType; + typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New(); + fixedImageRegionPyramid->SetRegion(metricRegion); + fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule()); + + + //======================================================= + // Rigid or Affine Transform + //======================================================= + typedef itk::AffineTransform RigidTransformType; + RigidTransformType::Pointer rigidTransform=ITK_NULLPTR; + if (m_ArgsInfo.initMatrix_given) + { + if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."< rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg); + + //Set the rotation + itk::Matrix finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix); + rigidTransform->SetMatrix(finalRotation); + + //Set the translation + itk::Vector finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix); + rigidTransform->SetTranslation(finalTranslation); + } + else if (m_ArgsInfo.centre_flag) + { + if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."< TransformInitializerType; + typename TransformInitializerType::Pointer initializer = TransformInitializerType::New(); + initializer->SetTransform( rigidTransform ); + initializer->SetFixedImage( fixedImage ); + initializer->SetMovingImage( movingImage ); + initializer->GeometryOn(); + initializer->InitializeTransform(); + } + + + //======================================================= + // B-LUT FFD Transform + //======================================================= + typedef clitk::MultipleBSplineDeformableTransform TransformType; + typename TransformType::Pointer transform = TransformType::New(); + if (labels) transform->SetLabels(labels); + if (rigidTransform) transform->SetBulkTransform(rigidTransform); + + //------------------------------------------------------------------------- + // The transform initializer + //------------------------------------------------------------------------- + typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType; + typename InitializerType::Pointer initializer = InitializerType::New(); + initializer->SetVerbose(m_Verbose); + initializer->SetImage(fixedImagePyramid->GetOutput(0)); + initializer->SetTransform(transform); + + //------------------------------------------------------------------------- + // Order + //------------------------------------------------------------------------- + typename FixedImageType::RegionType::SizeType splineOrders ; + splineOrders.Fill(3); + if (m_ArgsInfo.order_given) + for(unsigned int i=0; iSetSplineOrders(splineOrders); + + //------------------------------------------------------------------------- + // Levels + //------------------------------------------------------------------------- + + // Spacing + if (m_ArgsInfo.spacing_given) + { + initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg); + initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg); + initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing; + if (m_Verbose) std::cout<<"Using a control point spacing of "<m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1] + <<" at level "<m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2; + if (m_Verbose) std::cout<<"Using a control point spacing of "<m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i] + <<" at level "<m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg); + initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg); + initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage; + if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image" + <<" at level "<m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.); + // initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.); + if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image" + <<" at level "<SetControlPointSpacing( initializer->m_ControlPointSpacingArray[0]); + if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]); + if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg); + + // Initialize + initializer->InitializeTransform(); + + //------------------------------------------------------------------------- + // Initial parameters (passed by reference) + //------------------------------------------------------------------------- + typedef typename TransformType::ParametersType ParametersType; + const unsigned int numberOfParameters = transform->GetNumberOfParameters(); + ParametersType parameters(numberOfParameters); + parameters.Fill( 0.0 ); + transform->SetParameters( parameters ); + if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters); + + //------------------------------------------------------------------------- + // DEBUG: use an itk BSpline instead of multilabel BLUTs + //------------------------------------------------------------------------- + typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType; + RegistrationTransformType::Pointer regTransform(transform); + typedef itk::BSplineDeformableTransform SingleBSplineTransformType; + typename SingleBSplineTransformType::Pointer sTransform; + if(m_ArgsInfo.itkbspline_flag) { + if( transform->GetTransforms().size()>1) + itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label") + sTransform = SingleBSplineTransformType::New(); + sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() ); + sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() ); + sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() ); + sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() ); + sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() ); + regTransform = sTransform; + transform = ITK_NULLPTR; // free memory + } + + //======================================================= + // Interpolator + //======================================================= + typedef clitk::GenericInterpolator GenericInterpolatorType; + typename GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New(); + genericInterpolator->SetArgsInfo(m_ArgsInfo); + typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep > InterpolatorType; + typename InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer(); + + + //======================================================= + // Metric + //======================================================= + typedef clitk::GenericMetric GenericMetricType; + typename GenericMetricType::Pointer genericMetric=GenericMetricType::New(); + genericMetric->SetArgsInfo(m_ArgsInfo); + genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0)); + if (fixedMask) genericMetric->SetFixedImageMask(fixedMask); + genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0)); + typedef itk::ImageToImageMetric< FixedImageType, MovingImageType > MetricType; + typename MetricType::Pointer metric=genericMetric->GetMetricPointer(); + if (movingMask) metric->SetMovingImageMask(movingMask); + if (threadsGiven) { +#if ITK_VERSION_MAJOR <= 4 + metric->SetNumberOfThreads( threads ); +#else + metric->SetNumberOfWorkUnits( threads ); +#endif + if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl; + } + + //======================================================= + // Optimizer + //======================================================= + typedef clitk::GenericOptimizer GenericOptimizerType; + GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New(); + genericOptimizer->SetArgsInfo(m_ArgsInfo); + genericOptimizer->SetMaximize(genericMetric->GetMaximize()); + genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters()); + typedef itk::SingleValuedNonLinearOptimizer OptimizerType; + OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer(); + + + //======================================================= + // Registration + //======================================================= + typedef itk::MultiResolutionImageRegistrationMethod< FixedImageType, MovingImageType > RegistrationType; + typename RegistrationType::Pointer registration = RegistrationType::New(); + registration->SetMetric( metric ); + registration->SetOptimizer( optimizer ); + registration->SetInterpolator( interpolator ); + registration->SetTransform (regTransform ); + if(threadsGiven) { +#if ITK_VERSION_MAJOR <= 4 + registration->SetNumberOfThreads(threads); +#else + registration->SetNumberOfWorkUnits(threads); +#endif + if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl; + } + registration->SetFixedImage( croppedFixedImage ); + registration->SetMovingImage( movingImage ); + registration->SetFixedImageRegion( metricRegion ); + registration->SetFixedImagePyramid( fixedImagePyramid ); + registration->SetMovingImagePyramid( movingImagePyramid ); + registration->SetInitialTransformParameters( regTransform->GetParameters() ); + registration->SetNumberOfLevels( m_ArgsInfo.levels_arg ); + if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<SetOptimizer(genericOptimizer); + optimizer->AddObserver( itk::IterationEvent(), observer ); + + // Output level info + typedef RegistrationInterfaceCommand CommandType; + typename CommandType::Pointer command = CommandType::New(); + command->SetInitializer(initializer); + command->SetArgsInfo(m_ArgsInfo); + command->SetCommandIterationUpdate(observer); + command->SetMaximize(genericMetric->GetMaximize()); + command->SetMetricRegion(metricRegion); + registration->AddObserver( itk::IterationEvent(), command ); + + if (m_ArgsInfo.coeff_given) + { + if(m_ArgsInfo.itkbspline_flag) { + itkExceptionMacro("--coeff and --itkbpline are incompatible"); + } + + std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl; + typedef CommandIterationUpdateDVF DVFCommandType; + typename DVFCommandType::Pointer observerdvf = DVFCommandType::New(); + observerdvf->SetFixedImage(fixedImage); + observerdvf->SetTransform(transform); + observerdvf->SetArgsInfo(m_ArgsInfo); + optimizer->AddObserver( itk::IterationEvent(), observerdvf ); + } + } + + + //======================================================= + // Let's go + //======================================================= + if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl; + + try + { + registration->Update(); + } + catch( itk::ExceptionObject & err ) + { + std::cerr << "ExceptionObject caught while registering!" << std::endl; + std::cerr << err << std::endl; + return; + } + + + //======================================================= + // Get the result + //======================================================= + OptimizerType::ParametersType finalParameters = registration->GetLastTransformParameters(); + regTransform->SetParameters( finalParameters ); + if (m_Verbose) + { + std::cout<<"Stop condition description: " + <GetOptimizer()->GetStopConditionDescription()< coefficientImages = transform->GetCoefficientImages(); + typedef itk::ImageFileWriter CoeffWriterType; + typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New(); + unsigned nLabels = transform->GetnLabels(); + + std::string fname(m_ArgsInfo.coeff_arg); + int dotpos = fname.length() - 1; + while (dotpos >= 0 && fname[dotpos] != '.') + dotpos--; + + for (unsigned i = 0; i < nLabels; ++i) + { + std::ostringstream osfname; + osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos); + coeffWriter->SetInput(coefficientImages[i]); + coeffWriter->SetFileName(osfname.str()); + coeffWriter->Update(); + } + } + + + + //======================================================= + // Compute the DVF (only deformable transform) + //======================================================= + typedef itk::Vector< float, SpaceDimension > DisplacementType; + typedef itk::Image< DisplacementType, InputImageType::ImageDimension > DisplacementFieldType; +#if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6) + typedef itk::TransformToDisplacementFieldSource ConvertorType; +#else + typedef itk::TransformToDisplacementFieldFilter ConvertorType; +#endif + typename ConvertorType::Pointer filter= ConvertorType::New(); +#if ITK_VERSION_MAJOR <= 4 + filter->SetNumberOfThreads(1); +#else + filter->SetNumberOfWorkUnits(1); +#endif + if(m_ArgsInfo.itkbspline_flag) + sTransform->SetBulkTransform(ITK_NULLPTR); + else + transform->SetBulkTransform(ITK_NULLPTR); + filter->SetTransform(regTransform); +#if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6) + filter->SetReferenceImage(fixedImage); +#else + filter->SetOutputParametersFromImage(fixedImage); +#endif + filter->Update(); + typename DisplacementFieldType::Pointer field = filter->GetOutput(); + + + //======================================================= + // Write the DVF + //======================================================= + typedef itk::ImageFileWriter< DisplacementFieldType > FieldWriterType; + typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New(); + fieldWriter->SetFileName( m_ArgsInfo.vf_arg ); + fieldWriter->SetInput( field ); + try + { + fieldWriter->Update(); + } + catch( itk::ExceptionObject & excp ) + { + std::cerr << "Exception thrown writing the DVF" << std::endl; + std::cerr << excp << std::endl; + return; + } + + + //======================================================= + // Resample the moving image + //======================================================= + typedef itk::ResampleImageFilter< MovingImageType, FixedImageType > ResampleFilterType; + typename ResampleFilterType::Pointer resampler = ResampleFilterType::New(); + if (rigidTransform) { + if(m_ArgsInfo.itkbspline_flag) + sTransform->SetBulkTransform(rigidTransform); + else + transform->SetBulkTransform(rigidTransform); + } + resampler->SetTransform( regTransform ); + resampler->SetInput( movingImage); + resampler->SetOutputParametersFromImage(fixedImage); + resampler->Update(); + typename FixedImageType::Pointer result=resampler->GetOutput(); + + // typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType > WarpFilterType; + // typename WarpFilterType::Pointer warp = WarpFilterType::New(); + + // warp->SetDeformationField( field ); + // warp->SetInput( movingImageReader->GetOutput() ); + // warp->SetOutputOrigin( fixedImage->GetOrigin() ); + // warp->SetOutputSpacing( fixedImage->GetSpacing() ); + // warp->SetOutputDirection( fixedImage->GetDirection() ); + // warp->SetEdgePaddingValue( 0.0 ); + // warp->Update(); + + + //======================================================= + // Write the warped image + //======================================================= + typedef itk::ImageFileWriter< FixedImageType > WriterType; + typename WriterType::Pointer writer = WriterType::New(); + writer->SetFileName( m_ArgsInfo.output_arg ); + writer->SetInput( result ); + + try + { + writer->Update(); + } + catch( itk::ExceptionObject & err ) + { + std::cerr << "ExceptionObject caught !" << std::endl; + std::cerr << err << std::endl; + return; + } + + + //======================================================= + // Calculate the difference after the deformable transform + //======================================================= + typedef clitk::DifferenceImageFilter< FixedImageType, FixedImageType> DifferenceFilterType; + if (m_ArgsInfo.after_given) + { + typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New(); + difference->SetValidInput( fixedImage ); + difference->SetTestInput( result ); + + try + { + difference->Update(); + } + catch( itk::ExceptionObject & err ) + { + std::cerr << "ExceptionObject caught calculating the difference !" << std::endl; + std::cerr << err << std::endl; + return; + } + + typename WriterType::Pointer differenceWriter=WriterType::New(); + differenceWriter->SetInput(difference->GetOutput()); + differenceWriter->SetFileName(m_ArgsInfo.after_arg); + differenceWriter->Update(); + + } + + + //======================================================= + // Calculate the difference before the deformable transform + //======================================================= + if( m_ArgsInfo.before_given ) + { + + typename FixedImageType::Pointer moving=FixedImageType::New(); + if (m_ArgsInfo.initMatrix_given) + { + typedef itk::ResampleImageFilter ResamplerType; + typename ResamplerType::Pointer resampler=ResamplerType::New(); + resampler->SetInput(movingImage); + resampler->SetOutputOrigin(fixedImage->GetOrigin()); + resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize()); + resampler->SetOutputSpacing(fixedImage->GetSpacing()); + resampler->SetDefaultPixelValue( 0. ); + if (rigidTransform ) resampler->SetTransform(rigidTransform); + resampler->Update(); + moving=resampler->GetOutput(); + } + else + moving=movingImage; + + typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New(); + difference->SetValidInput( fixedImage ); + difference->SetTestInput( moving ); + + try + { + difference->Update(); + } + catch( itk::ExceptionObject & err ) + { + std::cerr << "ExceptionObject caught calculating the difference !" << std::endl; + std::cerr << err << std::endl; + return; + } + + typename WriterType::Pointer differenceWriter=WriterType::New(); + writer->SetFileName( m_ArgsInfo.before_arg ); + writer->SetInput( difference->GetOutput() ); + writer->Update( ); } - } + return; -} //end clitk + } +}//end clitk -#endif //#define clitkBLUTDIRGenericFilter_cxx +#endif // #define clitkBLUTDIRGenericFilter_txx