]> Creatis software - clitk.git/blob - registration/clitkBLUTDIRGenericFilter.cxx
Merge branch 'master' of git.creatis.insa-lyon.fr:clitk
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
1 /*=========================================================================
2 Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4 Authors belong to:
5 - University of LYON              http://www.universite-lyon.fr/
6 - Léon Bérard cancer center       http://www.centreleonberard.fr
7 - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9 This software is distributed WITHOUT ANY WARRANTY; without even
10 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11 PURPOSE.  See the copyright notices for more information.
12
13 It is distributed under dual licence
14
15 - BSD        See included LICENSE.txt file
16 - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef clitkBLUTDIRGenericFilter_cxx
19 #define clitkBLUTDIRGenericFilter_cxx
20
21 /* =================================================
22  * @file   clitkBLUTDIRGenericFilter.cxx
23  * @author
24  * @date
25  *
26  * @brief
27  *
28  ===================================================*/
29
30 #include "clitkBLUTDIRGenericFilter.h"
31 #include "clitkBLUTDIRCommandIterationUpdateDVF.h"
32 #include "itkCenteredTransformInitializer.h"
33 #if ITK_VERSION_MAJOR >= 4
34 #  if ITK_VERSION_MINOR < 6
35 #    include "itkTransformToDisplacementFieldSource.h"
36 #  else
37 #    include "itkTransformToDisplacementFieldFilter.h"
38 #  endif
39 #else
40 #  include "itkTransformToDeformationFieldSource.h"
41 #endif
42
43 namespace clitk
44 {
45
46   //==============================================================================
47   // Creating an observer class that allows output at each iteration
48   //==============================================================================
49   class CommandIterationUpdate : public itk::Command
50   {
51     public:
52       typedef  CommandIterationUpdate   Self;
53       typedef  itk::Command             Superclass;
54       typedef  itk::SmartPointer<Self>  Pointer;
55       itkNewMacro( Self );
56     protected:
57       CommandIterationUpdate() {};
58     public:
59       typedef   clitk::GenericOptimizer<args_info_clitkBLUTDIR>     OptimizerType;
60       typedef   const OptimizerType   *           OptimizerPointer;
61
62       // Set the generic optimizer
63       void SetOptimizer(OptimizerPointer o){m_Optimizer=o;}
64
65       // Execute
66       void Execute(itk::Object *caller, const itk::EventObject & event)
67       {
68         Execute( (const itk::Object *)caller, event);
69       }
70
71       void Execute(const itk::Object * object, const itk::EventObject & event)
72       {
73         if( !(itk::IterationEvent().CheckEvent( &event )) )
74         {
75           return;
76         }
77
78         m_Optimizer->OutputIterationInfo();
79       }
80
81       OptimizerPointer m_Optimizer;
82   };
83
84   //===========================================================================//
85   //Constructor
86   //==========================================================================//
87   BLUTDIRGenericFilter::BLUTDIRGenericFilter():
88     ImageToImageGenericFilter<Self>("Register DIR")
89   {
90     InitializeImageType<2>();
91     InitializeImageType<3>();
92     m_Verbose=false;
93   }
94
95   //=========================================================================//
96   //SetArgsInfo
97   //==========================================================================//
98   void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){
99     m_ArgsInfo=a;
100     if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg);
101
102     if (m_ArgsInfo.target_given) {
103       AddInputFilename(m_ArgsInfo.target_arg);
104     }
105
106     if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg);
107     
108     if (m_ArgsInfo.verbose_given) m_Verbose=true;
109   }
110
111   //=========================================================================//
112   //===========================================================================//
113   template<unsigned int Dim>
114     void BLUTDIRGenericFilter::InitializeImageType()
115     {
116       ADD_DEFAULT_IMAGE_TYPES(3);
117     }
118   //--------------------------------------------------------------------
119
120   //==============================================================================
121   //Creating an observer class that allows us to change parameters at subsequent levels
122   //==============================================================================
123   template <typename TRegistration,class args_info_clitkBLUTDIR>
124     class RegistrationInterfaceCommand : public itk::Command
125   {
126     public:
127       typedef RegistrationInterfaceCommand   Self;
128       typedef itk::Command             Superclass;
129       typedef itk::SmartPointer<Self>  Pointer;
130       itkNewMacro( Self );
131     protected:
132       RegistrationInterfaceCommand() { };
133     public:
134
135       // Registration
136       typedef   TRegistration                              RegistrationType;
137       typedef   RegistrationType *                         RegistrationPointer;
138
139       // Transform
140       typedef typename RegistrationType::FixedImageType FixedImageType;
141       typedef typename FixedImageType::RegionType RegionType;
142       itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension);
143       typedef clitk::MultipleBSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
144       typedef clitk::MultipleBSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
145       typedef typename InitializerType::CoefficientImageType CoefficientImageType;
146       typedef itk::CastImageFilter<CoefficientImageType, CoefficientImageType> CastImageFilterType;
147       typedef typename TransformType::ParametersType ParametersType;
148       typedef typename InitializerType::Pointer InitializerPointer;
149       typedef   CommandIterationUpdate::Pointer CommandIterationUpdatePointer;
150
151       // Optimizer
152       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
153       typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer;
154
155       // Metric
156       typedef typename RegistrationType::FixedImageType    InternalImageType;
157       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, InternalImageType, InternalImageType> GenericMetricType;
158       typedef typename GenericMetricType::Pointer GenericMetricPointer;
159
160       // Two arguments are passed to the Execute() method: the first
161       // is the pointer to the object which invoked the event and the
162       // second is the event that was invoked.
163       void Execute(itk::Object * object, const itk::EventObject & event)
164       {
165         if( !(itk::IterationEvent().CheckEvent( &event )) )
166         {
167           return;
168         }
169
170         // Get the levels
171         RegistrationPointer registration = dynamic_cast<RegistrationPointer>( object );
172         unsigned int numberOfLevels=registration->GetNumberOfLevels();
173         unsigned int currentLevel=registration->GetCurrentLevel()+1;
174
175         // Output the levels
176         std::cout<<std::endl;
177         std::cout<<"========================================"<<std::endl;
178         std::cout<<"Starting resolution level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
179         std::cout<<"========================================"<<std::endl;
180         std::cout<<std::endl;
181
182         // Higher level?
183         if (currentLevel>1)
184         {
185           // fixed image region pyramid
186           typedef clitk::MultiResolutionPyramidRegionFilter<InternalImageType> FixedImageRegionPyramidType;
187           typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
188           fixedImageRegionPyramid->SetRegion(m_MetricRegion);
189           fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule());
190
191           // Reinitialize the metric (!= number of samples)
192           m_GenericMetric= GenericMetricType::New();
193           m_GenericMetric->SetArgsInfo(m_ArgsInfo);
194           m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel()));
195           if (m_ArgsInfo.referenceMask_given)  m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask());
196           m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel()));
197           typedef itk::ImageToImageMetric< InternalImageType, InternalImageType >  MetricType;
198           typename  MetricType::Pointer metric=m_GenericMetric->GetMetricPointer();
199           registration->SetMetric(metric);
200
201           // Get the current coefficient image and make a COPY
202           typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster = itk::ImageDuplicator<CoefficientImageType>::New();
203           std::vector<typename CoefficientImageType::Pointer> currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages();
204           for (unsigned i = 0; i < currentCoefficientImages.size(); ++i)
205           {
206             caster->SetInputImage(currentCoefficientImages[i]);
207             caster->Update();
208             currentCoefficientImages[i] = caster->GetOutput();
209           }
210
211           /*
212           // Write the intermediate result?
213           if (m_ArgsInfo.intermediate_given>=numberOfLevels)
214             writeImage<CoefficientImageType>(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag);
215             */
216
217           // Set the new transform properties
218           m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1));
219           if( m_Initializer->m_ControlPointSpacingIsGiven)
220             m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]);
221           if( m_Initializer->m_NumberOfControlPointsIsGiven)
222             m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]);
223
224           // Reinitialize the transform
225           if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
226           m_Initializer->InitializeTransform();
227           ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters());
228
229           // DS : if we want to skip the last pyramid level, force to only 1 iteration
230           DD(m_ArgsInfo.skipLastPyramidLevel_flag);
231           if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) {
232             DD(m_ArgsInfo.maxIt_arg);
233             std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl;
234             m_ArgsInfo.maxIt_arg = 0;
235             DD(m_ArgsInfo.maxIt_arg);
236           }
237
238           // Reinitialize an Optimizer (!= number of parameters)
239           m_GenericOptimizer = GenericOptimizerType::New();
240           m_GenericOptimizer->SetArgsInfo(m_ArgsInfo);
241           m_GenericOptimizer->SetMaximize(m_Maximize);
242           m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters());
243
244
245           typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
246           OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer();
247           optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate);
248           registration->SetOptimizer(optimizer);
249           m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer);
250
251           // Set the previous transform parameters to the registration
252           // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
253           m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
254           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
255         }
256       }
257
258       void Execute(const itk::Object * , const itk::EventObject & )
259       { return; }
260
261
262       // Members
263       void SetInitializer(InitializerPointer i){m_Initializer=i;}
264       InitializerPointer m_Initializer;
265
266       void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;}
267       args_info_clitkBLUTDIR m_ArgsInfo;
268
269       void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;};
270       CommandIterationUpdatePointer m_CommandIterationUpdate;
271
272       GenericOptimizerPointer m_GenericOptimizer;
273       void SetMaximize(bool b){m_Maximize=b;}
274       bool m_Maximize;
275
276       GenericMetricPointer m_GenericMetric;
277       void SetMetricRegion(RegionType i){m_MetricRegion=i;}
278       RegionType m_MetricRegion;
279
280
281   };
282
283   //==============================================================================
284   // Update with the number of dimensions and pixeltype
285   //==============================================================================
286   template<class InputImageType>
287     void BLUTDIRGenericFilter::UpdateWithInputImageType()
288     {
289       if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl;
290       
291       //=============================================================================
292       //Input
293       //=============================================================================
294       bool threadsGiven=m_ArgsInfo.threads_given;
295       int threads=m_ArgsInfo.threads_arg;
296       typedef typename InputImageType::PixelType PixelType;
297
298       typedef double TCoordRep;
299
300       typename InputImageType::Pointer fixedImage = this->template GetInput<InputImageType>(0);
301
302       typename InputImageType::Pointer inputFixedImage = this->template GetInput<InputImageType>(0);
303
304       // typedef input2
305       typename InputImageType::Pointer movingImage = this->template GetInput<InputImageType>(1);
306
307       typename InputImageType::Pointer inputMovingImage = this->template GetInput<InputImageType>(1);
308
309       typedef itk::Image< PixelType,InputImageType::ImageDimension >  FixedImageType;
310       typedef itk::Image< PixelType, InputImageType::ImageDimension>  MovingImageType;
311       const unsigned int SpaceDimension = InputImageType::ImageDimension;
312       //Whatever the pixel type, internally we work with an image represented in float
313       //Reading reference image
314       if (m_Verbose) std::cout<<"Reading images..."<<std::endl;
315       //=======================================================
316       //Input
317       //=======================================================
318       typename FixedImageType::Pointer croppedFixedImage=fixedImage;
319       //=======================================================
320       // Regions
321       //=======================================================
322       // The original input region
323       typename FixedImageType::RegionType fixedImageRegion = fixedImage->GetLargestPossibleRegion();
324
325       // The transform region with respect to the input region:
326       // where should the transform be DEFINED (depends on mask)
327       typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion();
328       typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize();
329       typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex();
330       typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin();
331
332       // The metric region with respect to the extracted transform region:
333       // where should the metric be CALCULATED (depends on transform)
334       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
335       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
336       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
337
338
339       //===========================================================================
340       // If given, we connect a mask to reference or target
341       //============================================================================
342       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
343       typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
344       typename MaskType::Pointer        fixedMask = NULL;
345       typename ImageLabelType::Pointer  labels = NULL;
346       if (m_ArgsInfo.referenceMask_given)
347       {
348         fixedMask = MaskType::New();
349         labels = ImageLabelType::New();
350         typedef itk::ImageFileReader< ImageLabelType >    LabelReaderType;
351         typename LabelReaderType::Pointer  labelReader = LabelReaderType::New();
352         labelReader->SetFileName(m_ArgsInfo.referenceMask_arg);
353         try
354         {
355           labelReader->Update();
356         }
357         catch( itk::ExceptionObject & err )
358         {
359           std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl;
360           std::cerr << err << std::endl;
361           return;
362         }
363         if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
364
365         // Resample labels
366         typedef itk::ResampleImageFilter<ImageLabelType, ImageLabelType> ResamplerType;
367         typename ResamplerType::Pointer resampler = ResamplerType::New();
368         typedef itk::NearestNeighborInterpolateImageFunction<ImageLabelType, TCoordRep> InterpolatorType;
369         typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
370         resampler->SetOutputParametersFromImage(fixedImage);
371         resampler->SetInterpolator(interpolator);
372         resampler->SetInput(labelReader->GetOutput());
373         resampler->Update();
374         labels = resampler->GetOutput();
375
376         // Set the image to the spatialObject
377         fixedMask->SetImage(labels);
378
379         // Find the bounding box of the "inside" label
380         typedef itk::LabelGeometryImageFilter<ImageLabelType> GeometryImageFilterType;
381         typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New();
382         geometryImageFilter->SetInput(labels);
383         geometryImageFilter->Update();
384         typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1);
385
386         // Limit the transform region to the mask
387         for (unsigned int i=0; i<InputImageType::ImageDimension; i++)
388         {
389           transformRegionIndex[i]=boundingBox[2*i];
390           transformRegionSize[i]=boundingBox[2*i+1]-boundingBox[2*i]+1;
391         }
392         transformRegion.SetSize(transformRegionSize);
393         transformRegion.SetIndex(transformRegionIndex);
394         fixedImage->TransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin);
395
396         // Crop the fixedImage to the bounding box to facilitate multi-resolution
397         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
398         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
399 #if ITK_VERSION_MAJOR == 4
400         extractImageFilter->SetDirectionCollapseToSubmatrix();
401 #endif
402         extractImageFilter->SetInput(fixedImage);
403         extractImageFilter->SetExtractionRegion(transformRegion);
404         extractImageFilter->Update();
405         croppedFixedImage=extractImageFilter->GetOutput();
406
407         // Update the metric region
408         metricRegion = croppedFixedImage->GetLargestPossibleRegion();
409         metricRegionIndex=metricRegion.GetIndex();
410         croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin);
411
412         // Set start index to zero (with respect to croppedFixedImage/transform region)
413         metricRegionIndex.Fill(0);
414         metricRegion.SetIndex(metricRegionIndex);
415         croppedFixedImage->SetRegions(metricRegion);
416         croppedFixedImage->SetOrigin(metricRegionOrigin);
417       }
418
419       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
420       typename MaskType::Pointer  movingMask=NULL;
421       if (m_ArgsInfo.targetMask_given)
422       {
423         movingMask= MaskType::New();
424         typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageMaskType;
425         typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
426         typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
427         maskReader->SetFileName(m_ArgsInfo.targetMask_arg);
428         try
429         {
430           maskReader->Update();
431         }
432         catch( itk::ExceptionObject & err )
433         {
434           std::cerr << "ExceptionObject caught !" << std::endl;
435           std::cerr << err << std::endl;
436         }
437         if (m_Verbose)std::cout <<"Target image mask was read..." <<std::endl;
438
439         movingMask->SetImage( maskReader->GetOutput() );
440       }
441
442
443       //=======================================================
444       // Output Regions
445       //=======================================================
446
447       if (m_Verbose)
448       {
449         // Fixed image region
450         std::cout<<"The fixed image has its origin at "<<fixedImage->GetOrigin()<<std::endl
451           <<"The fixed image region starts at index "<<fixedImageRegion.GetIndex()<<std::endl
452           <<"The fixed image region has size "<< fixedImageRegion.GetSize()<<std::endl;
453
454         // Transform region
455         std::cout<<"The transform has its origin at "<<transformRegionOrigin<<std::endl
456           <<"The transform region will start at index "<<transformRegion.GetIndex()<<std::endl
457           <<"The transform region has size "<< transformRegion.GetSize()<<std::endl;
458
459         // Metric region
460         std::cout<<"The metric region has its origin at "<<metricRegionOrigin<<std::endl
461           <<"The metric region will start at index "<<metricRegion.GetIndex()<<std::endl
462           <<"The metric region has size "<< metricRegion.GetSize()<<std::endl;
463
464       }
465
466
467       //=======================================================
468       // Pyramids (update them for transform initializer)
469       //=======================================================
470       typedef itk::RecursiveMultiResolutionPyramidImageFilter< FixedImageType, FixedImageType>    FixedImagePyramidType;
471       typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType>    MovingImagePyramidType;
472       typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
473       typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();
474       fixedImagePyramid->SetUseShrinkImageFilter(false);
475       fixedImagePyramid->SetInput(croppedFixedImage);
476       fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
477       movingImagePyramid->SetUseShrinkImageFilter(false);
478       movingImagePyramid->SetInput(movingImage);
479       movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
480       if (m_Verbose) std::cout<<"Creating the image pyramid..."<<std::endl;
481       fixedImagePyramid->Update();
482       movingImagePyramid->Update();
483       typedef clitk::MultiResolutionPyramidRegionFilter<FixedImageType> FixedImageRegionPyramidType;
484       typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
485       fixedImageRegionPyramid->SetRegion(metricRegion);
486       fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule());
487
488
489       //=======================================================
490       // Rigid or Affine Transform
491       //=======================================================
492       typedef itk::AffineTransform <double,3> RigidTransformType;
493       RigidTransformType::Pointer rigidTransform=NULL;
494       if (m_ArgsInfo.initMatrix_given)
495       {
496         if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."<<std::endl;
497         rigidTransform=RigidTransformType::New();
498         itk::Matrix<double,4,4> rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg);
499
500         //Set the rotation
501         itk::Matrix<double,3,3> finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix);
502         rigidTransform->SetMatrix(finalRotation);
503
504         //Set the translation
505         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
506         rigidTransform->SetTranslation(finalTranslation);
507       }
508       else if (m_ArgsInfo.centre_flag)
509       {
510         if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."<<std::endl;
511         
512         rigidTransform=RigidTransformType::New();
513         
514         typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
515         typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
516         initializer->SetTransform( rigidTransform );
517         initializer->SetFixedImage( fixedImage );
518         initializer->SetMovingImage( movingImage );        
519         initializer->GeometryOn();
520         initializer->InitializeTransform();
521       }
522
523
524       //=======================================================
525       // B-LUT FFD Transform
526       //=======================================================
527       typedef  clitk::MultipleBSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
528       typename TransformType::Pointer transform = TransformType::New();
529       if (labels) transform->SetLabels(labels);
530       if (rigidTransform) transform->SetBulkTransform(rigidTransform);
531
532       //-------------------------------------------------------------------------
533       // The transform initializer
534       //-------------------------------------------------------------------------
535       typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
536       typename InitializerType::Pointer initializer = InitializerType::New();
537       initializer->SetVerbose(m_Verbose);
538       initializer->SetImage(fixedImagePyramid->GetOutput(0));
539       initializer->SetTransform(transform);
540
541       //-------------------------------------------------------------------------
542       // Order
543       //-------------------------------------------------------------------------
544       typename FixedImageType::RegionType::SizeType splineOrders ;
545       splineOrders.Fill(3);
546       if (m_ArgsInfo.order_given)
547         for(unsigned int i=0; i<InputImageType::ImageDimension;i++)
548           splineOrders[i]=m_ArgsInfo.order_arg[i];
549       if (m_Verbose) std::cout<<"Setting the spline orders  to "<<splineOrders<<"..."<<std::endl;
550       initializer->SetSplineOrders(splineOrders);
551
552       //-------------------------------------------------------------------------
553       // Levels
554       //-------------------------------------------------------------------------
555
556       // Spacing
557       if (m_ArgsInfo.spacing_given)
558       {
559         initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg);
560         initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg);
561         initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing;
562         if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]
563           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
564
565         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
566         {
567           initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2;
568           if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]
569             <<" at level "<<m_ArgsInfo.levels_arg-i<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
570         }
571
572       }
573
574       // Control
575       if (m_ArgsInfo.control_given)
576       {
577         initializer->m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg);
578         initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg);
579         initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage;
580         if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image"
581           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
582
583         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
584         {
585           for(unsigned int j=0;j<InputImageType::ImageDimension;j++)
586             initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.);
587           //        initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.);
588           if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image"
589             <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
590
591         }
592       }
593
594       // Inialize on the first level
595       if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level 1 of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
596       if (m_ArgsInfo.spacing_given) initializer->SetControlPointSpacing(        initializer->m_ControlPointSpacingArray[0]);
597       if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]);
598       if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg);
599
600       // Initialize
601       initializer->InitializeTransform();
602
603       //-------------------------------------------------------------------------
604       // Initial parameters (passed by reference)
605       //-------------------------------------------------------------------------
606       typedef typename TransformType::ParametersType     ParametersType;
607       const unsigned int numberOfParameters =    transform->GetNumberOfParameters();
608       ParametersType parameters(numberOfParameters);
609       parameters.Fill( 0.0 );
610       transform->SetParameters( parameters );
611       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
612
613       //-------------------------------------------------------------------------
614       // DEBUG: use an itk BSpline instead of multilabel BLUTs
615       //-------------------------------------------------------------------------
616       typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
617       RegistrationTransformType::Pointer regTransform(transform);
618       typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
619       typename SingleBSplineTransformType::Pointer sTransform;
620       if(m_ArgsInfo.itkbspline_flag) {
621         if( transform->GetTransforms().size()>1)
622           itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
623         sTransform = SingleBSplineTransformType::New();
624         sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
625         sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
626         sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
627         sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
628         sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
629         regTransform = sTransform;
630         transform = NULL; // free memory
631       }
632
633       //=======================================================
634       // Interpolator
635       //=======================================================
636       typedef clitk::GenericInterpolator<args_info_clitkBLUTDIR, FixedImageType, TCoordRep > GenericInterpolatorType;
637       typename   GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New();
638       genericInterpolator->SetArgsInfo(m_ArgsInfo);
639       typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep >  InterpolatorType;
640       typename  InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer();
641
642
643       //=======================================================
644       // Metric
645       //=======================================================
646       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, FixedImageType,MovingImageType > GenericMetricType;
647       typename GenericMetricType::Pointer genericMetric=GenericMetricType::New();
648       genericMetric->SetArgsInfo(m_ArgsInfo);
649       genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0));
650       if (fixedMask) genericMetric->SetFixedImageMask(fixedMask);
651       genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0));
652       typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
653       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
654       if (movingMask) metric->SetMovingImageMask(movingMask);
655
656 #if defined(ITK_USE_OPTIMIZED_REGISTRATION_METHODS) || ITK_VERSION_MAJOR >= 4
657       if (threadsGiven) {
658         metric->SetNumberOfThreads( threads );
659         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
660       }
661 #else
662       if (m_Verbose) std::cout<<"Not setting the number of threads (not compiled with USE_OPTIMIZED_REGISTRATION_METHODS)..."<<std::endl;
663 #endif
664
665
666       //=======================================================
667       // Optimizer
668       //=======================================================
669       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
670       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
671       genericOptimizer->SetArgsInfo(m_ArgsInfo);
672       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
673       genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
674       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
675       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
676
677
678       //=======================================================
679       // Registration
680       //=======================================================
681       typedef itk::MultiResolutionImageRegistrationMethod<  FixedImageType, MovingImageType >    RegistrationType;
682       typename RegistrationType::Pointer   registration  = RegistrationType::New();
683       registration->SetMetric(        metric        );
684       registration->SetOptimizer(     optimizer     );
685       registration->SetInterpolator(  interpolator  );
686       registration->SetTransform (regTransform );
687       if(threadsGiven) {
688         registration->SetNumberOfThreads(threads);
689         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
690       }
691       registration->SetFixedImage(  croppedFixedImage   );
692       registration->SetMovingImage(  movingImage   );
693       registration->SetFixedImageRegion( metricRegion );
694       registration->SetFixedImagePyramid( fixedImagePyramid );
695       registration->SetMovingImagePyramid( movingImagePyramid );
696       registration->SetInitialTransformParameters( regTransform->GetParameters() );
697       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
698       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
699
700
701       //================================================================================================
702       // Observers
703       //================================================================================================
704       if (m_Verbose)
705       {
706         // Output iteration info
707         CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
708         observer->SetOptimizer(genericOptimizer);
709         optimizer->AddObserver( itk::IterationEvent(), observer );
710
711         // Output level info
712         typedef RegistrationInterfaceCommand<RegistrationType,args_info_clitkBLUTDIR> CommandType;
713         typename CommandType::Pointer command = CommandType::New();
714         command->SetInitializer(initializer);
715         command->SetArgsInfo(m_ArgsInfo);
716         command->SetCommandIterationUpdate(observer);
717         command->SetMaximize(genericMetric->GetMaximize());
718         command->SetMetricRegion(metricRegion);
719         registration->AddObserver( itk::IterationEvent(), command );
720
721         if (m_ArgsInfo.coeff_given)
722         {
723           if(m_ArgsInfo.itkbspline_flag) {
724             itkExceptionMacro("--coeff and --itkbpline are incompatible");
725           }
726
727           std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
728           typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
729           typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
730           observerdvf->SetFixedImage(fixedImage);
731           observerdvf->SetTransform(transform);
732           observerdvf->SetArgsInfo(m_ArgsInfo);
733           optimizer->AddObserver( itk::IterationEvent(), observerdvf );
734         }
735       }
736
737
738       //=======================================================
739       // Let's go
740       //=======================================================
741       if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl;
742
743       try
744       {
745 #if ITK_VERSION_MAJOR < 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR <= 2)
746         registration->StartRegistration();
747 #else
748         registration->Update();
749 #endif
750       }
751       catch( itk::ExceptionObject & err )
752       {
753         std::cerr << "ExceptionObject caught while registering!" << std::endl;
754         std::cerr << err << std::endl;
755         return;
756       }
757
758
759       //=======================================================
760       // Get the result
761       //=======================================================
762       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
763       regTransform->SetParameters( finalParameters );
764       if (m_Verbose)
765       {
766         std::cout<<"Stop condition description: "
767           <<registration->GetOptimizer()->GetStopConditionDescription()<<std::endl;
768       }
769
770
771       //=======================================================
772       // Get the BSpline coefficient images and write them
773       //=======================================================
774       if (m_ArgsInfo.coeff_given)
775       {
776         typedef typename TransformType::CoefficientImageType CoefficientImageType;
777         std::vector<typename CoefficientImageType::Pointer> coefficientImages = transform->GetCoefficientImages();
778         typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
779         typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New();
780         unsigned nLabels = transform->GetnLabels();
781
782         std::string fname(m_ArgsInfo.coeff_arg);
783         int dotpos = fname.length() - 1;
784         while (dotpos >= 0 && fname[dotpos] != '.')
785           dotpos--;
786
787         for (unsigned i = 0; i < nLabels; ++i)
788         {
789           std::ostringstream osfname;
790           osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos);
791           coeffWriter->SetInput(coefficientImages[i]);
792           coeffWriter->SetFileName(osfname.str());
793           coeffWriter->Update();
794         }
795       }
796
797
798
799       //=======================================================
800       // Compute the DVF (only deformable transform)
801       //=======================================================
802       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
803       typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
804 #if ITK_VERSION_MAJOR >= 4
805 #  if ITK_VERSION_MINOR < 6
806       typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
807 #  else
808       typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldType, double> ConvertorType;
809 #  endif
810 #else
811       typedef itk::TransformToDeformationFieldSource<DisplacementFieldType, double> ConvertorType;
812 #endif
813       typename ConvertorType::Pointer filter= ConvertorType::New();
814       filter->SetNumberOfThreads(1);
815       if(m_ArgsInfo.itkbspline_flag)
816         sTransform->SetBulkTransform(NULL);
817       else
818         transform->SetBulkTransform(NULL);
819       filter->SetTransform(regTransform);
820 #if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6)
821       filter->SetReferenceImage(fixedImage);
822 #else
823       filter->SetOutputParametersFromImage(fixedImage);
824 #endif
825       filter->Update();
826       typename DisplacementFieldType::Pointer field = filter->GetOutput();
827
828
829       //=======================================================
830       // Write the DVF
831       //=======================================================
832       typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
833       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
834       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
835       fieldWriter->SetInput( field );
836       try
837       {
838         fieldWriter->Update();
839       }
840       catch( itk::ExceptionObject & excp )
841       {
842         std::cerr << "Exception thrown writing the DVF" << std::endl;
843         std::cerr << excp << std::endl;
844         return;
845       }
846
847
848       //=======================================================
849       // Resample the moving image
850       //=======================================================
851       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
852       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
853       if (rigidTransform) {
854         if(m_ArgsInfo.itkbspline_flag)
855           sTransform->SetBulkTransform(rigidTransform);
856         else
857           transform->SetBulkTransform(rigidTransform);
858       }
859       resampler->SetTransform( regTransform );
860       resampler->SetInput( movingImage);
861       resampler->SetOutputParametersFromImage(fixedImage);
862       resampler->Update();
863       typename FixedImageType::Pointer result=resampler->GetOutput();
864
865       //     typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType >    WarpFilterType;
866       //     typename WarpFilterType::Pointer warp = WarpFilterType::New();
867
868       //     warp->SetDeformationField( field );
869       //     warp->SetInput( movingImageReader->GetOutput() );
870       //     warp->SetOutputOrigin(  fixedImage->GetOrigin() );
871       //     warp->SetOutputSpacing( fixedImage->GetSpacing() );
872       //     warp->SetOutputDirection( fixedImage->GetDirection() );
873       //     warp->SetEdgePaddingValue( 0.0 );
874       //     warp->Update();
875
876
877       //=======================================================
878       // Write the warped image
879       //=======================================================
880       typedef itk::ImageFileWriter< FixedImageType >  WriterType;
881       typename WriterType::Pointer      writer =  WriterType::New();
882       writer->SetFileName( m_ArgsInfo.output_arg );
883       writer->SetInput( result    );
884
885       try
886       {
887         writer->Update();
888       }
889       catch( itk::ExceptionObject & err )
890       {
891         std::cerr << "ExceptionObject caught !" << std::endl;
892         std::cerr << err << std::endl;
893         return;
894       }
895
896
897       //=======================================================
898       // Calculate the difference after the deformable transform
899       //=======================================================
900       typedef clitk::DifferenceImageFilter<  FixedImageType, FixedImageType> DifferenceFilterType;
901       if (m_ArgsInfo.after_given)
902       {
903         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
904         difference->SetValidInput( fixedImage );
905         difference->SetTestInput( result );
906
907         try
908         {
909           difference->Update();
910         }
911         catch( itk::ExceptionObject & err )
912         {
913           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
914           std::cerr << err << std::endl;
915           return;
916         }
917
918         typename WriterType::Pointer differenceWriter=WriterType::New();
919         differenceWriter->SetInput(difference->GetOutput());
920         differenceWriter->SetFileName(m_ArgsInfo.after_arg);
921         differenceWriter->Update();
922
923       }
924
925
926       //=======================================================
927       // Calculate the difference before the deformable transform
928       //=======================================================
929       if( m_ArgsInfo.before_given )
930       {
931
932         typename FixedImageType::Pointer moving=FixedImageType::New();
933         if (m_ArgsInfo.initMatrix_given)
934         {
935           typedef itk::ResampleImageFilter<MovingImageType, FixedImageType> ResamplerType;
936           typename ResamplerType::Pointer resampler=ResamplerType::New();
937           resampler->SetInput(movingImage);
938           resampler->SetOutputOrigin(fixedImage->GetOrigin());
939           resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
940           resampler->SetOutputSpacing(fixedImage->GetSpacing());
941           resampler->SetDefaultPixelValue( 0. );
942           if (rigidTransform ) resampler->SetTransform(rigidTransform);
943           resampler->Update();
944           moving=resampler->GetOutput();
945         }
946         else
947           moving=movingImage;
948
949         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
950         difference->SetValidInput( fixedImage );
951         difference->SetTestInput( moving );
952
953         try
954         {
955           difference->Update();
956         }
957         catch( itk::ExceptionObject & err )
958         {
959           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
960           std::cerr << err << std::endl;
961           return;
962         }
963
964         typename WriterType::Pointer differenceWriter=WriterType::New();
965         writer->SetFileName( m_ArgsInfo.before_arg  );
966         writer->SetInput( difference->GetOutput()  );
967         writer->Update( );
968       }
969
970       return;
971
972     }
973 }//end clitk
974
975 #endif // #define clitkBLUTDIRGenericFilter_txx