1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
8 #include <itkImageRegionConstIteratorWithIndex.h>
9 #include <itkImageRegionIteratorWithIndex.h>
10 #include <Eigen/Sparse>
12 // -------------------------------------------------------------------------
// Registers one seed: appends `seed` to m_Seeds and `label` to m_Labels.
// Both are appended together here, so entry k of each container describes
// the same seed.
// NOTE(review): the seed collection visible in this file (_ThreadedBoundary)
// reads the InputLabels image, not m_Seeds/m_Labels; confirm where these
// containers are consumed.
14 template< class _TImage, class _TLabels, class _TScalar >
15 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16 AddSeed( const TIndex& seed, const TLabel& label )
18 this->m_Seeds.push_back( seed );
19 this->m_Labels.push_back( label );
24 // -------------------------------------------------------------------------
// Constructor: declares two extra filter ports through the project's fpa
// configure macros (defined elsewhere):
//  - the input `InputLabels`, of type TLabels;
//  - the output `OutputProbabilities`, of type TScalarImage.
// NOTE(review): the initializer list and any other body lines are not
// visible in this view.
25 template< class _TImage, class _TLabels, class _TScalar >
26 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
27 OriginalRandomWalker( )
30 fpaFilterInputConfigureMacro( InputLabels, TLabels );
31 fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
34 // -------------------------------------------------------------------------
// Destructor. Its body is not visible in this view (presumably empty).
35 template< class _TImage, class _TLabels, class _TScalar >
36 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
37 ~OriginalRandomWalker( )
41 // -------------------------------------------------------------------------
// Main pipeline entry. The method-name line is not visible here; it is
// presumably GenerateData( ). Builds and solves the random-walker Dirichlet
// problem:
//   1. validates m_EdgeFunction and binds it to the input image;
//   2. gathers the seed triplets St from InputLabels (row = linear pixel
//      index, col = raw label value) plus a label -> dense column map;
//   3. turns St into the seeds x labels indicator matrix B;
//   4. assembles the reduced Laplacian A (unseeded x unseeded) and the
//      coupling matrix R (unseeded x seeds);
//   5. solves A * x = R * B with Eigen's SimplicialLDLT;
//   6. writes the label and probability outputs through _Output.
// Throws (itkExceptionMacro) when no edge function is set, or when the
// factorization or the solve reports a failure.
42 template< class _TImage, class _TLabels, class _TScalar >
43 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
47 typedef Eigen::Triplet< TScalar > _TTriplet;
48 typedef std::vector< _TTriplet > _TTriplets;
49 typedef Eigen::SparseMatrix< TScalar > _TMatrix;
50 typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
52 // Configure edge function
53 if( this->m_EdgeFunction.IsNull( ) )
54 itkExceptionMacro( << "Undefined edge function." );
55 const TImage* input = this->GetInput( );
56 this->m_EdgeFunction->SetDataObject( input );
59 this->AllocateOutputs( );
61 // Build boundary triplets and count labels
// `labels` maps each label value to the dense column id that
// _ThreadedBoundary assigns (first-seen order, which depends on thread timing).
63 std::map< TLabel, unsigned long > labels;
64 this->_Boundary( St, labels );
67 bool operator()( const _TTriplet& a, const _TTriplet& b )
69 return( a.row( ) < b.row( ) );
// Seed triplets arrive in thread-dependent order. Sorting them by pixel
// index lets _SeedIndex/_NearSeedIndex binary-search them, and makes row i
// of B the i-th seed in ascending pixel order.
72 std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
73 for( unsigned long i = 0; i < St.size( ); ++i )
// operator[] does not insert here: _ThreadedBoundary registered every
// column value of St in `labels`.
74 Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) );
78 this->_Laplacian( At, Rt, St );
// NOTE(review): N is counted on the input image's requested region, but the
// seeds come from InputLabels' requested region; both are assumed to match.
81 TRegion region = input->GetRequestedRegion( );
82 unsigned long nSeeds = St.size( );
83 unsigned long nLabels = labels.size( );
84 unsigned long N = region.GetNumberOfPixels( );
// Inverse of `labels`: column id -> original label value.
86 std::vector< TLabel > invLabels( nLabels );
87 for( typename std::map< TLabel, unsigned long >::value_type s: labels )
88 invLabels[ s.second ] = s.first;
90 _TMatrix B( nSeeds, nLabels );
91 B.setFromTriplets( Bt.begin( ), Bt.end( ) );
94 _TMatrix R( N - nSeeds, nSeeds );
95 R.setFromTriplets( Rt.begin( ), Rt.end( ) );
98 _TMatrix A( N - nSeeds, N - nSeeds );
99 A.setFromTriplets( At.begin( ), At.end( ) );
// NOTE(review): no guard is visible for the degenerate cases nSeeds == 0
// (no labels at all) or nSeeds == N (empty unseeded system).
102 // Solve Dirichlet problem
// The solver declaration and the factorization of A are on lines not
// visible here.
105 if( solver.info( ) != Eigen::Success )
106 itkExceptionMacro( << "Error decomposing matrix." );
// x is (N - nSeeds) x nLabels. Column k holds, for every unseeded pixel, the
// random-walker value (presumably the probability) for label invLabels[ k ].
107 _TMatrix x = solver.solve( R * B );
108 if( solver.info( ) != Eigen::Success )
109 itkExceptionMacro( << "Error solving system." );
112 this->_Output( x, St, invLabels );
115 // -------------------------------------------------------------------------
// Random-walker Laplacian entry for the pixel pair (i, j).
//  - Diagonal case: the branch condition is on lines not visible here;
//    callers request it with _L( idx, idx ). The result is the sum of
//    m_EdgeFunction weights from i to each axis-aligned +/-1 neighbour k
//    that lies inside the input's requested region. k is built on hidden
//    lines, presumably as i shifted by n along axis d.
//  - Off-diagonal case: the negated edge weight -w( i, j ).
// NOTE(review): the requested region is fetched again on every call; this
// runs once per pixel and per neighbour.
116 template< class _TImage, class _TLabels, class _TScalar >
117 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
118 _L( const TIndex& i, const TIndex& j )
122 TRegion r = this->GetInput( )->GetRequestedRegion( );
123 TScalar s = TScalar( 0 );
124 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
126 for( int n = -1; n <= 1; n += 2 )
130 if( r.IsInside( k ) )
131 s += this->m_EdgeFunction->Evaluate( i, k );
139 return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
142 // -------------------------------------------------------------------------
// Collects the seed pixels of InputLabels into the triplet list B and fills
// the label -> column map `labels`. The work is split across threads over
// the InputLabels requested region. Arguments reach the static _BoundaryCbk
// callback through a _TBoundaryThreadStruct. The triplet vector is
// type-erased to void* there and cast back to _TTriplets* in the callback.
143 template< class _TImage, class _TLabels, class _TScalar >
144 template< class _TTriplets >
145 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
146 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
150 // Set up the multithreaded processing
151 _TBoundaryThreadStruct thrStr;
152 thrStr.Filter = this;
153 thrStr.Triplets = reinterpret_cast< void* >( &B );
154 thrStr.Labels = &labels;
156 // Configure threader
// The thread count is the number of pieces the region splitter produces for
// the InputLabels requested region.
157 const TLabels* in_labels = this->GetInputLabels( );
158 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
159 const unsigned int nThreads =
160 split->GetNumberOfSplits(
161 in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
164 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
165 threads->SetNumberOfThreads( nThreads );
166 threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
169 threads->SingleMethodExecute( );
172 // -------------------------------------------------------------------------
// Static per-thread trampoline for _Boundary. It unpacks the ITK thread info,
// asks the filter for this thread's piece of the requested region
// (SplitRequestedRegion), and forwards that piece to _ThreadedBoundary along
// with the shared triplet vector and label map.
// NOTE(review): the `region` declaration and the use of `total` (presumably a
// "thrId < total" guard) are on lines not visible here.
// NOTE(review): in ITK, SplitRequestedRegion partitions the filter's output
// requested region, while _Boundary sized the thread count from InputLabels'
// requested region; the two are assumed to coincide.
173 template< class _TImage, class _TLabels, class _TScalar >
174 template< class _TTriplets >
175 ITK_THREAD_RETURN_TYPE
176 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
177 _BoundaryCbk( void* arg )
179 _TBoundaryThreadStruct* thrStr;
180 itk::ThreadIdType total, thrId, thrCount;
181 itk::MultiThreader::ThreadInfoStruct* thrInfo =
182 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
183 thrId = thrInfo->ThreadID;
184 thrCount = thrInfo->NumberOfThreads;
185 thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
188 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
190 thrStr->Filter->_ThreadedBoundary(
192 reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
195 return( ITK_THREAD_RETURN_VALUE );
198 // -------------------------------------------------------------------------
// Worker for one thread of _Boundary (the method-name line is not visible
// here). It visits the pixels of `region` in InputLabels. For each seed
// pixel it:
//  - appends the triplet ( linear index, raw label value, 1 ) to B;
//  - registers the label in `labels` with the next free column id
//    (labels->size( )) if it is new.
// The seed test is on lines not visible here; elsewhere in this file, label 0
// means "unlabelled". Both shared containers are modified under m_Mutex.
// NOTE(review): column ids follow the order in which threads first meet each
// label, so they can differ between runs. GenerateData maps them back
// through invLabels, so the final labels are unaffected.
// NOTE(review): Lock/Unlock is manual, so if push_back throws, m_Mutex stays
// locked. A scoped lock guard would be safer.
// NOTE(review): _TInvValue is not used in the visible lines.
199 template< class _TImage, class _TLabels, class _TScalar >
200 template< class _TTriplets >
201 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
203 const TRegion& region, const itk::ThreadIdType& id,
205 std::map< TLabel, unsigned long >* labels
208 typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
209 typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
210 typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
211 typedef typename _TTriplets::value_type _TTriplet;
213 const TLabels* in_labels = this->GetInputLabels( );
// Linear indices are always computed against the full requested region,
// not against this thread's sub-region, so they are globally consistent.
214 TRegion reqRegion = in_labels->GetRequestedRegion( );
215 _TIt it( in_labels, region );
216 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
220 unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
221 this->m_Mutex.Lock( );
222 B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
223 if( labels->find( it.Get( ) ) == labels->end( ) )
224 labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
225 this->m_Mutex.Unlock( );
232 // -------------------------------------------------------------------------
// Assembles two sets of triplets, in parallel over the input image's
// requested region:
//  - A: the reduced Laplacian (unseeded x unseeded);
//  - R: the coupling matrix (unseeded x seeds).
// B must be the seed triplet list sorted by row, as GenerateData passes it,
// because the index helpers binary-search it. Arguments reach the static
// _LaplacianCbk through a _TLaplacianThreadStruct, type-erased to void*.
233 template< class _TImage, class _TLabels, class _TScalar >
234 template< class _TTriplets >
235 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
236 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
241 // Set up the multithreaded processing
242 _TLaplacianThreadStruct thrStr;
243 thrStr.Filter = this;
244 thrStr.A = reinterpret_cast< void* >( &A );
245 thrStr.R = reinterpret_cast< void* >( &R );
246 thrStr.B = reinterpret_cast< const void* >( &B );
248 // Configure threader
// The thread count is the number of pieces the region splitter produces for
// the input image's requested region.
249 const TImage* in = this->GetInput( );
250 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
251 const unsigned int nThreads =
252 split->GetNumberOfSplits(
253 in->GetRequestedRegion( ), this->GetNumberOfThreads( )
256 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
257 threads->SetNumberOfThreads( nThreads );
258 threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
261 threads->SingleMethodExecute( );
264 // -------------------------------------------------------------------------
// Static per-thread trampoline for _Laplacian. It unpacks the ITK thread
// info, obtains this thread's piece of the requested region
// (SplitRequestedRegion), and forwards it to _ThreadedLaplacian. The A, R and
// B pointers are cast back from the type-erased thread struct.
// NOTE(review): the `region` declaration and the use of `total` (presumably a
// "thrId < total" guard) are on lines not visible here.
265 template< class _TImage, class _TLabels, class _TScalar >
266 template< class _TTriplets >
267 ITK_THREAD_RETURN_TYPE
268 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
269 _LaplacianCbk( void* arg )
271 _TLaplacianThreadStruct* thrStr;
272 itk::ThreadIdType total, thrId, thrCount;
273 itk::MultiThreader::ThreadInfoStruct* thrInfo =
274 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
275 thrId = thrInfo->ThreadID;
276 thrCount = thrInfo->NumberOfThreads;
277 thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
280 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
282 thrStr->Filter->_ThreadedLaplacian(
284 reinterpret_cast< _TTriplets* >( thrStr->A ),
285 reinterpret_cast< _TTriplets* >( thrStr->R ),
286 reinterpret_cast< const _TTriplets* >( thrStr->B )
288 return( ITK_THREAD_RETURN_VALUE );
291 // -------------------------------------------------------------------------
// Worker for one thread of _Laplacian (the method-name line is not visible
// here). For each pixel idx of `region`, it computes the linear index i;
// iSeed is true when the pixel's InputLabels value is non-zero. The branch
// conditions that follow are on lines not visible here (presumably tests on
// iSeed / jSeed):
//  - unseeded pixel: si = its row in the reduced system (_NearSeedIndex),
//    and the diagonal triplet A( si, si ) = _L( idx, idx ) is emitted;
//  - seed pixel: si = its position among the sorted seeds (_SeedIndex).
// It then visits every axis-aligned +/-1 neighbour jdx that lies inside the
// requested region, with L = _L( idx, jdx ). For unseeded neighbours it sets
// sj = _NearSeedIndex( j ) and emits either A( si, sj, L ) or
// R( sj, si, -L ).
// Every push_back into the shared A/R vectors is serialized by m_Mutex.
// NOTE(review): locking once per triplet makes m_Mutex a contention point.
// Per-thread buffers merged once at the end would avoid it, and a scoped
// lock guard would keep the mutex from staying locked if push_back throws.
292 template< class _TImage, class _TLabels, class _TScalar >
293 template< class _TTriplets >
294 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
296 const TRegion& region, const itk::ThreadIdType& id,
297 _TTriplets* A, _TTriplets* R, const _TTriplets* B
300 typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
301 typedef typename _TTriplets::value_type _TTriplet;
303 const TImage* in = this->GetInput( );
304 const TLabels* in_labels = this->GetInputLabels( );
305 TRegion reqRegion = in->GetRequestedRegion( );
306 _TIt it( in, region );
307 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
309 TIndex idx = it.GetIndex( );
310 bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
311 unsigned long i = Self::_1D( idx, reqRegion );
314 // A's diagonal values
317 si = Self::_NearSeedIndex( i, *B );
318 this->m_Mutex.Lock( );
319 A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
320 this->m_Mutex.Unlock( );
323 si = Self::_SeedIndex( i, *B );
325 // Neighbors (final matrix is symmetric)
326 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
328 for( int s = -1; s <= 1; s += 2 )
332 if( reqRegion.IsInside( jdx ) )
334 TScalar L = this->_L( idx, jdx );
335 unsigned long j = Self::_1D( jdx, reqRegion );
336 bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
339 unsigned long sj = Self::_NearSeedIndex( j, *B );
342 this->m_Mutex.Lock( );
343 A->push_back( _TTriplet( si, sj, L ) );
344 this->m_Mutex.Unlock( );
348 this->m_Mutex.Lock( );
349 R->push_back( _TTriplet( sj, si, -L ) );
350 this->m_Mutex.Unlock( );
365 // -------------------------------------------------------------------------
// Writes the final label and probability outputs, in parallel over the
// output requested region. The method-name line (called as _Output( ... ) by
// the solve step) is not visible here. Arguments:
//  - X: the solved (unseeded x labels) matrix;
//  - S: the row-sorted seed triplets;
//  - invLabels: the column -> label lookup.
// They reach the static _OutputCbk through a _TOutputThreadStruct,
// type-erased to const void*.
366 template< class _TImage, class _TLabels, class _TScalar >
367 template< class _TMatrix, class _TTriplets >
368 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
370 const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
373 // Set up the multithreaded processing
374 _TOutputThreadStruct thrStr;
375 thrStr.Filter = this;
376 thrStr.X = reinterpret_cast< const void* >( &X );
377 thrStr.S = reinterpret_cast< const void* >( &S );
378 thrStr.InvLabels = &invLabels;
380 // Configure threader
// The thread count is the number of pieces the region splitter produces for
// the label output's requested region.
381 const TLabels* out = this->GetOutput( );
382 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
383 const unsigned int nThreads =
384 split->GetNumberOfSplits(
385 out->GetRequestedRegion( ), this->GetNumberOfThreads( )
388 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
389 threads->SetNumberOfThreads( nThreads );
390 threads->SetSingleMethod(
391 this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
395 threads->SingleMethodExecute( );
398 // -------------------------------------------------------------------------
// Static per-thread trampoline for _Output. It unpacks the ITK thread info,
// obtains this thread's piece of the requested region (SplitRequestedRegion),
// and forwards it to _ThreadedOutput. X and S are cast back from the
// type-erased thread struct.
// NOTE(review): the `region` declaration, the use of `total` (presumably a
// "thrId < total" guard) and the remaining arguments are on lines not
// visible here.
399 template< class _TImage, class _TLabels, class _TScalar >
400 template< class _TMatrix, class _TTriplets >
401 ITK_THREAD_RETURN_TYPE
402 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
403 _OutputCbk( void* arg )
405 _TOutputThreadStruct* thrStr;
406 itk::ThreadIdType total, thrId, thrCount;
407 itk::MultiThreader::ThreadInfoStruct* thrInfo =
408 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
409 thrId = thrInfo->ThreadID;
410 thrCount = thrInfo->NumberOfThreads;
411 thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
414 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
416 thrStr->Filter->_ThreadedOutput(
418 reinterpret_cast< const _TMatrix* >( thrStr->X ),
419 reinterpret_cast< const _TTriplets* >( thrStr->S ),
422 return( ITK_THREAD_RETURN_VALUE );
425 // -------------------------------------------------------------------------
// Worker for one thread of _Output (the method-name line is not visible
// here). It walks InputLabels, the label output and the probability output
// in lockstep over `region`; iterator initialization is on hidden lines.
//  - Unlabelled pixel (input label 0): finds its row j of X via
//    _NearSeedIndex, then takes the label column with the largest value.
//    The argmax update and the probability write are on hidden lines. The
//    winning label invLabels[ maxL ] is written to the label output.
//  - Seed pixel: the input label is copied and its probability is set to 1.
// NOTE(review): linear indices here use the *output* requested region, while
// S was built with InputLabels' requested region; they are assumed equal.
// NOTE(review): X->coeff( j, 0 ) assumes at least one label exists.
426 template< class _TImage, class _TLabels, class _TScalar >
427 template< class _TMatrix, class _TTriplets >
428 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
430 const TRegion& region, const itk::ThreadIdType& id,
431 const _TMatrix* X, const _TTriplets* S,
432 const std::vector< TLabel >* invLabels
436 const TLabels* in_labels = this->GetInputLabels( );
437 TLabels* out_labels = this->GetOutput( );
438 TScalarImage* out_probs = this->GetOutputProbabilities( );
439 TRegion reqRegion = out_labels->GetRequestedRegion( );
440 itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
441 itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
442 itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
446 for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
448 if( iIt.Get( ) == 0 )
450 unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
451 unsigned long j = Self::_NearSeedIndex( i, *S );
452 TScalar maxP = X->coeff( j, 0 );
453 unsigned long maxL = 0;
454 for( unsigned int s = 1; s < invLabels->size( ); ++s )
456 TScalar p = X->coeff( j, s );
465 oIt.Set( ( *invLabels )[ maxL ] );
470 oIt.Set( iIt.Get( ) );
471 pIt.Set( TScalar( 1 ) );
478 // -------------------------------------------------------------------------
// Static helper: linearizes an N-D index with x varying fastest:
//   i = idx[0] + idx[1]*size[0] + idx[2]*size[0]*size[1] + ...
// `off` holds the running stride. The accumulation into i and the return
// statement are on lines not visible here, as is the return-type line.
// NOTE(review): the visible lines do not subtract region.GetIndex( ). Unless
// the hidden lines do, this assumes the region starts at index 0; a negative
// idx component would also wrap when converted to unsigned long.
479 template< class _TImage, class _TLabels, class _TScalar >
481 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
482 _1D( const TIndex& idx, const TRegion& region )
484 unsigned long i = idx[ 0 ];
485 unsigned long off = 1;
486 typename TRegion::SizeType size = region.GetSize( );
487 for( unsigned int d = 1; d < TIndex::Dimension; ++d )
489 off *= size[ d - 1 ];
496 // -------------------------------------------------------------------------
// Static helper: binary search in the row-sorted triplet list t for the
// entry whose row equals the linear pixel index i. f stays equal to
// t.size( ) until a result position is settled, which ends the loop. The
// declaration of s, part of the loop body and the return are on lines not
// visible here.
// NOTE(review): when t is empty, `f - 1` wraps to ULONG_MAX, so t[ h ] would
// be read out of bounds unless a hidden line guards that case.
497 template< class _TImage, class _TLabels, class _TScalar >
498 template< class _TTriplets >
500 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
501 _SeedIndex( const unsigned long& i, const _TTriplets& t )
504 unsigned long f = t.size( );
505 unsigned long e = f - 1;
506 while( e > s && f == t.size( ) )
510 unsigned long h = ( e + s ) >> 1;
511 if ( i < t[ h ].row( ) ) e = h;
512 else if( t[ h ].row( ) < i ) s = h;
516 f = ( t[ s ].row( ) == i )? s: e;
522 // -------------------------------------------------------------------------
// Static helper: binary search that locates the linear pixel index i among
// the row-sorted seed triplets t. Callers use the result as i's row in the
// reduced (unseeded) system, i.e. presumably i minus the number of seeds
// before it. The initialization of s, the loop header and the return
// statements are on lines not visible here.
// NOTE(review): t.size( ) - 1 is computed in unsigned arithmetic before being
// stored in a long. For an empty t it becomes -1 only through a
// (pre-C++20 implementation-defined) conversion, and t[ h ] would then be
// indexed out of bounds unless a hidden line guards it.
// NOTE(review): i (unsigned long) is compared against Triplet::row( ), which
// is a signed StorageIndex (int) for Eigen::Triplet< TScalar > by default.
523 template< class _TImage, class _TLabels, class _TScalar >
524 template< class _TTriplets >
526 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
527 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
530 long e = t.size( ) - 1;
533 long h = ( e + s ) >> 1;
534 if ( i < t[ h ].row( ) ) e = h;
535 else if( t[ h ].row( ) < i ) s = h;
539 if( i < t[ s ].row( ) )
541 else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
549 #endif // __fpa__Common__OriginalRandomWalker__hxx__