]> Creatis software - FrontAlgorithms.git/blob - appli/CTArteries/algorithms/RandomWalkSegmentation.hxx
...
[FrontAlgorithms.git] / appli / CTArteries / algorithms / RandomWalkSegmentation.hxx
1 // =========================================================================
2 // @author Leonardo Florez-Valencia (florez-l@javeriana.edu.co)
3 // =========================================================================
4 #ifndef __RandomWalkSegmentation__hxx__
5 #define __RandomWalkSegmentation__hxx__
6
7 #include <itkBinaryThresholdImageFilter.h>
8 #include <itkSmoothingRecursiveGaussianImageFilter.h>
9 #include <ivq/ITK/RegionOfInterestWithPaddingImageFilter.h>
10 #include <fpa/Filters/Image/ExtractAxis.h>
11 #include <fpa/Filters/Image/RandomWalker.h>
12 #include "DijkstraWithMeanAndVariance.h"
13 #include "RandomWalkLabelling.h"
14
15 /* TODO
16    #include <itkImageRegionConstIterator.h>
17    #include <itkImageRegionIterator.h>
18
19
20    #include <fpa/Functors/Dijkstra/Image/Gaussian.h>
21
22    #include <itkImageFileWriter.h>
23 */
24
25 // -------------------------------------------------------------------------
26 template< class _TInputImage, class _TOutputImage >
27 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
28 AddSeed( const TIndex& s )
29 {
30   TSeed seed;
31   seed.Index = s;
32   seed.IsPoint = false;
33   this->m_Seeds.push_back( seed );
34   this->Modified( );
35 }
36
37 // -------------------------------------------------------------------------
38 template< class _TInputImage, class _TOutputImage >
39 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
40 AddSeed( const TPoint& s )
41 {
42   TSeed seed;
43   seed.Point = s;
44   seed.IsPoint = true;
45   this->m_Seeds.push_back( seed );
46   this->Modified( );
47 }
48
49 // -------------------------------------------------------------------------
50 template< class _TInputImage, class _TOutputImage >
51 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
52 ClearSeeds( )
53 {
54   this->m_Seeds.clear( );
55   this->Modified( );
56 }
57
58 // -------------------------------------------------------------------------
59 template< class _TInputImage, class _TOutputImage >
60 unsigned long RandomWalkSegmentation< _TInputImage, _TOutputImage >::
61 GetNumberOfSeeds( ) const
62 {
63   return( this->m_Seeds.size( ) );
64 }
65
66 // -------------------------------------------------------------------------
67 template< class _TInputImage, class _TOutputImage >
68 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
69 SetStartSeed( const TIndex& s )
70 {
71   this->m_StartSeed.Index = s;
72   this->m_StartSeed.IsPoint = false;
73   this->Modified( );
74 }
75
76 // -------------------------------------------------------------------------
77 template< class _TInputImage, class _TOutputImage >
78 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
79 SetStartSeed( const TPoint& s )
80 {
81   this->m_StartSeed.Point = s;
82   this->m_StartSeed.IsPoint = true;
83   this->Modified( );
84 }
85
86 // -------------------------------------------------------------------------
87 template< class _TInputImage, class _TOutputImage >
88 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
89 SetEndSeed( const TIndex& s )
90 {
91   this->m_EndSeed.Index = s;
92   this->m_EndSeed.IsPoint = false;
93   this->Modified( );
94 }
95
96 // -------------------------------------------------------------------------
97 template< class _TInputImage, class _TOutputImage >
98 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
99 SetEndSeed( const TPoint& s )
100 {
101   this->m_EndSeed.Point = s;
102   this->m_EndSeed.IsPoint = true;
103   this->Modified( );
104 }
105
106 // -------------------------------------------------------------------------
107 template< class _TInputImage, class _TOutputImage >
108 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
109 RandomWalkSegmentation( )
110   : Superclass( ),
111     m_Beta( double( 20 ) ),
112     m_Sigma( double( 10 ) ),
113     m_Radius( double( 5 ) )
114 {
115   fpaFilterOutputConfigureMacro( OutputAxis, TPath );
116 }
117
118 // -------------------------------------------------------------------------
119 template< class _TInputImage, class _TOutputImage >
120 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
121 ~RandomWalkSegmentation( )
122 {
123 }
124
125 // -------------------------------------------------------------------------
126 template< class _TInputImage, class _TOutputImage >
127 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
128 GenerateData( )
129 {
130   typedef unsigned char _TLabel;
131   typedef typename TOutputImage::PixelType _TScalar;
132   typedef itk::Image< _TLabel, TInputImage::ImageDimension > _TLabels;
133
134   // Prepare initial seeds
135   const TInputImage* input = this->GetInput( );
136   this->_SynchSeed( input, this->m_StartSeed );
137   this->_SynchSeed( input, this->m_EndSeed );
138   for( TSeed& seed: this->m_Seeds )
139     this->_SynchSeed( input, seed );
140
141   // Smooth input
142   typename TOutputImage::Pointer smooth_in;
143   this->_Smooth( this->GetInput( ), smooth_in );
144
145   // Initial segmentation
146   typename TOutputImage::Pointer init_seg;
147   typename TPath::Pointer init_axis;
148   _TScalar init_mean, init_std;
149   typename TInputImage::RegionType roi =
150     this->_RawSegmentation(
151       smooth_in.GetPointer( ), this->m_Seeds,
152       this->m_StartSeed.Index, this->m_EndSeed.Index,
153       this->m_Beta,
154       init_seg, init_axis, init_mean, init_std
155       );
156   init_std *= _TScalar( this->m_Sigma );
157
158   // Extract input ROIs
159   typename TOutputImage::Pointer smooth_in_roi, init_seg_roi;
160   roi = this->_ROI( smooth_in.GetPointer( ), roi, 10, smooth_in_roi );
161   this->_ROI( init_seg.GetPointer( ), roi, 0, init_seg_roi );
162   typename TPath::Pointer init_axis_roi = TPath::New( );
163   init_axis_roi->SetReferenceImage( smooth_in_roi.GetPointer( ) );
164   this->_AxisROI( init_axis.GetPointer( ), roi, init_axis_roi );
165
166   // Labelling
167   typename _TLabels::Pointer init_labels;
168   _TScalar radius = _TScalar( this->m_Radius );
169   this->_Label(
170     smooth_in_roi.GetPointer( ),
171     init_seg_roi.GetPointer( ),
172     init_axis_roi.GetPointer( ),
173     init_mean, init_std, radius,
174     init_labels
175     );
176
177   // Random walker
178   typename _TLabels::Pointer rw_seg;
179   this->_RandomWalker(
180     smooth_in_roi.GetPointer( ),
181     init_labels.GetPointer( ),
182     init_std / _TScalar( 2 ),
183     rw_seg
184     );
185
186   // ROI outputs
187   typename TOutputImage::Pointer out_dist;
188   typename TPath::Pointer out_axis;
189   this->_DistanceAndAxis(
190     rw_seg.GetPointer( ),
191     this->m_Seeds, this->m_StartSeed, this->m_EndSeed,
192     out_dist, out_axis
193     );
194
195   // Put everything back to requested region
196   /* TODO
197      std::cout << "6" << std::endl;
198      { // begin
199      TOutputImage* output = this->GetOutput( );
200      output->SetBufferedRegion( output->GetRequestedRegion( ) );
201      output->Allocate( );
202      output->FillBuffer( -std::numeric_limits< _TScalar >::max( ) );
203
204      itk::ImageRegionConstIterator< TOutputImage > rIt(
205      output_roi, output_roi->GetRequestedRegion( )
206      );
207      itk::ImageRegionIterator< TOutputImage > oIt( output, roi );
208      rIt.GoToBegin( );
209      oIt.GoToBegin( );
210      for( ; !rIt.IsAtEnd( ); ++rIt, ++oIt )
211      oIt.Set( rIt.Get( ) );
212
213      TPath* output_axis = this->GetOutputAxis( );
214      output_axis->SetReferenceImage( output );
215      for( unsigned long i = 0; i < output_axis_roi->GetSize( ); ++i )
216      {
217      TIndex v = output_axis_roi->GetVertex( i );
218      for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
219      v[ d ] += roi.GetIndex( )[ d ];
220      output_axis->AddVertex( v );
221
222      } // rof
223
224      } // end
225      std::cout << "7" << std::endl;
226   */
227 }
228
229 // -------------------------------------------------------------------------
230 template< class _TInputImage, class _TOutputImage >
231 template< class _TIn, class _TSeed >
232 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
233 _SynchSeed( const _TIn* in, _TSeed& seed )
234 {
235   if( seed.IsPoint )
236     in->TransformPhysicalPointToIndex( seed.Point, seed.Index );
237   else
238     in->TransformIndexToPhysicalPoint( seed.Index, seed.Point );
239   seed.IsPoint = true;
240 }
241
242 // -------------------------------------------------------------------------
243 template< class _TInputImage, class _TOutputImage >
244 template< class _TIn, class _TOutPtr >
245 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
246 _Smooth( const _TIn* in, _TOutPtr& out )
247 {
248   typedef typename _TOutPtr::ObjectType _TOut;
249   typedef itk::SmoothingRecursiveGaussianImageFilter< _TIn, _TOut > _TSmooth;
250
251   typename _TSmooth::Pointer smooth = _TSmooth::New( );
252   smooth->SetInput( in );
253   smooth->SetNormalizeAcrossScale( true );
254   smooth->SetSigmaArray( in->GetSpacing( ) * double( 2 ) );
255   smooth->Update( );
256   out = smooth->GetOutput( );
257   out->DisconnectPipeline( );
258 }
259
260 // -------------------------------------------------------------------------
261 template< class _TInputImage, class _TOutputImage >
262 template< class _TIn, class _TOutPtr, class _TAxisPtr, class _TSeeds >
263 typename _TIn::RegionType
264 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
265 _RawSegmentation(
266   const _TIn* in, const _TSeeds& seeds,
267   const typename _TIn::IndexType& s0,
268   const typename _TIn::IndexType& s1,
269   const typename _TOutPtr::ObjectType::PixelType& beta,
270   _TOutPtr& out, _TAxisPtr& out_axis,
271   typename _TOutPtr::ObjectType::PixelType& oMean,
272   typename _TOutPtr::ObjectType::PixelType& oSTD
273   )
274 {
275   typedef typename _TOutPtr::ObjectType _TOut;
276   typedef typename _TOut::PixelType  _TScalar;
277   typedef DijkstraWithMeanAndVariance< _TIn, _TOut > _TInit;
278   typedef fpa::Functors::Dijkstra::Image::Gaussian< _TIn, _TScalar > _TFun;
279
280   typename _TFun::Pointer fun = _TFun::New( );
281   fun->SetBeta( beta );
282
283   typename _TInit::Pointer init = _TInit::New( );
284   init->SetInput( in );
285   init->SetWeightFunction( fun );
286   init->StopAtOneFrontOn( );
287   for( typename _TSeeds::value_type seed: seeds )
288     init->AddSeed( seed.Point );
289   init->Update( );
290
291   // Get initial values
292   oMean = _TScalar( init->GetMean( ) );
293   oSTD = _TScalar( init->GetDeviation( ) );
294
295   // Get initial objects
296   init->GetMinimumSpanningTree( )->GetPath( out_axis, s0, s1 );
297   out = init->GetOutput( );
298   out->DisconnectPipeline( );
299
300   // Get ROI
301   typename _TIn::IndexType min_idx = init->GetMinVertex( );
302   typename _TIn::IndexType max_idx = init->GetMaxVertex( );
303   typename _TIn::SizeType roi_size;
304   for( unsigned int i = 0; i < TInputImage::ImageDimension; ++i )
305     roi_size[ i ] = max_idx[ i ] - min_idx[ i ] + 1;
306   typename _TIn::RegionType roi;
307   roi.SetIndex( min_idx );
308   roi.SetSize( roi_size );
309   return( roi );
310 }
311
312 // -------------------------------------------------------------------------
313 template< class _TInputImage, class _TOutputImage >
314 template< class _TIn, class _TOutPtr >
315 typename _TIn::RegionType
316 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
317 _ROI(
318   const _TIn* in, const typename _TIn::RegionType& roi, unsigned int pad,
319   _TOutPtr& out
320   )
321 {
322   typedef typename _TOutPtr::ObjectType _TOut;
323   typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< _TIn, _TOut > _TROI;
324
325   typename _TROI::Pointer filter = _TROI::New( );
326   filter->SetInput( in );
327   filter->SetRegionOfInterest( roi );
328   filter->SetPadding( pad );
329   filter->Update( );
330   out = filter->GetOutput( );
331   out->DisconnectPipeline( );
332   return( filter->GetRegionOfInterest( ) );
333 }
334
335 // -------------------------------------------------------------------------
336 template< class _TInputImage, class _TOutputImage >
337 template< class _TAxisPtr, class _TRegion >
338 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
339 _AxisROI(
340   const typename _TAxisPtr::ObjectType* in, const _TRegion& roi,
341   _TAxisPtr& out
342   )
343 {
344   typedef typename _TAxisPtr::ObjectType _TAxis;
345
346   for( unsigned long i = 0; i < in->GetSize( ); ++i )
347   {
348     typename _TAxis::TIndex v = in->GetVertex( i );
349     for( unsigned int d = 0; d < _TAxis::Dimension; ++d )
350       v[ d ] -= roi.GetIndex( )[ d ];
351     out->AddVertex( v );
352
353   } // rof
354 }
355
356 // -------------------------------------------------------------------------
357 template< class _TInputImage, class _TOutputImage >
358 template< class _TInRaw, class _TInCosts, class _TAxis, class _TScalar, class _TOutPtr >
359 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
360 _Label(
361   const _TInRaw* raw, const _TInCosts* costs, const _TAxis* axis,
362   const _TScalar& mean, const _TScalar& dev, const _TScalar& radius,
363   _TOutPtr& out
364   )
365 {
366   typedef typename _TOutPtr::ObjectType _TOut;
367   typedef RandomWalkLabelling< _TInRaw, _TInCosts, _TOut > _TLabel;
368
369   typename _TLabel::Pointer label = _TLabel::New( );
370   label->SetInputImage( raw );
371   label->SetInputCosts( costs );
372   label->SetInputPath( axis );
373   label->SetInsideLabel( 1 );
374   label->SetOutsideLabel( 2 );
375   label->SetLowerLabel( 3 );
376   label->SetUpperLabel( 4 );
377   label->SetRadius( radius );
378   label->SetLowerThreshold( mean - dev );
379   label->SetUpperThreshold( mean + dev );
380   label->Update( );
381   out = label->GetOutputLabels( );
382   out->DisconnectPipeline( );
383 }
384
385 // -------------------------------------------------------------------------
386 template< class _TInputImage, class _TOutputImage >
387 template< class _TIn, class _TOutPtr >
388 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
389 _RandomWalker(
390   const _TIn* in, const typename _TOutPtr::ObjectType* labels,
391   const typename _TIn::PixelType& beta,
392   _TOutPtr& out
393   )
394 {
395   typedef typename _TIn::PixelType _TScalar;
396   typedef typename _TOutPtr::ObjectType _TOut;
397   typedef fpa::Functors::Dijkstra::Image::Gaussian< _TIn, _TScalar > _TFun;
398   typedef fpa::Filters::Image::RandomWalker< _TIn, _TOut, _TScalar > _TRandomWalker;
399   typedef itk::BinaryThresholdImageFilter< _TOut, _TOut > _TExtract;
400
401   typename _TFun::Pointer fun = _TFun::New( );
402   fun->SetBeta( beta );
403
404   typename _TRandomWalker::Pointer rw = _TRandomWalker::New( );
405   rw->SetInputImage( in );
406   rw->SetInputLabels( labels );
407   rw->SetWeightFunction( fun );
408
409   typename _TExtract::Pointer extract = _TExtract::New( );
410   extract->SetInput( rw->GetOutputLabels( ) );
411   extract->SetInsideValue( 1 );
412   extract->SetOutsideValue( 0 );
413   extract->SetLowerThreshold( 1 );
414   extract->SetUpperThreshold( 1 );
415   extract->Update( );
416   out = extract->GetOutput( );
417   out->DisconnectPipeline( );
418 }
419
420 // -------------------------------------------------------------------------
421 template< class _TInputImage, class _TOutputImage >
422 template< class _TIn, class _TSeeds, class _TSeed, class _TOutPtr, class _TAxisPtr >
423 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
424 _DistanceAndAxis(
425   const _TIn* in, const _TSeeds& seeds,
426   const _TSeed& p0, const _TSeed& p1,
427   _TOutPtr& out_dist, _TAxisPtr& out_axis
428   )
429 {
430   typedef typename _TOutPtr::ObjectType _TOut;
431   typedef typename _TOut::PixelType _TScalar;
432   typedef fpa::Filters::Image::ExtractAxis< _TIn, _TScalar > _TExtract;
433
434   // Prepare output values
435   typename _TIn::IndexType s0, s1;
436   in->TransformPhysicalPointToIndex( p0.Point, s0 );
437   in->TransformPhysicalPointToIndex( p1.Point, s1 );
438
439   typename _TExtract::Pointer extract = _TExtract::New( );
440   extract->SetInput( const_cast< _TIn* >( in ) );
441   extract->AddSeed( s0 );
442   extract->AddSeed( s1 );
443   for( typename _TSeeds::value_type seed: seeds )
444     extract->AddSeed( seed.Point );
445   extract->SetStartIndex( s0 );
446   extract->SetEndIndex( s1 );
447   extract->Update( );
448   out_axis = TPath::New( );
449   out_axis->Graft( extract->GetOutput( ) );
450   this->_Smooth( extract->GetCenterness( )->GetOutput( ), out_dist );
451 }
452
453 #endif // __RandomWalkSegmentation__hxx__
454
455 // eof - $RCSfile$