]> Creatis software - FrontAlgorithms.git/blob - appli/CTArteries/algorithms/RandomWalkSegmentation.hxx
...
[FrontAlgorithms.git] / appli / CTArteries / algorithms / RandomWalkSegmentation.hxx
1 // =========================================================================
2 // @author Leonardo Florez-Valencia (florez-l@javeriana.edu.co)
3 // =========================================================================
4 #ifndef __RandomWalkSegmentation__hxx__
5 #define __RandomWalkSegmentation__hxx__
6
7 #include <itkBinaryThresholdImageFilter.h>
8 #include <itkSmoothingRecursiveGaussianImageFilter.h>
9 #include <ivq/ITK/RegionOfInterestWithPaddingImageFilter.h>
10 #include <fpa/Filters/Image/ExtractAxis.h>
11 #include <fpa/Filters/Image/RandomWalker.h>
12 #include "DijkstraWithMeanAndVariance.h"
13 #include "RandomWalkLabelling.h"
14
15 #include <itkImageFileWriter.h>
16
17 // -------------------------------------------------------------------------
18 template< class _TInputImage, class _TOutputImage >
19 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
20 AddSeed( const TIndex& s )
21 {
22   TSeed seed;
23   seed.Index = s;
24   seed.IsPoint = false;
25   this->m_Seeds.push_back( seed );
26   this->Modified( );
27 }
28
29 // -------------------------------------------------------------------------
30 template< class _TInputImage, class _TOutputImage >
31 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
32 AddSeed( const TPoint& s )
33 {
34   TSeed seed;
35   seed.Point = s;
36   seed.IsPoint = true;
37   this->m_Seeds.push_back( seed );
38   this->Modified( );
39 }
40
41 // -------------------------------------------------------------------------
42 template< class _TInputImage, class _TOutputImage >
43 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
44 ClearSeeds( )
45 {
46   this->m_Seeds.clear( );
47   this->Modified( );
48 }
49
50 // -------------------------------------------------------------------------
51 template< class _TInputImage, class _TOutputImage >
52 unsigned long RandomWalkSegmentation< _TInputImage, _TOutputImage >::
53 GetNumberOfSeeds( ) const
54 {
55   return( this->m_Seeds.size( ) );
56 }
57
58 // -------------------------------------------------------------------------
59 template< class _TInputImage, class _TOutputImage >
60 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
61 SetStartSeed( const TIndex& s )
62 {
63   this->m_StartSeed.Index = s;
64   this->m_StartSeed.IsPoint = false;
65   this->Modified( );
66 }
67
68 // -------------------------------------------------------------------------
69 template< class _TInputImage, class _TOutputImage >
70 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
71 SetStartSeed( const TPoint& s )
72 {
73   this->m_StartSeed.Point = s;
74   this->m_StartSeed.IsPoint = true;
75   this->Modified( );
76 }
77
78 // -------------------------------------------------------------------------
79 template< class _TInputImage, class _TOutputImage >
80 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
81 SetEndSeed( const TIndex& s )
82 {
83   this->m_EndSeed.Index = s;
84   this->m_EndSeed.IsPoint = false;
85   this->Modified( );
86 }
87
88 // -------------------------------------------------------------------------
89 template< class _TInputImage, class _TOutputImage >
90 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
91 SetEndSeed( const TPoint& s )
92 {
93   this->m_EndSeed.Point = s;
94   this->m_EndSeed.IsPoint = true;
95   this->Modified( );
96 }
97
98 // -------------------------------------------------------------------------
99 template< class _TInputImage, class _TOutputImage >
100 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
101 RandomWalkSegmentation( )
102   : Superclass( ),
103     m_Beta( double( 20 ) ),
104     m_Sigma( double( 10 ) ),
105     m_Radius( double( 5 ) )
106 {
107   fpaFilterOutputConfigureMacro( OutputAxis, TPath );
108 }
109
110 // -------------------------------------------------------------------------
111 template< class _TInputImage, class _TOutputImage >
112 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
113 ~RandomWalkSegmentation( )
114 {
115 }
116
117 // -------------------------------------------------------------------------
118 template< class _TInputImage, class _TOutputImage >
119 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
120 GenerateData( )
121 {
122   typedef unsigned char _TLabel;
123   typedef typename TOutputImage::PixelType _TScalar;
124   typedef itk::Image< _TLabel, TInputImage::ImageDimension > _TLabels;
125
126   // Prepare initial seeds
127   const TInputImage* input = this->GetInput( );
128   this->_SynchSeed( input, this->m_StartSeed );
129   this->_SynchSeed( input, this->m_EndSeed );
130   for( TSeed& seed: this->m_Seeds )
131     this->_SynchSeed( input, seed );
132
133   // Smooth input
134   typename TOutputImage::Pointer smooth_in;
135   std::cout << "smooth" << std::endl;
136   this->_Smooth( this->GetInput( ), smooth_in, 2 );
137   this->_Save( smooth_in, "smooth.mhd" );
138
139   // Initial segmentation
140   std::cout << "raw" << std::endl;
141   typename TOutputImage::Pointer init_seg;
142   typename TPath::Pointer init_axis;
143   _TScalar init_mean, init_std;
144   typename TInputImage::RegionType roi =
145     this->_RawSegmentation(
146       smooth_in.GetPointer( ), this->m_Seeds,
147       this->m_StartSeed.Index, this->m_EndSeed.Index,
148       this->m_Beta,
149       init_seg, init_axis, init_mean, init_std
150       );
151   std::cout << "Stat: " << init_mean << " +/- " << init_std << std::endl;
152   init_std *= _TScalar( this->m_Sigma );
153   this->_Save( init_seg, "raw.mhd" );
154
155   // Extract input ROIs
156   std::cout << "ROI" << std::endl;
157   typename TOutputImage::Pointer smooth_in_roi, init_seg_roi;
158   roi = this->_ROI( smooth_in.GetPointer( ), roi, 10, smooth_in_roi );
159   this->_ROI( init_seg.GetPointer( ), roi, 0, init_seg_roi );
160   typename TPath::Pointer init_axis_roi = TPath::New( );
161   init_axis_roi->SetReferenceImage( smooth_in_roi.GetPointer( ) );
162   this->_AxisROI( init_axis.GetPointer( ), roi, init_axis_roi );
163
164   // Labelling
165   std::cout << "labelling" << std::endl;
166   typename _TLabels::Pointer init_labels;
167   _TScalar radius = _TScalar( this->m_Radius );
168   this->_Label(
169     smooth_in_roi.GetPointer( ),
170     init_seg_roi.GetPointer( ),
171     init_axis_roi.GetPointer( ),
172     init_mean, init_std, radius,
173     init_labels
174     );
175   this->_Save( init_labels, "init_labels.mhd" );
176
177   // Random walker
178   std::cout << "random walker " << init_std << " " << this->m_Beta << std::endl;
179   typename _TLabels::Pointer rw_seg;
180   this->_RandomWalker(
181     smooth_in_roi.GetPointer( ),
182     init_labels.GetPointer( ),
183     this->m_Beta, // init_std / _TScalar( 2 ),
184     rw_seg
185     );
186
187   // ROI outputs
188   std::cout << "axis" << std::endl;
189   typename TOutputImage::Pointer out_dist;
190   typename TPath::Pointer out_axis;
191   this->_DistanceAndAxis(
192     rw_seg.GetPointer( ),
193     this->m_Seeds, this->m_StartSeed, this->m_EndSeed,
194     out_dist, out_axis
195     );
196
197   // Put everything back to requested region
198   std::cout << "output" << std::endl;
199   { // begin
200     TOutputImage* output = this->GetOutput( );
201     output->SetBufferedRegion( output->GetRequestedRegion( ) );
202     output->Allocate( );
203     output->FillBuffer( -std::numeric_limits< _TScalar >::max( ) );
204
205     itk::ImageRegionConstIterator< TOutputImage > rIt(
206       out_dist, out_dist->GetRequestedRegion( )
207       );
208     itk::ImageRegionIterator< TOutputImage > oIt( output, roi );
209     rIt.GoToBegin( );
210     oIt.GoToBegin( );
211     for( ; !rIt.IsAtEnd( ); ++rIt, ++oIt )
212       oIt.Set( rIt.Get( ) );
213
214     TPath* output_axis = this->GetOutputAxis( );
215     output_axis->SetReferenceImage( output );
216     for( unsigned long i = 0; i < out_axis->GetSize( ); ++i )
217     {
218       TIndex v = out_axis->GetVertex( i );
219       for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
220         v[ d ] += roi.GetIndex( )[ d ];
221       output_axis->AddVertex( v );
222
223     } // rof
224
225   } // end
226 }
227
228 // -------------------------------------------------------------------------
229 template< class _TInputImage, class _TOutputImage >
230 template< class _TIn, class _TSeed >
231 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
232 _SynchSeed( const _TIn* in, _TSeed& seed )
233 {
234   if( seed.IsPoint )
235     in->TransformPhysicalPointToIndex( seed.Point, seed.Index );
236   else
237     in->TransformIndexToPhysicalPoint( seed.Index, seed.Point );
238   seed.IsPoint = true;
239 }
240
241 // -------------------------------------------------------------------------
242 template< class _TInputImage, class _TOutputImage >
243 template< class _TIn, class _TOutPtr >
244 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
245 _Smooth( const _TIn* in, _TOutPtr& out, double s )
246 {
247   typedef typename _TOutPtr::ObjectType _TOut;
248   typedef itk::SmoothingRecursiveGaussianImageFilter< _TIn, _TOut > _TSmooth;
249
250   typename _TSmooth::Pointer smooth = _TSmooth::New( );
251   smooth->SetInput( in );
252   smooth->SetNormalizeAcrossScale( true );
253   smooth->SetSigmaArray( in->GetSpacing( ) * s );
254   smooth->Update( );
255   out = smooth->GetOutput( );
256   out->DisconnectPipeline( );
257 }
258
259 // -------------------------------------------------------------------------
260 template< class _TInputImage, class _TOutputImage >
261 template< class _TIn, class _TOutPtr, class _TAxisPtr, class _TSeeds >
262 typename _TIn::RegionType
263 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
264 _RawSegmentation(
265   const _TIn* in, const _TSeeds& seeds,
266   const typename _TIn::IndexType& s0,
267   const typename _TIn::IndexType& s1,
268   const typename _TOutPtr::ObjectType::PixelType& beta,
269   _TOutPtr& out, _TAxisPtr& out_axis,
270   typename _TOutPtr::ObjectType::PixelType& oMean,
271   typename _TOutPtr::ObjectType::PixelType& oSTD
272   )
273 {
274   typedef typename _TOutPtr::ObjectType _TOut;
275   typedef typename _TOut::PixelType  _TScalar;
276   typedef DijkstraWithMeanAndVariance< _TIn, _TOut > _TInit;
277   typedef fpa::Functors::Dijkstra::Image::Gaussian< _TIn, _TScalar > _TFun;
278
279   typename _TFun::Pointer fun = _TFun::New( );
280   fun->SetBeta( beta );
281
282   typename _TInit::Pointer init = _TInit::New( );
283   init->SetInput( in );
284   init->SetWeightFunction( fun );
285   init->StopAtOneFrontOn( );
286   for( typename _TSeeds::value_type seed: seeds )
287     init->AddSeed( seed.Point );
288   init->Update( );
289
290   // Get initial values
291   oMean = _TScalar( init->GetMean( ) );
292   oSTD = _TScalar( init->GetDeviation( ) );
293
294   // Get initial objects
295   init->GetMinimumSpanningTree( )->GetPath( out_axis, s0, s1 );
296   out = init->GetOutput( );
297   out->DisconnectPipeline( );
298
299   // Get ROI
300   typename _TIn::IndexType min_idx = init->GetMinVertex( );
301   typename _TIn::IndexType max_idx = init->GetMaxVertex( );
302   typename _TIn::SizeType roi_size;
303   for( unsigned int i = 0; i < TInputImage::ImageDimension; ++i )
304     roi_size[ i ] = max_idx[ i ] - min_idx[ i ] + 1;
305   typename _TIn::RegionType roi;
306   roi.SetIndex( min_idx );
307   roi.SetSize( roi_size );
308   return( roi );
309 }
310
311 // -------------------------------------------------------------------------
312 template< class _TInputImage, class _TOutputImage >
313 template< class _TIn, class _TOutPtr >
314 typename _TIn::RegionType
315 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
316 _ROI(
317   const _TIn* in, const typename _TIn::RegionType& roi, unsigned int pad,
318   _TOutPtr& out
319   )
320 {
321   typedef typename _TOutPtr::ObjectType _TOut;
322   typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< _TIn, _TOut > _TROI;
323
324   typename _TROI::Pointer filter = _TROI::New( );
325   filter->SetInput( in );
326   filter->SetRegionOfInterest( roi );
327   filter->SetPadding( pad );
328   filter->Update( );
329   out = filter->GetOutput( );
330   out->DisconnectPipeline( );
331   return( filter->GetRegionOfInterest( ) );
332 }
333
334 // -------------------------------------------------------------------------
335 template< class _TInputImage, class _TOutputImage >
336 template< class _TAxisPtr, class _TRegion >
337 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
338 _AxisROI(
339   const typename _TAxisPtr::ObjectType* in, const _TRegion& roi,
340   _TAxisPtr& out
341   )
342 {
343   typedef typename _TAxisPtr::ObjectType _TAxis;
344
345   for( unsigned long i = 0; i < in->GetSize( ); ++i )
346   {
347     typename _TAxis::TIndex v = in->GetVertex( i );
348     for( unsigned int d = 0; d < _TAxis::Dimension; ++d )
349       v[ d ] -= roi.GetIndex( )[ d ];
350     out->AddVertex( v );
351
352   } // rof
353 }
354
355 // -------------------------------------------------------------------------
356 template< class _TInputImage, class _TOutputImage >
357 template< class _TInRaw, class _TInCosts, class _TAxis, class _TScalar, class _TOutPtr >
358 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
359 _Label(
360   const _TInRaw* raw, const _TInCosts* costs, const _TAxis* axis,
361   const _TScalar& mean, const _TScalar& dev, const _TScalar& radius,
362   _TOutPtr& out
363   )
364 {
365   typedef typename _TOutPtr::ObjectType _TOut;
366   typedef RandomWalkLabelling< _TInRaw, _TInCosts, _TOut > _TLabel;
367
368   typename _TLabel::Pointer label = _TLabel::New( );
369   label->SetInputImage( raw );
370   label->SetInputCosts( costs );
371   label->SetInputPath( axis );
372   label->SetInsideLabel( 1 );
373   label->SetOutsideLabel( 2 );
374   label->SetLowerLabel( 3 );
375   label->SetUpperLabel( 4 );
376   label->SetRadius( radius );
377   label->SetLowerThreshold( mean - dev );
378   label->SetUpperThreshold( mean + dev );
379   label->Update( );
380   out = label->GetOutputLabels( );
381   out->DisconnectPipeline( );
382
383   std::cout << label->GetLowerThreshold( ) << std::endl;
384   std::cout << label->GetUpperThreshold( ) << std::endl;
385
386 }
387
388 // -------------------------------------------------------------------------
389 template< class _TInputImage, class _TOutputImage >
390 template< class _TIn, class _TOutPtr >
391 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
392 _RandomWalker(
393   const _TIn* in, const typename _TOutPtr::ObjectType* labels,
394   const typename _TIn::PixelType& beta,
395   _TOutPtr& out
396   )
397 {
398   typedef typename _TIn::PixelType _TScalar;
399   typedef typename _TOutPtr::ObjectType _TOut;
400   typedef fpa::Functors::Dijkstra::Image::Gaussian< _TIn, _TScalar > _TFun;
401   typedef fpa::Filters::Image::RandomWalker< _TIn, _TOut, _TScalar > _TRandomWalker;
402   typedef itk::BinaryThresholdImageFilter< _TOut, _TOut > _TExtract;
403
404   typename _TFun::Pointer fun = _TFun::New( );
405   fun->SetBeta( beta );
406
407   typename _TRandomWalker::Pointer rw = _TRandomWalker::New( );
408   rw->SetInputImage( in );
409   rw->SetInputLabels( labels );
410   rw->SetWeightFunction( fun );
411
412   typename _TExtract::Pointer extract = _TExtract::New( );
413   extract->SetInput( rw->GetOutputLabels( ) );
414   extract->SetInsideValue( 1 );
415   extract->SetOutsideValue( 0 );
416   extract->SetLowerThreshold( 1 );
417   extract->SetUpperThreshold( 1 );
418   extract->Update( );
419   out = extract->GetOutput( );
420   out->DisconnectPipeline( );
421 }
422
423 // -------------------------------------------------------------------------
424 template< class _TInputImage, class _TOutputImage >
425 template< class _TIn, class _TSeeds, class _TSeed, class _TOutPtr, class _TAxisPtr >
426 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
427 _DistanceAndAxis(
428   const _TIn* in, const _TSeeds& seeds,
429   const _TSeed& p0, const _TSeed& p1,
430   _TOutPtr& out_dist, _TAxisPtr& out_axis
431   )
432 {
433   typedef typename _TOutPtr::ObjectType _TOut;
434   typedef typename _TOut::PixelType _TScalar;
435   typedef fpa::Filters::Image::ExtractAxis< _TIn, _TScalar > _TExtract;
436
437   // Prepare output values
438   typename _TIn::IndexType s0, s1;
439   in->TransformPhysicalPointToIndex( p0.Point, s0 );
440   in->TransformPhysicalPointToIndex( p1.Point, s1 );
441
442   typename _TExtract::Pointer extract = _TExtract::New( );
443   extract->SetInput( const_cast< _TIn* >( in ) );
444   extract->AddSeed( s0 );
445   extract->AddSeed( s1 );
446   for( typename _TSeeds::value_type seed: seeds )
447     extract->AddSeed( seed.Point );
448   extract->SetStartIndex( s0 );
449   extract->SetEndIndex( s1 );
450   extract->Update( );
451   out_axis = TPath::New( );
452   out_axis->Graft( extract->GetOutput( ) );
453   this->_Smooth( extract->GetCenterness( )->GetOutput( ), out_dist, 1 );
454 }
455
456 // -------------------------------------------------------------------------
457 template< class _TInputImage, class _TOutputImage >
458 template< class _TInPtr >
459 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
460 _Save( const _TInPtr& in, const std::string& fname )
461 {
462   typedef itk::ImageFileWriter< typename _TInPtr::ObjectType > _TWriter;
463   typename _TWriter::Pointer w = _TWriter::New( );
464   w->SetInput( in );
465   w->SetFileName( fname );
466   w->Update( );
467 }
468
469 #endif // __RandomWalkSegmentation__hxx__
470
471 // eof - $RCSfile$