1 // =========================================================================
2 // @author Leonardo Florez-Valencia (florez-l@javeriana.edu.co)
3 // =========================================================================
4 #ifndef __RandomWalkSegmentation__hxx__
5 #define __RandomWalkSegmentation__hxx__
7 #include <itkBinaryThresholdImageFilter.h>
8 #include <itkSmoothingRecursiveGaussianImageFilter.h>
9 #include <ivq/ITK/RegionOfInterestWithPaddingImageFilter.h>
10 #include <fpa/Filters/Image/ExtractAxis.h>
11 #include <fpa/Filters/Image/RandomWalker.h>
12 #include "DijkstraWithMeanAndVariance.h"
13 #include "RandomWalkLabelling.h"
15 #include <itkImageFileWriter.h>
17 // -------------------------------------------------------------------------
// Appends a seed given as an image index to m_Seeds.
// NOTE(review): the statements that build `seed` from `s` are not visible
// in this view; presumably they store `s` in seed.Index and clear
// seed.IsPoint (mirroring SetStartSeed( const TIndex& )) -- confirm.
// Each entry of m_Seeds is later passed through _SynchSeed by the main
// update routine before the segmentation runs.
18 template< class _TInputImage, class _TOutputImage >
19 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
20 AddSeed( const TIndex& s )
25 this->m_Seeds.push_back( seed );
29 // -------------------------------------------------------------------------
// Appends a seed given as a physical point to m_Seeds.
// NOTE(review): the statements that build `seed` from `s` are not visible
// in this view; presumably they store `s` in seed.Point and set
// seed.IsPoint (mirroring SetStartSeed( const TPoint& )) -- confirm.
30 template< class _TInputImage, class _TOutputImage >
31 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
32 AddSeed( const TPoint& s )
37 this->m_Seeds.push_back( seed );
41 // -------------------------------------------------------------------------
// Empties m_Seeds, the list filled by AddSeed( ).
// (The method's name line is not visible in this view.)
42 template< class _TInputImage, class _TOutputImage >
43 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
46 this->m_Seeds.clear( );
50 // -------------------------------------------------------------------------
// Returns the number of entries in m_Seeds. The start and end seeds are
// held in m_StartSeed / m_EndSeed and are not part of this count.
51 template< class _TInputImage, class _TOutputImage >
52 unsigned long RandomWalkSegmentation< _TInputImage, _TOutputImage >::
53 GetNumberOfSeeds( ) const
55 return( this->m_Seeds.size( ) );
58 // -------------------------------------------------------------------------
// Sets the start seed from an image index. IsPoint = false marks the index
// as the authoritative value; presumably _SynchSeed then derives the
// matching physical point from the input image geometry.
59 template< class _TInputImage, class _TOutputImage >
60 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
61 SetStartSeed( const TIndex& s )
63 this->m_StartSeed.Index = s;
64 this->m_StartSeed.IsPoint = false;
68 // -------------------------------------------------------------------------
// Sets the start seed from a physical point. IsPoint = true marks the point
// as the authoritative value; presumably _SynchSeed then derives the
// matching image index from the input image geometry.
69 template< class _TInputImage, class _TOutputImage >
70 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
71 SetStartSeed( const TPoint& s )
73 this->m_StartSeed.Point = s;
74 this->m_StartSeed.IsPoint = true;
78 // -------------------------------------------------------------------------
// Sets the end seed from an image index. IsPoint = false marks the index
// as the authoritative value; presumably _SynchSeed then derives the
// matching physical point from the input image geometry.
79 template< class _TInputImage, class _TOutputImage >
80 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
81 SetEndSeed( const TIndex& s )
83 this->m_EndSeed.Index = s;
84 this->m_EndSeed.IsPoint = false;
88 // -------------------------------------------------------------------------
// Sets the end seed from a physical point. IsPoint = true marks the point
// as the authoritative value; presumably _SynchSeed then derives the
// matching image index from the input image geometry.
89 template< class _TInputImage, class _TOutputImage >
90 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
91 SetEndSeed( const TPoint& s )
93 this->m_EndSeed.Point = s;
94 this->m_EndSeed.IsPoint = true;
98 // -------------------------------------------------------------------------
// Default constructor.
//   m_Beta   = 20 : beta handed to the random-walker stage (at least; other
//                   uses may sit on lines not visible here).
//   m_Sigma  = 10 : factor applied to the initial intensity deviation before
//                   it is used for the labelling thresholds.
//   m_Radius = 5  : radius handed to the labelling stage.
// Also configures the secondary OutputAxis output, of type TPath.
99 template< class _TInputImage, class _TOutputImage >
100 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
101 RandomWalkSegmentation( )
103 m_Beta( double( 20 ) ),
104 m_Sigma( double( 10 ) ),
105 m_Radius( double( 5 ) )
107 fpaFilterOutputConfigureMacro( OutputAxis, TPath );
110 // -------------------------------------------------------------------------
// Destructor (its body is not visible in this view).
111 template< class _TInputImage, class _TOutputImage >
112 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
113 ~RandomWalkSegmentation( )
117 // -------------------------------------------------------------------------
// Main update routine of the filter (its name line is not visible in this
// view; presumably GenerateData( )). Stages, as far as visible:
//   1. Synchronize index/point of the start, end and extra seeds against
//      the input image geometry (_SynchSeed).
//   2. Gaussian-smooth the input (_Smooth, sigma = 2 voxels per axis).
//   3. Initial Dijkstra-based segmentation of the smoothed image
//      (_RawSegmentation): a segmentation, an axis between the start and
//      end seeds, an intensity mean/deviation and a bounding region.
//   4. Crop the smoothed image (padding 10) and the initial segmentation
//      (padding 0) to that region (_ROI), and move the axis into ROI-local
//      indices (_AxisROI).
//   5. Build an initial label map from mean +/- (m_Sigma * deviation) and
//      m_Radius, then run the random walker with beta = m_Beta (the names
//      of these two helper calls are not visible here).
//   6. Extract a centerness map and an axis from the random-walker result
//      (_DistanceAndAxis) and paste both back into full-image coordinates.
// NOTE(review): each stage prints progress to std::cout, and _Save is asked
// to write smooth.mhd, raw.mhd and init_labels.mhd (relative to the current
// working directory). This looks like debugging code that should be removed
// or made optional.
118 template< class _TInputImage, class _TOutputImage >
119 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
122 typedef unsigned char _TLabel;
123 typedef typename TOutputImage::PixelType _TScalar;
124 typedef itk::Image< _TLabel, TInputImage::ImageDimension > _TLabels;
126 // Prepare initial seeds
127 const TInputImage* input = this->GetInput( );
128 this->_SynchSeed( input, this->m_StartSeed );
129 this->_SynchSeed( input, this->m_EndSeed );
130 for( TSeed& seed: this->m_Seeds )
131 this->_SynchSeed( input, seed );
// Smoothed copy of the input, stored with the output pixel type.
134 typename TOutputImage::Pointer smooth_in;
135 std::cout << "smooth" << std::endl;
136 this->_Smooth( this->GetInput( ), smooth_in, 2 );
137 this->_Save( smooth_in, "smooth.mhd" );
139 // Initial segmentation
140 std::cout << "raw" << std::endl;
141 typename TOutputImage::Pointer init_seg;
142 typename TPath::Pointer init_axis;
143 _TScalar init_mean, init_std;
// `roi` is the bounding box (min/max vertex) of the Dijkstra front, in
// full-image indices.
144 typename TInputImage::RegionType roi =
145 this->_RawSegmentation(
146 smooth_in.GetPointer( ), this->m_Seeds,
147 this->m_StartSeed.Index, this->m_EndSeed.Index,
149 init_seg, init_axis, init_mean, init_std
151 std::cout << "Stat: " << init_mean << " +/- " << init_std << std::endl;
// Widen the deviation by m_Sigma; the widened value feeds the labelling
// thresholds (mean - dev, mean + dev).
152 init_std *= _TScalar( this->m_Sigma );
153 this->_Save( init_seg, "raw.mhd" );
155 // Extract input ROIs
156 std::cout << "ROI" << std::endl;
157 typename TOutputImage::Pointer smooth_in_roi, init_seg_roi;
// `roi` is replaced by the region the ROI filter reports after cropping with
// padding 10 (presumably the padded box clipped to the image -- confirm), so
// the second crop and the axis shift below line up with smooth_in_roi.
158 roi = this->_ROI( smooth_in.GetPointer( ), roi, 10, smooth_in_roi );
159 this->_ROI( init_seg.GetPointer( ), roi, 0, init_seg_roi );
160 typename TPath::Pointer init_axis_roi = TPath::New( );
161 init_axis_roi->SetReferenceImage( smooth_in_roi.GetPointer( ) );
162 this->_AxisROI( init_axis.GetPointer( ), roi, init_axis_roi );
165 std::cout << "labelling" << std::endl;
166 typename _TLabels::Pointer init_labels;
167 _TScalar radius = _TScalar( this->m_Radius );
169 smooth_in_roi.GetPointer( ),
170 init_seg_roi.GetPointer( ),
171 init_axis_roi.GetPointer( ),
172 init_mean, init_std, radius,
175 this->_Save( init_labels, "init_labels.mhd" );
178 std::cout << "random walker " << init_std << " " << this->m_Beta << std::endl;
179 typename _TLabels::Pointer rw_seg;
181 smooth_in_roi.GetPointer( ),
182 init_labels.GetPointer( ),
183 this->m_Beta, // init_std / _TScalar( 2 ),
188 std::cout << "axis" << std::endl;
189 typename TOutputImage::Pointer out_dist;
190 typename TPath::Pointer out_axis;
191 this->_DistanceAndAxis(
192 rw_seg.GetPointer( ),
193 this->m_Seeds, this->m_StartSeed, this->m_EndSeed,
197 // Put everything back to requested region
198 std::cout << "output" << std::endl;
// The output buffer is reset to the lowest representable value, and the
// ROI-local map is then copied into `roi`.
// NOTE(review): std::numeric_limits requires <limits>, which is not among
// the includes visible here; presumably it arrives through ITK headers.
200 TOutputImage* output = this->GetOutput( );
201 output->SetBufferedRegion( output->GetRequestedRegion( ) );
203 output->FillBuffer( -std::numeric_limits< _TScalar >::max( ) );
// NOTE(review): the copy loop only tests rIt.IsAtEnd( ). It assumes
// out_dist's requested region has exactly as many pixels as `roi`;
// otherwise oIt would run past `roi`.
205 itk::ImageRegionConstIterator< TOutputImage > rIt(
206 out_dist, out_dist->GetRequestedRegion( )
208 itk::ImageRegionIterator< TOutputImage > oIt( output, roi );
211 for( ; !rIt.IsAtEnd( ); ++rIt, ++oIt )
212 oIt.Set( rIt.Get( ) );
// Shift the ROI-local axis vertices back into full-image indices.
214 TPath* output_axis = this->GetOutputAxis( );
215 output_axis->SetReferenceImage( output );
216 for( unsigned long i = 0; i < out_axis->GetSize( ); ++i )
218 TIndex v = out_axis->GetVertex( i );
219 for( unsigned int d = 0; d < TInputImage::ImageDimension; ++d )
220 v[ d ] += roi.GetIndex( )[ d ];
221 output_axis->AddVertex( v );
228 // -------------------------------------------------------------------------
// Makes seed.Point and seed.Index describe the same location in the
// geometry of image `in`.
// NOTE(review): the condition that picks the conversion direction is not
// visible here. Presumably seed.IsPoint selects point->index and otherwise
// index->point, matching the Set*Seed / AddSeed setters.
// NOTE(review): the bool returned by TransformPhysicalPointToIndex (whether
// the resulting index lies inside the image) is ignored, so a seed point
// outside the image goes undetected here.
229 template< class _TInputImage, class _TOutputImage >
230 template< class _TIn, class _TSeed >
231 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
232 _SynchSeed( const _TIn* in, _TSeed& seed )
235 in->TransformPhysicalPointToIndex( seed.Point, seed.Index );
237 in->TransformIndexToPhysicalPoint( seed.Index, seed.Point );
241 // -------------------------------------------------------------------------
// Smooths `in` with itk::SmoothingRecursiveGaussianImageFilter and stores
// the result in `out`, detached from the pipeline.
//   s : standard deviation in voxels. The per-axis sigma given to the
//       filter is in->GetSpacing( ) * s, i.e. physical units.
// Normalization across scale is enabled.
242 template< class _TInputImage, class _TOutputImage >
243 template< class _TIn, class _TOutPtr >
244 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
245 _Smooth( const _TIn* in, _TOutPtr& out, double s )
247 typedef typename _TOutPtr::ObjectType _TOut;
248 typedef itk::SmoothingRecursiveGaussianImageFilter< _TIn, _TOut > _TSmooth;
250 typename _TSmooth::Pointer smooth = _TSmooth::New( );
251 smooth->SetInput( in );
252 smooth->SetNormalizeAcrossScale( true );
253 smooth->SetSigmaArray( in->GetSpacing( ) * s );
255 out = smooth->GetOutput( );
256 out->DisconnectPipeline( );
259 // -------------------------------------------------------------------------
// Initial segmentation with a DijkstraWithMeanAndVariance front grown over
// `in` from the physical points of `seeds`. The front uses a Gaussian weight
// function parameterized by `beta`, and StopAtOneFrontOn is set (presumably
// the front stops once the seed fronts have merged -- confirm).
// Outputs:
//   out         : the Dijkstra output image, detached from the pipeline.
//   out_axis    : minimum-spanning-tree path from index s0 to index s1.
//   oMean, oSTD : mean and deviation reported by the Dijkstra filter.
// Returns the bounding region between the filter's min and max vertices.
// NOTE(review): only `seeds` are visibly added to the front; s0/s1 are used
// only to query the path, so they presumably must be reached by the front
// (e.g. because the caller's seed list contains them) -- confirm.
260 template< class _TInputImage, class _TOutputImage >
261 template< class _TIn, class _TOutPtr, class _TAxisPtr, class _TSeeds >
262 typename _TIn::RegionType
263 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
265 const _TIn* in, const _TSeeds& seeds,
266 const typename _TIn::IndexType& s0,
267 const typename _TIn::IndexType& s1,
268 const typename _TOutPtr::ObjectType::PixelType& beta,
269 _TOutPtr& out, _TAxisPtr& out_axis,
270 typename _TOutPtr::ObjectType::PixelType& oMean,
271 typename _TOutPtr::ObjectType::PixelType& oSTD
274 typedef typename _TOutPtr::ObjectType _TOut;
275 typedef typename _TOut::PixelType _TScalar;
276 typedef DijkstraWithMeanAndVariance< _TIn, _TOut > _TInit;
277 typedef fpa::Functors::Dijkstra::Image::Gaussian< _TIn, _TScalar > _TFun;
279 typename _TFun::Pointer fun = _TFun::New( );
280 fun->SetBeta( beta );
282 typename _TInit::Pointer init = _TInit::New( );
283 init->SetInput( in );
284 init->SetWeightFunction( fun );
285 init->StopAtOneFrontOn( );
// NOTE(review): each seed is copied by value; a const reference would
// avoid the copies.
286 for( typename _TSeeds::value_type seed: seeds )
287 init->AddSeed( seed.Point );
290 // Get initial values
291 oMean = _TScalar( init->GetMean( ) );
292 oSTD = _TScalar( init->GetDeviation( ) );
294 // Get initial objects
295 init->GetMinimumSpanningTree( )->GetPath( out_axis, s0, s1 );
296 out = init->GetOutput( );
297 out->DisconnectPipeline( );
300 typename _TIn::IndexType min_idx = init->GetMinVertex( );
301 typename _TIn::IndexType max_idx = init->GetMaxVertex( );
302 typename _TIn::SizeType roi_size;
// NOTE(review): the bound uses TInputImage::ImageDimension while the index,
// size and region types come from _TIn. The caller passes a TOutputImage,
// so this silently assumes both image types have the same dimension;
// _TIn::ImageDimension would be the consistent bound.
303 for( unsigned int i = 0; i < TInputImage::ImageDimension; ++i )
304 roi_size[ i ] = max_idx[ i ] - min_idx[ i ] + 1;
305 typename _TIn::RegionType roi;
306 roi.SetIndex( min_idx );
307 roi.SetSize( roi_size );
311 // -------------------------------------------------------------------------
// Crops `in` to `roi` with ivq::ITK::RegionOfInterestWithPaddingImageFilter,
// using padding `pad`, and stores the result in `out`, detached from the
// pipeline.
// Returns the filter's region of interest as it stands after the crop. The
// caller uses it as the effective region (presumably `roi` grown by `pad`
// and clipped to the image -- confirm against the ivq filter).
312 template< class _TInputImage, class _TOutputImage >
313 template< class _TIn, class _TOutPtr >
314 typename _TIn::RegionType
315 RandomWalkSegmentation< _TInputImage, _TOutputImage >::
317 const _TIn* in, const typename _TIn::RegionType& roi, unsigned int pad,
321 typedef typename _TOutPtr::ObjectType _TOut;
322 typedef ivq::ITK::RegionOfInterestWithPaddingImageFilter< _TIn, _TOut > _TROI;
324 typename _TROI::Pointer filter = _TROI::New( );
325 filter->SetInput( in );
326 filter->SetRegionOfInterest( roi );
327 filter->SetPadding( pad );
329 out = filter->GetOutput( );
330 out->DisconnectPipeline( );
331 return( filter->GetRegionOfInterest( ) );
334 // -------------------------------------------------------------------------
// Translates every vertex of path `in` into indices local to `roi` by
// subtracting roi's starting index, component by component.
// NOTE(review): the output parameter and the statement that stores each
// translated vertex are not visible here; presumably `v` is appended to
// the output path supplied by the caller.
335 template< class _TInputImage, class _TOutputImage >
336 template< class _TAxisPtr, class _TRegion >
337 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
339 const typename _TAxisPtr::ObjectType* in, const _TRegion& roi,
343 typedef typename _TAxisPtr::ObjectType _TAxis;
345 for( unsigned long i = 0; i < in->GetSize( ); ++i )
347 typename _TAxis::TIndex v = in->GetVertex( i );
348 for( unsigned int d = 0; d < _TAxis::Dimension; ++d )
349 v[ d ] -= roi.GetIndex( )[ d ];
355 // -------------------------------------------------------------------------
// Builds the initial random-walker label map with RandomWalkLabelling.
// (The method's name line is not visible in this view.)
// Inputs:
//   raw    : intensity image.
//   costs  : cost image.
//   axis   : path.
//   radius : passed through to the filter.
// Intensity window: lower threshold = mean - dev, upper = mean + dev.
// Label codes: 1 = inside, 2 = outside, 3 = lower, 4 = upper.
// The label image is stored in `out`, detached from the pipeline.
// NOTE(review): the two thresholds are printed to std::cout, which looks
// like leftover debugging output.
356 template< class _TInputImage, class _TOutputImage >
357 template< class _TInRaw, class _TInCosts, class _TAxis, class _TScalar, class _TOutPtr >
358 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
360 const _TInRaw* raw, const _TInCosts* costs, const _TAxis* axis,
361 const _TScalar& mean, const _TScalar& dev, const _TScalar& radius,
365 typedef typename _TOutPtr::ObjectType _TOut;
366 typedef RandomWalkLabelling< _TInRaw, _TInCosts, _TOut > _TLabel;
368 typename _TLabel::Pointer label = _TLabel::New( );
369 label->SetInputImage( raw );
370 label->SetInputCosts( costs );
371 label->SetInputPath( axis );
372 label->SetInsideLabel( 1 );
373 label->SetOutsideLabel( 2 );
374 label->SetLowerLabel( 3 );
375 label->SetUpperLabel( 4 );
376 label->SetRadius( radius );
377 label->SetLowerThreshold( mean - dev );
378 label->SetUpperThreshold( mean + dev );
380 out = label->GetOutputLabels( );
381 out->DisconnectPipeline( );
383 std::cout << label->GetLowerThreshold( ) << std::endl;
384 std::cout << label->GetUpperThreshold( ) << std::endl;
388 // -------------------------------------------------------------------------
// Runs fpa::Filters::Image::RandomWalker on `in`, seeded by the label image
// `labels`, with a Gaussian weight function parameterized by `beta`.
// (The method's name line is not visible in this view.)
// An itk::BinaryThresholdImageFilter then keeps only label 1 (the "inside"
// label assigned by the labelling step), mapping it to 1 and every other
// label to 0. The binary result is stored in `out`, detached from the
// pipeline.
// NOTE(review): `beta` has the input pixel type, while the caller passes a
// double (m_Beta); with an integral pixel type the value would be
// truncated.
389 template< class _TInputImage, class _TOutputImage >
390 template< class _TIn, class _TOutPtr >
391 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
393 const _TIn* in, const typename _TOutPtr::ObjectType* labels,
394 const typename _TIn::PixelType& beta,
398 typedef typename _TIn::PixelType _TScalar;
399 typedef typename _TOutPtr::ObjectType _TOut;
400 typedef fpa::Functors::Dijkstra::Image::Gaussian< _TIn, _TScalar > _TFun;
401 typedef fpa::Filters::Image::RandomWalker< _TIn, _TOut, _TScalar > _TRandomWalker;
402 typedef itk::BinaryThresholdImageFilter< _TOut, _TOut > _TExtract;
404 typename _TFun::Pointer fun = _TFun::New( );
405 fun->SetBeta( beta );
407 typename _TRandomWalker::Pointer rw = _TRandomWalker::New( );
408 rw->SetInputImage( in );
409 rw->SetInputLabels( labels );
410 rw->SetWeightFunction( fun );
412 typename _TExtract::Pointer extract = _TExtract::New( );
413 extract->SetInput( rw->GetOutputLabels( ) );
414 extract->SetInsideValue( 1 );
415 extract->SetOutsideValue( 0 );
416 extract->SetLowerThreshold( 1 );
417 extract->SetUpperThreshold( 1 );
419 out = extract->GetOutput( );
420 out->DisconnectPipeline( );
423 // -------------------------------------------------------------------------
// Extracts an axis and a centerness map from segmentation `in` with
// fpa::Filters::Image::ExtractAxis.
//   p0, p1 : start/end seeds. Their physical points are mapped to indices
//            in `in`'s own geometry (an ROI in the visible caller), and the
//            axis is requested from s0 to s1.
//   seeds  : extra seeds, added by physical point.
// out_axis receives a new TPath grafted from the filter output. out_dist is
// the filter's centerness image smoothed by _Smooth with s = 1 voxel.
// NOTE(review): the bools returned by TransformPhysicalPointToIndex are
// ignored, so start/end points lying outside `in` are not detected.
// NOTE(review): out_axis is created from the class-level TPath rather than
// _TAxisPtr's object type; this assumes the two match.
424 template< class _TInputImage, class _TOutputImage >
425 template< class _TIn, class _TSeeds, class _TSeed, class _TOutPtr, class _TAxisPtr >
426 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
428 const _TIn* in, const _TSeeds& seeds,
429 const _TSeed& p0, const _TSeed& p1,
430 _TOutPtr& out_dist, _TAxisPtr& out_axis
433 typedef typename _TOutPtr::ObjectType _TOut;
434 typedef typename _TOut::PixelType _TScalar;
435 typedef fpa::Filters::Image::ExtractAxis< _TIn, _TScalar > _TExtract;
437 // Prepare output values
438 typename _TIn::IndexType s0, s1;
439 in->TransformPhysicalPointToIndex( p0.Point, s0 );
440 in->TransformPhysicalPointToIndex( p1.Point, s1 );
442 typename _TExtract::Pointer extract = _TExtract::New( );
443 extract->SetInput( const_cast< _TIn* >( in ) );
444 extract->AddSeed( s0 );
445 extract->AddSeed( s1 );
446 for( typename _TSeeds::value_type seed: seeds )
447 extract->AddSeed( seed.Point );
448 extract->SetStartIndex( s0 );
449 extract->SetEndIndex( s1 );
451 out_axis = TPath::New( );
452 out_axis->Graft( extract->GetOutput( ) );
453 this->_Smooth( extract->GetCenterness( )->GetOutput( ), out_dist, 1 );
456 // -------------------------------------------------------------------------
// Debug helper: sets up an itk::ImageFileWriter for image `in` with file
// name `fname` (relative names resolve against the working directory).
// NOTE(review): the writer's SetInput/Update calls are not visible in this
// view; presumably they follow SetFileName.
457 template< class _TInputImage, class _TOutputImage >
458 template< class _TInPtr >
459 void RandomWalkSegmentation< _TInputImage, _TOutputImage >::
460 _Save( const _TInPtr& in, const std::string& fname )
462 typedef itk::ImageFileWriter< typename _TInPtr::ObjectType > _TWriter;
463 typename _TWriter::Pointer w = _TWriter::New( );
465 w->SetFileName( fname );
469 #endif // __RandomWalkSegmentation__hxx__