]> Creatis software - FrontAlgorithms.git/blob - appli/CTArteries/algorithms/RandomWalkSegmentation.hxx
...
[FrontAlgorithms.git] / appli / CTArteries / algorithms / RandomWalkSegmentation.hxx
1 // =========================================================================
2 // @author Leonardo Florez-Valencia (florez-l@javeriana.edu.co)
3 // =========================================================================
4 #ifndef __RandomWalkSegmentation__hxx__
5 #define __RandomWalkSegmentation__hxx__
6
7 #include <itkBinaryThresholdImageFilter.h>
8 #include <itkImageRegionConstIterator.h>
9 #include <itkImageRegionIterator.h>
10 #include <itkSmoothingRecursiveGaussianImageFilter.h>
11
12 #include <ivq/ITK/RegionOfInterestWithPaddingImageFilter.h>
13
14 #include <fpa/Filters/Image/RandomWalker.h>
15 #include <fpa/Filters/Image/ExtractAxis.h>
16 #include <fpa/Functors/Dijkstra/Image/Gaussian.h>
17
18 #include "DijkstraWithMeanAndVariance.h"
19 #include "RandomWalkLabelling.h"
20
21 #include <itkImageFileWriter.h>
22
23 // -------------------------------------------------------------------------
24 template< class _TInputImage, class _TOutputImage >
25 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
26 AddSeed( const TIndex& s )
27 {
28   TSeed seed;
29   seed.Index = s;
30   seed.IsPoint = false;
31   this->m_Seeds.push_back( seed );
32   this->Modified( );
33 }
34
35 // -------------------------------------------------------------------------
36 template< class _TInputImage, class _TOutputImage >
37 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
38 AddSeed( const TPoint& s )
39 {
40   TSeed seed;
41   seed.Point = s;
42   seed.IsPoint = true;
43   this->m_Seeds.push_back( seed );
44   this->Modified( );
45 }
46
47 // -------------------------------------------------------------------------
48 template< class _TInputImage, class _TOutputImage >
49 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
50 ClearSeeds( )
51 {
52   this->m_Seeds.clear( );
53   this->Modified( );
54 }
55
56 // -------------------------------------------------------------------------
57 template< class _TInputImage, class _TOutputImage >
58 unsigned long RandomWalkSegmentation< _TInputImage, _TOutputImage >::
59 GetNumberOfSeeds( ) const
60 {
61   return( this->m_Seeds.size( ) );
62 }
63
64 // -------------------------------------------------------------------------
65 template< class _TInputImage, class _TOutputImage >
66 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
67 SetStartSeed( const TIndex& s )
68 {
69   this->m_StartSeed.Index = s;
70   this->m_StartSeed.IsPoint = false;
71   this->Modified( );
72 }
73
74 // -------------------------------------------------------------------------
75 template< class _TInputImage, class _TOutputImage >
76 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
77 SetStartSeed( const TPoint& s )
78 {
79   this->m_StartSeed.Point = s;
80   this->m_StartSeed.IsPoint = true;
81   this->Modified( );
82 }
83
84 // -------------------------------------------------------------------------
85 template< class _TInputImage, class _TOutputImage >
86 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
87 SetEndSeed( const TIndex& s )
88 {
89   this->m_EndSeed.Index = s;
90   this->m_EndSeed.IsPoint = false;
91   this->Modified( );
92 }
93
94 // -------------------------------------------------------------------------
95 template< class _TInputImage, class _TOutputImage >
96 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
97 SetEndSeed( const TPoint& s )
98 {
99   this->m_EndSeed.Point = s;
100   this->m_EndSeed.IsPoint = true;
101   this->Modified( );
102 }
103
104 // -------------------------------------------------------------------------
105 template< class _TInputImage, class _TOutputImage >
106 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
107 RandomWalkSegmentation( )
108   : Superclass( ),
109     m_Beta( double( 20 ) ),
110     m_Sigma( double( 10 ) ),
111     m_Radius( double( 5 ) )
112 {
113   fpaFilterOutputConfigureMacro( OutputAxis, TPath );
114 }
115
116 // -------------------------------------------------------------------------
117 template< class _TInputImage, class _TOutputImage >
118 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
119 ~RandomWalkSegmentation( )
120 {
121 }
122
123 // -------------------------------------------------------------------------
124 template< class _TInputImage, class _TOutputImage >
125 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
126 GenerateData( )
127 {
128   typedef typename TOutputImage::PixelType _TScalar;
129   typedef DijkstraWithMeanAndVariance< TOutputImage, TOutputImage > _TInit;
130   typedef typename _TInit::TMarksImage _TMarksImage;
131   typedef fpa::Functors::Dijkstra::Image::Gaussian< TOutputImage, _TScalar > _TGaussFun;
132   typedef RandomWalkLabelling< TOutputImage, _TMarksImage > _TLabelling;
133   typedef typename _TLabelling::TLabelsImage _TLabelsImage;
134   typedef fpa::Filters::Image::RandomWalker< TOutputImage, _TLabelsImage, _TScalar > _TRandomWalker;
135   typedef itk::BinaryThresholdImageFilter< _TLabelsImage, _TLabelsImage > _TLabelExtract;
136   typedef fpa::Filters::Image::ExtractAxis< _TLabelsImage, _TScalar > _TAxisExtract;
137   typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< TOutputImage, TOutputImage > _TInputROI;
138   typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< _TMarksImage, _TMarksImage > _TMarksROI;
139   typedef itk::SmoothingRecursiveGaussianImageFilter< TInputImage, TOutputImage > _TSmooth;
140
141   // Prepare initial seeds
142   const TInputImage* input = this->GetInput( );
143   if( this->m_StartSeed.IsPoint )
144     input->TransformPhysicalPointToIndex(
145       this->m_StartSeed.Point,
146       this->m_StartSeed.Index
147       );
148   else
149     input->TransformIndexToPhysicalPoint(
150       this->m_StartSeed.Index,
151       this->m_StartSeed.Point
152       );
153   this->m_StartSeed.IsPoint = true;
154   if( this->m_EndSeed.IsPoint )
155     input->TransformPhysicalPointToIndex(
156       this->m_EndSeed.Point,
157       this->m_EndSeed.Index
158       );
159   else
160     input->TransformIndexToPhysicalPoint(
161       this->m_EndSeed.Index,
162       this->m_EndSeed.Point
163       );
164   this->m_EndSeed.IsPoint = true;
165   for( TSeed& seed: this->m_Seeds )
166   {
167     if( seed.IsPoint )
168       input->TransformPhysicalPointToIndex( seed.Point, seed.Index );
169     else
170       input->TransformIndexToPhysicalPoint( seed.Index, seed.Point );
171     seed.IsPoint = true;
172
173   } // rof
174
175   // Intermediary objects
176   typename TOutputImage::Pointer smooth_input;
177   typename TOutputImage::Pointer input_roi;
178   typename TPath::Pointer init_axis;
179   typename TPath::Pointer init_axis_roi;
180   typename TPath::Pointer output_axis_roi;
181   typename _TMarksImage::Pointer init_marks;
182   typename _TMarksImage::Pointer init_marks_roi;
183   typename _TLabelsImage::Pointer init_labels;
184   typename _TLabelsImage::Pointer final_labels;
185   typename _TLabelsImage::Pointer inside_labels;
186   typename TInputImage::RegionType roi;
187   typename TOutputImage::Pointer output_roi;
188   double init_mean, init_std;
189
190   // Smooth input
191   std::cout << "0" << std::endl;
192   { // begin
193     typename _TSmooth::Pointer smooth = _TSmooth::New( );
194     smooth->SetInput( input );
195     smooth->SetNormalizeAcrossScale( true );
196     smooth->SetSigmaArray( input->GetSpacing( ) * double( 2 ) );
197     smooth->Update( );
198     smooth_input = smooth->GetOutput( );
199     smooth_input->DisconnectPipeline( );
200
201   } // end
202
203   // Initial segmentation
204   std::cout << "1" << std::endl;
205   { // begin
206     typename _TGaussFun::Pointer init_fun = _TGaussFun::New( );
207     init_fun->SetBeta( this->m_Beta );
208
209     typename _TInit::Pointer init = _TInit::New( );
210     init->SetInput( smooth_input );
211     init->SetWeightFunction( init_fun );
212     init->StopAtOneFrontOn( );
213     for( TSeed seed: this->m_Seeds )
214       init->AddSeed( seed.Point );
215     init->Update( );
216
217     // Get initial values
218     init_mean = init->GetMean( );
219     init_std = init->GetDeviation( ) * this->m_Sigma;
220
221     // Get initial objects
222     init->GetMinimumSpanningTree( )->
223       GetPath( init_axis, this->m_StartSeed.Index, this->m_EndSeed.Index );
224     init_marks = init->GetMarks( );
225     init_marks->DisconnectPipeline( );
226
227     typename TInputImage::IndexType min_idx = init->GetMinVertex( );
228     typename TInputImage::IndexType max_idx = init->GetMaxVertex( );
229     typename TInputImage::SizeType roi_size;
230     for( unsigned int i = 0; i < TInputImage::ImageDimension; ++i )
231       roi_size[ i ] = max_idx[ i ] - min_idx[ i ] + 1;
232     roi.SetIndex( min_idx );
233     roi.SetSize( roi_size );
234
235   } // end
236
237   // Extract input ROIs
238   unsigned int pad = 10;
239   { // begin
240     typename _TInputROI::Pointer input_roi_filter = _TInputROI::New( );
241     input_roi_filter->SetInput( smooth_input );
242     input_roi_filter->SetRegionOfInterest( roi );
243     input_roi_filter->SetPadding( pad );
244     input_roi_filter->Update( );
245     input_roi = input_roi_filter->GetOutput( );
246     input_roi->DisconnectPipeline( );
247     roi = input_roi_filter->GetRegionOfInterest( );
248
249   } // end
250
251   { // begin
252     typename _TMarksROI::Pointer init_marks_roi_filter = _TMarksROI::New( );
253     init_marks_roi_filter->SetInput( init_marks );
254     init_marks_roi_filter->SetRegionOfInterest( roi );
255     init_marks_roi_filter->SetPadding( 0 );
256     init_marks_roi_filter->Update( );
257     init_marks_roi = init_marks_roi_filter->GetOutput( );
258     init_marks_roi->DisconnectPipeline( );
259
260   } // end
261   
262   // Convert initial axis
263   { // begin
264     init_axis_roi = TPath::New( );
265     init_axis_roi->SetReferenceImage( input_roi.GetPointer( ) );
266     for( unsigned long i = 0; i < init_axis->GetSize( ); ++i )
267     {
268       TIndex v = init_axis->GetVertex( i );
269       for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
270         v[ d ] -= roi.GetIndex( )[ d ];
271       init_axis_roi->AddVertex( v );
272
273     } // rof
274
275   } // end
276
277   // Labelling
278   std::cout << "2" << std::endl;
279   { // begin
280     typename _TLabelling::Pointer labelling = _TLabelling::New( );
281     labelling->SetInputImage( input_roi );
282     labelling->SetInputMarks( init_marks_roi );
283     labelling->SetInputPath( init_axis_roi );
284     labelling->SetInsideLabel( 1 );
285     labelling->SetOutsideLabel( 2 );
286     labelling->SetLowerLabel( 3 );
287     labelling->SetUpperLabel( 4 );
288     labelling->SetRadius( this->m_Radius );
289     labelling->SetLowerThreshold( init_mean - init_std );
290     labelling->SetUpperThreshold( init_mean + init_std );
291     labelling->Update( );
292     init_labels = labelling->GetOutputLabels( );
293     init_labels->DisconnectPipeline( );
294
295   } // end
296
297   // Random walker
298   std::cout << "3" << std::endl;
299   { // begin
300     typename _TGaussFun::Pointer rw_fun = _TGaussFun::New( );
301     rw_fun->SetBeta( init_std / double( 2 ) );
302
303     typename _TRandomWalker::Pointer rw = _TRandomWalker::New( );
304     rw->SetInputImage( input_roi );
305     rw->SetInputLabels( init_labels );
306     rw->SetWeightFunction( rw_fun );
307     rw->Update( );
308     final_labels = rw->GetOutputLabels( );
309     final_labels->DisconnectPipeline( );
310
311   } // end
312
313   // Extract inside label
314   std::cout << "4" << std::endl;
315   { // begin
316     typename _TLabelExtract::Pointer label_extract = _TLabelExtract::New( );
317     label_extract->SetInput( final_labels );
318     label_extract->SetInsideValue( 1 );
319     label_extract->SetOutsideValue( 0 );
320     label_extract->SetLowerThreshold( 1 );
321     label_extract->SetUpperThreshold( 1 );
322     label_extract->Update( );
323     inside_labels = label_extract->GetOutput( );
324     inside_labels->DisconnectPipeline( );
325
326   } // end
327
328   // Prepare output values
329   std::cout << "5" << std::endl;
330   { // begin
331     TIndex start_seed, end_seed;
332     inside_labels->TransformPhysicalPointToIndex( this->m_StartSeed.Point, start_seed );
333     inside_labels->TransformPhysicalPointToIndex( this->m_EndSeed.Point, end_seed );
334
335     typename _TAxisExtract::Pointer axis_extract = _TAxisExtract::New( );
336     axis_extract->SetInput( inside_labels );
337     axis_extract->AddSeed( start_seed );
338     axis_extract->AddSeed( end_seed );
339     for( TSeed seed: this->m_Seeds )
340       axis_extract->AddSeed( seed.Point );
341     axis_extract->SetStartIndex( start_seed );
342     axis_extract->SetEndIndex( end_seed );
343     axis_extract->Update( );
344     output_axis_roi = TPath::New( );
345     output_axis_roi->Graft( axis_extract->GetOutput( ) );
346     output_roi = axis_extract->GetCenterness( )->GetOutput( );
347     output_roi->DisconnectPipeline( );
348
349   } // end
350
351   // Put everything back to requested region
352   std::cout << "6" << std::endl;
353   { // begin
354     TOutputImage* output = this->GetOutput( );
355     output->SetBufferedRegion( output->GetRequestedRegion( ) );
356     output->Allocate( );
357     output->FillBuffer( -std::numeric_limits< _TScalar >::max( ) );
358
359     itk::ImageRegionConstIterator< TOutputImage > rIt(
360       output_roi, output_roi->GetRequestedRegion( )
361       );
362     itk::ImageRegionIterator< TOutputImage > oIt( output, roi );
363     rIt.GoToBegin( );
364     oIt.GoToBegin( );
365     for( ; !rIt.IsAtEnd( ); ++rIt, ++oIt )
366       oIt.Set( rIt.Get( ) );
367
368     TPath* output_axis = this->GetOutputAxis( );
369     output_axis->SetReferenceImage( output );
370     for( unsigned long i = 0; i < output_axis_roi->GetSize( ); ++i )
371     {
372       TIndex v = output_axis_roi->GetVertex( i );
373       for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
374         v[ d ] += roi.GetIndex( )[ d ];
375       output_axis->AddVertex( v );
376
377     } // rof
378
379   } // end
380   std::cout << "7" << std::endl;
381 }
382
383 // -------------------------------------------------------------------------
384 template< class _TInputImage, class _TOutputImage >
385 template< class _TImage >
386 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
387 _save( _TImage* image, const std::string& fname )
388 {
389   typename itk::ImageFileWriter< _TImage >::Pointer w =
390     itk::ImageFileWriter< _TImage >::New( );
391   w->SetInput( image );
392   w->SetFileName( fname );
393   w->Update( );
394 }
395
396
397 #endif // __RandomWalkSegmentation__hxx__
398
399 // eof - $RCSfile$