/*========================================================================= Program: vv http://www.creatis.insa-lyon.fr/rio/vv Authors belong to: - University of LYON http://www.universite-lyon.fr/ - Léon Bérard cancer center http://www.centreleonberard.fr - CREATIS CNRS laboratory http://www.creatis.insa-lyon.fr This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the copyright notices for more information. It is distributed under dual licence - BSD See included LICENSE.txt file - CeCILL-B http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html ===========================================================================**/ #ifndef clitkBLUTDIRGenericFilter_cxx #define clitkBLUTDIRGenericFilter_cxx /* ================================================= * @file clitkBLUTDIRGenericFilter.cxx * @author * @date * * @brief * ===================================================*/ #include "clitkBLUTDIRGenericFilter.h" #include "clitkBLUTDIRCommandIterationUpdateDVF.h" #include "itkCenteredTransformInitializer.h" #include "itkLabelStatisticsImageFilter.h" #if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6) # include "itkTransformToDisplacementFieldSource.h" #else # include "itkTransformToDisplacementFieldFilter.h" #endif namespace clitk { //============================================================================== // Creating an observer class that allows output at each iteration //============================================================================== class CommandIterationUpdate : public itk::Command { public: typedef CommandIterationUpdate Self; typedef itk::Command Superclass; typedef itk::SmartPointer Pointer; itkNewMacro( Self ); protected: CommandIterationUpdate() {}; public: typedef clitk::GenericOptimizer OptimizerType; typedef const OptimizerType * OptimizerPointer; // Set the generic optimizer void SetOptimizer(OptimizerPointer o){m_Optimizer=o;} // Execute void Execute(itk::Object *caller, const itk::EventObject & event) { Execute( (const itk::Object *)caller, event); } void Execute(const itk::Object * object, const itk::EventObject & event) { if( !(itk::IterationEvent().CheckEvent( &event )) ) { return; } m_Optimizer->OutputIterationInfo(); } OptimizerPointer m_Optimizer; }; //===========================================================================// //Constructor //==========================================================================// BLUTDIRGenericFilter::BLUTDIRGenericFilter(): ImageToImageGenericFilter("Register DIR") { InitializeImageType<2>(); InitializeImageType<3>(); m_Verbose=false; } //=========================================================================// //SetArgsInfo //==========================================================================// void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){ m_ArgsInfo=a; if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg); if (m_ArgsInfo.target_given) { AddInputFilename(m_ArgsInfo.target_arg); } if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg); if (m_ArgsInfo.verbose_given) m_Verbose=true; } //=========================================================================// //===========================================================================// template void BLUTDIRGenericFilter::InitializeImageType() { ADD_DEFAULT_IMAGE_TYPES(3); } //-------------------------------------------------------------------- //============================================================================== //Creating an observer class that allows us to change parameters at subsequent levels //============================================================================== template class RegistrationInterfaceCommand : public itk::Command { public: typedef RegistrationInterfaceCommand Self; typedef itk::Command Superclass; typedef itk::SmartPointer Pointer; itkNewMacro( Self ); protected: RegistrationInterfaceCommand() { }; public: // Registration typedef TRegistration RegistrationType; typedef RegistrationType * RegistrationPointer; // Transform typedef typename RegistrationType::FixedImageType FixedImageType; typedef typename FixedImageType::RegionType RegionType; itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension); typedef clitk::MultipleBSplineDeformableTransform TransformType; typedef clitk::MultipleBSplineDeformableTransformInitializer InitializerType; typedef typename InitializerType::CoefficientImageType CoefficientImageType; typedef itk::CastImageFilter CastImageFilterType; typedef typename TransformType::ParametersType ParametersType; typedef typename InitializerType::Pointer InitializerPointer; typedef CommandIterationUpdate::Pointer CommandIterationUpdatePointer; // Optimizer typedef clitk::GenericOptimizer GenericOptimizerType; typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer; // Metric typedef typename RegistrationType::FixedImageType InternalImageType; typedef clitk::GenericMetric GenericMetricType; typedef typename GenericMetricType::Pointer GenericMetricPointer; // Two arguments are passed to the Execute() method: the first // is the pointer to the object which invoked the event and the // second is the event that was invoked. void Execute(itk::Object * object, const itk::EventObject & event) { if( !(itk::IterationEvent().CheckEvent( &event )) ) { return; } // Get the levels RegistrationPointer registration = dynamic_cast( object ); unsigned int numberOfLevels=registration->GetNumberOfLevels(); unsigned int currentLevel=registration->GetCurrentLevel()+1; // Output the levels std::cout<1) { // fixed image region pyramid typedef clitk::MultiResolutionPyramidRegionFilter FixedImageRegionPyramidType; typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New(); fixedImageRegionPyramid->SetRegion(m_MetricRegion); fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule()); // Reinitialize the metric (!= number of samples) m_GenericMetric= GenericMetricType::New(); m_GenericMetric->SetArgsInfo(m_ArgsInfo); m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel())); if (m_ArgsInfo.referenceMask_given) m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask()); m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel())); typedef itk::ImageToImageMetric< InternalImageType, InternalImageType > MetricType; typename MetricType::Pointer metric=m_GenericMetric->GetMetricPointer(); registration->SetMetric(metric); // Get the current coefficient image and make a COPY typename itk::ImageDuplicator::Pointer caster = itk::ImageDuplicator::New(); std::vector currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages(); for (unsigned i = 0; i < currentCoefficientImages.size(); ++i) { caster->SetInputImage(currentCoefficientImages[i]); caster->Update(); currentCoefficientImages[i] = caster->GetOutput(); } /* // Write the intermediate result? if (m_ArgsInfo.intermediate_given>=numberOfLevels) writeImage(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag); */ // Set the new transform properties m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1)); if( m_Initializer->m_ControlPointSpacingIsGiven) m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]); if( m_Initializer->m_NumberOfControlPointsIsGiven) m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]); // Reinitialize the transform if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<InitializeTransform(); ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters()); // DS : if we want to skip the last pyramid level, force to only 1 iteration DD(m_ArgsInfo.skipLastPyramidLevel_flag); if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) { DD(m_ArgsInfo.maxIt_arg); std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl; m_ArgsInfo.maxIt_arg = 0; DD(m_ArgsInfo.maxIt_arg); } // Reinitialize an Optimizer (!= number of parameters) m_GenericOptimizer = GenericOptimizerType::New(); m_GenericOptimizer->SetArgsInfo(m_ArgsInfo); m_GenericOptimizer->SetMaximize(m_Maximize); m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters()); typedef itk::SingleValuedNonLinearOptimizer OptimizerType; OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer(); optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate); registration->SetOptimizer(optimizer); m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer); // Set the previous transform parameters to the registration // if(m_Initializer->m_Parameters!=ITK_NULLPTR )delete m_Initializer->m_Parameters; m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters); registration->SetInitialTransformParametersOfNextLevel(*newParameters); } } void Execute(const itk::Object * , const itk::EventObject & ) { return; } // Members void SetInitializer(InitializerPointer i){m_Initializer=i;} InitializerPointer m_Initializer; void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;} args_info_clitkBLUTDIR m_ArgsInfo; void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;}; CommandIterationUpdatePointer m_CommandIterationUpdate; GenericOptimizerPointer m_GenericOptimizer; void SetMaximize(bool b){m_Maximize=b;} bool m_Maximize; GenericMetricPointer m_GenericMetric; void SetMetricRegion(RegionType i){m_MetricRegion=i;} RegionType m_MetricRegion; }; //============================================================================== // Update with the number of dimensions and pixeltype //============================================================================== template void BLUTDIRGenericFilter::UpdateWithInputImageType() { if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl; //============================================================================= //Input //============================================================================= bool threadsGiven=m_ArgsInfo.threads_given; int threads=m_ArgsInfo.threads_arg; typedef typename InputImageType::PixelType PixelType; typedef double TCoordRep; typename InputImageType::Pointer fixedImage = this->template GetInput(0); typename InputImageType::Pointer inputFixedImage = this->template GetInput(0); // typedef input2 typename InputImageType::Pointer movingImage = this->template GetInput(1); typename InputImageType::Pointer inputMovingImage = this->template GetInput(1); typedef itk::Image< PixelType,InputImageType::ImageDimension > FixedImageType; typedef itk::Image< PixelType, InputImageType::ImageDimension> MovingImageType; const unsigned int SpaceDimension = InputImageType::ImageDimension; //Whatever the pixel type, internally we work with an image represented in float //Reading reference image if (m_Verbose) std::cout<<"Reading images..."<GetLargestPossibleRegion(); // The transform region with respect to the input region: // where should the transform be DEFINED (depends on mask) typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion(); typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize(); typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex(); typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin(); // The metric region with respect to the extracted transform region: // where should the metric be CALCULATED (depends on transform) typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion(); typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex(); typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin(); //=========================================================================== // If given, we connect a mask to reference or target //============================================================================ typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension > MaskType; typedef itk::Image< unsigned char, InputImageType::ImageDimension > ImageLabelType; typename MaskType::Pointer fixedMask = ITK_NULLPTR; typename ImageLabelType::Pointer labels = ITK_NULLPTR; if (m_ArgsInfo.referenceMask_given) { fixedMask = MaskType::New(); labels = ImageLabelType::New(); typedef itk::ImageFileReader< ImageLabelType > LabelReaderType; typename LabelReaderType::Pointer labelReader = LabelReaderType::New(); labelReader->SetFileName(m_ArgsInfo.referenceMask_arg); try { labelReader->Update(); } catch( itk::ExceptionObject & err ) { std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl; std::cerr << err << std::endl; return; } if (m_Verbose)std::cout <<"Reference image mask was read..." < ResamplerType; typename ResamplerType::Pointer resampler = ResamplerType::New(); typedef itk::NearestNeighborInterpolateImageFunction InterpolatorType; typename InterpolatorType::Pointer interpolator = InterpolatorType::New(); resampler->SetOutputParametersFromImage(fixedImage); resampler->SetInterpolator(interpolator); resampler->SetInput(labelReader->GetOutput()); resampler->Update(); labels = resampler->GetOutput(); // Set the image to the spatialObject fixedMask->SetImage(labels); // Find the bounding box of the "inside" label typedef itk::LabelStatisticsImageFilter StatisticsImageFilterType; typename StatisticsImageFilterType::Pointer statisticsImageFilter=StatisticsImageFilterType::New(); statisticsImageFilter->SetInput(labels); statisticsImageFilter->SetLabelInput(labels); statisticsImageFilter->Update(); typename StatisticsImageFilterType::BoundingBoxType boundingBox = statisticsImageFilter->GetBoundingBox(1); // Limit the transform region to the mask for (unsigned int i=0; iTransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin); // Crop the fixedImage to the bounding box to facilitate multi-resolution typedef itk::ExtractImageFilter ExtractImageFilterType; typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New(); extractImageFilter->SetDirectionCollapseToSubmatrix(); extractImageFilter->SetInput(fixedImage); extractImageFilter->SetExtractionRegion(transformRegion); extractImageFilter->Update(); croppedFixedImage=extractImageFilter->GetOutput(); // Update the metric region metricRegion = croppedFixedImage->GetLargestPossibleRegion(); metricRegionIndex=metricRegion.GetIndex(); croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin); // Set start index to zero (with respect to croppedFixedImage/transform region) metricRegionIndex.Fill(0); metricRegion.SetIndex(metricRegionIndex); croppedFixedImage->SetRegions(metricRegion); croppedFixedImage->SetOrigin(metricRegionOrigin); } typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension > MaskType; typename MaskType::Pointer movingMask=ITK_NULLPTR; if (m_ArgsInfo.targetMask_given) { movingMask= MaskType::New(); typedef itk::Image< unsigned char, InputImageType::ImageDimension > ImageMaskType; typedef itk::ImageFileReader< ImageMaskType > MaskReaderType; typename MaskReaderType::Pointer maskReader = MaskReaderType::New(); maskReader->SetFileName(m_ArgsInfo.targetMask_arg); try { maskReader->Update(); } catch( itk::ExceptionObject & err ) { std::cerr << "ExceptionObject caught !" << std::endl; std::cerr << err << std::endl; } if (m_Verbose)std::cout <<"Target image mask was read..." <SetImage( maskReader->GetOutput() ); } //======================================================= // Output Regions //======================================================= if (m_Verbose) { // Fixed image region std::cout<<"The fixed image has its origin at "<GetOrigin()< FixedImagePyramidType; typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType> MovingImagePyramidType; typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New(); typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New(); fixedImagePyramid->SetUseShrinkImageFilter(false); fixedImagePyramid->SetInput(croppedFixedImage); fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg); movingImagePyramid->SetUseShrinkImageFilter(false); movingImagePyramid->SetInput(movingImage); movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg); if (m_Verbose) std::cout<<"Creating the image pyramid..."<Update(); movingImagePyramid->Update(); typedef clitk::MultiResolutionPyramidRegionFilter FixedImageRegionPyramidType; typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New(); fixedImageRegionPyramid->SetRegion(metricRegion); fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule()); //======================================================= // Rigid or Affine Transform //======================================================= typedef itk::AffineTransform RigidTransformType; RigidTransformType::Pointer rigidTransform=ITK_NULLPTR; if (m_ArgsInfo.initMatrix_given) { if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."< rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg); //Set the rotation itk::Matrix finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix); rigidTransform->SetMatrix(finalRotation); //Set the translation itk::Vector finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix); rigidTransform->SetTranslation(finalTranslation); } else if (m_ArgsInfo.centre_flag) { if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."< TransformInitializerType; typename TransformInitializerType::Pointer initializer = TransformInitializerType::New(); initializer->SetTransform( rigidTransform ); initializer->SetFixedImage( fixedImage ); initializer->SetMovingImage( movingImage ); initializer->GeometryOn(); initializer->InitializeTransform(); } //======================================================= // B-LUT FFD Transform //======================================================= typedef clitk::MultipleBSplineDeformableTransform TransformType; typename TransformType::Pointer transform = TransformType::New(); if (labels) transform->SetLabels(labels); if (rigidTransform) transform->SetBulkTransform(rigidTransform); //------------------------------------------------------------------------- // The transform initializer //------------------------------------------------------------------------- typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType; typename InitializerType::Pointer initializer = InitializerType::New(); initializer->SetVerbose(m_Verbose); initializer->SetImage(fixedImagePyramid->GetOutput(0)); initializer->SetTransform(transform); //------------------------------------------------------------------------- // Order //------------------------------------------------------------------------- typename FixedImageType::RegionType::SizeType splineOrders ; splineOrders.Fill(3); if (m_ArgsInfo.order_given) for(unsigned int i=0; iSetSplineOrders(splineOrders); //------------------------------------------------------------------------- // Levels //------------------------------------------------------------------------- // Spacing if (m_ArgsInfo.spacing_given) { initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg); initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg); initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing; if (m_Verbose) std::cout<<"Using a control point spacing of "<m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1] <<" at level "<m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2; if (m_Verbose) std::cout<<"Using a control point spacing of "<m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i] <<" at level "<m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg); initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg); initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage; if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image" <<" at level "<m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.); // initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.); if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image" <<" at level "<SetControlPointSpacing( initializer->m_ControlPointSpacingArray[0]); if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]); if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg); // Initialize initializer->InitializeTransform(); //------------------------------------------------------------------------- // Initial parameters (passed by reference) //------------------------------------------------------------------------- typedef typename TransformType::ParametersType ParametersType; const unsigned int numberOfParameters = transform->GetNumberOfParameters(); ParametersType parameters(numberOfParameters); parameters.Fill( 0.0 ); transform->SetParameters( parameters ); if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters); //------------------------------------------------------------------------- // DEBUG: use an itk BSpline instead of multilabel BLUTs //------------------------------------------------------------------------- typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType; RegistrationTransformType::Pointer regTransform(transform); typedef itk::BSplineDeformableTransform SingleBSplineTransformType; typename SingleBSplineTransformType::Pointer sTransform; if(m_ArgsInfo.itkbspline_flag) { if( transform->GetTransforms().size()>1) itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label") sTransform = SingleBSplineTransformType::New(); sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() ); sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() ); sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() ); sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() ); sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() ); regTransform = sTransform; transform = ITK_NULLPTR; // free memory } //======================================================= // Interpolator //======================================================= typedef clitk::GenericInterpolator GenericInterpolatorType; typename GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New(); genericInterpolator->SetArgsInfo(m_ArgsInfo); typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep > InterpolatorType; typename InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer(); //======================================================= // Metric //======================================================= typedef clitk::GenericMetric GenericMetricType; typename GenericMetricType::Pointer genericMetric=GenericMetricType::New(); genericMetric->SetArgsInfo(m_ArgsInfo); genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0)); if (fixedMask) genericMetric->SetFixedImageMask(fixedMask); genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0)); typedef itk::ImageToImageMetric< FixedImageType, MovingImageType > MetricType; typename MetricType::Pointer metric=genericMetric->GetMetricPointer(); if (movingMask) metric->SetMovingImageMask(movingMask); if (threadsGiven) { #if ITK_VERSION_MAJOR <= 4 metric->SetNumberOfThreads( threads ); #else metric->SetNumberOfWorkUnits( threads ); #endif if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl; } //======================================================= // Optimizer //======================================================= typedef clitk::GenericOptimizer GenericOptimizerType; GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New(); genericOptimizer->SetArgsInfo(m_ArgsInfo); genericOptimizer->SetMaximize(genericMetric->GetMaximize()); genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters()); typedef itk::SingleValuedNonLinearOptimizer OptimizerType; OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer(); //======================================================= // Registration //======================================================= typedef itk::MultiResolutionImageRegistrationMethod< FixedImageType, MovingImageType > RegistrationType; typename RegistrationType::Pointer registration = RegistrationType::New(); registration->SetMetric( metric ); registration->SetOptimizer( optimizer ); registration->SetInterpolator( interpolator ); registration->SetTransform (regTransform ); if(threadsGiven) { #if ITK_VERSION_MAJOR <= 4 registration->SetNumberOfThreads(threads); #else registration->SetNumberOfWorkUnits(threads); #endif if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl; } registration->SetFixedImage( croppedFixedImage ); registration->SetMovingImage( movingImage ); registration->SetFixedImageRegion( metricRegion ); registration->SetFixedImagePyramid( fixedImagePyramid ); registration->SetMovingImagePyramid( movingImagePyramid ); registration->SetInitialTransformParameters( regTransform->GetParameters() ); registration->SetNumberOfLevels( m_ArgsInfo.levels_arg ); if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<SetOptimizer(genericOptimizer); optimizer->AddObserver( itk::IterationEvent(), observer ); // Output level info typedef RegistrationInterfaceCommand CommandType; typename CommandType::Pointer command = CommandType::New(); command->SetInitializer(initializer); command->SetArgsInfo(m_ArgsInfo); command->SetCommandIterationUpdate(observer); command->SetMaximize(genericMetric->GetMaximize()); command->SetMetricRegion(metricRegion); registration->AddObserver( itk::IterationEvent(), command ); if (m_ArgsInfo.coeff_given) { if(m_ArgsInfo.itkbspline_flag) { itkExceptionMacro("--coeff and --itkbpline are incompatible"); } std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl; typedef CommandIterationUpdateDVF DVFCommandType; typename DVFCommandType::Pointer observerdvf = DVFCommandType::New(); observerdvf->SetFixedImage(fixedImage); observerdvf->SetTransform(transform); observerdvf->SetArgsInfo(m_ArgsInfo); optimizer->AddObserver( itk::IterationEvent(), observerdvf ); } } //======================================================= // Let's go //======================================================= if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl; try { registration->Update(); } catch( itk::ExceptionObject & err ) { std::cerr << "ExceptionObject caught while registering!" << std::endl; std::cerr << err << std::endl; return; } //======================================================= // Get the result //======================================================= OptimizerType::ParametersType finalParameters = registration->GetLastTransformParameters(); regTransform->SetParameters( finalParameters ); if (m_Verbose) { std::cout<<"Stop condition description: " <GetOptimizer()->GetStopConditionDescription()< coefficientImages = transform->GetCoefficientImages(); typedef itk::ImageFileWriter CoeffWriterType; typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New(); unsigned nLabels = transform->GetnLabels(); std::string fname(m_ArgsInfo.coeff_arg); int dotpos = fname.length() - 1; while (dotpos >= 0 && fname[dotpos] != '.') dotpos--; for (unsigned i = 0; i < nLabels; ++i) { std::ostringstream osfname; osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos); coeffWriter->SetInput(coefficientImages[i]); coeffWriter->SetFileName(osfname.str()); coeffWriter->Update(); } } //======================================================= // Compute the DVF (only deformable transform) //======================================================= typedef itk::Vector< float, SpaceDimension > DisplacementType; typedef itk::Image< DisplacementType, InputImageType::ImageDimension > DisplacementFieldType; #if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6) typedef itk::TransformToDisplacementFieldSource ConvertorType; #else typedef itk::TransformToDisplacementFieldFilter ConvertorType; #endif typename ConvertorType::Pointer filter= ConvertorType::New(); #if ITK_VERSION_MAJOR <= 4 filter->SetNumberOfThreads(1); #else filter->SetNumberOfWorkUnits(1); #endif if(m_ArgsInfo.itkbspline_flag) sTransform->SetBulkTransform(ITK_NULLPTR); else transform->SetBulkTransform(ITK_NULLPTR); filter->SetTransform(regTransform); #if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6) filter->SetReferenceImage(fixedImage); #else filter->SetOutputParametersFromImage(fixedImage); #endif filter->Update(); typename DisplacementFieldType::Pointer field = filter->GetOutput(); //======================================================= // Write the DVF //======================================================= typedef itk::ImageFileWriter< DisplacementFieldType > FieldWriterType; typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New(); fieldWriter->SetFileName( m_ArgsInfo.vf_arg ); fieldWriter->SetInput( field ); try { fieldWriter->Update(); } catch( itk::ExceptionObject & excp ) { std::cerr << "Exception thrown writing the DVF" << std::endl; std::cerr << excp << std::endl; return; } //======================================================= // Resample the moving image //======================================================= typedef itk::ResampleImageFilter< MovingImageType, FixedImageType > ResampleFilterType; typename ResampleFilterType::Pointer resampler = ResampleFilterType::New(); if (rigidTransform) { if(m_ArgsInfo.itkbspline_flag) sTransform->SetBulkTransform(rigidTransform); else transform->SetBulkTransform(rigidTransform); } resampler->SetTransform( regTransform ); resampler->SetInput( movingImage); resampler->SetOutputParametersFromImage(fixedImage); resampler->Update(); typename FixedImageType::Pointer result=resampler->GetOutput(); // typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType > WarpFilterType; // typename WarpFilterType::Pointer warp = WarpFilterType::New(); // warp->SetDeformationField( field ); // warp->SetInput( movingImageReader->GetOutput() ); // warp->SetOutputOrigin( fixedImage->GetOrigin() ); // warp->SetOutputSpacing( fixedImage->GetSpacing() ); // warp->SetOutputDirection( fixedImage->GetDirection() ); // warp->SetEdgePaddingValue( 0.0 ); // warp->Update(); //======================================================= // Write the warped image //======================================================= typedef itk::ImageFileWriter< FixedImageType > WriterType; typename WriterType::Pointer writer = WriterType::New(); writer->SetFileName( m_ArgsInfo.output_arg ); writer->SetInput( result ); try { writer->Update(); } catch( itk::ExceptionObject & err ) { std::cerr << "ExceptionObject caught !" << std::endl; std::cerr << err << std::endl; return; } //======================================================= // Calculate the difference after the deformable transform //======================================================= typedef clitk::DifferenceImageFilter< FixedImageType, FixedImageType> DifferenceFilterType; if (m_ArgsInfo.after_given) { typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New(); difference->SetValidInput( fixedImage ); difference->SetTestInput( result ); try { difference->Update(); } catch( itk::ExceptionObject & err ) { std::cerr << "ExceptionObject caught calculating the difference !" << std::endl; std::cerr << err << std::endl; return; } typename WriterType::Pointer differenceWriter=WriterType::New(); differenceWriter->SetInput(difference->GetOutput()); differenceWriter->SetFileName(m_ArgsInfo.after_arg); differenceWriter->Update(); } //======================================================= // Calculate the difference before the deformable transform //======================================================= if( m_ArgsInfo.before_given ) { typename FixedImageType::Pointer moving=FixedImageType::New(); if (m_ArgsInfo.initMatrix_given) { typedef itk::ResampleImageFilter ResamplerType; typename ResamplerType::Pointer resampler=ResamplerType::New(); resampler->SetInput(movingImage); resampler->SetOutputOrigin(fixedImage->GetOrigin()); resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize()); resampler->SetOutputSpacing(fixedImage->GetSpacing()); resampler->SetDefaultPixelValue( 0. ); if (rigidTransform ) resampler->SetTransform(rigidTransform); resampler->Update(); moving=resampler->GetOutput(); } else moving=movingImage; typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New(); difference->SetValidInput( fixedImage ); difference->SetTestInput( moving ); try { difference->Update(); } catch( itk::ExceptionObject & err ) { std::cerr << "ExceptionObject caught calculating the difference !" << std::endl; std::cerr << err << std::endl; return; } typename WriterType::Pointer differenceWriter=WriterType::New(); writer->SetFileName( m_ArgsInfo.before_arg ); writer->SetInput( difference->GetOutput() ); writer->Update( ); } return; } }//end clitk #endif // #define clitkBLUTDIRGenericFilter_txx