]> Creatis software - clitk.git/blob - registration/clitkBLUTDIRGenericFilter.cxx
added initial image centralization to BLUTDIR
[clitk.git] / registration / clitkBLUTDIRGenericFilter.cxx
1 /*=========================================================================
2 Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4 Authors belong to:
5 - University of LYON              http://www.universite-lyon.fr/
6 - Léon Bérard cancer center       http://www.centreleonberard.fr
7 - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9 This software is distributed WITHOUT ANY WARRANTY; without even
10 the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11 PURPOSE.  See the copyright notices for more information.
12
13 It is distributed under dual licence
14
15 - BSD        See included LICENSE.txt file
16 - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef clitkBLUTDIRGenericFilter_cxx
19 #define clitkBLUTDIRGenericFilter_cxx
20
21 /* =================================================
22  * @file   clitkBLUTDIRGenericFilter.cxx
23  * @author
24  * @date
25  *
26  * @brief
27  *
28  ===================================================*/
29
30 #include "clitkBLUTDIRGenericFilter.h"
31 #include "clitkBLUTDIRCommandIterationUpdateDVF.h"
32 #include "itkCenteredTransformInitializer.h"
33   
34 namespace clitk
35 {
36
37   //==============================================================================
38   // Creating an observer class that allows output at each iteration
39   //==============================================================================
40   class CommandIterationUpdate : public itk::Command
41   {
42     public:
43       typedef  CommandIterationUpdate   Self;
44       typedef  itk::Command             Superclass;
45       typedef  itk::SmartPointer<Self>  Pointer;
46       itkNewMacro( Self );
47     protected:
48       CommandIterationUpdate() {};
49     public:
50       typedef   clitk::GenericOptimizer<args_info_clitkBLUTDIR>     OptimizerType;
51       typedef   const OptimizerType   *           OptimizerPointer;
52
53       // Set the generic optimizer
54       void SetOptimizer(OptimizerPointer o){m_Optimizer=o;}
55
56       // Execute
57       void Execute(itk::Object *caller, const itk::EventObject & event)
58       {
59         Execute( (const itk::Object *)caller, event);
60       }
61
62       void Execute(const itk::Object * object, const itk::EventObject & event)
63       {
64         if( !(itk::IterationEvent().CheckEvent( &event )) )
65         {
66           return;
67         }
68
69         m_Optimizer->OutputIterationInfo();
70       }
71
72       OptimizerPointer m_Optimizer;
73   };
74
75   //===========================================================================//
76   //Constructor
77   //==========================================================================//
78   BLUTDIRGenericFilter::BLUTDIRGenericFilter():
79     ImageToImageGenericFilter<Self>("Register DIR")
80   {
81     InitializeImageType<2>();
82     InitializeImageType<3>();
83     m_Verbose=false;
84   }
85
86   //=========================================================================//
87   //SetArgsInfo
88   //==========================================================================//
89   void BLUTDIRGenericFilter::SetArgsInfo(const args_info_clitkBLUTDIR & a){
90     m_ArgsInfo=a;
91     if (m_ArgsInfo.reference_given) AddInputFilename(m_ArgsInfo.reference_arg);
92
93     if (m_ArgsInfo.target_given) {
94       AddInputFilename(m_ArgsInfo.target_arg);
95     }
96
97     if (m_ArgsInfo.output_given) SetOutputFilename(m_ArgsInfo.output_arg);
98     
99     if (m_ArgsInfo.verbose_given) m_Verbose=true;
100   }
101
102   //=========================================================================//
103   //===========================================================================//
104   template<unsigned int Dim>
105     void BLUTDIRGenericFilter::InitializeImageType()
106     {
107       ADD_DEFAULT_IMAGE_TYPES(3);
108     }
109   //--------------------------------------------------------------------
110
111   //==============================================================================
112   //Creating an observer class that allows us to change parameters at subsequent levels
113   //==============================================================================
114   template <typename TRegistration,class args_info_clitkBLUTDIR>
115     class RegistrationInterfaceCommand : public itk::Command
116   {
117     public:
118       typedef RegistrationInterfaceCommand   Self;
119       typedef itk::Command             Superclass;
120       typedef itk::SmartPointer<Self>  Pointer;
121       itkNewMacro( Self );
122     protected:
123       RegistrationInterfaceCommand() { };
124     public:
125
126       // Registration
127       typedef   TRegistration                              RegistrationType;
128       typedef   RegistrationType *                         RegistrationPointer;
129
130       // Transform
131       typedef typename RegistrationType::FixedImageType FixedImageType;
132       typedef typename FixedImageType::RegionType RegionType;
133       itkStaticConstMacro(ImageDimension, unsigned int,FixedImageType::ImageDimension);
134       typedef clitk::MultipleBSplineDeformableTransform<double, ImageDimension, ImageDimension> TransformType;
135       typedef clitk::MultipleBSplineDeformableTransformInitializer<TransformType, FixedImageType> InitializerType;
136       typedef typename InitializerType::CoefficientImageType CoefficientImageType;
137       typedef itk::CastImageFilter<CoefficientImageType, CoefficientImageType> CastImageFilterType;
138       typedef typename TransformType::ParametersType ParametersType;
139       typedef typename InitializerType::Pointer InitializerPointer;
140       typedef   CommandIterationUpdate::Pointer CommandIterationUpdatePointer;
141
142       // Optimizer
143       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
144       typedef typename GenericOptimizerType::Pointer GenericOptimizerPointer;
145
146       // Metric
147       typedef typename RegistrationType::FixedImageType    InternalImageType;
148       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, InternalImageType, InternalImageType> GenericMetricType;
149       typedef typename GenericMetricType::Pointer GenericMetricPointer;
150
151       // Two arguments are passed to the Execute() method: the first
152       // is the pointer to the object which invoked the event and the
153       // second is the event that was invoked.
154       void Execute(itk::Object * object, const itk::EventObject & event)
155       {
156         if( !(itk::IterationEvent().CheckEvent( &event )) )
157         {
158           return;
159         }
160
161         // Get the levels
162         RegistrationPointer registration = dynamic_cast<RegistrationPointer>( object );
163         unsigned int numberOfLevels=registration->GetNumberOfLevels();
164         unsigned int currentLevel=registration->GetCurrentLevel()+1;
165
166         // Output the levels
167         std::cout<<std::endl;
168         std::cout<<"========================================"<<std::endl;
169         std::cout<<"Starting resolution level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
170         std::cout<<"========================================"<<std::endl;
171         std::cout<<std::endl;
172
173         // Higher level?
174         if (currentLevel>1)
175         {
176           // fixed image region pyramid
177           typedef clitk::MultiResolutionPyramidRegionFilter<InternalImageType> FixedImageRegionPyramidType;
178           typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
179           fixedImageRegionPyramid->SetRegion(m_MetricRegion);
180           fixedImageRegionPyramid->SetSchedule(registration->GetFixedImagePyramid()->GetSchedule());
181
182           // Reinitialize the metric (!= number of samples)
183           m_GenericMetric= GenericMetricType::New();
184           m_GenericMetric->SetArgsInfo(m_ArgsInfo);
185           m_GenericMetric->SetFixedImage(registration->GetFixedImagePyramid()->GetOutput(registration->GetCurrentLevel()));
186           if (m_ArgsInfo.referenceMask_given)  m_GenericMetric->SetFixedImageMask(registration->GetMetric()->GetFixedImageMask());
187           m_GenericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(registration->GetCurrentLevel()));
188           typedef itk::ImageToImageMetric< InternalImageType, InternalImageType >  MetricType;
189           typename  MetricType::Pointer metric=m_GenericMetric->GetMetricPointer();
190           registration->SetMetric(metric);
191
192           // Get the current coefficient image and make a COPY
193           typename itk::ImageDuplicator<CoefficientImageType>::Pointer caster = itk::ImageDuplicator<CoefficientImageType>::New();
194           std::vector<typename CoefficientImageType::Pointer> currentCoefficientImages = m_Initializer->GetTransform()->GetCoefficientImages();
195           for (unsigned i = 0; i < currentCoefficientImages.size(); ++i)
196           {
197             caster->SetInputImage(currentCoefficientImages[i]);
198             caster->Update();
199             currentCoefficientImages[i] = caster->GetOutput();
200           }
201
202           /*
203           // Write the intermediate result?
204           if (m_ArgsInfo.intermediate_given>=numberOfLevels)
205             writeImage<CoefficientImageType>(currentCoefficientImage, m_ArgsInfo.intermediate_arg[currentLevel-2], m_ArgsInfo.verbose_flag);
206             */
207
208           // Set the new transform properties
209           m_Initializer->SetImage(registration->GetFixedImagePyramid()->GetOutput(currentLevel-1));
210           if( m_Initializer->m_ControlPointSpacingIsGiven)
211             m_Initializer->SetControlPointSpacing(m_Initializer->m_ControlPointSpacingArray[registration->GetCurrentLevel()]);
212           if( m_Initializer->m_NumberOfControlPointsIsGiven)
213             m_Initializer->SetNumberOfControlPointsInsideTheImage(m_Initializer->m_NumberOfControlPointsInsideTheImageArray[registration->GetCurrentLevel()]);
214
215           // Reinitialize the transform
216           if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
217           m_Initializer->InitializeTransform();
218           ParametersType* newParameters= new typename TransformType::ParametersType(m_Initializer->GetTransform()->GetNumberOfParameters());
219
220           // DS : if we want to skip the last pyramid level, force to only 1 iteration
221           DD(m_ArgsInfo.skipLastPyramidLevel_flag);
222           if ((currentLevel == numberOfLevels) && (m_ArgsInfo.skipLastPyramidLevel_flag)) {
223             DD(m_ArgsInfo.maxIt_arg);
224             std::cout << "I skip the last pyramid level : set max iteration to 0" << std::endl;
225             m_ArgsInfo.maxIt_arg = 0;
226             DD(m_ArgsInfo.maxIt_arg);
227           }
228
229           // Reinitialize an Optimizer (!= number of parameters)
230           m_GenericOptimizer = GenericOptimizerType::New();
231           m_GenericOptimizer->SetArgsInfo(m_ArgsInfo);
232           m_GenericOptimizer->SetMaximize(m_Maximize);
233           m_GenericOptimizer->SetNumberOfParameters(m_Initializer->GetTransform()->GetNumberOfParameters());
234
235
236           typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
237           OptimizerType::Pointer optimizer = m_GenericOptimizer->GetOptimizerPointer();
238           optimizer->AddObserver( itk::IterationEvent(), m_CommandIterationUpdate);
239           registration->SetOptimizer(optimizer);
240           m_CommandIterationUpdate->SetOptimizer(m_GenericOptimizer);
241
242           // Set the previous transform parameters to the registration
243           // if(m_Initializer->m_Parameters!=NULL )delete m_Initializer->m_Parameters;
244           m_Initializer->SetInitialParameters(currentCoefficientImages, *newParameters);
245           registration->SetInitialTransformParametersOfNextLevel(*newParameters);
246         }
247       }
248
249       void Execute(const itk::Object * , const itk::EventObject & )
250       { return; }
251
252
253       // Members
254       void SetInitializer(InitializerPointer i){m_Initializer=i;}
255       InitializerPointer m_Initializer;
256
257       void SetArgsInfo(args_info_clitkBLUTDIR a){m_ArgsInfo=a;}
258       args_info_clitkBLUTDIR m_ArgsInfo;
259
260       void SetCommandIterationUpdate(CommandIterationUpdatePointer c){m_CommandIterationUpdate=c;};
261       CommandIterationUpdatePointer m_CommandIterationUpdate;
262
263       GenericOptimizerPointer m_GenericOptimizer;
264       void SetMaximize(bool b){m_Maximize=b;}
265       bool m_Maximize;
266
267       GenericMetricPointer m_GenericMetric;
268       void SetMetricRegion(RegionType i){m_MetricRegion=i;}
269       RegionType m_MetricRegion;
270
271
272   };
273
274   //==============================================================================
275   // Update with the number of dimensions and pixeltype
276   //==============================================================================
277   template<class InputImageType>
278     void BLUTDIRGenericFilter::UpdateWithInputImageType()
279     {
280       if (m_Verbose) std::cout << "BLUTDIRGenericFilter::UpdateWithInputImageType()" << std::endl;
281       
282       //=============================================================================
283       //Input
284       //=============================================================================
285       bool threadsGiven=m_ArgsInfo.threads_given;
286       int threads=m_ArgsInfo.threads_arg;
287       typedef typename InputImageType::PixelType PixelType;
288
289       typedef double TCoordRep;
290
291       typename InputImageType::Pointer fixedImage = this->template GetInput<InputImageType>(0);
292
293       typename InputImageType::Pointer inputFixedImage = this->template GetInput<InputImageType>(0);
294
295       // typedef input2
296       typename InputImageType::Pointer movingImage = this->template GetInput<InputImageType>(1);
297
298       typename InputImageType::Pointer inputMovingImage = this->template GetInput<InputImageType>(1);
299
300       typedef itk::Image< PixelType,InputImageType::ImageDimension >  FixedImageType;
301       typedef itk::Image< PixelType, InputImageType::ImageDimension>  MovingImageType;
302       const unsigned int SpaceDimension = InputImageType::ImageDimension;
303       //Whatever the pixel type, internally we work with an image represented in float
304       //Reading reference image
305       if (m_Verbose) std::cout<<"Reading images..."<<std::endl;
306       //=======================================================
307       //Input
308       //=======================================================
309       typename FixedImageType::Pointer croppedFixedImage=fixedImage;
310       //=======================================================
311       // Regions
312       //=======================================================
313       // The original input region
314       typename FixedImageType::RegionType fixedImageRegion = fixedImage->GetLargestPossibleRegion();
315
316       // The transform region with respect to the input region:
317       // where should the transform be DEFINED (depends on mask)
318       typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion();
319       typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize();
320       typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex();
321       typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin();
322
323       // The metric region with respect to the extracted transform region:
324       // where should the metric be CALCULATED (depends on transform)
325       typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
326       typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
327       typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
328
329
330       //===========================================================================
331       // If given, we connect a mask to reference or target
332       //============================================================================
333       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
334       typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageLabelType;
335       typename MaskType::Pointer        fixedMask = NULL;
336       typename ImageLabelType::Pointer  labels = NULL;
337       if (m_ArgsInfo.referenceMask_given)
338       {
339         fixedMask = MaskType::New();
340         labels = ImageLabelType::New();
341         typedef itk::ImageFileReader< ImageLabelType >    LabelReaderType;
342         typename LabelReaderType::Pointer  labelReader = LabelReaderType::New();
343         labelReader->SetFileName(m_ArgsInfo.referenceMask_arg);
344         try
345         {
346           labelReader->Update();
347         }
348         catch( itk::ExceptionObject & err )
349         {
350           std::cerr << "ExceptionObject caught while reading mask or labels !" << std::endl;
351           std::cerr << err << std::endl;
352           return;
353         }
354         if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
355
356         // Resample labels
357         typedef itk::ResampleImageFilter<ImageLabelType, ImageLabelType> ResamplerType;
358         typename ResamplerType::Pointer resampler = ResamplerType::New();
359         typedef itk::NearestNeighborInterpolateImageFunction<ImageLabelType, TCoordRep> InterpolatorType;
360         typename InterpolatorType::Pointer interpolator = InterpolatorType::New();
361         resampler->SetOutputParametersFromImage(fixedImage);
362         resampler->SetInterpolator(interpolator);
363         resampler->SetInput(labelReader->GetOutput());
364         resampler->Update();
365         labels = resampler->GetOutput();
366
367         // Set the image to the spatialObject
368         fixedMask->SetImage(labels);
369
370         // Find the bounding box of the "inside" label
371         typedef itk::LabelGeometryImageFilter<ImageLabelType> GeometryImageFilterType;
372         typename GeometryImageFilterType::Pointer geometryImageFilter=GeometryImageFilterType::New();
373         geometryImageFilter->SetInput(labels);
374         geometryImageFilter->Update();
375         typename GeometryImageFilterType::BoundingBoxType boundingBox = geometryImageFilter->GetBoundingBox(1);
376
377         // Limit the transform region to the mask
378         for (unsigned int i=0; i<InputImageType::ImageDimension; i++)
379         {
380           transformRegionIndex[i]=boundingBox[2*i];
381           transformRegionSize[i]=boundingBox[2*i+1]-boundingBox[2*i]+1;
382         }
383         transformRegion.SetSize(transformRegionSize);
384         transformRegion.SetIndex(transformRegionIndex);
385         fixedImage->TransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin);
386
387         // Crop the fixedImage to the bounding box to facilitate multi-resolution
388         typedef itk::ExtractImageFilter<FixedImageType,FixedImageType> ExtractImageFilterType;
389         typename ExtractImageFilterType::Pointer extractImageFilter=ExtractImageFilterType::New();
390 #if ITK_VERSION_MAJOR == 4
391         extractImageFilter->SetDirectionCollapseToSubmatrix();
392 #endif
393         extractImageFilter->SetInput(fixedImage);
394         extractImageFilter->SetExtractionRegion(transformRegion);
395         extractImageFilter->Update();
396         croppedFixedImage=extractImageFilter->GetOutput();
397
398         // Update the metric region
399         metricRegion = croppedFixedImage->GetLargestPossibleRegion();
400         metricRegionIndex=metricRegion.GetIndex();
401         croppedFixedImage->TransformIndexToPhysicalPoint(metricRegionIndex, metricRegionOrigin);
402
403         // Set start index to zero (with respect to croppedFixedImage/transform region)
404         metricRegionIndex.Fill(0);
405         metricRegion.SetIndex(metricRegionIndex);
406         croppedFixedImage->SetRegions(metricRegion);
407         croppedFixedImage->SetOrigin(metricRegionOrigin);
408       }
409
410       typedef itk::ImageMaskSpatialObject< InputImageType::ImageDimension >   MaskType;
411       typename MaskType::Pointer  movingMask=NULL;
412       if (m_ArgsInfo.targetMask_given)
413       {
414         movingMask= MaskType::New();
415         typedef itk::Image< unsigned char, InputImageType::ImageDimension >   ImageMaskType;
416         typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
417         typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
418         maskReader->SetFileName(m_ArgsInfo.targetMask_arg);
419         try
420         {
421           maskReader->Update();
422         }
423         catch( itk::ExceptionObject & err )
424         {
425           std::cerr << "ExceptionObject caught !" << std::endl;
426           std::cerr << err << std::endl;
427         }
428         if (m_Verbose)std::cout <<"Target image mask was read..." <<std::endl;
429
430         movingMask->SetImage( maskReader->GetOutput() );
431       }
432
433
434       //=======================================================
435       // Output Regions
436       //=======================================================
437
438       if (m_Verbose)
439       {
440         // Fixed image region
441         std::cout<<"The fixed image has its origin at "<<fixedImage->GetOrigin()<<std::endl
442           <<"The fixed image region starts at index "<<fixedImageRegion.GetIndex()<<std::endl
443           <<"The fixed image region has size "<< fixedImageRegion.GetSize()<<std::endl;
444
445         // Transform region
446         std::cout<<"The transform has its origin at "<<transformRegionOrigin<<std::endl
447           <<"The transform region will start at index "<<transformRegion.GetIndex()<<std::endl
448           <<"The transform region has size "<< transformRegion.GetSize()<<std::endl;
449
450         // Metric region
451         std::cout<<"The metric region has its origin at "<<metricRegionOrigin<<std::endl
452           <<"The metric region will start at index "<<metricRegion.GetIndex()<<std::endl
453           <<"The metric region has size "<< metricRegion.GetSize()<<std::endl;
454
455       }
456
457
458       //=======================================================
459       // Pyramids (update them for transform initializer)
460       //=======================================================
461       typedef itk::RecursiveMultiResolutionPyramidImageFilter< FixedImageType, FixedImageType>    FixedImagePyramidType;
462       typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType>    MovingImagePyramidType;
463       typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
464       typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();
465       fixedImagePyramid->SetUseShrinkImageFilter(false);
466       fixedImagePyramid->SetInput(croppedFixedImage);
467       fixedImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
468       movingImagePyramid->SetUseShrinkImageFilter(false);
469       movingImagePyramid->SetInput(movingImage);
470       movingImagePyramid->SetNumberOfLevels(m_ArgsInfo.levels_arg);
471       if (m_Verbose) std::cout<<"Creating the image pyramid..."<<std::endl;
472       fixedImagePyramid->Update();
473       movingImagePyramid->Update();
474       typedef clitk::MultiResolutionPyramidRegionFilter<FixedImageType> FixedImageRegionPyramidType;
475       typename FixedImageRegionPyramidType::Pointer fixedImageRegionPyramid=FixedImageRegionPyramidType::New();
476       fixedImageRegionPyramid->SetRegion(metricRegion);
477       fixedImageRegionPyramid->SetSchedule(fixedImagePyramid->GetSchedule());
478
479
480       //=======================================================
481       // Rigid or Affine Transform
482       //=======================================================
483       typedef itk::AffineTransform <double,3> RigidTransformType;
484       RigidTransformType::Pointer rigidTransform=NULL;
485       if (m_ArgsInfo.initMatrix_given)
486       {
487         if(m_Verbose) std::cout<<"Reading the prior transform matrix "<< m_ArgsInfo.initMatrix_arg<<"..."<<std::endl;
488         rigidTransform=RigidTransformType::New();
489         itk::Matrix<double,4,4> rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.initMatrix_arg);
490
491         //Set the rotation
492         itk::Matrix<double,3,3> finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix);
493         rigidTransform->SetMatrix(finalRotation);
494
495         //Set the translation
496         itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
497         rigidTransform->SetTranslation(finalTranslation);
498       }
499       else
500       {
501         if(m_Verbose) std::cout<<"No itinial matrix given. Centering all images..."<<std::endl;
502         
503         rigidTransform=RigidTransformType::New();
504         
505         typedef itk::CenteredTransformInitializer<RigidTransformType, FixedImageType, MovingImageType > TransformInitializerType;
506         typename TransformInitializerType::Pointer initializer = TransformInitializerType::New();
507         initializer->SetTransform( rigidTransform );
508         initializer->SetFixedImage( fixedImage );
509         initializer->SetMovingImage( movingImage );        
510         initializer->GeometryOn();
511         initializer->InitializeTransform();
512       }
513
514
515       //=======================================================
516       // B-LUT FFD Transform
517       //=======================================================
518       typedef  clitk::MultipleBSplineDeformableTransform<TCoordRep,InputImageType::ImageDimension, SpaceDimension > TransformType;
519       typename TransformType::Pointer transform = TransformType::New();
520       if (labels) transform->SetLabels(labels);
521       if (rigidTransform) transform->SetBulkTransform(rigidTransform);
522
523       //-------------------------------------------------------------------------
524       // The transform initializer
525       //-------------------------------------------------------------------------
526       typedef clitk::MultipleBSplineDeformableTransformInitializer< TransformType,FixedImageType> InitializerType;
527       typename InitializerType::Pointer initializer = InitializerType::New();
528       initializer->SetVerbose(m_Verbose);
529       initializer->SetImage(fixedImagePyramid->GetOutput(0));
530       initializer->SetTransform(transform);
531
532       //-------------------------------------------------------------------------
533       // Order
534       //-------------------------------------------------------------------------
535       typename FixedImageType::RegionType::SizeType splineOrders ;
536       splineOrders.Fill(3);
537       if (m_ArgsInfo.order_given)
538         for(unsigned int i=0; i<InputImageType::ImageDimension;i++)
539           splineOrders[i]=m_ArgsInfo.order_arg[i];
540       if (m_Verbose) std::cout<<"Setting the spline orders  to "<<splineOrders<<"..."<<std::endl;
541       initializer->SetSplineOrders(splineOrders);
542
543       //-------------------------------------------------------------------------
544       // Levels
545       //-------------------------------------------------------------------------
546
547       // Spacing
548       if (m_ArgsInfo.spacing_given)
549       {
550         initializer->m_ControlPointSpacingArray.resize(m_ArgsInfo.levels_arg);
551         initializer->SetControlPointSpacing(m_ArgsInfo.spacing_arg);
552         initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]=initializer->m_ControlPointSpacing;
553         if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1]
554           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
555
556         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
557         {
558           initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]=initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-i]*2;
559           if (m_Verbose) std::cout<<"Using a control point spacing of "<<initializer->m_ControlPointSpacingArray[m_ArgsInfo.levels_arg-1-i]
560             <<" at level "<<m_ArgsInfo.levels_arg-i<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
561         }
562
563       }
564
565       // Control
566       if (m_ArgsInfo.control_given)
567       {
568         initializer->m_NumberOfControlPointsInsideTheImageArray.resize(m_ArgsInfo.levels_arg);
569         initializer->SetNumberOfControlPointsInsideTheImage(m_ArgsInfo.control_arg);
570         initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]=initializer->m_NumberOfControlPointsInsideTheImage;
571         if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1]<<"control points inside the image"
572           <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
573
574         for (int i=1; i<m_ArgsInfo.levels_arg; i++ )
575         {
576           for(unsigned int j=0;j<InputImageType::ImageDimension;j++)
577             initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i][j]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i][j]/2.);
578           //        initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]=ceil ((double)initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-i]/2.);
579           if (m_Verbose) std::cout<<"Using "<< initializer->m_NumberOfControlPointsInsideTheImageArray[m_ArgsInfo.levels_arg-1-i]<<"control points inside the image"
580             <<" at level "<<m_ArgsInfo.levels_arg<<" of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
581
582         }
583       }
584
585       // Inialize on the first level
586       if (m_ArgsInfo.verbose_flag) std::cout<<"Initializing transform for level 1 of "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
587       if (m_ArgsInfo.spacing_given) initializer->SetControlPointSpacing(        initializer->m_ControlPointSpacingArray[0]);
588       if (m_ArgsInfo.control_given) initializer->SetNumberOfControlPointsInsideTheImage(initializer->m_NumberOfControlPointsInsideTheImageArray[0]);
589       if (m_ArgsInfo.samplingFactor_given) initializer->SetSamplingFactors(m_ArgsInfo.samplingFactor_arg);
590
591       // Initialize
592       initializer->InitializeTransform();
593
594       //-------------------------------------------------------------------------
595       // Initial parameters (passed by reference)
596       //-------------------------------------------------------------------------
597       typedef typename TransformType::ParametersType     ParametersType;
598       const unsigned int numberOfParameters =    transform->GetNumberOfParameters();
599       ParametersType parameters(numberOfParameters);
600       parameters.Fill( 0.0 );
601       transform->SetParameters( parameters );
602       if (m_ArgsInfo.initCoeff_given) initializer->SetInitialParameters(m_ArgsInfo.initCoeff_arg, parameters);
603
604       //-------------------------------------------------------------------------
605       // DEBUG: use an itk BSpline instead of multilabel BLUTs
606       //-------------------------------------------------------------------------
607       typedef itk::Transform< TCoordRep, 3, 3 > RegistrationTransformType;
608       RegistrationTransformType::Pointer regTransform(transform);
609       typedef itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> SingleBSplineTransformType;
610       typename SingleBSplineTransformType::Pointer sTransform;
611       if(m_ArgsInfo.itkbspline_flag) {
612         if( transform->GetTransforms().size()>1)
613           itkExceptionMacro(<< "invalid --itkbspline option if there is more than 1 label")
614         sTransform = SingleBSplineTransformType::New();
615         sTransform->SetBulkTransform( transform->GetTransforms()[0]->GetBulkTransform() );
616         sTransform->SetGridSpacing( transform->GetTransforms()[0]->GetGridSpacing() );
617         sTransform->SetGridOrigin( transform->GetTransforms()[0]->GetGridOrigin() );
618         sTransform->SetGridRegion( transform->GetTransforms()[0]->GetGridRegion() );
619         sTransform->SetParameters( transform->GetTransforms()[0]->GetParameters() );
620         regTransform = sTransform;
621         transform = NULL; // free memory
622       }
623
624       //=======================================================
625       // Interpolator
626       //=======================================================
627       typedef clitk::GenericInterpolator<args_info_clitkBLUTDIR, FixedImageType, TCoordRep > GenericInterpolatorType;
628       typename   GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New();
629       genericInterpolator->SetArgsInfo(m_ArgsInfo);
630       typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep >  InterpolatorType;
631       typename  InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer();
632
633
634       //=======================================================
635       // Metric
636       //=======================================================
637       typedef clitk::GenericMetric<args_info_clitkBLUTDIR, FixedImageType,MovingImageType > GenericMetricType;
638       typename GenericMetricType::Pointer genericMetric=GenericMetricType::New();
639       genericMetric->SetArgsInfo(m_ArgsInfo);
640       genericMetric->SetFixedImage(fixedImagePyramid->GetOutput(0));
641       if (fixedMask) genericMetric->SetFixedImageMask(fixedMask);
642       genericMetric->SetFixedImageRegion(fixedImageRegionPyramid->GetOutput(0));
643       typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
644       typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
645       if (movingMask) metric->SetMovingImageMask(movingMask);
646
647 #if defined(ITK_USE_OPTIMIZED_REGISTRATION_METHODS) || ITK_VERSION_MAJOR >= 4
648       if (threadsGiven) {
649         metric->SetNumberOfThreads( threads );
650         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
651       }
652 #else
653       if (m_Verbose) std::cout<<"Not setting the number of threads (not compiled with USE_OPTIMIZED_REGISTRATION_METHODS)..."<<std::endl;
654 #endif
655
656
657       //=======================================================
658       // Optimizer
659       //=======================================================
660       typedef clitk::GenericOptimizer<args_info_clitkBLUTDIR> GenericOptimizerType;
661       GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
662       genericOptimizer->SetArgsInfo(m_ArgsInfo);
663       genericOptimizer->SetMaximize(genericMetric->GetMaximize());
664       genericOptimizer->SetNumberOfParameters(regTransform->GetNumberOfParameters());
665       typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
666       OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
667
668
669       //=======================================================
670       // Registration
671       //=======================================================
672       typedef itk::MultiResolutionImageRegistrationMethod<  FixedImageType, MovingImageType >    RegistrationType;
673       typename RegistrationType::Pointer   registration  = RegistrationType::New();
674       registration->SetMetric(        metric        );
675       registration->SetOptimizer(     optimizer     );
676       registration->SetInterpolator(  interpolator  );
677       registration->SetTransform (regTransform );
678       if(threadsGiven) {
679         registration->SetNumberOfThreads(threads);
680         if (m_Verbose) std::cout<< "Using " << threads << " threads." << std::endl;
681       }
682       registration->SetFixedImage(  croppedFixedImage   );
683       registration->SetMovingImage(  movingImage   );
684       registration->SetFixedImageRegion( metricRegion );
685       registration->SetFixedImagePyramid( fixedImagePyramid );
686       registration->SetMovingImagePyramid( movingImagePyramid );
687       registration->SetInitialTransformParameters( regTransform->GetParameters() );
688       registration->SetNumberOfLevels( m_ArgsInfo.levels_arg );
689       if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
690
691
692       //================================================================================================
693       // Observers
694       //================================================================================================
695       if (m_Verbose)
696       {
697         // Output iteration info
698         CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
699         observer->SetOptimizer(genericOptimizer);
700         optimizer->AddObserver( itk::IterationEvent(), observer );
701
702         // Output level info
703         typedef RegistrationInterfaceCommand<RegistrationType,args_info_clitkBLUTDIR> CommandType;
704         typename CommandType::Pointer command = CommandType::New();
705         command->SetInitializer(initializer);
706         command->SetArgsInfo(m_ArgsInfo);
707         command->SetCommandIterationUpdate(observer);
708         command->SetMaximize(genericMetric->GetMaximize());
709         command->SetMetricRegion(metricRegion);
710         registration->AddObserver( itk::IterationEvent(), command );
711
712         if (m_ArgsInfo.coeff_given)
713         {
714           if(m_ArgsInfo.itkbspline_flag) {
715             itkExceptionMacro("--coeff and --itkbpline are incompatible");
716           }
717
718           std::cout << std::endl << "Output coefficient images every " << m_ArgsInfo.coeffEveryN_arg << " iterations." << std::endl;
719           typedef CommandIterationUpdateDVF<FixedImageType, OptimizerType, TransformType> DVFCommandType;
720           typename DVFCommandType::Pointer observerdvf = DVFCommandType::New();
721           observerdvf->SetFixedImage(fixedImage);
722           observerdvf->SetTransform(transform);
723           observerdvf->SetArgsInfo(m_ArgsInfo);
724           optimizer->AddObserver( itk::IterationEvent(), observerdvf );
725         }
726       }
727
728
729       //=======================================================
730       // Let's go
731       //=======================================================
732       if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl;
733
734       try
735       {
736         registration->StartRegistration();
737       }
738       catch( itk::ExceptionObject & err )
739       {
740         std::cerr << "ExceptionObject caught while registering!" << std::endl;
741         std::cerr << err << std::endl;
742         return;
743       }
744
745
746       //=======================================================
747       // Get the result
748       //=======================================================
749       OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
750       regTransform->SetParameters( finalParameters );
751       if (m_Verbose)
752       {
753         std::cout<<"Stop condition description: "
754           <<registration->GetOptimizer()->GetStopConditionDescription()<<std::endl;
755       }
756
757
758       //=======================================================
759       // Get the BSpline coefficient images and write them
760       //=======================================================
761       if (m_ArgsInfo.coeff_given)
762       {
763         typedef typename TransformType::CoefficientImageType CoefficientImageType;
764         std::vector<typename CoefficientImageType::Pointer> coefficientImages = transform->GetCoefficientImages();
765         typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
766         typename CoeffWriterType::Pointer coeffWriter = CoeffWriterType::New();
767         unsigned nLabels = transform->GetnLabels();
768
769         std::string fname(m_ArgsInfo.coeff_arg);
770         int dotpos = fname.length() - 1;
771         while (dotpos >= 0 && fname[dotpos] != '.')
772           dotpos--;
773
774         for (unsigned i = 0; i < nLabels; ++i)
775         {
776           std::ostringstream osfname;
777           osfname << fname.substr(0, dotpos) << '_' << i << fname.substr(dotpos);
778           coeffWriter->SetInput(coefficientImages[i]);
779           coeffWriter->SetFileName(osfname.str());
780           coeffWriter->Update();
781         }
782       }
783
784
785
786       //=======================================================
787       // Compute the DVF (only deformable transform)
788       //=======================================================
789       typedef itk::Vector< float, SpaceDimension >  DisplacementType;
790       typedef itk::Image< DisplacementType, InputImageType::ImageDimension >  DisplacementFieldType;
791 #if ITK_VERSION_MAJOR >= 4
792       typedef itk::TransformToDisplacementFieldSource<DisplacementFieldType, double> ConvertorType;
793 #else
794       typedef itk::TransformToDeformationFieldSource<DisplacementFieldType, double> ConvertorType;
795 #endif
796       typename ConvertorType::Pointer filter= ConvertorType::New();
797       filter->SetNumberOfThreads(1);
798       if(m_ArgsInfo.itkbspline_flag)
799         sTransform->SetBulkTransform(NULL);
800       else
801         transform->SetBulkTransform(NULL);
802       filter->SetTransform(regTransform);
803       filter->SetOutputParametersFromImage(fixedImage);
804       filter->Update();
805       typename DisplacementFieldType::Pointer field = filter->GetOutput();
806
807
808       //=======================================================
809       // Write the DVF
810       //=======================================================
811       typedef itk::ImageFileWriter< DisplacementFieldType >  FieldWriterType;
812       typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
813       fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
814       fieldWriter->SetInput( field );
815       try
816       {
817         fieldWriter->Update();
818       }
819       catch( itk::ExceptionObject & excp )
820       {
821         std::cerr << "Exception thrown writing the DVF" << std::endl;
822         std::cerr << excp << std::endl;
823         return;
824       }
825
826
827       //=======================================================
828       // Resample the moving image
829       //=======================================================
830       typedef itk::ResampleImageFilter< MovingImageType, FixedImageType >    ResampleFilterType;
831       typename ResampleFilterType::Pointer resampler = ResampleFilterType::New();
832       if (rigidTransform) {
833         if(m_ArgsInfo.itkbspline_flag)
834           sTransform->SetBulkTransform(rigidTransform);
835         else
836           transform->SetBulkTransform(rigidTransform);
837       }
838       resampler->SetTransform( regTransform );
839       resampler->SetInput( movingImage);
840       resampler->SetOutputParametersFromImage(fixedImage);
841       resampler->Update();
842       typename FixedImageType::Pointer result=resampler->GetOutput();
843
844       //     typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType >    WarpFilterType;
845       //     typename WarpFilterType::Pointer warp = WarpFilterType::New();
846
847       //     warp->SetDeformationField( field );
848       //     warp->SetInput( movingImageReader->GetOutput() );
849       //     warp->SetOutputOrigin(  fixedImage->GetOrigin() );
850       //     warp->SetOutputSpacing( fixedImage->GetSpacing() );
851       //     warp->SetOutputDirection( fixedImage->GetDirection() );
852       //     warp->SetEdgePaddingValue( 0.0 );
853       //     warp->Update();
854
855
856       //=======================================================
857       // Write the warped image
858       //=======================================================
859       typedef itk::ImageFileWriter< FixedImageType >  WriterType;
860       typename WriterType::Pointer      writer =  WriterType::New();
861       writer->SetFileName( m_ArgsInfo.output_arg );
862       writer->SetInput( result    );
863
864       try
865       {
866         writer->Update();
867       }
868       catch( itk::ExceptionObject & err )
869       {
870         std::cerr << "ExceptionObject caught !" << std::endl;
871         std::cerr << err << std::endl;
872         return;
873       }
874
875
876       //=======================================================
877       // Calculate the difference after the deformable transform
878       //=======================================================
879       typedef clitk::DifferenceImageFilter<  FixedImageType, FixedImageType> DifferenceFilterType;
880       if (m_ArgsInfo.after_given)
881       {
882         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
883         difference->SetValidInput( fixedImage );
884         difference->SetTestInput( result );
885
886         try
887         {
888           difference->Update();
889         }
890         catch( itk::ExceptionObject & err )
891         {
892           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
893           std::cerr << err << std::endl;
894           return;
895         }
896
897         typename WriterType::Pointer differenceWriter=WriterType::New();
898         differenceWriter->SetInput(difference->GetOutput());
899         differenceWriter->SetFileName(m_ArgsInfo.after_arg);
900         differenceWriter->Update();
901
902       }
903
904
905       //=======================================================
906       // Calculate the difference before the deformable transform
907       //=======================================================
908       if( m_ArgsInfo.before_given )
909       {
910
911         typename FixedImageType::Pointer moving=FixedImageType::New();
912         if (m_ArgsInfo.initMatrix_given)
913         {
914           typedef itk::ResampleImageFilter<MovingImageType, FixedImageType> ResamplerType;
915           typename ResamplerType::Pointer resampler=ResamplerType::New();
916           resampler->SetInput(movingImage);
917           resampler->SetOutputOrigin(fixedImage->GetOrigin());
918           resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
919           resampler->SetOutputSpacing(fixedImage->GetSpacing());
920           resampler->SetDefaultPixelValue( 0. );
921           if (rigidTransform ) resampler->SetTransform(rigidTransform);
922           resampler->Update();
923           moving=resampler->GetOutput();
924         }
925         else
926           moving=movingImage;
927
928         typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
929         difference->SetValidInput( fixedImage );
930         difference->SetTestInput( moving );
931
932         try
933         {
934           difference->Update();
935         }
936         catch( itk::ExceptionObject & err )
937         {
938           std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
939           std::cerr << err << std::endl;
940           return;
941         }
942
943         typename WriterType::Pointer differenceWriter=WriterType::New();
944         writer->SetFileName( m_ArgsInfo.before_arg  );
945         writer->SetInput( difference->GetOutput()  );
946         writer->Update( );
947       }
948
949       return;
950
951     }
952 }//end clitk
953
954 #endif // #define clitkBLUTDIRGenericFilter_txx