1 // =========================================================================
2 // @author Leonardo Florez-Valencia (florez-l@javeriana.edu.co)
3 // =========================================================================
4 #ifndef __RandomWalkSegmentation__hxx__
5 #define __RandomWalkSegmentation__hxx__
7 #include <itkBinaryThresholdImageFilter.h>
8 #include <itkImageRegionConstIterator.h>
9 #include <itkImageRegionIterator.h>
10 #include <itkSmoothingRecursiveGaussianImageFilter.h>
12 #include <ivq/ITK/RegionOfInterestWithPaddingImageFilter.h>
14 #include <fpa/Filters/Image/RandomWalker.h>
15 #include <fpa/Filters/Image/ExtractAxis.h>
16 #include <fpa/Functors/Dijkstra/Image/Gaussian.h>
18 #include "DijkstraWithMeanAndVariance.h"
19 #include "RandomWalkLabelling.h"
21 #include <itkImageFileWriter.h>
23 // -------------------------------------------------------------------------
24 template< class _TInputImage, class _TOutputImage >
25 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
26 AddSeed( const TIndex& s )
31 this->m_Seeds.push_back( seed );
35 // -------------------------------------------------------------------------
36 template< class _TInputImage, class _TOutputImage >
37 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
38 AddSeed( const TPoint& s )
43 this->m_Seeds.push_back( seed );
47 // -------------------------------------------------------------------------
48 template< class _TInputImage, class _TOutputImage >
49 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
52 this->m_Seeds.clear( );
56 // -------------------------------------------------------------------------
57 template< class _TInputImage, class _TOutputImage >
58 unsigned long RandomWalkSegmentation< _TInputImage, _TOutputImage >::
59 GetNumberOfSeeds( ) const
61 return( this->m_Seeds.size( ) );
64 // -------------------------------------------------------------------------
65 template< class _TInputImage, class _TOutputImage >
66 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
67 SetStartSeed( const TIndex& s )
69 this->m_StartSeed.Index = s;
70 this->m_StartSeed.IsPoint = false;
74 // -------------------------------------------------------------------------
75 template< class _TInputImage, class _TOutputImage >
76 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
77 SetStartSeed( const TPoint& s )
79 this->m_StartSeed.Point = s;
80 this->m_StartSeed.IsPoint = true;
84 // -------------------------------------------------------------------------
85 template< class _TInputImage, class _TOutputImage >
86 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
87 SetEndSeed( const TIndex& s )
89 this->m_EndSeed.Index = s;
90 this->m_EndSeed.IsPoint = false;
94 // -------------------------------------------------------------------------
95 template< class _TInputImage, class _TOutputImage >
96 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
97 SetEndSeed( const TPoint& s )
99 this->m_EndSeed.Point = s;
100 this->m_EndSeed.IsPoint = true;
104 // -------------------------------------------------------------------------
105 template< class _TInputImage, class _TOutputImage >
106 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
107 RandomWalkSegmentation( )
109 m_Beta( double( 20 ) ),
110 m_Sigma( double( 10 ) ),
111 m_Radius( double( 5 ) )
113 fpaFilterOutputConfigureMacro( OutputAxis, TPath );
116 // -------------------------------------------------------------------------
117 template< class _TInputImage, class _TOutputImage >
118 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
119 ~RandomWalkSegmentation( )
123 // -------------------------------------------------------------------------
124 template< class _TInputImage, class _TOutputImage >
125 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
128 typedef typename TOutputImage::PixelType _TScalar;
129 typedef DijkstraWithMeanAndVariance< TOutputImage, TOutputImage > _TInit;
130 typedef typename _TInit::TMarksImage _TMarksImage;
131 typedef fpa::Functors::Dijkstra::Image::Gaussian< TOutputImage, _TScalar > _TGaussFun;
132 typedef RandomWalkLabelling< TOutputImage, _TMarksImage > _TLabelling;
133 typedef typename _TLabelling::TLabelsImage _TLabelsImage;
134 typedef fpa::Filters::Image::RandomWalker< TOutputImage, _TLabelsImage, _TScalar > _TRandomWalker;
135 typedef itk::BinaryThresholdImageFilter< _TLabelsImage, _TLabelsImage > _TLabelExtract;
136 typedef fpa::Filters::Image::ExtractAxis< _TLabelsImage, _TScalar > _TAxisExtract;
137 typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< TOutputImage, TOutputImage > _TInputROI;
138 typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< _TMarksImage, _TMarksImage > _TMarksROI;
139 typedef itk::SmoothingRecursiveGaussianImageFilter< TInputImage, TOutputImage > _TSmooth;
141 // Prepare initial seeds
142 const TInputImage* input = this->GetInput( );
143 if( this->m_StartSeed.IsPoint )
144 input->TransformPhysicalPointToIndex(
145 this->m_StartSeed.Point,
146 this->m_StartSeed.Index
149 input->TransformIndexToPhysicalPoint(
150 this->m_StartSeed.Index,
151 this->m_StartSeed.Point
153 this->m_StartSeed.IsPoint = true;
154 if( this->m_EndSeed.IsPoint )
155 input->TransformPhysicalPointToIndex(
156 this->m_EndSeed.Point,
157 this->m_EndSeed.Index
160 input->TransformIndexToPhysicalPoint(
161 this->m_EndSeed.Index,
162 this->m_EndSeed.Point
164 this->m_EndSeed.IsPoint = true;
165 for( TSeed& seed: this->m_Seeds )
168 input->TransformPhysicalPointToIndex( seed.Point, seed.Index );
170 input->TransformIndexToPhysicalPoint( seed.Index, seed.Point );
175 // Intermediary objects
176 typename TOutputImage::Pointer smooth_input;
177 typename TOutputImage::Pointer input_roi;
178 typename TPath::Pointer init_axis;
179 typename TPath::Pointer init_axis_roi;
180 typename TPath::Pointer output_axis_roi;
181 typename _TMarksImage::Pointer init_marks;
182 typename _TMarksImage::Pointer init_marks_roi;
183 typename _TLabelsImage::Pointer init_labels;
184 typename _TLabelsImage::Pointer final_labels;
185 typename _TLabelsImage::Pointer inside_labels;
186 typename TInputImage::RegionType roi;
187 typename TOutputImage::Pointer output_roi;
188 double init_mean, init_std;
191 std::cout << "0" << std::endl;
193 typename _TSmooth::Pointer smooth = _TSmooth::New( );
194 smooth->SetInput( input );
195 smooth->SetNormalizeAcrossScale( true );
196 smooth->SetSigmaArray( input->GetSpacing( ) * double( 2 ) );
198 smooth_input = smooth->GetOutput( );
199 smooth_input->DisconnectPipeline( );
203 // Initial segmentation
204 std::cout << "1" << std::endl;
206 typename _TGaussFun::Pointer init_fun = _TGaussFun::New( );
207 init_fun->SetBeta( this->m_Beta );
209 typename _TInit::Pointer init = _TInit::New( );
210 init->SetInput( smooth_input );
211 init->SetWeightFunction( init_fun );
212 init->StopAtOneFrontOn( );
213 for( TSeed seed: this->m_Seeds )
214 init->AddSeed( seed.Point );
217 // Get initial values
218 init_mean = init->GetMean( );
219 init_std = init->GetDeviation( ) * this->m_Sigma;
221 // Get initial objects
222 init->GetMinimumSpanningTree( )->
223 GetPath( init_axis, this->m_StartSeed.Index, this->m_EndSeed.Index );
224 init_marks = init->GetMarks( );
225 init_marks->DisconnectPipeline( );
227 typename TInputImage::IndexType min_idx = init->GetMinVertex( );
228 typename TInputImage::IndexType max_idx = init->GetMaxVertex( );
229 typename TInputImage::SizeType roi_size;
230 for( unsigned int i = 0; i < TInputImage::ImageDimension; ++i )
231 roi_size[ i ] = max_idx[ i ] - min_idx[ i ] + 1;
232 roi.SetIndex( min_idx );
233 roi.SetSize( roi_size );
237 // Extract input ROIs
238 unsigned int pad = 10;
240 typename _TInputROI::Pointer input_roi_filter = _TInputROI::New( );
241 input_roi_filter->SetInput( smooth_input );
242 input_roi_filter->SetRegionOfInterest( roi );
243 input_roi_filter->SetPadding( pad );
244 input_roi_filter->Update( );
245 input_roi = input_roi_filter->GetOutput( );
246 input_roi->DisconnectPipeline( );
247 roi = input_roi_filter->GetRegionOfInterest( );
252 typename _TMarksROI::Pointer init_marks_roi_filter = _TMarksROI::New( );
253 init_marks_roi_filter->SetInput( init_marks );
254 init_marks_roi_filter->SetRegionOfInterest( roi );
255 init_marks_roi_filter->SetPadding( 0 );
256 init_marks_roi_filter->Update( );
257 init_marks_roi = init_marks_roi_filter->GetOutput( );
258 init_marks_roi->DisconnectPipeline( );
262 // Convert initial axis
264 init_axis_roi = TPath::New( );
265 init_axis_roi->SetReferenceImage( input_roi.GetPointer( ) );
266 for( unsigned long i = 0; i < init_axis->GetSize( ); ++i )
268 TIndex v = init_axis->GetVertex( i );
269 for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
270 v[ d ] -= roi.GetIndex( )[ d ];
271 init_axis_roi->AddVertex( v );
278 std::cout << "2" << std::endl;
280 typename _TLabelling::Pointer labelling = _TLabelling::New( );
281 labelling->SetInputImage( input_roi );
282 labelling->SetInputMarks( init_marks_roi );
283 labelling->SetInputPath( init_axis_roi );
284 labelling->SetInsideLabel( 1 );
285 labelling->SetOutsideLabel( 2 );
286 labelling->SetLowerLabel( 3 );
287 labelling->SetUpperLabel( 4 );
288 labelling->SetRadius( this->m_Radius );
289 labelling->SetLowerThreshold( init_mean - init_std );
290 labelling->SetUpperThreshold( init_mean + init_std );
291 labelling->Update( );
292 init_labels = labelling->GetOutputLabels( );
293 init_labels->DisconnectPipeline( );
298 std::cout << "3" << std::endl;
300 typename _TGaussFun::Pointer rw_fun = _TGaussFun::New( );
301 rw_fun->SetBeta( this->m_Beta /*init_std / double( 2 )*/ );
303 typename _TRandomWalker::Pointer rw = _TRandomWalker::New( );
304 rw->SetInputImage( input_roi );
305 rw->SetInputLabels( init_labels );
306 rw->SetWeightFunction( rw_fun );
308 final_labels = rw->GetOutputLabels( );
309 final_labels->DisconnectPipeline( );
313 // Extract inside label
314 std::cout << "4" << std::endl;
316 typename _TLabelExtract::Pointer label_extract = _TLabelExtract::New( );
317 label_extract->SetInput( final_labels );
318 label_extract->SetInsideValue( 1 );
319 label_extract->SetOutsideValue( 0 );
320 label_extract->SetLowerThreshold( 1 );
321 label_extract->SetUpperThreshold( 1 );
322 label_extract->Update( );
323 inside_labels = label_extract->GetOutput( );
324 inside_labels->DisconnectPipeline( );
328 // Prepare output values
329 std::cout << "5" << std::endl;
331 TIndex start_seed, end_seed;
332 inside_labels->TransformPhysicalPointToIndex( this->m_StartSeed.Point, start_seed );
333 inside_labels->TransformPhysicalPointToIndex( this->m_EndSeed.Point, end_seed );
335 typename _TAxisExtract::Pointer axis_extract = _TAxisExtract::New( );
336 axis_extract->SetInput( inside_labels );
337 axis_extract->AddSeed( start_seed );
338 axis_extract->AddSeed( end_seed );
339 for( TSeed seed: this->m_Seeds )
340 axis_extract->AddSeed( seed.Point );
341 axis_extract->SetStartIndex( start_seed );
342 axis_extract->SetEndIndex( end_seed );
343 axis_extract->Update( );
344 output_axis_roi = TPath::New( );
345 output_axis_roi->Graft( axis_extract->GetOutput( ) );
346 output_roi = axis_extract->GetCenterness( )->GetOutput( );
347 output_roi->DisconnectPipeline( );
351 // Put everything back to requested region
352 std::cout << "6" << std::endl;
354 TOutputImage* output = this->GetOutput( );
355 output->SetBufferedRegion( output->GetRequestedRegion( ) );
357 output->FillBuffer( -std::numeric_limits< _TScalar >::max( ) );
359 itk::ImageRegionConstIterator< TOutputImage > rIt(
360 output_roi, output_roi->GetRequestedRegion( )
362 itk::ImageRegionIterator< TOutputImage > oIt( output, roi );
365 for( ; !rIt.IsAtEnd( ); ++rIt, ++oIt )
366 oIt.Set( rIt.Get( ) );
368 TPath* output_axis = this->GetOutputAxis( );
369 output_axis->SetReferenceImage( output );
370 for( unsigned long i = 0; i < output_axis_roi->GetSize( ); ++i )
372 TIndex v = output_axis_roi->GetVertex( i );
373 for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
374 v[ d ] += roi.GetIndex( )[ d ];
375 output_axis->AddVertex( v );
380 std::cout << "7" << std::endl;
383 // -------------------------------------------------------------------------
384 template< class _TInputImage, class _TOutputImage >
385 template< class _TImage >
386 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
387 _save( _TImage* image, const std::string& fname )
389 typename itk::ImageFileWriter< _TImage >::Pointer w =
390 itk::ImageFileWriter< _TImage >::New( );
391 w->SetInput( image );
392 w->SetFileName( fname );
397 #endif // __RandomWalkSegmentation__hxx__