]> Creatis software - clitk.git/blob - registration/clitkBLUTDIRGenericFilter.cxx
close #56 Remove ITK_Review dependency
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
1 /*=========================================================================
2 Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4 Authors belong to:
5 - University of LYON              http://www.universite-lyon.fr/
6 - Léon Bérard cancer center       http://www.centreleonberard.fr
7 - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9 This software is distributed WITHOUT ANY WARRANTY; without even
10 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11 PURPOSE.  See the copyright notices for more information.
12
13 It is distributed under dual licence
14
15 - BSD        See included LICENSE.txt file
16 - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef clitkBLUTDIRGenericFilter_cxx
19 #define clitkBLUTDIRGenericFilter_cxx
20
21 /* =================================================
22  * @file   clitkBLUTDIRGenericFilter.cxx
23  * @author
24  * @date
25  *
26  * @brief
27  *
28  ===================================================*/
29
30 #include "clitkBLUTDIRGenericFilter.h"
31 #include "clitkBLUTDIRCommandIterationUpdateDVF.h"
32 #include "itkCenteredTransformInitializer.h"
33 #include "itkLabelStatisticsImageFilter.h"
34 #if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6)
35 # include "itkTransformToDisplacementFieldSource.h"
36 #else
37 # include "itkTransformToDisplacementFieldFilter.h"
38 #endif
39
40 namespace clitk
41 {
42
43   //==============================================================================
44   // Creating an observer class that allows output at each iteration
45   //==============================================================================
46   class CommandIterationUpdate : public itk::Command
47   {
48     public:
49       typedef  CommandIterationUpdate   Self;
50       typedef  itk::Command             Superclass;
51       typedef  itk::SmartPointer<Self>  Pointer;
52       itkNewMacro( Self );
53     protected:
54       CommandIterationUpdate() {};
55     public:
56       typedef   clitk::GenericOptimizer<args_info_clitkBLUTDIR>     OptimizerType;
57       typedef   const OptimizerType   *           OptimizerPointer;
58
59       // Set the generic optimizer
60       void SetOptimizer(OptimizerPointer o){m_Optimizer=o;}
61
62       // Execute
63       void Execute(itk::Object *caller, const itk::EventObject & event)
64       {
65         Execute( (const itk::Object *)caller, event);
66       }
67
68       void Execute(const itk::Object * object, const itk::EventObject & event)
69       {
70         if( !(itk::IterationEvent().CheckEvent( &event )) )
71         {
72           return;
73         }
74
75         m_Optimizer->OutputIterationInfo();
76       }
77
78       OptimizerPointer m_Optimizer;
79   };
80
81   //===========================================================================//
82   //Constructor
83   //==========================================================================//
84   BLUTDIRGenericFilter::BLUTDIRGenericFilter():
85     ImageToImageGenericFilter<Self>("Register DIR")
86   {
87     InitializeImageType<2>();
88     InitializeImageType<3>();
89     m_Verbose=false;
90   }
91
92   //=========================================================================//
93   //SetArgsInfo
94   //==========================================================================//
95   void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){
96     m_ArgsInfo=a;
97     if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg);
98
99     if (m_ArgsInfo.target_given) {
100       AddInputFilename(m_ArgsInfo.target_arg);
101     }
102
103     if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg);
104     
105     if (m_ArgsInfo.verbose_given) m_Verbose=true;
106   }
107
108   //=========================================================================//
109   //===========================================================================//
110   template<unsigned int Dim>
111     void BLUTDIRGenericFilter::InitializeImageType()
112     {
113       ADD_DEFAULT_IMAGE_TYPES(3);
114     }
115   //--------------------------------------------------------------------
116
117   //==============================================================================
118   //Creating an observer class that allows us to change parameters at subsequent levels
119   //==============================================================================
120   template <typename TRegistration,class args_info_clitkBLUTDIR>
121     class RegistrationInterfaceCommand : public itk::Command
122   {
123     public:
124       typedef RegistrationInterfaceCommand   Self;
125       typedef itk::Command             Superclass;
126       typedef itk::SmartPointer<Self>  Pointer;
127       itkNewMacro( Self );
128     protected:
129       RegistrationInterfaceCommand() { };
130     public:
131
132       // Registration
133       typedef   TRegistration                              RegistrationType;
134       typedef   RegistrationType *                         RegistrationPointer;
135
136       // Transform
137       typedef typename RegistrationType::FixedImageType FixedImageType;
138       typedef typename FixedImageType::RegionType RegionType;
139       itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension);
140       typedef clitk::MultipleBSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
141       typedef clitk::MultipleBSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
142       typedef typename InitializerType::CoefficientImageType CoefficientImageType;
143       typedef itk::CastImageFilter<CoefficientImageType, CoefficientImageType> CastImageFilterType;
144       typedef typename TransformType::ParametersType ParametersType;
145       typedef typename InitializerType::Pointer InitializerPointer;
146       typedef   CommandIterationUpdate::Pointer CommandIterationUpdatePointer;
147
148       // Optimizer
149       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
150       typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer;
151
152       // Metric
153       typedef typename RegistrationType::FixedImageType    InternalImageType;
154       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, InternalImageType, InternalImageType> GenericMetricType;
155       typedef typename GenericMetricType::Pointer GenericMetricPointer;
156
157       // Two arguments are passed to the Execute() method: the first
158       // is the pointer to the object which invoked the event and the
159       // second is the event that was invoked.
160       void Execute(itk::Object * object, const itk::EventObject & event)
161       {
162         if( !(itk::IterationEvent().CheckEvent( &event )) )
163         {
164           return;
165         }
166
167         // Get the levels
168         RegistrationPointer registration = dynamic_cast<RegistrationPointer>( object );
169         unsigned int numberOfLevels=registration->GetNumberOfLevels();
170         unsigned int currentLevel=registration->GetCurrentLevel()+1;
171
172         // Output the levels
173         std::cout<<std::endl;
174         std::cout<<"========================================"<<std::endl;
175         std::cout<<"Starting resolution level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
176         std::cout<<"========================================"<<std::endl;
177         std::cout<<std::endl;
178
179         // Higher level?
180         if (currentLevel>1)
181         {
182           // fixed image region pyramid
183           typedef clitk::MultiResolutionPyramidRegionFilter<InternalImageType> FixedImageRegionPyramidType;
184           typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
185           fixedImageRegionPyramid->SetRegion(m_MetricRegion);
186           fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule());
187
188           // Reinitialize the metric (!= number of samples)
189           m_GenericMetric= GenericMetricType::New();
190           m_GenericMetric->SetArgsInfo(m_ArgsInfo);
191           m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel()));
192           if (m_ArgsInfo.referenceMask_given)  m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask());
193           m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel()));
194           typedef itk::ImageToImageMetric< InternalImageType, InternalImageType >  MetricType;
195           typename  MetricType::Pointer metric=m_GenericMetric->GetMetricPointer();
196           registration->SetMetric(metric);
197
198           // Get the current coefficient image and make a COPY
199           typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster = itk::ImageDuplicator<CoefficientImageType>::New();
200           std::vector<typename CoefficientImageType::Pointer> currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages();
201           for (unsigned i = 0; i < currentCoefficientImages.size(); ++i)
202           {
203             caster->SetInputImage(currentCoefficientImages[i]);
204             caster->Update();
205             currentCoefficientImages[i] = caster->GetOutput();
206           }
207
208           /*
209           // Write the intermediate result?
210           if (m_ArgsInfo.intermediate_given>=numberOfLevels)
211             writeImage<CoefficientImageType>(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag);
212             */
213
214           // Set the new transform properties
215           m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1));
216           if( m_Initializer->m_ControlPointSpacingIsGiven)
217             m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]);
218           if( m_Initializer->m_NumberOfControlPointsIsGiven)
219             m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]);
220
221           // Reinitialize the transform
222           if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
223           m_Initializer->InitializeTransform();
224           ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters());
225
226           // DS : if we want to skip the last pyramid level, force to only 1 iteration
227           DD(m_ArgsInfo.skipLastPyramidLevel_flag);
228           if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) {
229             DD(m_ArgsInfo.maxIt_arg);
230             std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl;
231             m_ArgsInfo.maxIt_arg = 0;
232             DD(m_ArgsInfo.maxIt_arg);
233           }
234
235           // Reinitialize an Optimizer (!= number of parameters)
236           m_GenericOptimizer = GenericOptimizerType::New();
237           m_GenericOptimizer->SetArgsInfo(m_ArgsInfo);
238           m_GenericOptimizer->SetMaximize(m_Maximize);
239           m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters());
240
241
242           typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
243           OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer();
244           optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate);
245           registration->SetOptimizer(optimizer);
246           m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer);
247
248           // Set the previous transform parameters to the registration
249           // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
250           m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
251           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
252         }
253       }
254
255       void Execute(const itk::Object * , const itk::EventObject & )
256       { return; }
257
258
259       // Members
260       void SetInitializer(InitializerPointer i){m_Initializer=i;}
261       InitializerPointer m_Initializer;
262
263       void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;}
264       args_info_clitkBLUTDIR m_ArgsInfo;
265
266       void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;};
267       CommandIterationUpdatePointer m_CommandIterationUpdate;
268
269       GenericOptimizerPointer m_GenericOptimizer;
270       void SetMaximize(bool b){m_Maximize=b;}
271       bool m_Maximize;
272
273       GenericMetricPointer m_GenericMetric;
274       void SetMetricRegion(RegionType i){m_MetricRegion=i;}
275       RegionType m_MetricRegion;
276
277
278   };
279
280   //==============================================================================
281   // Update with the number of dimensions and pixeltype
282   //==============================================================================
283   template<class InputImageType>
284     void BLUTDIRGenericFilter::UpdateWithInputImageType()
285     {
286       if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl;
287       
288       //=============================================================================
289       //Input
290       //=============================================================================
291       bool threadsGiven=m_ArgsInfo.threads_given;
292       int threads=m_ArgsInfo.threads_arg;
293       typedef typename InputImageType::PixelType PixelType;
294
295       typedef double TCoordRep;
296
297       typename InputImageType::Pointer fixedImage = this->template GetInput<InputImageType>(0);
298
299       typename InputImageType::Pointer inputFixedImage = this->template GetInput<InputImageType>(0);
300
301       // typedef input2
302       typename InputImageType::Pointer movingImage = this->template GetInput<InputImageType>(1);
303
304       typename InputImageType::Pointer inputMovingImage = this->template GetInput<InputImageType>(1);
305
306       typedef itk::Image< PixelType,InputImageType::ImageDimension >  FixedImageType;
307       typedef itk::Image< PixelType, InputImageType::ImageDimension>  MovingImageType;
308       const unsigned int SpaceDimension = InputImageType::ImageDimension;
309       //Whatever the pixel type, internally we work with an image represented in float
310       //Reading reference image
311       if (m_Verbose) std::cout<<"Reading images..."<<std::endl;
312       //=======================================================
313       //Input
314       //=======================================================
315       typename FixedImageType::Pointer croppedFixedImage=fixedImage;
316       //=======================================================
317       // Regions
318       //=======================================================
319       // The original input region
320       typename FixedImageType::RegionType fixedImageRegion = fixedImage->GetLargestPossibleRegion();
321
322       // The transform region with respect to the input region:
323       // where should the transform be DEFINED (depends on mask)
324       typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion();
325       typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize();
326       typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex();
327       typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin();
328
329       // The metric region with respect to the extracted transform region:
330       // where should the metric be CALCULATED (depends on transform)
331       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
332       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
333       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
334
335
336       //===========================================================================
337       // If given, we connect a mask to reference or target
338       //============================================================================
339       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
340       typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
341       typename MaskType::Pointer        fixedMask = NULL;
342       typename ImageLabelType::Pointer  labels = NULL;
343       if (m_ArgsInfo.referenceMask_given)
344       {
345         fixedMask = MaskType::New();
346         labels = ImageLabelType::New();
347         typedef itk::ImageFileReader< ImageLabelType >    LabelReaderType;
348         typename LabelReaderType::Pointer  labelReader = LabelReaderType::New();
349         labelReader->SetFileName(m_ArgsInfo.referenceMask_arg);
350         try
351         {
352           labelReader->Update();
353         }
354         catch( itk::ExceptionObject & err )
355         {
356           std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl;
357           std::cerr << err << std::endl;
358           return;
359         }
360         if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
361
362         // Resample labels
363         typedef itk::ResampleImageFilter<ImageLabelType, ImageLabelType> ResamplerType;
364         typename ResamplerType::Pointer resampler = ResamplerType::New();
365         typedef itk::NearestNeighborInterpolateImageFunction<ImageLabelType, TCoordRep> InterpolatorType;
366         typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
367         resampler->SetOutputParametersFromImage(fixedImage);
368         resampler->SetInterpolator(interpolator);
369         resampler->SetInput(labelReader->GetOutput());
370         resampler->Update();
371         labels = resampler->GetOutput();
372
373         // Set the image to the spatialObject
374         fixedMask->SetImage(labels);
375
376         // Find the bounding box of the "inside" label
377         typedef itk::LabelStatisticsImageFilter<ImageLabelType, ImageLabelType> StatisticsImageFilterType;
378         typename StatisticsImageFilterType::Pointer statisticsImageFilter=StatisticsImageFilterType::New();
379         statisticsImageFilter->SetInput(labels);
380         statisticsImageFilter->SetLabelInput(labels);
381         statisticsImageFilter->Update();
382         typename StatisticsImageFilterType::BoundingBoxType boundingBox = statisticsImageFilter->GetBoundingBox(1);
383
384         // Limit the transform region to the mask
385         for (unsigned int i=0; i<InputImageType::ImageDimension; i++)
386         {
387           transformRegionIndex[i]=boundingBox[2*i];
388           transformRegionSize[i]=boundingBox[2*i+1]-boundingBox[2*i]+1;
389         }
390         transformRegion.SetSize(transformRegionSize);
391         transformRegion.SetIndex(transformRegionIndex);
392         fixedImage->TransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin);
393
394         // Crop the fixedImage to the bounding box to facilitate multi-resolution
395         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
396         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
397         extractImageFilter->SetDirectionCollapseToSubmatrix();
398         extractImageFilter->SetInput(fixedImage);
399         extractImageFilter->SetExtractionRegion(transformRegion);
400         extractImageFilter->Update();
401         croppedFixedImage=extractImageFilter->GetOutput();
402
403         // Update the metric region
404         metricRegion = croppedFixedImage->GetLargestPossibleRegion();
405         metricRegionIndex=metricRegion.GetIndex();
406         croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin);
407
408         // Set start index to zero (with respect to croppedFixedImage/transform region)
409         metricRegionIndex.Fill(0);
410         metricRegion.SetIndex(metricRegionIndex);
411         croppedFixedImage->SetRegions(metricRegion);
412         croppedFixedImage->SetOrigin(metricRegionOrigin);
413       }
414
415       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
416       typename MaskType::Pointer  movingMask=NULL;
417       if (m_ArgsInfo.targetMask_given)
418       {
419         movingMask= MaskType::New();
420         typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageMaskType;
421         typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
422         typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
423         maskReader->SetFileName(m_ArgsInfo.targetMask_arg);
424         try
425         {
426           maskReader->Update();
427         }
428         catch( itk::ExceptionObject & err )
429         {
430           std::cerr << "ExceptionObject caught !" << std::endl;
431           std::cerr << err << std::endl;
432         }
433         if (m_Verbose)std::cout <<"Target image mask was read..." <<std::endl;
434
435         movingMask->SetImage( maskReader->GetOutput() );
436       }
437
438
439       //=======================================================
440       // Output Regions
441       //=======================================================
442
443       if (m_Verbose)
444       {
445         // Fixed image region
446         std::cout<<"The fixed image has its origin at "<<fixedImage->GetOrigin()<<std::endl
447           <<"The fixed image region starts at index "<<fixedImageRegion.GetIndex()<<std::endl
448           <<"The fixed image region has size "<< fixedImageRegion.GetSize()<<std::endl;
449
450         // Transform region
451         std::cout<<"The transform has its origin at "<<transformRegionOrigin<<std::endl
452           <<"The transform region will start at index "<<transformRegion.GetIndex()<<std::endl
453           <<"The transform region has size "<< transformRegion.GetSize()<<std::endl;
454
455         // Metric region
456         std::cout<<"The metric region has its origin at "<<metricRegionOrigin<<std::endl
457           <<"The metric region will start at index "<<metricRegion.GetIndex()<<std::endl
458           <<"The metric region has size "<< metricRegion.GetSize()<<std::endl;
459
460       }
461
462
463       //=======================================================
464       // Pyramids (update them for transform initializer)
465       //=======================================================
466       typedef itk::RecursiveMultiResolutionPyramidImageFilter< FixedImageType, FixedImageType>    FixedImagePyramidType;
467       typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType>    MovingImagePyramidType;
468       typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
469       typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();
470       fixedImagePyramid->SetUseShrinkImageFilter(false);
471       fixedImagePyramid->SetInput(croppedFixedImage);
472       fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
473       movingImagePyramid->SetUseShrinkImageFilter(false);
474       movingImagePyramid->SetInput(movingImage);
475       movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
476       if (m_Verbose) std::cout<<"Creating the image pyramid..."<<std::endl;
477       fixedImagePyramid->Update();
478       movingImagePyramid->Update();
479       typedef clitk::MultiResolutionPyramidRegionFilter<FixedImageType> FixedImageRegionPyramidType;
480       typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
481       fixedImageRegionPyramid->SetRegion(metricRegion);
482       fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule());
483
484
485       //=======================================================
486       // Rigid or Affine Transform
487       //=======================================================
488       typedef itk::AffineTransform <double,3> RigidTransformType;
489       RigidTransformType::Pointer rigidTransform=NULL;
490       if (m_ArgsInfo.initMatrix_given)
491       {
492         if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."<<std::endl;
493         rigidTransform=RigidTransformType::New();
494         itk::Matrix<double,4,4> rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg);
495
496         //Set the rotation
497         itk::Matrix<double,3,3> finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix);
498         rigidTransform->SetMatrix(finalRotation);
499
500         //Set the translation
501         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
502         rigidTransform->SetTranslation(finalTranslation);
503       }
504       else if (m_ArgsInfo.centre_flag)
505       {
506         if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."<<std::endl;
507         
508         rigidTransform=RigidTransformType::New();
509         
510         typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
511         typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
512         initializer->SetTransform( rigidTransform );
513         initializer->SetFixedImage( fixedImage );
514         initializer->SetMovingImage( movingImage );        
515         initializer->GeometryOn();
516         initializer->InitializeTransform();
517       }
518
519
520       //=======================================================
521       // B-LUT FFD Transform
522       //=======================================================
523       typedef  clitk::MultipleBSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
524       typename TransformType::Pointer transform = TransformType::New();
525       if (labels) transform->SetLabels(labels);
526       if (rigidTransform) transform->SetBulkTransform(rigidTransform);
527
528       //-------------------------------------------------------------------------
529       // The transform initializer
530       //-------------------------------------------------------------------------
531       typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
532       typename InitializerType::Pointer initializer = InitializerType::New();
533       initializer->SetVerbose(m_Verbose);
534       initializer->SetImage(fixedImagePyramid->GetOutput(0));
535       initializer->SetTransform(transform);
536
537       //-------------------------------------------------------------------------
538       // Order
539       //-------------------------------------------------------------------------
540       typename FixedImageType::RegionType::SizeType splineOrders ;
541       splineOrders.Fill(3);
542       if (m_ArgsInfo.order_given)
543         for(unsigned int i=0; i<InputImageType::ImageDimension;i++)
544           splineOrders[i]=m_ArgsInfo.order_arg[i];
545       if (m_Verbose) std::cout<<"Setting the spline orders  to "<<splineOrders<<"..."<<std::endl;
546       initializer->SetSplineOrders(splineOrders);
547
548       //-------------------------------------------------------------------------
549       // Levels
550       //-------------------------------------------------------------------------
551
552       // Spacing
553       if (m_ArgsInfo.spacing_given)
554       {
555         initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg);
556         initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg);
557         initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing;
558         if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]
559           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
560
561         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
562         {
563           initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2;
564           if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]
565             <<" at level "<<m_ArgsInfo.levels_arg-i<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
566         }
567
568       }
569
570       // Control
571       if (m_ArgsInfo.control_given)
572       {
573         initializer->m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg);
574         initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg);
575         initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage;
576         if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image"
577           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
578
579         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
580         {
581           for(unsigned int j=0;j<InputImageType::ImageDimension;j++)
582             initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.);
583           //        initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.);
584           if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image"
585             <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
586
587         }
588       }
589
590       // Inialize on the first level
591       if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level 1 of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
592       if (m_ArgsInfo.spacing_given) initializer->SetControlPointSpacing(        initializer->m_ControlPointSpacingArray[0]);
593       if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]);
594       if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg);
595
596       // Initialize
597       initializer->InitializeTransform();
598
599       //-------------------------------------------------------------------------
600       // Initial parameters (passed by reference)
601       //-------------------------------------------------------------------------
602       typedef typename TransformType::ParametersType     ParametersType;
603       const unsigned int numberOfParameters =    transform->GetNumberOfParameters();
604       ParametersType parameters(numberOfParameters);
605       parameters.Fill( 0.0 );
606       transform->SetParameters( parameters );
607       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
608
609       //-------------------------------------------------------------------------
610       // DEBUG: use an itk BSpline instead of multilabel BLUTs
611       //-------------------------------------------------------------------------
612       typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
613       RegistrationTransformType::Pointer regTransform(transform);
614       typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
615       typename SingleBSplineTransformType::Pointer sTransform;
616       if(m_ArgsInfo.itkbspline_flag) {
617         if( transform->GetTransforms().size()>1)
618           itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
619         sTransform = SingleBSplineTransformType::New();
620         sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
621         sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
622         sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
623         sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
624         sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
625         regTransform = sTransform;
626         transform = NULL; // free memory
627       }
628
629       //=======================================================
630       // Interpolator
631       //=======================================================
632       typedef clitk::GenericInterpolator<args_info_clitkBLUTDIR, FixedImageType, TCoordRep > GenericInterpolatorType;
633       typename   GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New();
634       genericInterpolator->SetArgsInfo(m_ArgsInfo);
635       typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep >  InterpolatorType;
636       typename  InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer();
637
638
639       //=======================================================
640       // Metric
641       //=======================================================
642       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, FixedImageType,MovingImageType > GenericMetricType;
643       typename GenericMetricType::Pointer genericMetric=GenericMetricType::New();
644       genericMetric->SetArgsInfo(m_ArgsInfo);
645       genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0));
646       if (fixedMask) genericMetric->SetFixedImageMask(fixedMask);
647       genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0));
648       typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
649       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
650       if (movingMask) metric->SetMovingImageMask(movingMask);
651       if (threadsGiven) {
652         metric->SetNumberOfThreads( threads );
653         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
654       }
655
656       //=======================================================
657       // Optimizer
658       //=======================================================
659       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
660       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
661       genericOptimizer->SetArgsInfo(m_ArgsInfo);
662       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
663       genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
664       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
665       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
666
667
668       //=======================================================
669       // Registration
670       //=======================================================
671       typedef itk::MultiResolutionImageRegistrationMethod<  FixedImageType, MovingImageType >    RegistrationType;
672       typename RegistrationType::Pointer   registration  = RegistrationType::New();
673       registration->SetMetric(        metric        );
674       registration->SetOptimizer(     optimizer     );
675       registration->SetInterpolator(  interpolator  );
676       registration->SetTransform (regTransform );
677       if(threadsGiven) {
678         registration->SetNumberOfThreads(threads);
679         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
680       }
681       registration->SetFixedImage(  croppedFixedImage   );
682       registration->SetMovingImage(  movingImage   );
683       registration->SetFixedImageRegion( metricRegion );
684       registration->SetFixedImagePyramid( fixedImagePyramid );
685       registration->SetMovingImagePyramid( movingImagePyramid );
686       registration->SetInitialTransformParameters( regTransform->GetParameters() );
687       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
688       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
689
690
691       //================================================================================================
692       // Observers
693       //================================================================================================
694       if (m_Verbose)
695       {
696         // Output iteration info
697         CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
698         observer->SetOptimizer(genericOptimizer);
699         optimizer->AddObserver( itk::IterationEvent(), observer );
700
701         // Output level info
702         typedef RegistrationInterfaceCommand<RegistrationType,args_info_clitkBLUTDIR> CommandType;
703         typename CommandType::Pointer command = CommandType::New();
704         command->SetInitializer(initializer);
705         command->SetArgsInfo(m_ArgsInfo);
706         command->SetCommandIterationUpdate(observer);
707         command->SetMaximize(genericMetric->GetMaximize());
708         command->SetMetricRegion(metricRegion);
709         registration->AddObserver( itk::IterationEvent(), command );
710
711         if (m_ArgsInfo.coeff_given)
712         {
713           if(m_ArgsInfo.itkbspline_flag) {
714             itkExceptionMacro("--coeff and --itkbpline are incompatible");
715           }
716
717           std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
718           typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
719           typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
720           observerdvf->SetFixedImage(fixedImage);
721           observerdvf->SetTransform(transform);
722           observerdvf->SetArgsInfo(m_ArgsInfo);
723           optimizer->AddObserver( itk::IterationEvent(), observerdvf );
724         }
725       }
726
727
728       //=======================================================
729       // Let's go
730       //=======================================================
731       if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl;
732
733       try
734       {
735         registration->Update();
736       }
737       catch( itk::ExceptionObject & err )
738       {
739         std::cerr << "ExceptionObject caught while registering!" << std::endl;
740         std::cerr << err << std::endl;
741         return;
742       }
743
744
745       //=======================================================
746       // Get the result
747       //=======================================================
748       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
749       regTransform->SetParameters( finalParameters );
750       if (m_Verbose)
751       {
752         std::cout<<"Stop condition description: "
753           <<registration->GetOptimizer()->GetStopConditionDescription()<<std::endl;
754       }
755
756
757       //=======================================================
758       // Get the BSpline coefficient images and write them
759       //=======================================================
760       if (m_ArgsInfo.coeff_given)
761       {
762         typedef typename TransformType::CoefficientImageType CoefficientImageType;
763         std::vector<typename CoefficientImageType::Pointer> coefficientImages = transform->GetCoefficientImages();
764         typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
765         typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New();
766         unsigned nLabels = transform->GetnLabels();
767
768         std::string fname(m_ArgsInfo.coeff_arg);
769         int dotpos = fname.length() - 1;
770         while (dotpos >= 0 && fname[dotpos] != '.')
771           dotpos--;
772
773         for (unsigned i = 0; i < nLabels; ++i)
774         {
775           std::ostringstream osfname;
776           osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos);
777           coeffWriter->SetInput(coefficientImages[i]);
778           coeffWriter->SetFileName(osfname.str());
779           coeffWriter->Update();
780         }
781       }
782
783
784
785       //=======================================================
786       // Compute the DVF (only deformable transform)
787       //=======================================================
788       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
789       typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
790 #if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6)
791       typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
792 #else
793       typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldType, double> ConvertorType;
794 #endif
795       typename ConvertorType::Pointer filter= ConvertorType::New();
796       filter->SetNumberOfThreads(1);
797       if(m_ArgsInfo.itkbspline_flag)
798         sTransform->SetBulkTransform(NULL);
799       else
800         transform->SetBulkTransform(NULL);
801       filter->SetTransform(regTransform);
802 #if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6)
803       filter->SetReferenceImage(fixedImage);
804 #else
805       filter->SetOutputParametersFromImage(fixedImage);
806 #endif
807       filter->Update();
808       typename DisplacementFieldType::Pointer field = filter->GetOutput();
809
810
811       //=======================================================
812       // Write the DVF
813       //=======================================================
814       typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
815       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
816       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
817       fieldWriter->SetInput( field );
818       try
819       {
820         fieldWriter->Update();
821       }
822       catch( itk::ExceptionObject & excp )
823       {
824         std::cerr << "Exception thrown writing the DVF" << std::endl;
825         std::cerr << excp << std::endl;
826         return;
827       }
828
829
830       //=======================================================
831       // Resample the moving image
832       //=======================================================
833       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
834       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
835       if (rigidTransform) {
836         if(m_ArgsInfo.itkbspline_flag)
837           sTransform->SetBulkTransform(rigidTransform);
838         else
839           transform->SetBulkTransform(rigidTransform);
840       }
841       resampler->SetTransform( regTransform );
842       resampler->SetInput( movingImage);
843       resampler->SetOutputParametersFromImage(fixedImage);
844       resampler->Update();
845       typename FixedImageType::Pointer result=resampler->GetOutput();
846
847       //     typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType >    WarpFilterType;
848       //     typename WarpFilterType::Pointer warp = WarpFilterType::New();
849
850       //     warp->SetDeformationField( field );
851       //     warp->SetInput( movingImageReader->GetOutput() );
852       //     warp->SetOutputOrigin(  fixedImage->GetOrigin() );
853       //     warp->SetOutputSpacing( fixedImage->GetSpacing() );
854       //     warp->SetOutputDirection( fixedImage->GetDirection() );
855       //     warp->SetEdgePaddingValue( 0.0 );
856       //     warp->Update();
857
858
859       //=======================================================
860       // Write the warped image
861       //=======================================================
862       typedef itk::ImageFileWriter< FixedImageType >  WriterType;
863       typename WriterType::Pointer      writer =  WriterType::New();
864       writer->SetFileName( m_ArgsInfo.output_arg );
865       writer->SetInput( result    );
866
867       try
868       {
869         writer->Update();
870       }
871       catch( itk::ExceptionObject & err )
872       {
873         std::cerr << "ExceptionObject caught !" << std::endl;
874         std::cerr << err << std::endl;
875         return;
876       }
877
878
879       //=======================================================
880       // Calculate the difference after the deformable transform
881       //=======================================================
882       typedef clitk::DifferenceImageFilter<  FixedImageType, FixedImageType> DifferenceFilterType;
883       if (m_ArgsInfo.after_given)
884       {
885         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
886         difference->SetValidInput( fixedImage );
887         difference->SetTestInput( result );
888
889         try
890         {
891           difference->Update();
892         }
893         catch( itk::ExceptionObject & err )
894         {
895           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
896           std::cerr << err << std::endl;
897           return;
898         }
899
900         typename WriterType::Pointer differenceWriter=WriterType::New();
901         differenceWriter->SetInput(difference->GetOutput());
902         differenceWriter->SetFileName(m_ArgsInfo.after_arg);
903         differenceWriter->Update();
904
905       }
906
907
908       //=======================================================
909       // Calculate the difference before the deformable transform
910       //=======================================================
911       if( m_ArgsInfo.before_given )
912       {
913
914         typename FixedImageType::Pointer moving=FixedImageType::New();
915         if (m_ArgsInfo.initMatrix_given)
916         {
917           typedef itk::ResampleImageFilter<MovingImageType, FixedImageType> ResamplerType;
918           typename ResamplerType::Pointer resampler=ResamplerType::New();
919           resampler->SetInput(movingImage);
920           resampler->SetOutputOrigin(fixedImage->GetOrigin());
921           resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
922           resampler->SetOutputSpacing(fixedImage->GetSpacing());
923           resampler->SetDefaultPixelValue( 0. );
924           if (rigidTransform ) resampler->SetTransform(rigidTransform);
925           resampler->Update();
926           moving=resampler->GetOutput();
927         }
928         else
929           moving=movingImage;
930
931         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
932         difference->SetValidInput( fixedImage );
933         difference->SetTestInput( moving );
934
935         try
936         {
937           difference->Update();
938         }
939         catch( itk::ExceptionObject & err )
940         {
941           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
942           std::cerr << err << std::endl;
943           return;
944         }
945
946         typename WriterType::Pointer differenceWriter=WriterType::New();
947         writer->SetFileName( m_ArgsInfo.before_arg  );
948         writer->SetInput( difference->GetOutput()  );
949         writer->Update( );
950       }
951
952       return;
953
954     }
955 }//end clitk
956
957 #endif // #define clitkBLUTDIRGenericFilter_txx