X-Git-Url: https://git.creatis.insa-lyon.fr/pubgit/?a=blobdiff_plain;f=registration%2FclitkBLUTDIRGenericFilter.cxx;h=9f7b84f456f36496f1be156400a0c43441dd2983;hb=refs%2Fheads%2FextentSimon;hp=3e6c6ec18b36d8569b9d37e2e0e10a179e69f2c2;hpb=6a18aca6ed8f2384cd2183b3bea8737e0ac6a55c;p=clitk.git diff --git a/registration/clitkBLUTDIRGenericFilter.cxx b/registration/clitkBLUTDIRGenericFilter.cxx old mode 100755 new mode 100644 index 3e6c6ec..9f7b84f --- a/registration/clitkBLUTDIRGenericFilter.cxx +++ b/registration/clitkBLUTDIRGenericFilter.cxx @@ -3,7 +3,7 @@ Program: vv http://www.creatis.insa-lyon.fr/rio/vv Authors belong to: - University of LYON http://www.universite-lyon.fr/ -- Léon Bérard cancer center http://oncora1.lyon.fnclcc.fr +- Léon Bérard cancer center http://www.centreleonberard.fr - CREATIS CNRS laboratory http://www.creatis.insa-lyon.fr This software is distributed WITHOUT ANY WARRANTY; without even @@ -14,7 +14,7 @@ It is distributed under dual licence - BSD See included LICENSE.txt file - CeCILL-B http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html -======================================================================-====*/ +===========================================================================**/ #ifndef clitkBLUTDIRGenericFilter_cxx #define clitkBLUTDIRGenericFilter_cxx @@ -28,6 +28,17 @@ It is distributed under dual licence ===================================================*/ #include "clitkBLUTDIRGenericFilter.h" +#include "clitkBLUTDIRCommandIterationUpdateDVF.h" +#include "itkCenteredTransformInitializer.h" +#if ITK_VERSION_MAJOR >= 4 +# if ITK_VERSION_MINOR < 6 +# include "itkTransformToDisplacementFieldSource.h" +# else +# include "itkTransformToDisplacementFieldFilter.h" +# endif +#else +# include "itkTransformToDeformationFieldSource.h" +#endif namespace clitk { @@ -78,7 +89,7 @@ namespace clitk { InitializeImageType<2>(); InitializeImageType<3>(); - m_Verbose=true; + m_Verbose=false; } //=========================================================================// @@ -93,6 +104,8 @@ namespace clitk } if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg); + + if (m_ArgsInfo.verbose_given) m_Verbose=true; } //=========================================================================// @@ -127,8 +140,8 @@ namespace clitk typedef typename RegistrationType::FixedImageType FixedImageType; typedef typename FixedImageType::RegionType RegionType; itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension); - typedef clitk::BSplineDeformableTransform TransformType; - typedef clitk::BSplineDeformableTransformInitializer InitializerType; + typedef clitk::MultipleBSplineDeformableTransform TransformType; + typedef clitk::MultipleBSplineDeformableTransformInitializer InitializerType; typedef typename InitializerType::CoefficientImageType CoefficientImageType; typedef itk::CastImageFilter CastImageFilterType; typedef typename TransformType::ParametersType ParametersType; @@ -186,14 +199,20 @@ namespace clitk registration->SetMetric(metric); // Get the current coefficient image and make a COPY - typename itk::ImageDuplicator::Pointer caster=itk::ImageDuplicator::New(); - caster->SetInputImage(m_Initializer->GetTransform()->GetCoefficientImage()); - caster->Update(); - typename CoefficientImageType::Pointer currentCoefficientImage=caster->GetOutput(); + typename itk::ImageDuplicator::Pointer caster = itk::ImageDuplicator::New(); + std::vector currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages(); + for (unsigned i = 0; i < currentCoefficientImages.size(); ++i) + { + caster->SetInputImage(currentCoefficientImages[i]); + caster->Update(); + currentCoefficientImages[i] = caster->GetOutput(); + } + /* // Write the intermediate result? if (m_ArgsInfo.intermediate_given>=numberOfLevels) writeImage(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag); + */ // Set the new transform properties m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1)); @@ -231,7 +250,7 @@ namespace clitk // Set the previous transform parameters to the registration // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters; - m_Initializer->SetInitialParameters(currentCoefficientImage,*newParameters); + m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters); registration->SetInitialTransformParametersOfNextLevel(*newParameters); } } @@ -267,6 +286,8 @@ namespace clitk template void BLUTDIRGenericFilter::UpdateWithInputImageType() { + if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl; + //============================================================================= //Input //============================================================================= @@ -311,7 +332,6 @@ namespace clitk // The metric region with respect to the extracted transform region: // where should the metric be CALCULATED (depends on transform) typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion(); - typename FixedImageType::RegionType::SizeType metricRegionSize=metricRegion.GetSize(); typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex(); typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin(); @@ -320,33 +340,46 @@ namespace clitk // If given, we connect a mask to reference or target //============================================================================ typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension > MaskType; - typename MaskType::Pointer fixedMask=NULL; + typedef itk::Image< unsigned char, InputImageType::ImageDimension > ImageLabelType; + typename MaskType::Pointer fixedMask = NULL; + typename ImageLabelType::Pointer labels = NULL; if (m_ArgsInfo.referenceMask_given) { - fixedMask= MaskType::New(); - typedef itk::Image< unsigned char,InputImageType::ImageDimension > ImageMaskType; - typedef itk::ImageFileReader< ImageMaskType > MaskReaderType; - typename MaskReaderType::Pointer maskReader = MaskReaderType::New(); - maskReader->SetFileName(m_ArgsInfo.referenceMask_arg); + fixedMask = MaskType::New(); + labels = ImageLabelType::New(); + typedef itk::ImageFileReader< ImageLabelType > LabelReaderType; + typename LabelReaderType::Pointer labelReader = LabelReaderType::New(); + labelReader->SetFileName(m_ArgsInfo.referenceMask_arg); try { - maskReader->Update(); + labelReader->Update(); } catch( itk::ExceptionObject & err ) { - std::cerr << "ExceptionObject caught while reading mask !" << std::endl; + std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl; std::cerr << err << std::endl; return; } if (m_Verbose)std::cout <<"Reference image mask was read..." < ResamplerType; + typename ResamplerType::Pointer resampler = ResamplerType::New(); + typedef itk::NearestNeighborInterpolateImageFunction InterpolatorType; + typename InterpolatorType::Pointer interpolator = InterpolatorType::New(); + resampler->SetOutputParametersFromImage(fixedImage); + resampler->SetInterpolator(interpolator); + resampler->SetInput(labelReader->GetOutput()); + resampler->Update(); + labels = resampler->GetOutput(); + // Set the image to the spatialObject - fixedMask->SetImage( maskReader->GetOutput() ); + fixedMask->SetImage(labels); // Find the bounding box of the "inside" label - typedef itk::LabelGeometryImageFilter GeometryImageFilterType; + typedef itk::LabelGeometryImageFilter GeometryImageFilterType; typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New(); - geometryImageFilter->SetInput(maskReader->GetOutput()); + geometryImageFilter->SetInput(labels); geometryImageFilter->Update(); typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1); @@ -363,6 +396,7 @@ namespace clitk // Crop the fixedImage to the bounding box to facilitate multi-resolution typedef itk::ExtractImageFilter ExtractImageFilterType; typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New(); + extractImageFilter->SetDirectionCollapseToSubmatrix(); extractImageFilter->SetInput(fixedImage); extractImageFilter->SetExtractionRegion(transformRegion); extractImageFilter->Update(); @@ -469,20 +503,34 @@ namespace clitk itk::Vector finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix); rigidTransform->SetTranslation(finalTranslation); } + else if (m_ArgsInfo.centre_flag) + { + if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."< TransformInitializerType; + typename TransformInitializerType::Pointer initializer = TransformInitializerType::New(); + initializer->SetTransform( rigidTransform ); + initializer->SetFixedImage( fixedImage ); + initializer->SetMovingImage( movingImage ); + initializer->GeometryOn(); + initializer->InitializeTransform(); + } //======================================================= // B-LUT FFD Transform //======================================================= - typedef clitk::BSplineDeformableTransform TransformType; - typename TransformType::Pointer transform= TransformType::New(); - if (fixedMask) transform->SetMask( fixedMask ); - if (rigidTransform) transform->SetBulkTransform( rigidTransform ); + typedef clitk::MultipleBSplineDeformableTransform TransformType; + typename TransformType::Pointer transform = TransformType::New(); + if (labels) transform->SetLabels(labels); + if (rigidTransform) transform->SetBulkTransform(rigidTransform); //------------------------------------------------------------------------- // The transform initializer //------------------------------------------------------------------------- - typedef clitk::BSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType; + typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType; typename InitializerType::Pointer initializer = InitializerType::New(); initializer->SetVerbose(m_Verbose); initializer->SetImage(fixedImagePyramid->GetOutput(0)); @@ -560,6 +608,25 @@ namespace clitk transform->SetParameters( parameters ); if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters); + //------------------------------------------------------------------------- + // DEBUG: use an itk BSpline instead of multilabel BLUTs + //------------------------------------------------------------------------- + typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType; + RegistrationTransformType::Pointer regTransform(transform); + typedef itk::BSplineDeformableTransform SingleBSplineTransformType; + typename SingleBSplineTransformType::Pointer sTransform; + if(m_ArgsInfo.itkbspline_flag) { + if( transform->GetTransforms().size()>1) + itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label") + sTransform = SingleBSplineTransformType::New(); + sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() ); + sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() ); + sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() ); + sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() ); + sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() ); + regTransform = sTransform; + transform = NULL; // free memory + } //======================================================= // Interpolator @@ -583,16 +650,10 @@ namespace clitk typedef itk::ImageToImageMetric< FixedImageType, MovingImageType > MetricType; typename MetricType::Pointer metric=genericMetric->GetMetricPointer(); if (movingMask) metric->SetMovingImageMask(movingMask); - -#ifdef ITK_USE_OPTIMIZED_REGISTRATION_METHODS if (threadsGiven) { metric->SetNumberOfThreads( threads ); if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl; } -#else - if (m_Verbose) std::cout<<"Not setting the number of threads (not compiled with USE_OPTIMIZED_REGISTRATION_METHODS)..."<SetArgsInfo(m_ArgsInfo); genericOptimizer->SetMaximize(genericMetric->GetMaximize()); - genericOptimizer->SetNumberOfParameters(transform->GetNumberOfParameters()); + genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters()); typedef itk::SingleValuedNonLinearOptimizer OptimizerType; OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer(); @@ -614,7 +675,7 @@ namespace clitk registration->SetMetric( metric ); registration->SetOptimizer( optimizer ); registration->SetInterpolator( interpolator ); - registration->SetTransform (transform); + registration->SetTransform (regTransform ); if(threadsGiven) { registration->SetNumberOfThreads(threads); if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl; @@ -624,7 +685,7 @@ namespace clitk registration->SetFixedImageRegion( metricRegion ); registration->SetFixedImagePyramid( fixedImagePyramid ); registration->SetMovingImagePyramid( movingImagePyramid ); - registration->SetInitialTransformParameters( transform->GetParameters() ); + registration->SetInitialTransformParameters( regTransform->GetParameters() ); registration->SetNumberOfLevels( m_ArgsInfo.levels_arg ); if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<SetMaximize(genericMetric->GetMaximize()); command->SetMetricRegion(metricRegion); registration->AddObserver( itk::IterationEvent(), command ); + + if (m_ArgsInfo.coeff_given) + { + if(m_ArgsInfo.itkbspline_flag) { + itkExceptionMacro("--coeff and --itkbpline are incompatible"); + } + + std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl; + typedef CommandIterationUpdateDVF DVFCommandType; + typename DVFCommandType::Pointer observerdvf = DVFCommandType::New(); + observerdvf->SetFixedImage(fixedImage); + observerdvf->SetTransform(transform); + observerdvf->SetArgsInfo(m_ArgsInfo); + optimizer->AddObserver( itk::IterationEvent(), observerdvf ); + } } @@ -658,7 +734,7 @@ namespace clitk try { - registration->StartRegistration(); + registration->Update(); } catch( itk::ExceptionObject & err ) { @@ -672,7 +748,7 @@ namespace clitk // Get the result //======================================================= OptimizerType::ParametersType finalParameters = registration->GetLastTransformParameters(); - transform->SetParameters( finalParameters ); + regTransform->SetParameters( finalParameters ); if (m_Verbose) { std::cout<<"Stop condition description: " @@ -686,12 +762,24 @@ namespace clitk if (m_ArgsInfo.coeff_given) { typedef typename TransformType::CoefficientImageType CoefficientImageType; - typename CoefficientImageType::Pointer coefficientImage =transform->GetCoefficientImage(); + std::vector coefficientImages = transform->GetCoefficientImages(); typedef itk::ImageFileWriter CoeffWriterType; - typename CoeffWriterType::Pointer coeffWriter=CoeffWriterType::New(); - coeffWriter->SetInput(coefficientImage); - coeffWriter->SetFileName(m_ArgsInfo.coeff_arg); - coeffWriter->Update(); + typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New(); + unsigned nLabels = transform->GetnLabels(); + + std::string fname(m_ArgsInfo.coeff_arg); + int dotpos = fname.length() - 1; + while (dotpos >= 0 && fname[dotpos] != '.') + dotpos--; + + for (unsigned i = 0; i < nLabels; ++i) + { + std::ostringstream osfname; + osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos); + coeffWriter->SetInput(coefficientImages[i]); + coeffWriter->SetFileName(osfname.str()); + coeffWriter->Update(); + } } @@ -700,21 +788,34 @@ namespace clitk // Compute the DVF (only deformable transform) //======================================================= typedef itk::Vector< float, SpaceDimension > DisplacementType; - typedef itk::Image< DisplacementType, InputImageType::ImageDimension > DeformationFieldType; - typedef itk::TransformToDeformationFieldSource ConvertorType; + typedef itk::Image< DisplacementType, InputImageType::ImageDimension > DisplacementFieldType; +#if ITK_VERSION_MAJOR >= 4 +# if ITK_VERSION_MINOR < 6 + typedef itk::TransformToDisplacementFieldSource ConvertorType; +# else + typedef itk::TransformToDisplacementFieldFilter ConvertorType; +# endif +#endif typename ConvertorType::Pointer filter= ConvertorType::New(); filter->SetNumberOfThreads(1); - transform->SetBulkTransform(NULL); - filter->SetTransform(transform); + if(m_ArgsInfo.itkbspline_flag) + sTransform->SetBulkTransform(NULL); + else + transform->SetBulkTransform(NULL); + filter->SetTransform(regTransform); +#if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6) + filter->SetReferenceImage(fixedImage); +#else filter->SetOutputParametersFromImage(fixedImage); +#endif filter->Update(); - typename DeformationFieldType::Pointer field = filter->GetOutput(); + typename DisplacementFieldType::Pointer field = filter->GetOutput(); //======================================================= // Write the DVF //======================================================= - typedef itk::ImageFileWriter< DeformationFieldType > FieldWriterType; + typedef itk::ImageFileWriter< DisplacementFieldType > FieldWriterType; typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New(); fieldWriter->SetFileName( m_ArgsInfo.vf_arg ); fieldWriter->SetInput( field ); @@ -735,8 +836,13 @@ namespace clitk //======================================================= typedef itk::ResampleImageFilter< MovingImageType, FixedImageType > ResampleFilterType; typename ResampleFilterType::Pointer resampler = ResampleFilterType::New(); - if (rigidTransform) transform->SetBulkTransform(rigidTransform); - resampler->SetTransform( transform ); + if (rigidTransform) { + if(m_ArgsInfo.itkbspline_flag) + sTransform->SetBulkTransform(rigidTransform); + else + transform->SetBulkTransform(rigidTransform); + } + resampler->SetTransform( regTransform ); resampler->SetInput( movingImage); resampler->SetOutputParametersFromImage(fixedImage); resampler->Update();