1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__RandomWalker__hxx__
6 #define __fpa__Common__RandomWalker__hxx__
8 #include <itkImageRegionConstIteratorWithIndex.h>
9 #include <itkImageRegionIteratorWithIndex.h>
11 # include <Eigen/Sparse>
14 // -------------------------------------------------------------------------
// Set-up of the filter's extra pipeline objects: the labels input
// (InputLabels, type TLabels) and the probabilities output
// (OutputProbabilities, type TScalarImage) are configured through the fpa
// configure macros, whose expansion is not shown here.
// NOTE(review): presumably the constructor; its name line is not present here.
15 template< class _TImage, class _TLabels, class _TScalar >
16 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
20 fpaFilterInputConfigureMacro( InputLabels, TLabels );
21 fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
24 // -------------------------------------------------------------------------
// NOTE(review): second special member of RandomWalker. Only its template
// header and class qualifier are present here; presumably the destructor.
// Its body is not shown, so nothing is documented about what it releases.
25 template< class _TImage, class _TLabels, class _TScalar >
26 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
31 // -------------------------------------------------------------------------
// Runs the random-walker segmentation. This is presumably GenerateData; the
// name line is not present here. Steps:
//   1. Bind m_EdgeFunction to the input image and allocate the outputs.
//   2. _Boundary gathers one triplet per seed pixel, plus the map from label
//      value to dense label index. Those triplets (St) are sorted by row and
//      re-expressed as B (nSeeds x nLabels): row = seed ordinal in sorted
//      order, column = dense label index.
//   3. _Laplacian gathers the triplets for A ((N - nSeeds) x (N - nSeeds))
//      and R ((N - nSeeds) x nSeeds). N is the pixel count of the input
//      image's requested region.
//   4. A linear system with right-hand side C is factorized and solved.
//      The result x is read by _Output with one column per label.
//   5. _Output writes the output labels and probabilities from x.
// Raises itkExceptionMacro when no edge function is set, or when the solver
// reports failure after factorization or after solving.
32 template< class _TImage, class _TLabels, class _TScalar >
33 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
// Sparse linear-algebra types (Eigen): triplet lists, sparse matrices and an
// LDLT factorization for the symmetric system.
38 typedef Eigen::Triplet< TScalar > _TTriplet;
39 typedef std::vector< _TTriplet > _TTriplets;
40 typedef Eigen::SparseMatrix< TScalar > _TMatrix;
41 typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
43 // Configure edge function
// The edge function must be supplied by the caller; it is bound to the input
// image here and later evaluated by _L.
44 if( this->m_EdgeFunction.IsNull( ) )
45 itkExceptionMacro( << "Undefined edge function." );
46 const TImage* input = this->GetInput( );
47 this->m_EdgeFunction->SetDataObject( input );
50 this->AllocateOutputs( );
// Placeholder sizes; A is resized below once N and nSeeds are known.
53 _TMatrix A( 1, 1 ), C( 1, 1 );
55 std::vector< TLabel > invLabels;
58 // Build boundary triplets and count labels
60 std::map< TLabel, unsigned long > labels;
61 itkDebugMacro( << "Building boundary matrix..." );
62 this->_Boundary( St, labels );
// Row-based strict weak order on triplets. Sorting St with it lets
// _SeedIndex / _NearSeedIndex binary-search St by row.
65 bool operator()( const _TTriplet& a, const _TTriplet& b )
67 return( a.row( ) < b.row( ) );
70 itkDebugMacro( << "Sorting boundary pixels..." );
71 std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
72 itkDebugMacro( << "Assigning boundary pixels..." );
// Seed i (in row-sorted order) gets one B entry in its label's dense column.
// The value is taken from St (1, as pushed by _ThreadedBoundary).
73 for( unsigned long i = 0; i < St.size( ); ++i )
75 _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) )
79 itkDebugMacro( << "Building laplacian matrix..." );
81 this->_Laplacian( At, Rt, St );
// Problem sizes.
// NOTE(review): N comes from the input image's requested region, while the
// seed rows in St were linearized against the labels image's requested
// region. Both regions are assumed to coincide.
84 TRegion region = input->GetRequestedRegion( );
85 unsigned long nSeeds = St.size( );
86 unsigned long nLabels = labels.size( );
87 unsigned long N = region.GetNumberOfPixels( );
89 itkDebugMacro( << "Creating inverse labels..." );
// Inverse of `labels`: dense label index -> original label value.
90 invLabels.resize( nLabels );
91 for( typename std::map< TLabel, unsigned long >::value_type s: labels )
92 invLabels[ s.second ] = s.first;
94 itkDebugMacro( << "Creating B matrix..." );
95 _TMatrix B( nSeeds, nLabels );
96 B.setFromTriplets( Bt.begin( ), Bt.end( ) );
99 itkDebugMacro( << "Creating R matrix..." );
100 _TMatrix R( N - nSeeds, nSeeds );
101 R.setFromTriplets( Rt.begin( ), Rt.end( ) );
// NOTE(review): the statement that builds C is not present here. Presumably
// C = R * B, giving an (N - nSeeds) x nLabels right-hand side; confirm.
104 itkDebugMacro( << "Creating C matrix..." );
107 itkDebugMacro( << "Creating A matrix..." );
108 A.resize( N - nSeeds, N - nSeeds );
109 A.setFromTriplets( At.begin( ), At.end( ) );
113 // Solve Dirichlet problem
// NOTE(review): the declaration of `solver` and its factorization call are
// not present here. It is presumably a _TSolver computed on A.
115 itkDebugMacro( << "Factorizing problem..." );
117 if( solver.info( ) != Eigen::Success )
118 itkExceptionMacro( << "Error decomposing matrix." );
119 itkDebugMacro( << "Solving problem..." );
120 _TMatrix x = solver.solve( C );
121 if( solver.info( ) != Eigen::Success )
122 itkExceptionMacro( << "Error solving system." );
125 itkDebugMacro( << "Filling output..." );
126 this->_Output( x, St, invLabels );
130 // -------------------------------------------------------------------------
// Entry (i, j) of the weighted graph Laplacian over the input image's
// requested region. Edge weights come from m_EdgeFunction->Evaluate:
//   - s accumulates the weights from i to each neighbour k inside the region.
//     The loops run over each dimension d and each offset n in {-1, +1};
//     k is presumably i shifted by n along d. This is the diagonal (degree)
//     term.
//   - one return yields minus the weight between i and j (off-diagonal term);
//   - otherwise the result is 0.
// NOTE(review): the conditions that choose between these returns, and the
// return of s, are not present here. Presumably they test i == j and then
// adjacency; confirm.
// NOTE(review): _ThreadedLaplacian calls this from worker threads, so
// m_EdgeFunction->Evaluate must be safe to call concurrently. The requested
// region is also copied on every call.
131 template< class _TImage, class _TLabels, class _TScalar >
132 _TScalar fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
133 _L( const TIndex& i, const TIndex& j )
138 TRegion r = this->GetInput( )->GetRequestedRegion( );
139 TScalar s = TScalar( 0 );
140 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
142 for( int n = -1; n <= 1; n += 2 )
146 if( r.IsInside( k ) )
147 s += this->m_EdgeFunction->Evaluate( i, k );
155 return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
157 return( _TScalar( 0 ) );
161 // -------------------------------------------------------------------------
// Gathers the seed triplets (into B) and the map from label value to dense
// index (into labels). The labels image's requested region is split with the
// filter's region splitter, and _BoundaryCbk< _TTriplets > runs on a freshly
// created itk::MultiThreader with one thread per split. The shared outputs
// reach the workers through a _TBoundaryThreadStruct, with B type-erased to
// void*.
162 template< class _TImage, class _TLabels, class _TScalar >
163 template< class _TTriplets >
164 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
165 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
170 // Set up the multithreaded processing
171 _TBoundaryThreadStruct thrStr;
172 thrStr.Filter = this;
173 thrStr.Triplets = reinterpret_cast< void* >( &B );
174 thrStr.Labels = &labels;
176 // Configure threader
177 const TLabels* in_labels = this->GetInputLabels( );
178 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
179 const unsigned int nThreads =
180 split->GetNumberOfSplits(
181 in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
184 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
185 threads->SetNumberOfThreads( nThreads );
186 threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
189 threads->SingleMethodExecute( );
193 // -------------------------------------------------------------------------
// Thread entry point handed to itk::MultiThreader::SetSingleMethod by
// _Boundary. It:
//   - recovers the _TBoundaryThreadStruct from the ThreadInfoStruct user data;
//   - asks the filter for this thread's piece of the requested region
//     (SplitRequestedRegion's result is stored in `total`);
//   - forwards that piece, plus the shared outputs, to _ThreadedBoundary.
// NOTE(review): the declaration of `region` and the use of `total`
// (presumably a thrId < total guard) are not present here.
194 template< class _TImage, class _TLabels, class _TScalar >
195 template< class _TTriplets >
196 ITK_THREAD_RETURN_TYPE
197 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
198 _BoundaryCbk( void* arg )
201 _TBoundaryThreadStruct* thrStr;
202 itk::ThreadIdType total, thrId, thrCount;
203 itk::MultiThreader::ThreadInfoStruct* thrInfo =
204 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
205 thrId = thrInfo->ThreadID;
206 thrCount = thrInfo->NumberOfThreads;
207 thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
210 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
212 thrStr->Filter->_ThreadedBoundary(
214 reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
218 return( ITK_THREAD_RETURN_VALUE );
221 // -------------------------------------------------------------------------
// Worker body for _Boundary. It walks `region` of the labels image. For each
// pixel it handles, it:
//   - appends the triplet (linear index within the labels image's requested
//     region, label value, 1) to B;
//   - records the label value in `labels` under the next free dense index,
//     if the label is new.
// NOTE(review): the per-pixel condition is not present here. Presumably only
// nonzero labels are handled, matching the `!= 0` seed test in
// _ThreadedLaplacian; confirm.
// NOTE(review): dense label indices follow the order in which threads reach
// m_Mutex, so label numbering can vary between runs.
// NOTE(review): the mutex is taken once per seed pixel and released
// manually, so an exception from push_back or insert would leave it locked.
// _TInvValue is unused.
222 template< class _TImage, class _TLabels, class _TScalar >
223 template< class _TTriplets >
224 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
226 const TRegion& region, const itk::ThreadIdType& id,
228 std::map< TLabel, unsigned long >* labels
232 typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
233 typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
234 typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
235 typedef typename _TTriplets::value_type _TTriplet;
237 const TLabels* in_labels = this->GetInputLabels( );
238 TRegion reqRegion = in_labels->GetRequestedRegion( );
239 _TIt it( in_labels, region );
240 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
244 unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
// B and labels are shared by all worker threads.
245 this->m_Mutex.Lock( );
246 B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
247 if( labels->find( it.Get( ) ) == labels->end( ) )
248 labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
249 this->m_Mutex.Unlock( );
257 // -------------------------------------------------------------------------
// Gathers the Laplacian triplets. GenerateData sizes the resulting matrices
// as A ((N - nSeeds) square) and R ((N - nSeeds) x nSeeds). B holds the
// row-sorted seed triplets used to compress indices. The input image's
// requested region is split with the filter's region splitter, and
// _LaplacianCbk< _TTriplets > runs on a freshly created itk::MultiThreader.
// A, R and B reach the workers through a _TLaplacianThreadStruct as void
// pointers.
258 template< class _TImage, class _TLabels, class _TScalar >
259 template< class _TTriplets >
260 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
261 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
267 // Set up the multithreaded processing
268 _TLaplacianThreadStruct thrStr;
269 thrStr.Filter = this;
270 thrStr.A = reinterpret_cast< void* >( &A );
271 thrStr.R = reinterpret_cast< void* >( &R );
272 thrStr.B = reinterpret_cast< const void* >( &B );
274 // Configure threader
275 const TImage* in = this->GetInput( );
276 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
277 const unsigned int nThreads =
278 split->GetNumberOfSplits(
279 in->GetRequestedRegion( ), this->GetNumberOfThreads( )
282 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
283 threads->SetNumberOfThreads( nThreads );
284 threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
287 threads->SingleMethodExecute( );
291 // -------------------------------------------------------------------------
// Thread entry point handed to itk::MultiThreader::SetSingleMethod by
// _Laplacian. It:
//   - recovers the _TLaplacianThreadStruct from the ThreadInfoStruct user
//     data;
//   - asks the filter for this thread's piece of the requested region
//     (result stored in `total`);
//   - forwards that piece, with A, R and B cast back to _TTriplets, to
//     _ThreadedLaplacian.
// NOTE(review): the declaration of `region` and the use of `total`
// (presumably a thrId < total guard) are not present here.
292 template< class _TImage, class _TLabels, class _TScalar >
293 template< class _TTriplets >
294 ITK_THREAD_RETURN_TYPE
295 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
296 _LaplacianCbk( void* arg )
299 _TLaplacianThreadStruct* thrStr;
300 itk::ThreadIdType total, thrId, thrCount;
301 itk::MultiThreader::ThreadInfoStruct* thrInfo =
302 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
303 thrId = thrInfo->ThreadID;
304 thrCount = thrInfo->NumberOfThreads;
305 thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
308 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
310 thrStr->Filter->_ThreadedLaplacian(
312 reinterpret_cast< _TTriplets* >( thrStr->A ),
313 reinterpret_cast< _TTriplets* >( thrStr->R ),
314 reinterpret_cast< const _TTriplets* >( thrStr->B )
317 return( ITK_THREAD_RETURN_VALUE );
320 // -------------------------------------------------------------------------
// Worker body for _Laplacian. It visits each pixel i of `region` in the
// input image. A pixel is a seed when its InputLabels value is nonzero.
//   - The diagonal entry _L(i, i) is appended to A at (si, si), where si is
//     the compressed index from _NearSeedIndex.
//   - For every neighbour jdx inside the input's requested region, the weight
//     L = _L(i, j) goes either to A at (si, sj) or, negated, to R at
//     (sj, si). si/sj come from _SeedIndex or _NearSeedIndex.
// Every push_back is serialized through m_Mutex.
// NOTE(review): the branch conditions that choose between the seed and
// non-seed paths, and between A and R, are not present here. Presumably:
//   - the diagonal and A entries are for unseeded pixels;
//   - R entries are for seed / unseeded pairs;
//   - the neighbour index jdx is idx shifted by s along d.
// Confirm against the full file.
// NOTE(review): seed status is read from the labels image at input-image
// indices, so both images are assumed to share the same index space.
321 template< class _TImage, class _TLabels, class _TScalar >
322 template< class _TTriplets >
323 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
325 const TRegion& region, const itk::ThreadIdType& id,
326 _TTriplets* A, _TTriplets* R, const _TTriplets* B
330 typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
331 typedef typename _TTriplets::value_type _TTriplet;
333 const TImage* in = this->GetInput( );
334 const TLabels* in_labels = this->GetInputLabels( );
335 TRegion reqRegion = in->GetRequestedRegion( );
336 _TIt it( in, region );
337 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
339 TIndex idx = it.GetIndex( );
340 bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
341 unsigned long i = Self::_1D( idx, reqRegion );
344 // A's diagonal values
347 si = Self::_NearSeedIndex( i, *B );
348 this->m_Mutex.Lock( );
349 A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
350 this->m_Mutex.Unlock( );
// Seed ordinal of i, i.e. its column in R.
353 si = Self::_SeedIndex( i, *B );
355 // Neighbors (final matrix is symmetric)
356 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
358 for( int s = -1; s <= 1; s += 2 )
362 if( reqRegion.IsInside( jdx ) )
364 TScalar L = this->_L( idx, jdx );
365 unsigned long j = Self::_1D( jdx, reqRegion );
366 bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
// Compressed (non-seed) index of the neighbour.
369 unsigned long sj = Self::_NearSeedIndex( j, *B );
372 this->m_Mutex.Lock( );
373 A->push_back( _TTriplet( si, sj, L ) );
374 this->m_Mutex.Unlock( );
// Coupling term, stored negated: row = unseeded index, column = seed ordinal.
378 this->m_Mutex.Lock( );
379 R->push_back( _TTriplet( sj, si, -L ) );
380 this->m_Mutex.Unlock( );
396 // -------------------------------------------------------------------------
// Writes the output labels and probabilities from the solution X. S holds
// the row-sorted seed triplets used to map pixels to rows of X, and
// invLabels maps dense label indices back to label values. The output's
// requested region is split with the filter's region splitter, and
// _OutputCbk< _TMatrix, _TTriplets > runs on a freshly created
// itk::MultiThreader. The inputs reach the workers through a
// _TOutputThreadStruct as const void pointers.
397 template< class _TImage, class _TLabels, class _TScalar >
398 template< class _TMatrix, class _TTriplets >
399 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
401 const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
405 // Set up the multithreaded processing
406 _TOutputThreadStruct thrStr;
407 thrStr.Filter = this;
408 thrStr.X = reinterpret_cast< const void* >( &X );
409 thrStr.S = reinterpret_cast< const void* >( &S );
410 thrStr.InvLabels = &invLabels;
412 // Configure threader
413 const TLabels* out = this->GetOutput( );
414 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
415 const unsigned int nThreads =
416 split->GetNumberOfSplits(
417 out->GetRequestedRegion( ), this->GetNumberOfThreads( )
420 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
421 threads->SetNumberOfThreads( nThreads );
422 threads->SetSingleMethod(
423 this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
427 threads->SingleMethodExecute( );
431 // -------------------------------------------------------------------------
// Thread entry point handed to itk::MultiThreader::SetSingleMethod by
// _Output. It:
//   - recovers the _TOutputThreadStruct from the ThreadInfoStruct user data;
//   - asks the filter for this thread's piece of the requested region
//     (result stored in `total`);
//   - forwards that piece, with X and S cast back to their template types,
//     to _ThreadedOutput.
// NOTE(review): the declaration of `region`, the use of `total`, and the
// invLabels argument are not present here.
432 template< class _TImage, class _TLabels, class _TScalar >
433 template< class _TMatrix, class _TTriplets >
434 ITK_THREAD_RETURN_TYPE
435 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
436 _OutputCbk( void* arg )
439 _TOutputThreadStruct* thrStr;
440 itk::ThreadIdType total, thrId, thrCount;
441 itk::MultiThreader::ThreadInfoStruct* thrInfo =
442 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
443 thrId = thrInfo->ThreadID;
444 thrCount = thrInfo->NumberOfThreads;
445 thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
448 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
450 thrStr->Filter->_ThreadedOutput(
452 reinterpret_cast< const _TMatrix* >( thrStr->X ),
453 reinterpret_cast< const _TTriplets* >( thrStr->S ),
457 return( ITK_THREAD_RETURN_VALUE );
460 // -------------------------------------------------------------------------
// Worker body for _Output. It walks `region` of the labels input and of both
// outputs in lockstep.
//   - Unseeded pixels (input label 0) take the label whose column of X has
//     the largest value at row j = _NearSeedIndex(i). The search starts from
//     column 0.
//   - Seed pixels copy their input label and get probability 1.
// NOTE(review): the comparison, and the probability written for unseeded
// pixels, are not present here. Presumably it is a strict `>` update of
// maxP/maxL followed by pIt.Set( maxP ); confirm.
// NOTE(review): dense label indices depend on thread scheduling in
// _ThreadedBoundary, so ties between labels may resolve differently between
// runs. X->coeff on an Eigen sparse matrix performs a search on every access.
461 template< class _TImage, class _TLabels, class _TScalar >
462 template< class _TMatrix, class _TTriplets >
463 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
465 const TRegion& region, const itk::ThreadIdType& id,
466 const _TMatrix* X, const _TTriplets* S,
467 const std::vector< TLabel >* invLabels
472 const TLabels* in_labels = this->GetInputLabels( );
473 TLabels* out_labels = this->GetOutput( );
474 TScalarImage* out_probs = this->GetOutputProbabilities( );
475 TRegion reqRegion = out_labels->GetRequestedRegion( );
476 itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
477 itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
478 itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
482 for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
484 if( iIt.Get( ) == 0 )
// Row of X that belongs to this unseeded pixel.
486 unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
487 unsigned long j = Self::_NearSeedIndex( i, *S );
488 TScalar maxP = X->coeff( j, 0 );
489 unsigned long maxL = 0;
490 for( unsigned int s = 1; s < invLabels->size( ); ++s )
492 TScalar p = X->coeff( j, s );
501 oIt.Set( ( *invLabels )[ maxL ] );
// Seed pixel: keep its label with full certainty.
506 oIt.Set( iIt.Get( ) );
507 pIt.Set( TScalar( 1 ) );
515 // -------------------------------------------------------------------------
// Linear offset of idx inside `region`, with dimension 0 varying fastest.
// It starts from idx[0]; the offset factor for dimension d is the product of
// the region sizes of dimensions 0..d-1.
// NOTE(review): the accumulation statement and the return type are not
// present here. idx[0] is used without subtracting region.GetIndex()[0], so
// the result is a 0-based offset only when the region starts at index 0.
// Confirm this holds for every requested region passed in by the callers.
516 template< class _TImage, class _TLabels, class _TScalar >
518 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
519 _1D( const TIndex& idx, const TRegion& region )
521 unsigned long i = idx[ 0 ];
522 unsigned long off = 1;
523 typename TRegion::SizeType size = region.GetSize( );
524 for( unsigned int d = 1; d < TIndex::Dimension; ++d )
526 off *= size[ d - 1 ];
533 // -------------------------------------------------------------------------
// Binary search of the row-sorted triplets t for the entry whose row equals
// the linear pixel index i. The callers use the resulting position as the
// seed's ordinal, i.e. its column in R.
// NOTE(review): the initialization of `s`, part of the loop body, and the
// return statement are not present here.
// NOTE(review): with an empty t, `e = f - 1` wraps to the maximum unsigned
// long, and the loop would index t out of bounds unless guarded elsewhere.
534 template< class _TImage, class _TLabels, class _TScalar >
535 template< class _TTriplets >
537 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
538 _SeedIndex( const unsigned long& i, const _TTriplets& t )
541 unsigned long f = t.size( );
542 unsigned long e = f - 1;
543 while( e > s && f == t.size( ) )
547 unsigned long h = ( e + s ) >> 1;
548 if ( i < t[ h ].row( ) ) e = h;
549 else if( t[ h ].row( ) < i ) s = h;
553 f = ( t[ s ].row( ) == i )? s: e;
559 // -------------------------------------------------------------------------
// Maps a linear pixel index i to its index in the compressed (non-seed)
// system. The callers use the result as a row/column of A (size N - nSeeds),
// as a row of R, and as a row of X. A binary search over the row-sorted
// triplets t locates where i falls among the seed rows.
// NOTE(review): the return statements are not present here. Presumably the
// result is i minus the number of seed rows smaller than i; confirm.
// NOTE(review): `long e = t.size( ) - 1` mixes unsigned and signed
// arithmetic. With an empty t the subtraction wraps before the conversion
// to long.
560 template< class _TImage, class _TLabels, class _TScalar >
561 template< class _TTriplets >
563 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
564 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
567 long e = t.size( ) - 1;
570 long h = ( e + s ) >> 1;
571 if ( i < t[ h ].row( ) ) e = h;
572 else if( t[ h ].row( ) < i ) s = h;
576 if( i < t[ s ].row( ) )
578 else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
586 #endif // __fpa__Common__RandomWalker__hxx__