]> Creatis software - clitk.git/blob - registration/clitkBLUTDIRGenericFilter.cxx
COMP: fix test logic for ITK version which failed with the new ITK v5
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
1 /*=========================================================================
2 Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4 Authors belong to:
5 - University of LYON              http://www.universite-lyon.fr/
6 - Léon Bérard cancer center       http://www.centreleonberard.fr
7 - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9 This software is distributed WITHOUT ANY WARRANTY; without even
10 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11 PURPOSE.  See the copyright notices for more information.
12
13 It is distributed under dual licence
14
15 - BSD        See included LICENSE.txt file
16 - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef clitkBLUTDIRGenericFilter_cxx
19 #define clitkBLUTDIRGenericFilter_cxx
20
21 /* =================================================
22  * @file   clitkBLUTDIRGenericFilter.cxx
23  * @author
24  * @date
25  *
26  * @brief
27  *
28  ===================================================*/
29
30 #include "clitkBLUTDIRGenericFilter.h"
31 #include "clitkBLUTDIRCommandIterationUpdateDVF.h"
32 #include "itkCenteredTransformInitializer.h"
33 #if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6)
34 # include "itkTransformToDisplacementFieldSource.h"
35 #else
36 # include "itkTransformToDisplacementFieldFilter.h"
37 #endif
38
39 namespace clitk
40 {
41
42   //==============================================================================
43   // Creating an observer class that allows output at each iteration
44   //==============================================================================
45   class CommandIterationUpdate : public itk::Command
46   {
47     public:
48       typedef  CommandIterationUpdate   Self;
49       typedef  itk::Command             Superclass;
50       typedef  itk::SmartPointer<Self>  Pointer;
51       itkNewMacro( Self );
52     protected:
53       CommandIterationUpdate() {};
54     public:
55       typedef   clitk::GenericOptimizer<args_info_clitkBLUTDIR>     OptimizerType;
56       typedef   const OptimizerType   *           OptimizerPointer;
57
58       // Set the generic optimizer
59       void SetOptimizer(OptimizerPointer o){m_Optimizer=o;}
60
61       // Execute
62       void Execute(itk::Object *caller, const itk::EventObject & event)
63       {
64         Execute( (const itk::Object *)caller, event);
65       }
66
67       void Execute(const itk::Object * object, const itk::EventObject & event)
68       {
69         if( !(itk::IterationEvent().CheckEvent( &event )) )
70         {
71           return;
72         }
73
74         m_Optimizer->OutputIterationInfo();
75       }
76
77       OptimizerPointer m_Optimizer;
78   };
79
80   //===========================================================================//
81   //Constructor
82   //==========================================================================//
83   BLUTDIRGenericFilter::BLUTDIRGenericFilter():
84     ImageToImageGenericFilter<Self>("Register DIR")
85   {
86     InitializeImageType<2>();
87     InitializeImageType<3>();
88     m_Verbose=false;
89   }
90
91   //=========================================================================//
92   //SetArgsInfo
93   //==========================================================================//
94   void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){
95     m_ArgsInfo=a;
96     if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg);
97
98     if (m_ArgsInfo.target_given) {
99       AddInputFilename(m_ArgsInfo.target_arg);
100     }
101
102     if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg);
103     
104     if (m_ArgsInfo.verbose_given) m_Verbose=true;
105   }
106
107   //=========================================================================//
108   //===========================================================================//
109   template<unsigned int Dim>
110     void BLUTDIRGenericFilter::InitializeImageType()
111     {
112       ADD_DEFAULT_IMAGE_TYPES(3);
113     }
114   //--------------------------------------------------------------------
115
116   //==============================================================================
117   //Creating an observer class that allows us to change parameters at subsequent levels
118   //==============================================================================
119   template <typename TRegistration,class args_info_clitkBLUTDIR>
120     class RegistrationInterfaceCommand : public itk::Command
121   {
122     public:
123       typedef RegistrationInterfaceCommand   Self;
124       typedef itk::Command             Superclass;
125       typedef itk::SmartPointer<Self>  Pointer;
126       itkNewMacro( Self );
127     protected:
128       RegistrationInterfaceCommand() { };
129     public:
130
131       // Registration
132       typedef   TRegistration                              RegistrationType;
133       typedef   RegistrationType *                         RegistrationPointer;
134
135       // Transform
136       typedef typename RegistrationType::FixedImageType FixedImageType;
137       typedef typename FixedImageType::RegionType RegionType;
138       itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension);
139       typedef clitk::MultipleBSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
140       typedef clitk::MultipleBSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
141       typedef typename InitializerType::CoefficientImageType CoefficientImageType;
142       typedef itk::CastImageFilter<CoefficientImageType, CoefficientImageType> CastImageFilterType;
143       typedef typename TransformType::ParametersType ParametersType;
144       typedef typename InitializerType::Pointer InitializerPointer;
145       typedef   CommandIterationUpdate::Pointer CommandIterationUpdatePointer;
146
147       // Optimizer
148       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
149       typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer;
150
151       // Metric
152       typedef typename RegistrationType::FixedImageType    InternalImageType;
153       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, InternalImageType, InternalImageType> GenericMetricType;
154       typedef typename GenericMetricType::Pointer GenericMetricPointer;
155
156       // Two arguments are passed to the Execute() method: the first
157       // is the pointer to the object which invoked the event and the
158       // second is the event that was invoked.
159       void Execute(itk::Object * object, const itk::EventObject & event)
160       {
161         if( !(itk::IterationEvent().CheckEvent( &event )) )
162         {
163           return;
164         }
165
166         // Get the levels
167         RegistrationPointer registration = dynamic_cast<RegistrationPointer>( object );
168         unsigned int numberOfLevels=registration->GetNumberOfLevels();
169         unsigned int currentLevel=registration->GetCurrentLevel()+1;
170
171         // Output the levels
172         std::cout<<std::endl;
173         std::cout<<"========================================"<<std::endl;
174         std::cout<<"Starting resolution level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
175         std::cout<<"========================================"<<std::endl;
176         std::cout<<std::endl;
177
178         // Higher level?
179         if (currentLevel>1)
180         {
181           // fixed image region pyramid
182           typedef clitk::MultiResolutionPyramidRegionFilter<InternalImageType> FixedImageRegionPyramidType;
183           typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
184           fixedImageRegionPyramid->SetRegion(m_MetricRegion);
185           fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule());
186
187           // Reinitialize the metric (!= number of samples)
188           m_GenericMetric= GenericMetricType::New();
189           m_GenericMetric->SetArgsInfo(m_ArgsInfo);
190           m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel()));
191           if (m_ArgsInfo.referenceMask_given)  m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask());
192           m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel()));
193           typedef itk::ImageToImageMetric< InternalImageType, InternalImageType >  MetricType;
194           typename  MetricType::Pointer metric=m_GenericMetric->GetMetricPointer();
195           registration->SetMetric(metric);
196
197           // Get the current coefficient image and make a COPY
198           typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster = itk::ImageDuplicator<CoefficientImageType>::New();
199           std::vector<typename CoefficientImageType::Pointer> currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages();
200           for (unsigned i = 0; i < currentCoefficientImages.size(); ++i)
201           {
202             caster->SetInputImage(currentCoefficientImages[i]);
203             caster->Update();
204             currentCoefficientImages[i] = caster->GetOutput();
205           }
206
207           /*
208           // Write the intermediate result?
209           if (m_ArgsInfo.intermediate_given>=numberOfLevels)
210             writeImage<CoefficientImageType>(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag);
211             */
212
213           // Set the new transform properties
214           m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1));
215           if( m_Initializer->m_ControlPointSpacingIsGiven)
216             m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]);
217           if( m_Initializer->m_NumberOfControlPointsIsGiven)
218             m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]);
219
220           // Reinitialize the transform
221           if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
222           m_Initializer->InitializeTransform();
223           ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters());
224
225           // DS : if we want to skip the last pyramid level, force to only 1 iteration
226           DD(m_ArgsInfo.skipLastPyramidLevel_flag);
227           if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) {
228             DD(m_ArgsInfo.maxIt_arg);
229             std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl;
230             m_ArgsInfo.maxIt_arg = 0;
231             DD(m_ArgsInfo.maxIt_arg);
232           }
233
234           // Reinitialize an Optimizer (!= number of parameters)
235           m_GenericOptimizer = GenericOptimizerType::New();
236           m_GenericOptimizer->SetArgsInfo(m_ArgsInfo);
237           m_GenericOptimizer->SetMaximize(m_Maximize);
238           m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters());
239
240
241           typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
242           OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer();
243           optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate);
244           registration->SetOptimizer(optimizer);
245           m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer);
246
247           // Set the previous transform parameters to the registration
248           // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
249           m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
250           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
251         }
252       }
253
254       void Execute(const itk::Object * , const itk::EventObject & )
255       { return; }
256
257
258       // Members
259       void SetInitializer(InitializerPointer i){m_Initializer=i;}
260       InitializerPointer m_Initializer;
261
262       void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;}
263       args_info_clitkBLUTDIR m_ArgsInfo;
264
265       void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;};
266       CommandIterationUpdatePointer m_CommandIterationUpdate;
267
268       GenericOptimizerPointer m_GenericOptimizer;
269       void SetMaximize(bool b){m_Maximize=b;}
270       bool m_Maximize;
271
272       GenericMetricPointer m_GenericMetric;
273       void SetMetricRegion(RegionType i){m_MetricRegion=i;}
274       RegionType m_MetricRegion;
275
276
277   };
278
279   //==============================================================================
280   // Update with the number of dimensions and pixeltype
281   //==============================================================================
282   template<class InputImageType>
283     void BLUTDIRGenericFilter::UpdateWithInputImageType()
284     {
285       if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl;
286       
287       //=============================================================================
288       //Input
289       //=============================================================================
290       bool threadsGiven=m_ArgsInfo.threads_given;
291       int threads=m_ArgsInfo.threads_arg;
292       typedef typename InputImageType::PixelType PixelType;
293
294       typedef double TCoordRep;
295
296       typename InputImageType::Pointer fixedImage = this->template GetInput<InputImageType>(0);
297
298       typename InputImageType::Pointer inputFixedImage = this->template GetInput<InputImageType>(0);
299
300       // typedef input2
301       typename InputImageType::Pointer movingImage = this->template GetInput<InputImageType>(1);
302
303       typename InputImageType::Pointer inputMovingImage = this->template GetInput<InputImageType>(1);
304
305       typedef itk::Image< PixelType,InputImageType::ImageDimension >  FixedImageType;
306       typedef itk::Image< PixelType, InputImageType::ImageDimension>  MovingImageType;
307       const unsigned int SpaceDimension = InputImageType::ImageDimension;
308       //Whatever the pixel type, internally we work with an image represented in float
309       //Reading reference image
310       if (m_Verbose) std::cout<<"Reading images..."<<std::endl;
311       //=======================================================
312       //Input
313       //=======================================================
314       typename FixedImageType::Pointer croppedFixedImage=fixedImage;
315       //=======================================================
316       // Regions
317       //=======================================================
318       // The original input region
319       typename FixedImageType::RegionType fixedImageRegion = fixedImage->GetLargestPossibleRegion();
320
321       // The transform region with respect to the input region:
322       // where should the transform be DEFINED (depends on mask)
323       typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion();
324       typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize();
325       typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex();
326       typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin();
327
328       // The metric region with respect to the extracted transform region:
329       // where should the metric be CALCULATED (depends on transform)
330       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
331       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
332       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
333
334
335       //===========================================================================
336       // If given, we connect a mask to reference or target
337       //============================================================================
338       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
339       typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
340       typename MaskType::Pointer        fixedMask = NULL;
341       typename ImageLabelType::Pointer  labels = NULL;
342       if (m_ArgsInfo.referenceMask_given)
343       {
344         fixedMask = MaskType::New();
345         labels = ImageLabelType::New();
346         typedef itk::ImageFileReader< ImageLabelType >    LabelReaderType;
347         typename LabelReaderType::Pointer  labelReader = LabelReaderType::New();
348         labelReader->SetFileName(m_ArgsInfo.referenceMask_arg);
349         try
350         {
351           labelReader->Update();
352         }
353         catch( itk::ExceptionObject & err )
354         {
355           std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl;
356           std::cerr << err << std::endl;
357           return;
358         }
359         if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
360
361         // Resample labels
362         typedef itk::ResampleImageFilter<ImageLabelType, ImageLabelType> ResamplerType;
363         typename ResamplerType::Pointer resampler = ResamplerType::New();
364         typedef itk::NearestNeighborInterpolateImageFunction<ImageLabelType, TCoordRep> InterpolatorType;
365         typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
366         resampler->SetOutputParametersFromImage(fixedImage);
367         resampler->SetInterpolator(interpolator);
368         resampler->SetInput(labelReader->GetOutput());
369         resampler->Update();
370         labels = resampler->GetOutput();
371
372         // Set the image to the spatialObject
373         fixedMask->SetImage(labels);
374
375         // Find the bounding box of the "inside" label
376         typedef itk::LabelGeometryImageFilter<ImageLabelType> GeometryImageFilterType;
377         typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New();
378         geometryImageFilter->SetInput(labels);
379         geometryImageFilter->Update();
380         typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1);
381
382         // Limit the transform region to the mask
383         for (unsigned int i=0; i<InputImageType::ImageDimension; i++)
384         {
385           transformRegionIndex[i]=boundingBox[2*i];
386           transformRegionSize[i]=boundingBox[2*i+1]-boundingBox[2*i]+1;
387         }
388         transformRegion.SetSize(transformRegionSize);
389         transformRegion.SetIndex(transformRegionIndex);
390         fixedImage->TransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin);
391
392         // Crop the fixedImage to the bounding box to facilitate multi-resolution
393         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
394         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
395         extractImageFilter->SetDirectionCollapseToSubmatrix();
396         extractImageFilter->SetInput(fixedImage);
397         extractImageFilter->SetExtractionRegion(transformRegion);
398         extractImageFilter->Update();
399         croppedFixedImage=extractImageFilter->GetOutput();
400
401         // Update the metric region
402         metricRegion = croppedFixedImage->GetLargestPossibleRegion();
403         metricRegionIndex=metricRegion.GetIndex();
404         croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin);
405
406         // Set start index to zero (with respect to croppedFixedImage/transform region)
407         metricRegionIndex.Fill(0);
408         metricRegion.SetIndex(metricRegionIndex);
409         croppedFixedImage->SetRegions(metricRegion);
410         croppedFixedImage->SetOrigin(metricRegionOrigin);
411       }
412
413       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
414       typename MaskType::Pointer  movingMask=NULL;
415       if (m_ArgsInfo.targetMask_given)
416       {
417         movingMask= MaskType::New();
418         typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageMaskType;
419         typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
420         typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
421         maskReader->SetFileName(m_ArgsInfo.targetMask_arg);
422         try
423         {
424           maskReader->Update();
425         }
426         catch( itk::ExceptionObject & err )
427         {
428           std::cerr << "ExceptionObject caught !" << std::endl;
429           std::cerr << err << std::endl;
430         }
431         if (m_Verbose)std::cout <<"Target image mask was read..." <<std::endl;
432
433         movingMask->SetImage( maskReader->GetOutput() );
434       }
435
436
437       //=======================================================
438       // Output Regions
439       //=======================================================
440
441       if (m_Verbose)
442       {
443         // Fixed image region
444         std::cout<<"The fixed image has its origin at "<<fixedImage->GetOrigin()<<std::endl
445           <<"The fixed image region starts at index "<<fixedImageRegion.GetIndex()<<std::endl
446           <<"The fixed image region has size "<< fixedImageRegion.GetSize()<<std::endl;
447
448         // Transform region
449         std::cout<<"The transform has its origin at "<<transformRegionOrigin<<std::endl
450           <<"The transform region will start at index "<<transformRegion.GetIndex()<<std::endl
451           <<"The transform region has size "<< transformRegion.GetSize()<<std::endl;
452
453         // Metric region
454         std::cout<<"The metric region has its origin at "<<metricRegionOrigin<<std::endl
455           <<"The metric region will start at index "<<metricRegion.GetIndex()<<std::endl
456           <<"The metric region has size "<< metricRegion.GetSize()<<std::endl;
457
458       }
459
460
461       //=======================================================
462       // Pyramids (update them for transform initializer)
463       //=======================================================
464       typedef itk::RecursiveMultiResolutionPyramidImageFilter< FixedImageType, FixedImageType>    FixedImagePyramidType;
465       typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType>    MovingImagePyramidType;
466       typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
467       typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();
468       fixedImagePyramid->SetUseShrinkImageFilter(false);
469       fixedImagePyramid->SetInput(croppedFixedImage);
470       fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
471       movingImagePyramid->SetUseShrinkImageFilter(false);
472       movingImagePyramid->SetInput(movingImage);
473       movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
474       if (m_Verbose) std::cout<<"Creating the image pyramid..."<<std::endl;
475       fixedImagePyramid->Update();
476       movingImagePyramid->Update();
477       typedef clitk::MultiResolutionPyramidRegionFilter<FixedImageType> FixedImageRegionPyramidType;
478       typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
479       fixedImageRegionPyramid->SetRegion(metricRegion);
480       fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule());
481
482
483       //=======================================================
484       // Rigid or Affine Transform
485       //=======================================================
486       typedef itk::AffineTransform <double,3> RigidTransformType;
487       RigidTransformType::Pointer rigidTransform=NULL;
488       if (m_ArgsInfo.initMatrix_given)
489       {
490         if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."<<std::endl;
491         rigidTransform=RigidTransformType::New();
492         itk::Matrix<double,4,4> rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg);
493
494         //Set the rotation
495         itk::Matrix<double,3,3> finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix);
496         rigidTransform->SetMatrix(finalRotation);
497
498         //Set the translation
499         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
500         rigidTransform->SetTranslation(finalTranslation);
501       }
502       else if (m_ArgsInfo.centre_flag)
503       {
504         if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."<<std::endl;
505         
506         rigidTransform=RigidTransformType::New();
507         
508         typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
509         typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
510         initializer->SetTransform( rigidTransform );
511         initializer->SetFixedImage( fixedImage );
512         initializer->SetMovingImage( movingImage );        
513         initializer->GeometryOn();
514         initializer->InitializeTransform();
515       }
516
517
518       //=======================================================
519       // B-LUT FFD Transform
520       //=======================================================
521       typedef  clitk::MultipleBSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
522       typename TransformType::Pointer transform = TransformType::New();
523       if (labels) transform->SetLabels(labels);
524       if (rigidTransform) transform->SetBulkTransform(rigidTransform);
525
526       //-------------------------------------------------------------------------
527       // The transform initializer
528       //-------------------------------------------------------------------------
529       typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
530       typename InitializerType::Pointer initializer = InitializerType::New();
531       initializer->SetVerbose(m_Verbose);
532       initializer->SetImage(fixedImagePyramid->GetOutput(0));
533       initializer->SetTransform(transform);
534
535       //-------------------------------------------------------------------------
536       // Order
537       //-------------------------------------------------------------------------
538       typename FixedImageType::RegionType::SizeType splineOrders ;
539       splineOrders.Fill(3);
540       if (m_ArgsInfo.order_given)
541         for(unsigned int i=0; i<InputImageType::ImageDimension;i++)
542           splineOrders[i]=m_ArgsInfo.order_arg[i];
543       if (m_Verbose) std::cout<<"Setting the spline orders  to "<<splineOrders<<"..."<<std::endl;
544       initializer->SetSplineOrders(splineOrders);
545
546       //-------------------------------------------------------------------------
547       // Levels
548       //-------------------------------------------------------------------------
549
550       // Spacing
551       if (m_ArgsInfo.spacing_given)
552       {
553         initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg);
554         initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg);
555         initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing;
556         if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]
557           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
558
559         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
560         {
561           initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2;
562           if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]
563             <<" at level "<<m_ArgsInfo.levels_arg-i<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
564         }
565
566       }
567
568       // Control
569       if (m_ArgsInfo.control_given)
570       {
571         initializer->m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg);
572         initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg);
573         initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage;
574         if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image"
575           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
576
577         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
578         {
579           for(unsigned int j=0;j<InputImageType::ImageDimension;j++)
580             initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.);
581           //        initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.);
582           if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image"
583             <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
584
585         }
586       }
587
588       // Inialize on the first level
589       if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level 1 of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
590       if (m_ArgsInfo.spacing_given) initializer->SetControlPointSpacing(        initializer->m_ControlPointSpacingArray[0]);
591       if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]);
592       if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg);
593
594       // Initialize
595       initializer->InitializeTransform();
596
597       //-------------------------------------------------------------------------
598       // Initial parameters (passed by reference)
599       //-------------------------------------------------------------------------
600       typedef typename TransformType::ParametersType     ParametersType;
601       const unsigned int numberOfParameters =    transform->GetNumberOfParameters();
602       ParametersType parameters(numberOfParameters);
603       parameters.Fill( 0.0 );
604       transform->SetParameters( parameters );
605       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
606
607       //-------------------------------------------------------------------------
608       // DEBUG: use an itk BSpline instead of multilabel BLUTs
609       //-------------------------------------------------------------------------
610       typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
611       RegistrationTransformType::Pointer regTransform(transform);
612       typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
613       typename SingleBSplineTransformType::Pointer sTransform;
614       if(m_ArgsInfo.itkbspline_flag) {
615         if( transform->GetTransforms().size()>1)
616           itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
617         sTransform = SingleBSplineTransformType::New();
618         sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
619         sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
620         sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
621         sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
622         sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
623         regTransform = sTransform;
624         transform = NULL; // free memory
625       }
626
627       //=======================================================
628       // Interpolator
629       //=======================================================
630       typedef clitk::GenericInterpolator<args_info_clitkBLUTDIR, FixedImageType, TCoordRep > GenericInterpolatorType;
631       typename   GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New();
632       genericInterpolator->SetArgsInfo(m_ArgsInfo);
633       typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep >  InterpolatorType;
634       typename  InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer();
635
636
637       //=======================================================
638       // Metric
639       //=======================================================
640       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, FixedImageType,MovingImageType > GenericMetricType;
641       typename GenericMetricType::Pointer genericMetric=GenericMetricType::New();
642       genericMetric->SetArgsInfo(m_ArgsInfo);
643       genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0));
644       if (fixedMask) genericMetric->SetFixedImageMask(fixedMask);
645       genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0));
646       typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
647       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
648       if (movingMask) metric->SetMovingImageMask(movingMask);
649       if (threadsGiven) {
650         metric->SetNumberOfThreads( threads );
651         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
652       }
653
654       //=======================================================
655       // Optimizer
656       //=======================================================
657       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
658       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
659       genericOptimizer->SetArgsInfo(m_ArgsInfo);
660       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
661       genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
662       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
663       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
664
665
666       //=======================================================
667       // Registration
668       //=======================================================
669       typedef itk::MultiResolutionImageRegistrationMethod<  FixedImageType, MovingImageType >    RegistrationType;
670       typename RegistrationType::Pointer   registration  = RegistrationType::New();
671       registration->SetMetric(        metric        );
672       registration->SetOptimizer(     optimizer     );
673       registration->SetInterpolator(  interpolator  );
674       registration->SetTransform (regTransform );
675       if(threadsGiven) {
676         registration->SetNumberOfThreads(threads);
677         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
678       }
679       registration->SetFixedImage(  croppedFixedImage   );
680       registration->SetMovingImage(  movingImage   );
681       registration->SetFixedImageRegion( metricRegion );
682       registration->SetFixedImagePyramid( fixedImagePyramid );
683       registration->SetMovingImagePyramid( movingImagePyramid );
684       registration->SetInitialTransformParameters( regTransform->GetParameters() );
685       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
686       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
687
688
689       //================================================================================================
690       // Observers
691       //================================================================================================
692       if (m_Verbose)
693       {
694         // Output iteration info
695         CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
696         observer->SetOptimizer(genericOptimizer);
697         optimizer->AddObserver( itk::IterationEvent(), observer );
698
699         // Output level info
700         typedef RegistrationInterfaceCommand<RegistrationType,args_info_clitkBLUTDIR> CommandType;
701         typename CommandType::Pointer command = CommandType::New();
702         command->SetInitializer(initializer);
703         command->SetArgsInfo(m_ArgsInfo);
704         command->SetCommandIterationUpdate(observer);
705         command->SetMaximize(genericMetric->GetMaximize());
706         command->SetMetricRegion(metricRegion);
707         registration->AddObserver( itk::IterationEvent(), command );
708
709         if (m_ArgsInfo.coeff_given)
710         {
711           if(m_ArgsInfo.itkbspline_flag) {
712             itkExceptionMacro("--coeff and --itkbpline are incompatible");
713           }
714
715           std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
716           typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
717           typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
718           observerdvf->SetFixedImage(fixedImage);
719           observerdvf->SetTransform(transform);
720           observerdvf->SetArgsInfo(m_ArgsInfo);
721           optimizer->AddObserver( itk::IterationEvent(), observerdvf );
722         }
723       }
724
725
726       //=======================================================
727       // Let's go
728       //=======================================================
729       if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl;
730
731       try
732       {
733         registration->Update();
734       }
735       catch( itk::ExceptionObject & err )
736       {
737         std::cerr << "ExceptionObject caught while registering!" << std::endl;
738         std::cerr << err << std::endl;
739         return;
740       }
741
742
743       //=======================================================
744       // Get the result
745       //=======================================================
746       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
747       regTransform->SetParameters( finalParameters );
748       if (m_Verbose)
749       {
750         std::cout<<"Stop condition description: "
751           <<registration->GetOptimizer()->GetStopConditionDescription()<<std::endl;
752       }
753
754
755       //=======================================================
756       // Get the BSpline coefficient images and write them
757       //=======================================================
758       if (m_ArgsInfo.coeff_given)
759       {
760         typedef typename TransformType::CoefficientImageType CoefficientImageType;
761         std::vector<typename CoefficientImageType::Pointer> coefficientImages = transform->GetCoefficientImages();
762         typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
763         typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New();
764         unsigned nLabels = transform->GetnLabels();
765
766         std::string fname(m_ArgsInfo.coeff_arg);
767         int dotpos = fname.length() - 1;
768         while (dotpos >= 0 && fname[dotpos] != '.')
769           dotpos--;
770
771         for (unsigned i = 0; i < nLabels; ++i)
772         {
773           std::ostringstream osfname;
774           osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos);
775           coeffWriter->SetInput(coefficientImages[i]);
776           coeffWriter->SetFileName(osfname.str());
777           coeffWriter->Update();
778         }
779       }
780
781
782
783       //=======================================================
784       // Compute the DVF (only deformable transform)
785       //=======================================================
786       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
787       typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
788 #if (ITK_VERSION_MAJOR == 4) && (ITK_VERSION_MINOR < 6)
789       typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
790 #else
791       typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldType, double> ConvertorType;
792 #endif
793       typename ConvertorType::Pointer filter= ConvertorType::New();
794       filter->SetNumberOfThreads(1);
795       if(m_ArgsInfo.itkbspline_flag)
796         sTransform->SetBulkTransform(NULL);
797       else
798         transform->SetBulkTransform(NULL);
799       filter->SetTransform(regTransform);
800 #if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6)
801       filter->SetReferenceImage(fixedImage);
802 #else
803       filter->SetOutputParametersFromImage(fixedImage);
804 #endif
805       filter->Update();
806       typename DisplacementFieldType::Pointer field = filter->GetOutput();
807
808
809       //=======================================================
810       // Write the DVF
811       //=======================================================
812       typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
813       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
814       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
815       fieldWriter->SetInput( field );
816       try
817       {
818         fieldWriter->Update();
819       }
820       catch( itk::ExceptionObject & excp )
821       {
822         std::cerr << "Exception thrown writing the DVF" << std::endl;
823         std::cerr << excp << std::endl;
824         return;
825       }
826
827
828       //=======================================================
829       // Resample the moving image
830       //=======================================================
831       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
832       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
833       if (rigidTransform) {
834         if(m_ArgsInfo.itkbspline_flag)
835           sTransform->SetBulkTransform(rigidTransform);
836         else
837           transform->SetBulkTransform(rigidTransform);
838       }
839       resampler->SetTransform( regTransform );
840       resampler->SetInput( movingImage);
841       resampler->SetOutputParametersFromImage(fixedImage);
842       resampler->Update();
843       typename FixedImageType::Pointer result=resampler->GetOutput();
844
845       //     typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType >    WarpFilterType;
846       //     typename WarpFilterType::Pointer warp = WarpFilterType::New();
847
848       //     warp->SetDeformationField( field );
849       //     warp->SetInput( movingImageReader->GetOutput() );
850       //     warp->SetOutputOrigin(  fixedImage->GetOrigin() );
851       //     warp->SetOutputSpacing( fixedImage->GetSpacing() );
852       //     warp->SetOutputDirection( fixedImage->GetDirection() );
853       //     warp->SetEdgePaddingValue( 0.0 );
854       //     warp->Update();
855
856
857       //=======================================================
858       // Write the warped image
859       //=======================================================
860       typedef itk::ImageFileWriter< FixedImageType >  WriterType;
861       typename WriterType::Pointer      writer =  WriterType::New();
862       writer->SetFileName( m_ArgsInfo.output_arg );
863       writer->SetInput( result    );
864
865       try
866       {
867         writer->Update();
868       }
869       catch( itk::ExceptionObject & err )
870       {
871         std::cerr << "ExceptionObject caught !" << std::endl;
872         std::cerr << err << std::endl;
873         return;
874       }
875
876
877       //=======================================================
878       // Calculate the difference after the deformable transform
879       //=======================================================
880       typedef clitk::DifferenceImageFilter<  FixedImageType, FixedImageType> DifferenceFilterType;
881       if (m_ArgsInfo.after_given)
882       {
883         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
884         difference->SetValidInput( fixedImage );
885         difference->SetTestInput( result );
886
887         try
888         {
889           difference->Update();
890         }
891         catch( itk::ExceptionObject & err )
892         {
893           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
894           std::cerr << err << std::endl;
895           return;
896         }
897
898         typename WriterType::Pointer differenceWriter=WriterType::New();
899         differenceWriter->SetInput(difference->GetOutput());
900         differenceWriter->SetFileName(m_ArgsInfo.after_arg);
901         differenceWriter->Update();
902
903       }
904
905
906       //=======================================================
907       // Calculate the difference before the deformable transform
908       //=======================================================
909       if( m_ArgsInfo.before_given )
910       {
911
912         typename FixedImageType::Pointer moving=FixedImageType::New();
913         if (m_ArgsInfo.initMatrix_given)
914         {
915           typedef itk::ResampleImageFilter<MovingImageType, FixedImageType> ResamplerType;
916           typename ResamplerType::Pointer resampler=ResamplerType::New();
917           resampler->SetInput(movingImage);
918           resampler->SetOutputOrigin(fixedImage->GetOrigin());
919           resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
920           resampler->SetOutputSpacing(fixedImage->GetSpacing());
921           resampler->SetDefaultPixelValue( 0. );
922           if (rigidTransform ) resampler->SetTransform(rigidTransform);
923           resampler->Update();
924           moving=resampler->GetOutput();
925         }
926         else
927           moving=movingImage;
928
929         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
930         difference->SetValidInput( fixedImage );
931         difference->SetTestInput( moving );
932
933         try
934         {
935           difference->Update();
936         }
937         catch( itk::ExceptionObject & err )
938         {
939           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
940           std::cerr << err << std::endl;
941           return;
942         }
943
944         typename WriterType::Pointer differenceWriter=WriterType::New();
945         writer->SetFileName( m_ArgsInfo.before_arg  );
946         writer->SetInput( difference->GetOutput()  );
947         writer->Update( );
948       }
949
950       return;
951
952     }
953 }//end clitk
954
955 #endif // #define clitkBLUTDIRGenericFilter_txx