]> Creatis software - clitk.git/blob - registration/clitkBSplineDeformableRegistrationGenericFilter.txx
Attempt to make clitkBSplineDeformableRegistration work
[clitk.git] / registration / clitkBSplineDeformableRegistrationGenericFilter.txx
1 /*=========================================================================
2   Program:   vv                     http://www.creatis.insa-lyon.fr/rio/vv
3
4   Authors belong to:
5   - University of LYON              http://www.universite-lyon.fr/
6   - Léon Bérard cancer center       http://www.centreleonberard.fr
7   - CREATIS CNRS laboratory         http://www.creatis.insa-lyon.fr
8
9   This software is distributed WITHOUT ANY WARRANTY; without even
10   the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
11   PURPOSE.  See the copyright notices for more information.
12
13   It is distributed under dual licence
14
15   - BSD        See included LICENSE.txt file
16   - CeCILL-B   http://www.cecill.info/licences/Licence_CeCILL-B_V1-en.html
17 ===========================================================================**/
18 #ifndef __clitkBSplineDeformableRegistrationGenericFilter_txx
19 #define __clitkBSplineDeformableRegistrationGenericFilter_txx
20 #include "clitkBSplineDeformableRegistrationGenericFilter.h"
21
22
23 namespace clitk
24 {
25
26 //==============================================================================
27 //Creating an observer class that allows us to change parameters at subsequent levels
28 //==============================================================================
29 template <typename TRegistration>
30 class RegistrationInterfaceCommand : public itk::Command
31 {
32 public:
33   typedef RegistrationInterfaceCommand   Self;
34
35
36   typedef itk::Command             Superclass;
37   typedef itk::SmartPointer<Self>  Pointer;
38   itkNewMacro( Self );
39 protected:
40   RegistrationInterfaceCommand() {};
41 public:
42   typedef   TRegistration                              RegistrationType;
43   typedef   RegistrationType *                         RegistrationPointer;
44
45   // Two arguments are passed to the Execute() method: the first
46   // is the pointer to the object which invoked the event and the
47   // second is the event that was invoked.
48   void Execute(itk::Object * object, const itk::EventObject & event) {
49     if( !(itk::IterationEvent().CheckEvent( &event )) ) {
50       return;
51     }
52     RegistrationPointer registration = dynamic_cast<RegistrationPointer>( object );
53     unsigned int numberOfLevels=registration->GetNumberOfLevels();
54     unsigned int currentLevel=registration->GetCurrentLevel()+1;
55     std::cout<<std::endl;
56     std::cout<<"========================================"<<std::endl;
57     std::cout<<"Starting resolution level "<<currentLevel<<" of "<<numberOfLevels<<"..."<<std::endl;
58     std::cout<<"========================================"<<std::endl;
59     std::cout<<std::endl;
60   }
61
62   void Execute(const itk::Object * , const itk::EventObject & ) {
63     return;
64   }
65
66 };
67
68
69 //==============================================================================
70 // Creating an observer class that allows output at each iteration
71 //==============================================================================
72 class CommandIterationUpdate : public itk::Command
73 {
74 public:
75   typedef  CommandIterationUpdate   Self;
76   typedef  itk::Command             Superclass;
77   typedef  itk::SmartPointer<Self>  Pointer;
78   itkNewMacro( Self );
79 protected:
80   CommandIterationUpdate() {};
81 public:
82   typedef   clitk::GenericOptimizer<args_info_clitkBSplineDeformableRegistration>     OptimizerType;
83   typedef   const OptimizerType   *           OptimizerPointer;
84
85   // We set the generic optimizer
86   void SetOptimizer(OptimizerPointer o) {
87     m_Optimizer=o;
88   }
89
90   // Execute
91   void Execute(itk::Object *caller, const itk::EventObject & event) {
92     Execute( (const itk::Object *)caller, event);
93   }
94
95   void Execute(const itk::Object * object, const itk::EventObject & event) {
96     if( !(itk::IterationEvent().CheckEvent( &event )) ) {
97       return;
98     }
99
100     m_Optimizer->OutputIterationInfo();
101   }
102
103   OptimizerPointer m_Optimizer;
104 };
105
106
107 //==============================================================================
108 // Update with the number of dimensions
109 //==============================================================================
110 template<unsigned int Dimension>
111 void BSplineDeformableRegistrationGenericFilter::UpdateWithDim(std::string PixelType)
112 {
113
114   if (m_Verbose) std::cout  << "Images were detected to be "<< Dimension << "D and " << PixelType << "..." << std::endl;
115
116   if(PixelType == "short") {
117     if (m_Verbose) std::cout  << "Launching warp in "<< Dimension <<"D and signed short..." << std::endl;
118     UpdateWithDimAndPixelType<Dimension, signed short>();
119   }
120   //    else if(PixelType == "unsigned_short"){
121   //       if (m_Verbose) std::cout  << "Launching warp in "<< Dimension <<"D and unsigned_short..." << std::endl;
122   //       UpdateWithDimAndPixelType<Dimension, unsigned short>();
123   //     }
124
125   //     else if (PixelType == "unsigned_char"){
126   //       if (m_Verbose) std::cout  << "Launching warp in "<< Dimension <<"D and unsigned_char..." << std::endl;
127   //       UpdateWithDimAndPixelType<Dimension, unsigned char>();
128   //     }
129
130   //     else if (PixelType == "char"){
131   //       if (m_Verbose) std::cout  << "Launching warp in "<< Dimension <<"D and signed_char..." << std::endl;
132   //       UpdateWithDimAndPixelType<Dimension, signed char>();
133   //    }
134   else {
135     if (m_Verbose) std::cout  << "Launching filter in "<< Dimension <<"D and float..." << std::endl;
136     UpdateWithDimAndPixelType<Dimension, float>();
137   }
138 }
139
140
141
142 //==============================================================================
143 // Update with the number of dimensions and pixeltype
144 //==============================================================================
145 template<unsigned int ImageDimension, class PixelType>
146 void BSplineDeformableRegistrationGenericFilter::UpdateWithDimAndPixelType()
147 {
148
149   //=======================================================
150   // Run-time
151   //=======================================================
152   bool threadsGiven=m_ArgsInfo.threads_given;
153   int threads=m_ArgsInfo.threads_arg;
154
155   typedef itk::Image< PixelType, ImageDimension >  FixedImageType;
156   typedef itk::Image< PixelType, ImageDimension >  MovingImageType;
157   const unsigned int SpaceDimension = ImageDimension;
158   typedef double TCoordRep;
159
160
161   //=======================================================
162   //Input
163   //=======================================================
164   typedef itk::ImageFileReader< FixedImageType  > FixedImageReaderType;
165   typedef itk::ImageFileReader< MovingImageType > MovingImageReaderType;
166
167   typename FixedImageReaderType::Pointer  fixedImageReader  = FixedImageReaderType::New();
168   typename MovingImageReaderType::Pointer movingImageReader = MovingImageReaderType::New();
169
170   fixedImageReader->SetFileName(  m_ArgsInfo.reference_arg );
171   movingImageReader->SetFileName( m_ArgsInfo.target_arg );
172   if (m_Verbose) std::cout<<"Reading images..."<<std::endl;
173   fixedImageReader->Update();
174   movingImageReader->Update();
175
176   typename FixedImageType::Pointer fixedImage = fixedImageReader->GetOutput();
177   typename MovingImageType::Pointer movingImage =movingImageReader->GetOutput();
178   typename FixedImageType::RegionType fixedImageRegion = fixedImage->GetLargestPossibleRegion();
179
180   // The metric region: where should the metric be CALCULATED (depends on mask)
181   typename FixedImageType::RegionType metricRegion = fixedImage->GetLargestPossibleRegion();
182   typename FixedImageType::RegionType::SizeType metricRegionSize=metricRegion.GetSize();
183   typename FixedImageType::RegionType::IndexType metricRegionIndex=metricRegion.GetIndex();
184   typename FixedImageType::PointType metricRegionOrigin=fixedImage->GetOrigin();
185
186   // The transform region: where should the transform be DEFINED (depends on mask)
187   typename FixedImageType::RegionType transformRegion = fixedImage->GetLargestPossibleRegion();
188   typename FixedImageType::RegionType::SizeType transformRegionSize=transformRegion.GetSize();
189   typename FixedImageType::RegionType::IndexType transformRegionIndex=transformRegion.GetIndex();
190   typename FixedImageType::PointType transformRegionOrigin=fixedImage->GetOrigin();
191
192
193   //=======================================================
194   // If given, we connect a mask to the fixed image
195   //======================================================
196   typedef itk::ImageMaskSpatialObject<  ImageDimension >   MaskType;
197   typename MaskType::Pointer  spatialObjectMask=NULL;
198
199   if (m_ArgsInfo.mask_given) {
200     typedef itk::Image< unsigned char, ImageDimension >   ImageMaskType;
201     typedef itk::ImageFileReader< ImageMaskType >    MaskReaderType;
202     typename MaskReaderType::Pointer  maskReader = MaskReaderType::New();
203     maskReader->SetFileName(m_ArgsInfo.mask_arg);
204
205     try {
206       maskReader->Update();
207     } catch( itk::ExceptionObject & err ) {
208       std::cerr << "ExceptionObject caught while reading mask !" << std::endl;
209       std::cerr << err << std::endl;
210       return;
211     }
212     if (m_Verbose)std::cout <<"Reference image mask was read..." <<std::endl;
213
214
215     // Set the image to the spatialObject
216     spatialObjectMask = MaskType::New();
217     spatialObjectMask->SetImage( maskReader->GetOutput() );
218
219     // Find the bounding box of the "inside" label
220     typedef itk::LabelStatisticsImageFilter<ImageMaskType, ImageMaskType> StatisticsImageFilterType;
221     typename StatisticsImageFilterType::Pointer statisticsImageFilter=StatisticsImageFilterType::New();
222     statisticsImageFilter->SetInput(maskReader->GetOutput());
223     statisticsImageFilter->SetLabelInput(maskReader->GetOutput());
224     statisticsImageFilter->Update();
225     typename StatisticsImageFilterType::BoundingBoxType boundingBox = statisticsImageFilter->GetBoundingBox(1);
226
227     // Limit the transform region to the mask
228     for (unsigned int i=0; i<ImageDimension; i++) {
229       transformRegionIndex[i]=boundingBox[2*i];
230       transformRegionSize[i]=boundingBox[2*i+1]-boundingBox[2*i]+1;
231     }
232     transformRegion.SetSize(transformRegionSize);
233     transformRegion.SetIndex(transformRegionIndex);
234     fixedImage->TransformIndexToPhysicalPoint(transformRegion.GetIndex(), transformRegionOrigin);
235
236     // Limit the metric region to the mask
237     metricRegion=transformRegion;
238     fixedImage->TransformIndexToPhysicalPoint(metricRegion.GetIndex(), metricRegionOrigin);
239
240   }
241
242
243
244   //=======================================================
245   // Regions
246   //=====================================================
247   if (m_Verbose) {
248     // Fixed image region
249     std::cout<<"The fixed image has its origin at "<<fixedImage->GetOrigin()<<std::endl
250              <<"The fixed image region starts at index "<<fixedImageRegion.GetIndex()<<std::endl
251              <<"The fixed image region has size "<< fixedImageRegion.GetSize()<<std::endl;
252
253     // Transform region
254     std::cout<<"The transform has its origin at "<<transformRegionOrigin<<std::endl
255              <<"The transform region will start at index "<<transformRegion.GetIndex()<<std::endl
256              <<"The transform region has size "<< transformRegion.GetSize()<<std::endl;
257
258     // Metric region
259     std::cout<<"The metric region has its origin at "<<metricRegionOrigin<<std::endl
260              <<"The metric region will start at index "<<metricRegion.GetIndex()<<std::endl
261              <<"The metric region has size "<< metricRegion.GetSize()<<std::endl;
262
263   }
264
265
266   //=======================================================
267   //Pyramids
268   //=======================================================
269   typedef itk::RecursiveMultiResolutionPyramidImageFilter< FixedImageType, FixedImageType>    FixedImagePyramidType;
270   typedef itk::RecursiveMultiResolutionPyramidImageFilter< MovingImageType, MovingImageType>    MovingImagePyramidType;
271   typename FixedImagePyramidType::Pointer fixedImagePyramid = FixedImagePyramidType::New();
272   typename MovingImagePyramidType::Pointer movingImagePyramid = MovingImagePyramidType::New();
273
274
275   //     //=======================================================
276   //     // Rigid Transform
277   //     //=======================================================
278   //     typedef itk::Euler3DTransform <double> RigidTransformType;
279   //     RigidTransformType::Pointer rigidTransform=RigidTransformType::New();
280
281   //     if (m_ArgsInfo.rigid_given)
282   //       {
283   //            itk::Matrix<double,4,4> rigidTransformMatrix=clitk::ReadMatrix3D(m_ArgsInfo.rigid_arg);
284
285   //            //Set the rotation
286   //            itk::Matrix<double,3,3> finalRotation = clitk::GetRotationalPartMatrix3D(rigidTransformMatrix);
287   //            rigidTransform->SetMatrix(finalRotation);
288
289   //            //Set the translation
290   //            itk::Vector<double,3> finalTranslation = clitk::GetTranslationPartMatrix3D(rigidTransformMatrix);
291   //            rigidTransform->SetTranslation(finalTranslation);
292
293   //       }
294
295
296   //=======================================================
297   // BSpline Transform
298   //=======================================================
299   typename FixedImageType::RegionType::SizeType splineOrders ;
300
301   //Default is cubic splines
302   splineOrders.Fill(3);
303   if (m_ArgsInfo.order_given)
304     for(unsigned int i=0; i<ImageDimension; i++)
305       splineOrders[i]=m_ArgsInfo.order_arg[i];
306
307   // BLUT or ITK FFD
308   typedef itk::Transform<TCoordRep, ImageDimension, SpaceDimension> TransformType;
309   typename TransformType::Pointer transform;
310   typedef  itk::BSplineDeformableTransform<TCoordRep,SpaceDimension, 3> BSplineTransformType;
311   typedef  BSplineTransformType* BSplineTransformPointer;
312   typedef  clitk::BSplineDeformableTransform<TCoordRep,ImageDimension, SpaceDimension > BLUTTransformType;
313   typedef  BLUTTransformType* BLUTTransformPointer;
314
315   // JV parameter array is passed by reference, create outside context so it exists afterwards!!!!!
316   typedef typename TransformType::ParametersType     ParametersType;
317   ParametersType parameters;
318
319
320   // CLITK BLUT transform
321   if(m_ArgsInfo.wlut_flag) {
322     typename BLUTTransformType::Pointer  bsplineTransform = BLUTTransformType::New();
323     if (m_Verbose) std::cout<<"Setting the spline orders  to "<<splineOrders<<"..."<<std::endl;
324     bsplineTransform->SetSplineOrders(splineOrders);
325
326     //-------------------------------------------------------------------------
327     // Define the region: Either the spacing or the number of CP should be given
328     //-------------------------------------------------------------------------
329
330     // Region
331     typedef typename BSplineTransformType::RegionType RegionType;
332     RegionType bsplineRegion;
333     typename RegionType::SizeType   gridSizeOnImage;
334     typename RegionType::SizeType   gridBorderSize;
335     typename RegionType::SizeType   totalGridSize;
336
337     // Spacing
338     typedef typename BSplineTransformType::SpacingType SpacingType;
339     SpacingType fixedImageSpacing, chosenSpacing, adaptedSpacing;
340     fixedImageSpacing = fixedImage->GetSpacing();
341
342     // Only spacing given: adjust if necessary
343     if (m_ArgsInfo.spacing_given && !m_ArgsInfo.control_given) {
344       for(unsigned int r=0; r<ImageDimension; r++) {
345         chosenSpacing[r]= m_ArgsInfo.spacing_arg[r];
346         gridSizeOnImage[r] = ceil( (double) transformRegion.GetSize()[r] / ( itk::Math::Round<double>(chosenSpacing[r]/fixedImageSpacing[r]) ) );
347         adaptedSpacing[r]= ( itk::Math::Round<double>(chosenSpacing[r]/fixedImageSpacing[r]) *fixedImageSpacing[r] ) ;
348       }
349       if (m_Verbose) std::cout<<"The chosen control point spacing "<<chosenSpacing<<"..."<<std::endl;
350       if (m_Verbose) std::cout<<"The control points spacing was adapted to "<<adaptedSpacing<<"..."<<std::endl;
351       if (m_Verbose) std::cout<<"The number of (internal) control points is "<<gridSizeOnImage<<"..."<<std::endl;
352     }
353
354     // Only number of CP given: adjust if necessary
355     else if (m_ArgsInfo.control_given && !m_ArgsInfo.spacing_given) {
356       for(unsigned int r=0; r<ImageDimension; r++) {
357         gridSizeOnImage[r]= m_ArgsInfo.control_arg[r];
358         chosenSpacing[r]=fixedImageSpacing[r]*( (double)(transformRegion.GetSize()[r])  /
359                                                 (double)(gridSizeOnImage[r]) );
360         adaptedSpacing[r]= fixedImageSpacing[r]* ceil( (double)(transformRegion.GetSize()[r] - 1)  /
361                            (double)(gridSizeOnImage[r] - 1) );
362       }
363       if (m_Verbose) std::cout<<"The chosen control point spacing "<<chosenSpacing<<"..."<<std::endl;
364       if (m_Verbose) std::cout<<"The control points spacing was adapted to "<<adaptedSpacing<<"..."<<std::endl;
365       if (m_Verbose) std::cout<<"The number of (internal) control points is "<<gridSizeOnImage<<"..."<<std::endl;
366     }
367
368     // Spacing and number of CP given: no adjustment adjust, just warnings
369     else if (m_ArgsInfo.control_given && m_ArgsInfo.spacing_given) {
370       for(unsigned int r=0; r<ImageDimension; r++) {
371         adaptedSpacing[r]= m_ArgsInfo.spacing_arg[r];
372         gridSizeOnImage[r] =  m_ArgsInfo.control_arg[r];
373         if (gridSizeOnImage[r]*adaptedSpacing[r]< transformRegion.GetSize()[r]*fixedImageSpacing[r]) {
374           std::cout<<"WARNING: Specified control point region ("<<gridSizeOnImage[r]*adaptedSpacing[r]
375                    <<"mm) does not cover the transform region ("<< transformRegion.GetSize()[r]*fixedImageSpacing[r]
376                    <<"mm) for dimension "<<r<<"!" <<std::endl
377                    <<"Specify only --spacing or --control for automatic adjustment..."<<std::endl;
378         }
379         if (  fmod(adaptedSpacing[r], fixedImageSpacing[r]) ) {
380           std::cout<<"WARNING: Specified control point spacing for dimension "<<r
381                    <<" does not allow exact representation of BLUT FFD!"<<std::endl
382                    <<"Spacing ratio is non-integer: "<<adaptedSpacing[r]/ fixedImageSpacing[r]<<std::endl
383                    <<"Specify only --spacing or --control for automatic adjustment..."<<std::endl;
384         }
385       }
386       if (m_Verbose) std::cout<<"The control points spacing was set to "<<adaptedSpacing<<"..."<<std::endl;
387       if (m_Verbose) std::cout<<"The number of (internal) control points spacing is "<<gridSizeOnImage<<"..."<<std::endl;
388     }
389
390     //JV  border size should depend on spline order
391     for(unsigned int r=0; r<ImageDimension; r++) gridBorderSize[r]=splineOrders[r]; // Border for spline order = 3 ( 1 lower, 2 upper )
392     totalGridSize = gridSizeOnImage + gridBorderSize;
393     bsplineRegion.SetSize( totalGridSize );
394     if (m_Verbose) std::cout<<"The total control point grid size was set to "<<totalGridSize<<"..."<<std::endl;
395
396     // Direction
397     typename FixedImageType::DirectionType gridDirection = fixedImage->GetDirection();
398     SpacingType gridOriginOffset = gridDirection * adaptedSpacing;
399
400     // Origin: 1 CP border for spatial dimensions
401     typedef typename BSplineTransformType::OriginType OriginType;
402     OriginType gridOrigin = transformRegionOrigin - gridOriginOffset;
403     if (m_Verbose) std::cout<<"The control point grid origin was set to "<<gridOrigin<<"..."<<std::endl;
404
405     // Set
406     bsplineTransform->SetGridSpacing( adaptedSpacing );
407     bsplineTransform->SetGridOrigin( gridOrigin );
408     bsplineTransform->SetGridRegion( bsplineRegion );
409     bsplineTransform->SetGridDirection( gridDirection );
410
411     //Bulk transform
412     //if (m_Verbose) std::cout<<"Setting rigid transform..."<<std::endl;
413     //bsplineTransform->SetBulkTransform( rigidTransform );
414
415     //Vector BSpline interpolator
416     //bsplineTransform->SetOutputSpacing(fixedImage->GetSpacing());
417     typename RegionType::SizeType samplingFactors;
418     for (unsigned int i=0; i< ImageDimension; i++) {
419       if (m_Verbose) std::cout<<"For dimension "<<i<<", the ideal sampling factor (if integer) is a multitude of "
420                                 << (double)adaptedSpacing[i]/ (double) fixedImageSpacing[i]<<"..."<<std::endl;
421       if (m_ArgsInfo.samplingFactor_given) samplingFactors[i]=m_ArgsInfo.samplingFactor_arg[i];
422       else samplingFactors[i]=(int) ((double)adaptedSpacing[i]/ (double) movingImage->GetSpacing()[i]);
423       if (m_Verbose) std::cout<<"Setting sampling factor "<<i<<" to "<<samplingFactors[i]<<"..."<<std::endl;
424     }
425     bsplineTransform->SetLUTSamplingFactors(samplingFactors);
426
427     //initial parameters
428     if (m_ArgsInfo.init_given) {
429       typedef itk::ImageFileReader<typename BLUTTransformType::CoefficientImageType> CoefficientReaderType;
430       typename CoefficientReaderType::Pointer coeffReader=CoefficientReaderType::New();
431       coeffReader->SetFileName(m_ArgsInfo.init_arg[0]);
432       coeffReader->Update();
433       bsplineTransform->SetCoefficientImage(coeffReader->GetOutput());
434     } else {
435       //typedef typename TransformType::ParametersType     ParametersType;
436       const unsigned int numberOfParameters =    bsplineTransform->GetNumberOfParameters();
437       parameters=ParametersType( numberOfParameters );
438       parameters.Fill( 0.0 );
439       bsplineTransform->SetParameters( parameters );
440     }
441
442     // Mask
443     if (spatialObjectMask) bsplineTransform->SetMask( spatialObjectMask );
444
445     // Pass
446     transform=bsplineTransform;
447   }
448
449   //ITK BSpline transform
450   else {
451     typename BSplineTransformType::Pointer  bsplineTransform = BSplineTransformType::New();
452
453     // Define the region
454     typedef typename BSplineTransformType::RegionType RegionType;
455     RegionType bsplineRegion;
456     typename RegionType::SizeType   gridSizeOnImage;
457     typename RegionType::SizeType   gridBorderSize;
458     typename RegionType::SizeType   totalGridSize;
459
460     //Set the number of control points
461     for(unsigned int r=0; r<ImageDimension; r++)  gridSizeOnImage[r]=m_ArgsInfo.control_arg[r];
462     if (m_Verbose) std::cout<<"Setting the number of internal control points "<<gridSizeOnImage<<"..."<<std::endl;
463     gridBorderSize.Fill( 3 );    // Border for spline order = 3 ( 1 lower, 2 upper )
464     totalGridSize = gridSizeOnImage + gridBorderSize;
465     bsplineRegion.SetSize( totalGridSize );
466
467     // Spacing
468     typedef typename BSplineTransformType::SpacingType SpacingType;
469     SpacingType spacing = fixedImage->GetSpacing();
470     typename FixedImageType::SizeType fixedImageSize = fixedImageRegion.GetSize();
471     if (m_ArgsInfo.spacing_given) {
472
473       for(unsigned int r=0; r<ImageDimension; r++) {
474         spacing[r] =m_ArgsInfo.spacing_arg[r];
475       }
476     } else {
477       for(unsigned int r=0; r<ImageDimension; r++) {
478         spacing[r] *= static_cast<double>(fixedImageSize[r] - 1)  /
479                       static_cast<double>(gridSizeOnImage[r] - 1);
480       }
481     }
482     if (m_Verbose) std::cout<<"The control points spacing was set to "<<spacing<<"..."<<std::endl;
483
484     // Direction
485     typename FixedImageType::DirectionType gridDirection = fixedImage->GetDirection();
486     SpacingType gridOriginOffset = gridDirection * spacing;
487
488     // Origin
489     typedef typename BSplineTransformType::OriginType OriginType;
490     OriginType origin = fixedImage->GetOrigin();
491     OriginType gridOrigin = origin - gridOriginOffset;
492
493     // Set
494     bsplineTransform->SetGridSpacing( spacing );
495     bsplineTransform->SetGridOrigin( gridOrigin );
496     bsplineTransform->SetGridRegion( bsplineRegion );
497     bsplineTransform->SetGridDirection( gridDirection );
498
499     // Bulk transform
500     // if (m_Verbose) std::cout<<"Setting rigid transform..."<<std::endl;
501     // bsplineTransform->SetBulkTransform( rigidTransform );
502
503     // Initial parameters
504     if (m_ArgsInfo.init_given) {
505       typedef itk::ImageFileReader<typename BSplineTransformType::ImageType> CoefficientReaderType;
506       typename BSplineTransformType::ImageType::Pointer coeffImages[SpaceDimension];
507       for(unsigned int i=0; i<SpaceDimension; i++) {
508         typename CoefficientReaderType::Pointer coeffReader=CoefficientReaderType::New();
509         coeffReader->SetFileName(m_ArgsInfo.init_arg[i]);
510         coeffReader->Update();
511         coeffImages[i]=coeffReader->GetOutput();
512       }
513       bsplineTransform->SetCoefficientImage(coeffImages);
514     } else {
515       const unsigned int numberOfParameters =    bsplineTransform->GetNumberOfParameters();
516       parameters=ParametersType( numberOfParameters );
517       parameters.Fill( 0.0 );
518       bsplineTransform->SetParameters( parameters );
519     }
520
521     // Pass
522     transform=bsplineTransform;
523
524   }
525
526   //=======================================================
527   // Interpolator
528   //=======================================================
529   typedef clitk::GenericInterpolator<args_info_clitkBSplineDeformableRegistration, FixedImageType, TCoordRep > GenericInterpolatorType;
530   typename   GenericInterpolatorType::Pointer genericInterpolator=GenericInterpolatorType::New();
531   genericInterpolator->SetArgsInfo(m_ArgsInfo);
532   typedef itk::InterpolateImageFunction< FixedImageType, TCoordRep >  InterpolatorType;
533   typename  InterpolatorType::Pointer interpolator=genericInterpolator->GetInterpolatorPointer();
534
535
536   //=======================================================
537   // Metric
538   //=======================================================
539   typedef clitk::GenericMetric<args_info_clitkBSplineDeformableRegistration, FixedImageType,MovingImageType > GenericMetricType;
540   typename GenericMetricType::Pointer genericMetric=GenericMetricType::New();
541   genericMetric->SetArgsInfo(m_ArgsInfo);
542   genericMetric->SetFixedImage(fixedImage);
543   genericMetric->SetFixedImageRegion(metricRegion);
544   typedef itk::ImageToImageMetric< FixedImageType, MovingImageType >  MetricType;
545   typename  MetricType::Pointer metric=genericMetric->GetMetricPointer();
546   if (spatialObjectMask) metric->SetFixedImageMask( spatialObjectMask );
547
548 #ifdef ITK_USE_OPTIMIZED_REGISTRATION_METHODS
549   if (threadsGiven) metric->SetNumberOfThreads( threads );
550 #else
551   if (m_Verbose) std::cout<<"Not setting the number of threads (not compiled with USE_OPTIMIZED_REGISTRATION_METHODS)..."<<std::endl;
552 #endif
553
554
555   //=======================================================
556   // Optimizer
557   //=======================================================
558   typedef clitk::GenericOptimizer<args_info_clitkBSplineDeformableRegistration> GenericOptimizerType;
559   GenericOptimizerType::Pointer genericOptimizer = GenericOptimizerType::New();
560   genericOptimizer->SetArgsInfo(m_ArgsInfo);
561   genericOptimizer->SetMaximize(genericMetric->GetMaximize());
562   genericOptimizer->SetNumberOfParameters(transform->GetNumberOfParameters());
563   typedef itk::SingleValuedNonLinearOptimizer OptimizerType;
564   OptimizerType::Pointer optimizer = genericOptimizer->GetOptimizerPointer();
565
566
567   //=======================================================
568   // Registration
569   //=======================================================
570   typedef itk::MultiResolutionImageRegistrationMethod<  FixedImageType, MovingImageType >    RegistrationType;
571   typename RegistrationType::Pointer   registration  = RegistrationType::New();
572
573   registration->SetMetric(        metric        );
574   registration->SetOptimizer(     optimizer     );
575   registration->SetInterpolator(  interpolator  );
576   registration->SetTransform (transform);
577   if(threadsGiven) registration->SetNumberOfThreads(threads);
578   registration->SetFixedImage(  fixedImage   );
579   registration->SetMovingImage(   movingImage   );
580   registration->SetFixedImageRegion( metricRegion );
581   registration->SetFixedImagePyramid( fixedImagePyramid );
582   registration->SetMovingImagePyramid( movingImagePyramid );
583   registration->SetInitialTransformParameters( transform->GetParameters() );
584   registration->SetNumberOfLevels(m_ArgsInfo.levels_arg);
585   if (m_Verbose) std::cout<<"Setting the number of resolution levels to "<<m_ArgsInfo.levels_arg<<"..."<<std::endl;
586
587
588   //================================================================================================
589   // Observers
590   //================================================================================================
591   if (m_Verbose) {
592     // Output iteration info
593     CommandIterationUpdate::Pointer observer = CommandIterationUpdate::New();
594     observer->SetOptimizer(genericOptimizer);
595     optimizer->AddObserver( itk::IterationEvent(), observer );
596
597
598     // Output level info
599     typedef RegistrationInterfaceCommand<RegistrationType> CommandType;
600     typename CommandType::Pointer command = CommandType::New();
601     registration->AddObserver( itk::IterationEvent(), command );
602   }
603
604
605   //=======================================================
606   // Let's go
607   //=======================================================
608   if (m_Verbose) std::cout << std::endl << "Starting Registration" << std::endl;
609
610   try {
611     registration->StartRegistration();
612   } catch( itk::ExceptionObject & err ) {
613     std::cerr << "ExceptionObject caught while registering!" << std::endl;
614     std::cerr << err << std::endl;
615     return;
616   }
617
618
619   //=======================================================
620   // Get the result
621   //=======================================================
622   OptimizerType::ParametersType finalParameters =  registration->GetLastTransformParameters();
623   transform->SetParameters( finalParameters );
624
625
626   //=======================================================
627   // Get the BSpline coefficient images and write them
628   //=======================================================
629   if (m_ArgsInfo.coeff_given) {
630     if(m_ArgsInfo.wlut_flag) {
631       BLUTTransformPointer bsplineTransform=dynamic_cast<BLUTTransformPointer>(registration->GetTransform());
632       typedef  itk::Image<itk::Vector<TCoordRep, SpaceDimension>, ImageDimension> CoefficientImageType;
633       typename CoefficientImageType::Pointer coefficientImage =bsplineTransform->GetCoefficientImage();
634       typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
635       typename CoeffWriterType::Pointer coeffWriter=CoeffWriterType::New();
636       coeffWriter->SetInput(coefficientImage);
637       coeffWriter->SetFileName(m_ArgsInfo.coeff_arg[0]);
638       coeffWriter->Update();
639     } else {
640       BSplineTransformPointer bsplineTransform=dynamic_cast<BSplineTransformPointer>(registration->GetTransform());
641       typedef  itk::Image<TCoordRep, ImageDimension> CoefficientImageType;
642 #if ITK_VERSION_MAJOR > 3
643       typename BSplineTransformType::CoefficientImageArray coefficientImages = bsplineTransform->GetCoefficientImage();
644 #else
645       typename CoefficientImageType::Pointer *coefficientImages =bsplineTransform->GetCoefficientImage();
646 #endif
647       typedef itk::ImageFileWriter<CoefficientImageType> CoeffWriterType;
648       for (unsigned int i=0; i<std::min(SpaceDimension,m_ArgsInfo.coeff_given); i ++) {
649         typename CoeffWriterType::Pointer coeffWriter=CoeffWriterType::New();
650         coeffWriter->SetInput(coefficientImages[i]);
651         coeffWriter->SetFileName(m_ArgsInfo.coeff_arg[i]);
652         coeffWriter->Update();
653       }
654     }
655   }
656
657
658   //=======================================================
659   // Generate the DVF
660   //=======================================================
661   typedef itk::Vector< float, SpaceDimension >  DisplacementType;
662   typedef itk::Image< DisplacementType, ImageDimension >  DeformationFieldType;
663
664   typename DeformationFieldType::Pointer field = DeformationFieldType::New();
665   field->SetRegions( fixedImageRegion );
666   field->SetOrigin( fixedImage->GetOrigin() );
667   field->SetSpacing( fixedImage->GetSpacing() );
668   field->SetDirection( fixedImage->GetDirection() );
669   field->Allocate();
670
671   typedef itk::ImageRegionIteratorWithIndex< DeformationFieldType > FieldIterator;
672   FieldIterator fi( field, fixedImageRegion );
673   fi.GoToBegin();
674
675   typename TransformType::InputPointType  fixedPoint;
676   typename TransformType::OutputPointType movingPoint;
677   typename DeformationFieldType::IndexType index;
678
679   DisplacementType displacement;
680   while( ! fi.IsAtEnd() ) {
681     index = fi.GetIndex();
682     field->TransformIndexToPhysicalPoint( index, fixedPoint );
683     movingPoint = transform->TransformPoint( fixedPoint );
684     displacement = movingPoint - fixedPoint;
685     fi.Set( displacement );
686     ++fi;
687   }
688
689
690   //=======================================================
691   // Write the DVF
692   //=======================================================
693   typedef itk::ImageFileWriter< DeformationFieldType >  FieldWriterType;
694   typename FieldWriterType::Pointer fieldWriter = FieldWriterType::New();
695   fieldWriter->SetFileName( m_ArgsInfo.vf_arg );
696   fieldWriter->SetInput( field );
697   try {
698     fieldWriter->Update();
699   } catch( itk::ExceptionObject & excp ) {
700     std::cerr << "Exception thrown writing the DVF" << std::endl;
701     std::cerr << excp << std::endl;
702     return;
703   }
704
705
706   //=======================================================
707   // Resample the moving image
708   //=======================================================
709   typedef itk::WarpImageFilter< MovingImageType, FixedImageType, DeformationFieldType >    WarpFilterType;
710   typename WarpFilterType::Pointer warp = WarpFilterType::New();
711
712   warp->SetDeformationField( field );
713   warp->SetInput( movingImageReader->GetOutput() );
714   warp->SetOutputOrigin(  fixedImage->GetOrigin() );
715   warp->SetOutputSpacing( fixedImage->GetSpacing() );
716   warp->SetOutputDirection( fixedImage->GetDirection() );
717   warp->SetEdgePaddingValue( 0.0 );
718   warp->Update();
719
720
721   //=======================================================
722   // Write the warped image
723   //=======================================================
724   typedef itk::ImageFileWriter< FixedImageType >  WriterType;
725   typename WriterType::Pointer      writer =  WriterType::New();
726   writer->SetFileName( m_ArgsInfo.output_arg );
727   writer->SetInput( warp->GetOutput()    );
728
729   try {
730     writer->Update();
731   } catch( itk::ExceptionObject & err ) {
732     std::cerr << "ExceptionObject caught !" << std::endl;
733     std::cerr << err << std::endl;
734     return;
735   }
736
737 DD("here")
738   //=======================================================
739   // Calculate the difference after the deformable transform
740   //=======================================================
741   typedef clitk::DifferenceImageFilter<  FixedImageType, FixedImageType> DifferenceFilterType;
742   if (m_ArgsInfo.after_given) {
743     typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
744     difference->SetValidInput( fixedImage );
745     difference->SetTestInput( warp->GetOutput() );
746
747     try {
748       difference->Update();
749     } catch( itk::ExceptionObject & err ) {
750       std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
751       std::cerr << err << std::endl;
752       return;
753     }
754
755     typename WriterType::Pointer differenceWriter=WriterType::New();
756     differenceWriter->SetInput(difference->GetOutput());
757     differenceWriter->SetFileName(m_ArgsInfo.after_arg);
758     differenceWriter->Update();
759
760   }
761
762
763   //=======================================================
764   // Calculate the difference before the deformable transform
765   //=======================================================
766   if( m_ArgsInfo.before_given ) {
767
768     typename FixedImageType::Pointer moving=FixedImageType::New();
769     if (m_ArgsInfo.rigid_given) {
770       typedef itk::ResampleImageFilter<MovingImageType, FixedImageType> ResamplerType;
771       typename ResamplerType::Pointer resampler=ResamplerType::New();
772       resampler->SetInput(movingImage);
773       resampler->SetOutputOrigin(fixedImage->GetOrigin());
774       resampler->SetSize(fixedImage->GetLargestPossibleRegion().GetSize());
775       resampler->SetOutputSpacing(fixedImage->GetSpacing());
776       resampler->SetDefaultPixelValue( 0. );
777       //resampler->SetTransform(rigidTransform);
778       resampler->Update();
779       moving=resampler->GetOutput();
780     } else
781       moving=movingImage;
782
783     typename DifferenceFilterType::Pointer difference = DifferenceFilterType::New();
784     difference->SetValidInput( fixedImage );
785     difference->SetTestInput( moving );
786
787     try {
788       difference->Update();
789     } catch( itk::ExceptionObject & err ) {
790       std::cerr << "ExceptionObject caught calculating the difference !" << std::endl;
791       std::cerr << err << std::endl;
792       return;
793     }
794
795     typename WriterType::Pointer differenceWriter=WriterType::New();
796     writer->SetFileName( m_ArgsInfo.before_arg  );
797     writer->SetInput( difference->GetOutput()  );
798     writer->Update( );
799   }
800
801   return;
802 }
803 }
804
805 #endif // __clitkBSplineDeformableRegistrationGenericFilter_txx