]> Creatis software - clitk.git/blob - registration/clitkBLUTDIRGenericFilter.cxx
Remove ITK3 and ITK4.2 tests and dependencies
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
1 /*=========================================================================
2 Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4 Authors belong to:
5 - University of LYON              http://www.universite-lyon.fr/
6 - Léon Bérard cancer center       http://www.centreleonberard.fr
7 - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9 This software is distributed WITHOUT ANY WARRANTY; without even
10 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11 PURPOSE.  See the copyright notices for more information.
12
13 It is distributed under dual licence
14
15 - BSD        See included LICENSE.txt file
16 - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef clitkBLUTDIRGenericFilter_cxx
19 #define clitkBLUTDIRGenericFilter_cxx
20
21 /* =================================================
22  * @file   clitkBLUTDIRGenericFilter.cxx
23  * @author
24  * @date
25  *
26  * @brief
27  *
28  ===================================================*/
29
30 #include "clitkBLUTDIRGenericFilter.h"
31 #include "clitkBLUTDIRCommandIterationUpdateDVF.h"
32 #include "itkCenteredTransformInitializer.h"
33 #if ITK_VERSION_MAJOR >= 4
34 #  if ITK_VERSION_MINOR < 6
35 #    include "itkTransformToDisplacementFieldSource.h"
36 #  else
37 #    include "itkTransformToDisplacementFieldFilter.h"
38 #  endif
39 #else
40 #  include "itkTransformToDeformationFieldSource.h"
41 #endif
42
43 namespace clitk
44 {
45
46   //==============================================================================
47   // Creating an observer class that allows output at each iteration
48   //==============================================================================
49   class CommandIterationUpdate : public itk::Command
50   {
51     public:
52       typedef  CommandIterationUpdate   Self;
53       typedef  itk::Command             Superclass;
54       typedef  itk::SmartPointer<Self>  Pointer;
55       itkNewMacro( Self );
56     protected:
57       CommandIterationUpdate() {};
58     public:
59       typedef   clitk::GenericOptimizer<args_info_clitkBLUTDIR>     OptimizerType;
60       typedef   const OptimizerType   *           OptimizerPointer;
61
62       // Set the generic optimizer
63       void SetOptimizer(OptimizerPointer o){m_Optimizer=o;}
64
65       // Execute
66       void Execute(itk::Object *caller, const itk::EventObject & event)
67       {
68         Execute( (const itk::Object *)caller, event);
69       }
70
71       void Execute(const itk::Object * object, const itk::EventObject & event)
72       {
73         if( !(itk::IterationEvent().CheckEvent( &event )) )
74         {
75           return;
76         }
77
78         m_Optimizer->OutputIterationInfo();
79       }
80
81       OptimizerPointer m_Optimizer;
82   };
83
84   //===========================================================================//
85   //Constructor
86   //==========================================================================//
87   BLUTDIRGenericFilter::BLUTDIRGenericFilter():
88     ImageToImageGenericFilter<Self>("Register DIR")
89   {
90     InitializeImageType<2>();
91     InitializeImageType<3>();
92     m_Verbose=false;
93   }
94
95   //=========================================================================//
96   //SetArgsInfo
97   //==========================================================================//
98   void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){
99     m_ArgsInfo=a;
100     if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg);
101
102     if (m_ArgsInfo.target_given) {
103       AddInputFilename(m_ArgsInfo.target_arg);
104     }
105
106     if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg);
107     
108     if (m_ArgsInfo.verbose_given) m_Verbose=true;
109   }
110
111   //=========================================================================//
112   //===========================================================================//
113   template<unsigned int Dim>
114     void BLUTDIRGenericFilter::InitializeImageType()
115     {
116       ADD_DEFAULT_IMAGE_TYPES(3);
117     }
118   //--------------------------------------------------------------------
119
120   //==============================================================================
121   //Creating an observer class that allows us to change parameters at subsequent levels
122   //==============================================================================
123   template <typename TRegistration,class args_info_clitkBLUTDIR>
124     class RegistrationInterfaceCommand : public itk::Command
125   {
126     public:
127       typedef RegistrationInterfaceCommand   Self;
128       typedef itk::Command             Superclass;
129       typedef itk::SmartPointer<Self>  Pointer;
130       itkNewMacro( Self );
131     protected:
132       RegistrationInterfaceCommand() { };
133     public:
134
135       // Registration
136       typedef   TRegistration                              RegistrationType;
137       typedef   RegistrationType *                         RegistrationPointer;
138
139       // Transform
140       typedef typename RegistrationType::FixedImageType FixedImageType;
141       typedef typename FixedImageType::RegionType RegionType;
142       itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension);
143       typedef clitk::MultipleBSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
144       typedef clitk::MultipleBSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
145       typedef typename InitializerType::CoefficientImageType CoefficientImageType;
146       typedef itk::CastImageFilter<CoefficientImageType, CoefficientImageType> CastImageFilterType;
147       typedef typename TransformType::ParametersType ParametersType;
148       typedef typename InitializerType::Pointer InitializerPointer;
149       typedef   CommandIterationUpdate::Pointer CommandIterationUpdatePointer;
150
151       // Optimizer
152       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
153       typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer;
154
155       // Metric
156       typedef typename RegistrationType::FixedImageType    InternalImageType;
157       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, InternalImageType, InternalImageType> GenericMetricType;
158       typedef typename GenericMetricType::Pointer GenericMetricPointer;
159
160       // Two arguments are passed to the Execute() method: the first
161       // is the pointer to the object which invoked the event and the
162       // second is the event that was invoked.
163       void Execute(itk::Object * object, const itk::EventObject & event)
164       {
165         if( !(itk::IterationEvent().CheckEvent( &event )) )
166         {
167           return;
168         }
169
170         // Get the levels
171         RegistrationPointer registration = dynamic_cast<RegistrationPointer>( object );
172         unsigned int numberOfLevels=registration->GetNumberOfLevels();
173         unsigned int currentLevel=registration->GetCurrentLevel()+1;
174
175         // Output the levels
176         std::cout<<std::endl;
177         std::cout<<"========================================"<<std::endl;
178         std::cout<<"Starting resolution level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
179         std::cout<<"========================================"<<std::endl;
180         std::cout<<std::endl;
181
182         // Higher level?
183         if (currentLevel>1)
184         {
185           // fixed image region pyramid
186           typedef clitk::MultiResolutionPyramidRegionFilter<InternalImageType> FixedImageRegionPyramidType;
187           typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
188           fixedImageRegionPyramid->SetRegion(m_MetricRegion);
189           fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule());
190
191           // Reinitialize the metric (!= number of samples)
192           m_GenericMetric= GenericMetricType::New();
193           m_GenericMetric->SetArgsInfo(m_ArgsInfo);
194           m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel()));
195           if (m_ArgsInfo.referenceMask_given)  m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask());
196           m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel()));
197           typedef itk::ImageToImageMetric< InternalImageType, InternalImageType >  MetricType;
198           typename  MetricType::Pointer metric=m_GenericMetric->GetMetricPointer();
199           registration->SetMetric(metric);
200
201           // Get the current coefficient image and make a COPY
202           typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster = itk::ImageDuplicator<CoefficientImageType>::New();
203           std::vector<typename CoefficientImageType::Pointer> currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages();
204           for (unsigned i = 0; i < currentCoefficientImages.size(); ++i)
205           {
206             caster->SetInputImage(currentCoefficientImages[i]);
207             caster->Update();
208             currentCoefficientImages[i] = caster->GetOutput();
209           }
210
211           /*
212           // Write the intermediate result?
213           if (m_ArgsInfo.intermediate_given>=numberOfLevels)
214             writeImage<CoefficientImageType>(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag);
215             */
216
217           // Set the new transform properties
218           m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1));
219           if( m_Initializer->m_ControlPointSpacingIsGiven)
220             m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]);
221           if( m_Initializer->m_NumberOfControlPointsIsGiven)
222             m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]);
223
224           // Reinitialize the transform
225           if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
226           m_Initializer->InitializeTransform();
227           ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters());
228
229           // DS : if we want to skip the last pyramid level, force to only 1 iteration
230           DD(m_ArgsInfo.skipLastPyramidLevel_flag);
231           if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) {
232             DD(m_ArgsInfo.maxIt_arg);
233             std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl;
234             m_ArgsInfo.maxIt_arg = 0;
235             DD(m_ArgsInfo.maxIt_arg);
236           }
237
238           // Reinitialize an Optimizer (!= number of parameters)
239           m_GenericOptimizer = GenericOptimizerType::New();
240           m_GenericOptimizer->SetArgsInfo(m_ArgsInfo);
241           m_GenericOptimizer->SetMaximize(m_Maximize);
242           m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters());
243
244
245           typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
246           OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer();
247           optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate);
248           registration->SetOptimizer(optimizer);
249           m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer);
250
251           // Set the previous transform parameters to the registration
252           // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
253           m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
254           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
255         }
256       }
257
258       void Execute(const itk::Object * , const itk::EventObject & )
259       { return; }
260
261
262       // Members
263       void SetInitializer(InitializerPointer i){m_Initializer=i;}
264       InitializerPointer m_Initializer;
265
266       void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;}
267       args_info_clitkBLUTDIR m_ArgsInfo;
268
269       void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;};
270       CommandIterationUpdatePointer m_CommandIterationUpdate;
271
272       GenericOptimizerPointer m_GenericOptimizer;
273       void SetMaximize(bool b){m_Maximize=b;}
274       bool m_Maximize;
275
276       GenericMetricPointer m_GenericMetric;
277       void SetMetricRegion(RegionType i){m_MetricRegion=i;}
278       RegionType m_MetricRegion;
279
280
281   };
282
283   //==============================================================================
284   // Update with the number of dimensions and pixeltype
285   //==============================================================================
286   template<class InputImageType>
287     void BLUTDIRGenericFilter::UpdateWithInputImageType()
288     {
289       if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl;
290       
291       //=============================================================================
292       //Input
293       //=============================================================================
294       bool threadsGiven=m_ArgsInfo.threads_given;
295       int threads=m_ArgsInfo.threads_arg;
296       typedef typename InputImageType::PixelType PixelType;
297
298       typedef double TCoordRep;
299
300       typename InputImageType::Pointer fixedImage = this->template GetInput<InputImageType>(0);
301
302       typename InputImageType::Pointer inputFixedImage = this->template GetInput<InputImageType>(0);
303
304       // typedef input2
305       typename InputImageType::Pointer movingImage = this->template GetInput<InputImageType>(1);
306
307       typename InputImageType::Pointer inputMovingImage = this->template GetInput<InputImageType>(1);
308
309       typedef itk::Image< PixelType,InputImageType::ImageDimension >  FixedImageType;
310       typedef itk::Image< PixelType, InputImageType::ImageDimension>  MovingImageType;
311       const unsigned int SpaceDimension = InputImageType::ImageDimension;
312       //Whatever the pixel type, internally we work with an image represented in float
313       //Reading reference image
314       if (m_Verbose) std::cout<<"Reading images..."<<std::endl;
315       //=======================================================
316       //Input
317       //=======================================================
318       typename FixedImageType::Pointer croppedFixedImage=fixedImage;
319       //=======================================================
320       // Regions
321       //=======================================================
322       // The original input region
323       typename FixedImageType::RegionType fixedImageRegion = fixedImage->GetLargestPossibleRegion();
324
325       // The transform region with respect to the input region:
326       // where should the transform be DEFINED (depends on mask)
327       typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion();
328       typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize();
329       typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex();
330       typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin();
331
332       // The metric region with respect to the extracted transform region:
333       // where should the metric be CALCULATED (depends on transform)
334       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
335       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
336       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
337
338
339       //===========================================================================
340       // If given, we connect a mask to reference or target
341       //============================================================================
342       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
343       typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
344       typename MaskType::Pointer        fixedMask = NULL;
345       typename ImageLabelType::Pointer  labels = NULL;
346       if (m_ArgsInfo.referenceMask_given)
347       {
348         fixedMask = MaskType::New();
349         labels = ImageLabelType::New();
350         typedef itk::ImageFileReader< ImageLabelType >    LabelReaderType;
351         typename LabelReaderType::Pointer  labelReader = LabelReaderType::New();
352         labelReader->SetFileName(m_ArgsInfo.referenceMask_arg);
353         try
354         {
355           labelReader->Update();
356         }
357         catch( itk::ExceptionObject & err )
358         {
359           std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl;
360           std::cerr << err << std::endl;
361           return;
362         }
363         if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
364
365         // Resample labels
366         typedef itk::ResampleImageFilter<ImageLabelType, ImageLabelType> ResamplerType;
367         typename ResamplerType::Pointer resampler = ResamplerType::New();
368         typedef itk::NearestNeighborInterpolateImageFunction<ImageLabelType, TCoordRep> InterpolatorType;
369         typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
370         resampler->SetOutputParametersFromImage(fixedImage);
371         resampler->SetInterpolator(interpolator);
372         resampler->SetInput(labelReader->GetOutput());
373         resampler->Update();
374         labels = resampler->GetOutput();
375
376         // Set the image to the spatialObject
377         fixedMask->SetImage(labels);
378
379         // Find the bounding box of the "inside" label
380         typedef itk::LabelGeometryImageFilter<ImageLabelType> GeometryImageFilterType;
381         typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New();
382         geometryImageFilter->SetInput(labels);
383         geometryImageFilter->Update();
384         typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1);
385
386         // Limit the transform region to the mask
387         for (unsigned int i=0; i<InputImageType::ImageDimension; i++)
388         {
389           transformRegionIndex[i]=boundingBox[2*i];
390           transformRegionSize[i]=boundingBox[2*i+1]-boundingBox[2*i]+1;
391         }
392         transformRegion.SetSize(transformRegionSize);
393         transformRegion.SetIndex(transformRegionIndex);
394         fixedImage->TransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin);
395
396         // Crop the fixedImage to the bounding box to facilitate multi-resolution
397         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
398         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
399         extractImageFilter->SetDirectionCollapseToSubmatrix();
400         extractImageFilter->SetInput(fixedImage);
401         extractImageFilter->SetExtractionRegion(transformRegion);
402         extractImageFilter->Update();
403         croppedFixedImage=extractImageFilter->GetOutput();
404
405         // Update the metric region
406         metricRegion = croppedFixedImage->GetLargestPossibleRegion();
407         metricRegionIndex=metricRegion.GetIndex();
408         croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin);
409
410         // Set start index to zero (with respect to croppedFixedImage/transform region)
411         metricRegionIndex.Fill(0);
412         metricRegion.SetIndex(metricRegionIndex);
413         croppedFixedImage->SetRegions(metricRegion);
414         croppedFixedImage->SetOrigin(metricRegionOrigin);
415       }
416
417       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
418       typename MaskType::Pointer  movingMask=NULL;
419       if (m_ArgsInfo.targetMask_given)
420       {
421         movingMask= MaskType::New();
422         typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageMaskType;
423         typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
424         typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
425         maskReader->SetFileName(m_ArgsInfo.targetMask_arg);
426         try
427         {
428           maskReader->Update();
429         }
430         catch( itk::ExceptionObject & err )
431         {
432           std::cerr << "ExceptionObject caught !" << std::endl;
433           std::cerr << err << std::endl;
434         }
435         if (m_Verbose)std::cout <<"Target image mask was read..." <<std::endl;
436
437         movingMask->SetImage( maskReader->GetOutput() );
438       }
439
440
441       //=======================================================
442       // Output Regions
443       //=======================================================
444
445       if (m_Verbose)
446       {
447         // Fixed image region
448         std::cout<<"The fixed image has its origin at "<<fixedImage->GetOrigin()<<std::endl
449           <<"The fixed image region starts at index "<<fixedImageRegion.GetIndex()<<std::endl
450           <<"The fixed image region has size "<< fixedImageRegion.GetSize()<<std::endl;
451
452         // Transform region
453         std::cout<<"The transform has its origin at "<<transformRegionOrigin<<std::endl
454           <<"The transform region will start at index "<<transformRegion.GetIndex()<<std::endl
455           <<"The transform region has size "<< transformRegion.GetSize()<<std::endl;
456
457         // Metric region
458         std::cout<<"The metric region has its origin at "<<metricRegionOrigin<<std::endl
459           <<"The metric region will start at index "<<metricRegion.GetIndex()<<std::endl
460           <<"The metric region has size "<< metricRegion.GetSize()<<std::endl;
461
462       }
463
464
465       //=======================================================
466       // Pyramids (update them for transform initializer)
467       //=======================================================
468       typedef itk::RecursiveMultiResolutionPyramidImageFilter< FixedImageType, FixedImageType>    FixedImagePyramidType;
469       typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType>    MovingImagePyramidType;
470       typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
471       typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();
472       fixedImagePyramid->SetUseShrinkImageFilter(false);
473       fixedImagePyramid->SetInput(croppedFixedImage);
474       fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
475       movingImagePyramid->SetUseShrinkImageFilter(false);
476       movingImagePyramid->SetInput(movingImage);
477       movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
478       if (m_Verbose) std::cout<<"Creating the image pyramid..."<<std::endl;
479       fixedImagePyramid->Update();
480       movingImagePyramid->Update();
481       typedef clitk::MultiResolutionPyramidRegionFilter<FixedImageType> FixedImageRegionPyramidType;
482       typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
483       fixedImageRegionPyramid->SetRegion(metricRegion);
484       fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule());
485
486
487       //=======================================================
488       // Rigid or Affine Transform
489       //=======================================================
490       typedef itk::AffineTransform <double,3> RigidTransformType;
491       RigidTransformType::Pointer rigidTransform=NULL;
492       if (m_ArgsInfo.initMatrix_given)
493       {
494         if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."<<std::endl;
495         rigidTransform=RigidTransformType::New();
496         itk::Matrix<double,4,4> rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg);
497
498         //Set the rotation
499         itk::Matrix<double,3,3> finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix);
500         rigidTransform->SetMatrix(finalRotation);
501
502         //Set the translation
503         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
504         rigidTransform->SetTranslation(finalTranslation);
505       }
506       else if (m_ArgsInfo.centre_flag)
507       {
508         if(m_Verbose) std::cout<<"No itinial matrix given and \"centre\" flag switched on. Centering all images..."<<std::endl;
509         
510         rigidTransform=RigidTransformType::New();
511         
512         typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
513         typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
514         initializer->SetTransform( rigidTransform );
515         initializer->SetFixedImage( fixedImage );
516         initializer->SetMovingImage( movingImage );        
517         initializer->GeometryOn();
518         initializer->InitializeTransform();
519       }
520
521
522       //=======================================================
523       // B-LUT FFD Transform
524       //=======================================================
525       typedef  clitk::MultipleBSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
526       typename TransformType::Pointer transform = TransformType::New();
527       if (labels) transform->SetLabels(labels);
528       if (rigidTransform) transform->SetBulkTransform(rigidTransform);
529
530       //-------------------------------------------------------------------------
531       // The transform initializer
532       //-------------------------------------------------------------------------
533       typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
534       typename InitializerType::Pointer initializer = InitializerType::New();
535       initializer->SetVerbose(m_Verbose);
536       initializer->SetImage(fixedImagePyramid->GetOutput(0));
537       initializer->SetTransform(transform);
538
539       //-------------------------------------------------------------------------
540       // Order
541       //-------------------------------------------------------------------------
542       typename FixedImageType::RegionType::SizeType splineOrders ;
543       splineOrders.Fill(3);
544       if (m_ArgsInfo.order_given)
545         for(unsigned int i=0; i<InputImageType::ImageDimension;i++)
546           splineOrders[i]=m_ArgsInfo.order_arg[i];
547       if (m_Verbose) std::cout<<"Setting the spline orders  to "<<splineOrders<<"..."<<std::endl;
548       initializer->SetSplineOrders(splineOrders);
549
550       //-------------------------------------------------------------------------
551       // Levels
552       //-------------------------------------------------------------------------
553
554       // Spacing
555       if (m_ArgsInfo.spacing_given)
556       {
557         initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg);
558         initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg);
559         initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing;
560         if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]
561           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
562
563         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
564         {
565           initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2;
566           if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]
567             <<" at level "<<m_ArgsInfo.levels_arg-i<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
568         }
569
570       }
571
572       // Control
573       if (m_ArgsInfo.control_given)
574       {
575         initializer->m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg);
576         initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg);
577         initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage;
578         if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image"
579           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
580
581         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
582         {
583           for(unsigned int j=0;j<InputImageType::ImageDimension;j++)
584             initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.);
585           //        initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.);
586           if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image"
587             <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
588
589         }
590       }
591
592       // Inialize on the first level
593       if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level 1 of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
594       if (m_ArgsInfo.spacing_given) initializer->SetControlPointSpacing(        initializer->m_ControlPointSpacingArray[0]);
595       if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]);
596       if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg);
597
598       // Initialize
599       initializer->InitializeTransform();
600
601       //-------------------------------------------------------------------------
602       // Initial parameters (passed by reference)
603       //-------------------------------------------------------------------------
604       typedef typename TransformType::ParametersType     ParametersType;
605       const unsigned int numberOfParameters =    transform->GetNumberOfParameters();
606       ParametersType parameters(numberOfParameters);
607       parameters.Fill( 0.0 );
608       transform->SetParameters( parameters );
609       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
610
611       //-------------------------------------------------------------------------
612       // DEBUG: use an itk BSpline instead of multilabel BLUTs
613       //-------------------------------------------------------------------------
614       typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
615       RegistrationTransformType::Pointer regTransform(transform);
616       typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
617       typename SingleBSplineTransformType::Pointer sTransform;
618       if(m_ArgsInfo.itkbspline_flag) {
619         if( transform->GetTransforms().size()>1)
620           itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
621         sTransform = SingleBSplineTransformType::New();
622         sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
623         sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
624         sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
625         sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
626         sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
627         regTransform = sTransform;
628         transform = NULL; // free memory
629       }
630
631       //=======================================================
632       // Interpolator
633       //=======================================================
634       typedef clitk::GenericInterpolator<args_info_clitkBLUTDIR, FixedImageType, TCoordRep > GenericInterpolatorType;
635       typename   GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New();
636       genericInterpolator->SetArgsInfo(m_ArgsInfo);
637       typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep >  InterpolatorType;
638       typename  InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer();
639
640
641       //=======================================================
642       // Metric
643       //=======================================================
644       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, FixedImageType,MovingImageType > GenericMetricType;
645       typename GenericMetricType::Pointer genericMetric=GenericMetricType::New();
646       genericMetric->SetArgsInfo(m_ArgsInfo);
647       genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0));
648       if (fixedMask) genericMetric->SetFixedImageMask(fixedMask);
649       genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0));
650       typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
651       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
652       if (movingMask) metric->SetMovingImageMask(movingMask);
653       if (threadsGiven) {
654         metric->SetNumberOfThreads( threads );
655         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
656       }
657
658       //=======================================================
659       // Optimizer
660       //=======================================================
661       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
662       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
663       genericOptimizer->SetArgsInfo(m_ArgsInfo);
664       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
665       genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
666       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
667       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
668
669
670       //=======================================================
671       // Registration
672       //=======================================================
673       typedef itk::MultiResolutionImageRegistrationMethod<  FixedImageType, MovingImageType >    RegistrationType;
674       typename RegistrationType::Pointer   registration  = RegistrationType::New();
675       registration->SetMetric(        metric        );
676       registration->SetOptimizer(     optimizer     );
677       registration->SetInterpolator(  interpolator  );
678       registration->SetTransform (regTransform );
679       if(threadsGiven) {
680         registration->SetNumberOfThreads(threads);
681         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
682       }
683       registration->SetFixedImage(  croppedFixedImage   );
684       registration->SetMovingImage(  movingImage   );
685       registration->SetFixedImageRegion( metricRegion );
686       registration->SetFixedImagePyramid( fixedImagePyramid );
687       registration->SetMovingImagePyramid( movingImagePyramid );
688       registration->SetInitialTransformParameters( regTransform->GetParameters() );
689       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
690       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
691
692
693       //================================================================================================
694       // Observers
695       //================================================================================================
696       if (m_Verbose)
697       {
698         // Output iteration info
699         CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
700         observer->SetOptimizer(genericOptimizer);
701         optimizer->AddObserver( itk::IterationEvent(), observer );
702
703         // Output level info
704         typedef RegistrationInterfaceCommand<RegistrationType,args_info_clitkBLUTDIR> CommandType;
705         typename CommandType::Pointer command = CommandType::New();
706         command->SetInitializer(initializer);
707         command->SetArgsInfo(m_ArgsInfo);
708         command->SetCommandIterationUpdate(observer);
709         command->SetMaximize(genericMetric->GetMaximize());
710         command->SetMetricRegion(metricRegion);
711         registration->AddObserver( itk::IterationEvent(), command );
712
713         if (m_ArgsInfo.coeff_given)
714         {
715           if(m_ArgsInfo.itkbspline_flag) {
716             itkExceptionMacro("--coeff and --itkbpline are incompatible");
717           }
718
719           std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
720           typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
721           typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
722           observerdvf->SetFixedImage(fixedImage);
723           observerdvf->SetTransform(transform);
724           observerdvf->SetArgsInfo(m_ArgsInfo);
725           optimizer->AddObserver( itk::IterationEvent(), observerdvf );
726         }
727       }
728
729
730       //=======================================================
731       // Let's go
732       //=======================================================
733       if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl;
734
735       try
736       {
737         registration->Update();
738       }
739       catch( itk::ExceptionObject & err )
740       {
741         std::cerr << "ExceptionObject caught while registering!" << std::endl;
742         std::cerr << err << std::endl;
743         return;
744       }
745
746
747       //=======================================================
748       // Get the result
749       //=======================================================
750       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
751       regTransform->SetParameters( finalParameters );
752       if (m_Verbose)
753       {
754         std::cout<<"Stop condition description: "
755           <<registration->GetOptimizer()->GetStopConditionDescription()<<std::endl;
756       }
757
758
759       //=======================================================
760       // Get the BSpline coefficient images and write them
761       //=======================================================
762       if (m_ArgsInfo.coeff_given)
763       {
764         typedef typename TransformType::CoefficientImageType CoefficientImageType;
765         std::vector<typename CoefficientImageType::Pointer> coefficientImages = transform->GetCoefficientImages();
766         typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
767         typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New();
768         unsigned nLabels = transform->GetnLabels();
769
770         std::string fname(m_ArgsInfo.coeff_arg);
771         int dotpos = fname.length() - 1;
772         while (dotpos >= 0 && fname[dotpos] != '.')
773           dotpos--;
774
775         for (unsigned i = 0; i < nLabels; ++i)
776         {
777           std::ostringstream osfname;
778           osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos);
779           coeffWriter->SetInput(coefficientImages[i]);
780           coeffWriter->SetFileName(osfname.str());
781           coeffWriter->Update();
782         }
783       }
784
785
786
787       //=======================================================
788       // Compute the DVF (only deformable transform)
789       //=======================================================
790       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
791       typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
792 #if ITK_VERSION_MAJOR >= 4
793 #  if ITK_VERSION_MINOR < 6
794       typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
795 #  else
796       typedef itk::TransformToDisplacementFieldFilter<DisplacementFieldType, double> ConvertorType;
797 #  endif
798 #endif
799       typename ConvertorType::Pointer filter= ConvertorType::New();
800       filter->SetNumberOfThreads(1);
801       if(m_ArgsInfo.itkbspline_flag)
802         sTransform->SetBulkTransform(NULL);
803       else
804         transform->SetBulkTransform(NULL);
805       filter->SetTransform(regTransform);
806 #if ITK_VERSION_MAJOR > 4 || (ITK_VERSION_MAJOR == 4 && ITK_VERSION_MINOR >= 6)
807       filter->SetReferenceImage(fixedImage);
808 #else
809       filter->SetOutputParametersFromImage(fixedImage);
810 #endif
811       filter->Update();
812       typename DisplacementFieldType::Pointer field = filter->GetOutput();
813
814
815       //=======================================================
816       // Write the DVF
817       //=======================================================
818       typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
819       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
820       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
821       fieldWriter->SetInput( field );
822       try
823       {
824         fieldWriter->Update();
825       }
826       catch( itk::ExceptionObject & excp )
827       {
828         std::cerr << "Exception thrown writing the DVF" << std::endl;
829         std::cerr << excp << std::endl;
830         return;
831       }
832
833
834       //=======================================================
835       // Resample the moving image
836       //=======================================================
837       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
838       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
839       if (rigidTransform) {
840         if(m_ArgsInfo.itkbspline_flag)
841           sTransform->SetBulkTransform(rigidTransform);
842         else
843           transform->SetBulkTransform(rigidTransform);
844       }
845       resampler->SetTransform( regTransform );
846       resampler->SetInput( movingImage);
847       resampler->SetOutputParametersFromImage(fixedImage);
848       resampler->Update();
849       typename FixedImageType::Pointer result=resampler->GetOutput();
850
851       //     typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType >    WarpFilterType;
852       //     typename WarpFilterType::Pointer warp = WarpFilterType::New();
853
854       //     warp->SetDeformationField( field );
855       //     warp->SetInput( movingImageReader->GetOutput() );
856       //     warp->SetOutputOrigin(  fixedImage->GetOrigin() );
857       //     warp->SetOutputSpacing( fixedImage->GetSpacing() );
858       //     warp->SetOutputDirection( fixedImage->GetDirection() );
859       //     warp->SetEdgePaddingValue( 0.0 );
860       //     warp->Update();
861
862
863       //=======================================================
864       // Write the warped image
865       //=======================================================
866       typedef itk::ImageFileWriter< FixedImageType >  WriterType;
867       typename WriterType::Pointer      writer =  WriterType::New();
868       writer->SetFileName( m_ArgsInfo.output_arg );
869       writer->SetInput( result    );
870
871       try
872       {
873         writer->Update();
874       }
875       catch( itk::ExceptionObject & err )
876       {
877         std::cerr << "ExceptionObject caught !" << std::endl;
878         std::cerr << err << std::endl;
879         return;
880       }
881
882
883       //=======================================================
884       // Calculate the difference after the deformable transform
885       //=======================================================
886       typedef clitk::DifferenceImageFilter<  FixedImageType, FixedImageType> DifferenceFilterType;
887       if (m_ArgsInfo.after_given)
888       {
889         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
890         difference->SetValidInput( fixedImage );
891         difference->SetTestInput( result );
892
893         try
894         {
895           difference->Update();
896         }
897         catch( itk::ExceptionObject & err )
898         {
899           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
900           std::cerr << err << std::endl;
901           return;
902         }
903
904         typename WriterType::Pointer differenceWriter=WriterType::New();
905         differenceWriter->SetInput(difference->GetOutput());
906         differenceWriter->SetFileName(m_ArgsInfo.after_arg);
907         differenceWriter->Update();
908
909       }
910
911
912       //=======================================================
913       // Calculate the difference before the deformable transform
914       //=======================================================
915       if( m_ArgsInfo.before_given )
916       {
917
918         typename FixedImageType::Pointer moving=FixedImageType::New();
919         if (m_ArgsInfo.initMatrix_given)
920         {
921           typedef itk::ResampleImageFilter<MovingImageType, FixedImageType> ResamplerType;
922           typename ResamplerType::Pointer resampler=ResamplerType::New();
923           resampler->SetInput(movingImage);
924           resampler->SetOutputOrigin(fixedImage->GetOrigin());
925           resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
926           resampler->SetOutputSpacing(fixedImage->GetSpacing());
927           resampler->SetDefaultPixelValue( 0. );
928           if (rigidTransform ) resampler->SetTransform(rigidTransform);
929           resampler->Update();
930           moving=resampler->GetOutput();
931         }
932         else
933           moving=movingImage;
934
935         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
936         difference->SetValidInput( fixedImage );
937         difference->SetTestInput( moving );
938
939         try
940         {
941           difference->Update();
942         }
943         catch( itk::ExceptionObject & err )
944         {
945           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
946           std::cerr << err << std::endl;
947           return;
948         }
949
950         typename WriterType::Pointer differenceWriter=WriterType::New();
951         writer->SetFileName( m_ArgsInfo.before_arg  );
952         writer->SetInput( difference->GetOutput()  );
953         writer->Update( );
954       }
955
956       return;
957
958     }
959 }//end clitk
960
961 #endif // #define clitkBLUTDIRGenericFilter_txx