1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
9 #include <itkImageRegionConstIteratorWithIndex.h>
10 #include <itkImageRegionIteratorWithIndex.h>
11 #include <Eigen/Sparse>
13 // -------------------------------------------------------------------------
// AddSeed: records one seed index together with its label.
// Appends `seed` to m_Seeds and `label` to m_Labels in the same call, so the
// two vectors stay the same length and entry k of each describes one seed.
// NOTE(review): original lines 21-24 of this body are not visible in this
// chunk (they may hold e.g. a Modified( ) call) — confirm.
// NOTE(review): no code visible in this chunk reads m_Seeds / m_Labels; the
// solver code in this file takes its seeds from the InputLabels image (see
// _ThreadedBoundary) — confirm these vectors are consumed elsewhere.
15 template< class _TImage, class _TLabels, class _TScalar >
16 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
17 AddSeed( const TIndex& seed, const TLabel& label )
19 this->m_Seeds.push_back( seed );
20 this->m_Labels.push_back( label );
25 // -------------------------------------------------------------------------
// Constructor.
// Uses the fpa configuration macros (defined outside this chunk) to declare
// an extra input named InputLabels of type TLabels and an extra output named
// OutputProbabilities of type TScalarImage. These are presumably the objects
// returned later by GetInputLabels( ) / GetOutputProbabilities( ) — the
// macro expansions are not visible here.
26 template< class _TImage, class _TLabels, class _TScalar >
27 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
28 OriginalRandomWalker( )
31 fpaFilterInputConfigureMacro( InputLabels, TLabels );
32 fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
35 // -------------------------------------------------------------------------
// Destructor. Its body (original lines 39-41) is not visible in this chunk.
36 template< class _TImage, class _TLabels, class _TScalar >
37 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
38 ~OriginalRandomWalker( )
42 // -------------------------------------------------------------------------
// Filter body. The method name (original lines 45-47) is not visible in this
// chunk; presumably GenerateData — confirm.
// Solves the random-walker Dirichlet problem on the input image grid:
//  1. requires m_EdgeFunction (throws otherwise) and binds it to the input;
//  2. allocates the outputs;
//  3. collects the seed pixels of InputLabels (_Boundary) and sorts them by
//     linear pixel index;
//  4. builds B (seed x label indicator), A (unseeded x unseeded) and
//     R (unseeded x seeded) sparse matrices (_Laplacian);
//  5. factorizes with Eigen::SimplicialLDLT and solves for R * B, giving
//     one column per label for every unseeded pixel;
//  6. writes the label and probability outputs (_Output).
// Throws via itkExceptionMacro when the edge function is missing or when the
// factorization / solve does not report Eigen::Success.
// NOTE(review): no visible guard for an InputLabels image without any seed;
// then nSeeds == 0 and nLabels == 0, and _ThreadedOutput reads column 0 of a
// zero-column solution — confirm this case is rejected elsewhere.
43 template< class _TImage, class _TLabels, class _TScalar >
44 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
// Eigen types: (row, col, value) triplets feed the sparse matrices; the
// solver is Eigen's sparse LDL^T factorization (meant for symmetric
// positive-definite systems).
48 typedef Eigen::Triplet< TScalar > _TTriplet;
49 typedef std::vector< _TTriplet > _TTriplets;
50 typedef Eigen::SparseMatrix< TScalar > _TMatrix;
51 typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
53 // Configure edge function
54 if( this->m_EdgeFunction.IsNull( ) )
55 itkExceptionMacro( << "Undefined edge function." );
56 const TImage* input = this->GetInput( );
57 this->m_EdgeFunction->SetDataObject( input );
60 this->AllocateOutputs( );
62 // Build boundary triplets and count labels
// St (declared on a line not visible here) receives one triplet per seed
// pixel: ( linear pixel index, label value, 1 ). `labels` maps each label
// value found to a compact index 0 .. nLabels - 1.
64 std::map< TLabel, unsigned long > labels;
65 this->_Boundary( St, labels );
// Comparator body of _TTripletsOrd (its declaration lines are not visible):
// orders triplets by row, i.e. by linear pixel index.
68 bool operator()( const _TTriplet& a, const _TTriplet& b )
70 return( a.row( ) < b.row( ) );
// Sorting by row lets _SeedIndex / _NearSeedIndex binary-search St.
73 std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
// Bt: row = seed ordinal in the sorted St, col = compact label index,
// value = St's value (1). labels[ ] never inserts here, because every label
// stored in St was registered by _ThreadedBoundary.
74 for( unsigned long i = 0; i < St.size( ); ++i )
75 Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) );
// A and R triplets (see _ThreadedLaplacian); St must already be sorted.
79 this->_Laplacian( At, Rt, St );
82 TRegion region = input->GetRequestedRegion( );
83 unsigned long nSeeds = St.size( );
84 unsigned long nLabels = labels.size( );
85 unsigned long N = region.GetNumberOfPixels( );
// invLabels[ k ] is the label value whose compact index is k.
// NOTE(review): the range-for copies every map pair; a const reference
// would avoid the copy.
87 std::vector< TLabel > invLabels( nLabels );
88 for( typename std::map< TLabel, unsigned long >::value_type s: labels )
89 invLabels[ s.second ] = s.first;
// Eigen's setFromTriplets sums duplicate (row, col) entries.
91 _TMatrix B( nSeeds, nLabels );
92 B.setFromTriplets( Bt.begin( ), Bt.end( ) );
95 _TMatrix R( N - nSeeds, nSeeds );
96 R.setFromTriplets( Rt.begin( ), Rt.end( ) );
99 _TMatrix A( N - nSeeds, N - nSeeds );
100 A.setFromTriplets( At.begin( ), At.end( ) );
103 // Solve the Dirichlet problem
// The solver declaration and factorization (original lines 104-105) are not
// visible here; presumably solver.compute( A ).
106 if( solver.info( ) != Eigen::Success )
107 itkExceptionMacro( << "Error decomposing matrix." );
108 _TMatrix x = solver.solve( R * B );
109 if( solver.info( ) != Eigen::Success )
110 itkExceptionMacro( << "Error solving system." );
// Per pixel: winning label and its probability; seeds are copied through.
113 this->_Output( x, St, invLabels );
116 // -------------------------------------------------------------------------
// _L: graph-Laplacian entry for the pixel pair ( i, j ). Edge weights come
// from m_EdgeFunction->Evaluate.
//  - Diagonal case: the selecting test (original lines 120-122) is not
//    visible, but _ThreadedLaplacian calls _L( idx, idx ) for it. Returns the
//    sum of the weights from i to each face neighbour k (offset -1 / +1 along
//    every axis) that lies inside the input's requested region.
//  - Otherwise: returns the negated weight between i and j.
117 template< class _TImage, class _TLabels, class _TScalar >
118 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
119 _L( const TIndex& i, const TIndex& j )
123 TRegion r = this->GetInput( )->GetRequestedRegion( );
124 TScalar s = TScalar( 0 );
125 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
127 for( int n = -1; n <= 1; n += 2 )
// k: neighbour of i along axis d at offset n (built on the hidden lines
// 128-130).
131 if( r.IsInside( k ) )
132 s += this->m_EdgeFunction->Evaluate( i, k );
140 return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
143 // -------------------------------------------------------------------------
// _Boundary: collects the seed pixels of the InputLabels image in parallel.
// On return, B holds one triplet per seed pixel (built by _ThreadedBoundary)
// and `labels` maps every label value found to a compact index.
// The triplet vector travels to the worker threads type-erased as void* in
// _TBoundaryThreadStruct and is cast back in _BoundaryCbk< _TTriplets >.
// NOTE(review): the thread count is computed from the InputLabels requested
// region, but each worker's piece comes from SplitRequestedRegion (which in
// ITK splits the output's requested region) — confirm the regions coincide.
// NOTE(review): original lines 148-150 (start of the body) are not visible.
144 template< class _TImage, class _TLabels, class _TScalar >
145 template< class _TTriplets >
146 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
147 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
151 // Set up the multithreaded processing
152 _TBoundaryThreadStruct thrStr;
153 thrStr.Filter = this;
154 thrStr.Triplets = reinterpret_cast< void* >( &B );
155 thrStr.Labels = &labels;
157 // Configure threader
158 const TLabels* in_labels = this->GetInputLabels( );
159 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
160 const unsigned int nThreads =
161 split->GetNumberOfSplits(
162 in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
165 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
166 threads->SetNumberOfThreads( nThreads );
167 threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
170 threads->SingleMethodExecute( );
173 // -------------------------------------------------------------------------
// _BoundaryCbk: itk::MultiThreader entry point registered by _Boundary.
// It:
//  - recovers the _TBoundaryThreadStruct from ThreadInfoStruct::UserData;
//  - asks the filter for this thread's piece of the requested region
//    (total = number of pieces actually produced);
//  - forwards that piece, plus the triplet vector cast back to
//    _TTriplets*, to _ThreadedBoundary.
// NOTE(review): the declaration of `region` and any `thrId < total` check
// are on original lines 187-190 / 192-195, which are not visible here.
174 template< class _TImage, class _TLabels, class _TScalar >
175 template< class _TTriplets >
176 ITK_THREAD_RETURN_TYPE
177 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
178 _BoundaryCbk( void* arg )
180 _TBoundaryThreadStruct* thrStr;
181 itk::ThreadIdType total, thrId, thrCount;
182 itk::MultiThreader::ThreadInfoStruct* thrInfo =
183 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
184 thrId = thrInfo->ThreadID;
185 thrCount = thrInfo->NumberOfThreads;
186 thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
189 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
191 thrStr->Filter->_ThreadedBoundary(
193 reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
196 return( ITK_THREAD_RETURN_VALUE );
199 // -------------------------------------------------------------------------
// _ThreadedBoundary: per-thread seed collector. The method name and the `B`
// parameter (original lines 203 / 205) are not visible; it is invoked as
// _ThreadedBoundary( ... ) from _BoundaryCbk.
// Walks `region` of the InputLabels image. For each seed pixel, it appends
// the triplet ( linear index within the labels' requested region, label
// value, 1 ) to B. It also registers the label in `labels` with the next
// free compact index. The per-pixel seed test (original lines 218-220) is
// not visible; _ThreadedLaplacian treats a non-zero label as a seed.
// The shared B / labels are only modified while m_Mutex is held.
// NOTE(review): compact indices are handed out in the order threads reach
// new labels, so the label -> column mapping can differ between runs. The
// caller's invLabels keeps the result self-consistent.
// NOTE(review): the mutex is taken once per seed pixel, which serializes the
// threads. Lock( ) / Unlock( ) is also not exception-safe: if push_back
// throws (e.g. std::bad_alloc), m_Mutex stays locked. Per-thread buffers
// merged once would avoid both problems.
// NOTE(review): the _TInvValue typedef is unused in the visible code.
200 template< class _TImage, class _TLabels, class _TScalar >
201 template< class _TTriplets >
202 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
204 const TRegion& region, const itk::ThreadIdType& id,
206 std::map< TLabel, unsigned long >* labels
209 typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
210 typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
211 typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
212 typedef typename _TTriplets::value_type _TTriplet;
214 const TLabels* in_labels = this->GetInputLabels( );
215 TRegion reqRegion = in_labels->GetRequestedRegion( );
216 _TIt it( in_labels, region );
217 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
221 unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
// Critical section: append the seed triplet and register a new label.
222 this->m_Mutex.Lock( );
223 B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
224 if( labels->find( it.Get( ) ) == labels->end( ) )
225 labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
226 this->m_Mutex.Unlock( );
233 // -------------------------------------------------------------------------
// _Laplacian: builds the triplets of A and R in parallel.
//  - A: the unseeded x unseeded block of the graph Laplacian.
//  - R: the unseeded x seeded coupling, stored negated.
// It uses the seed triplets B, which must already be sorted by row. The work
// itself is done by _ThreadedLaplacian through _LaplacianCbk< _TTriplets >.
// NOTE(review): original lines 238-241 (start of the body) are not visible.
234 template< class _TImage, class _TLabels, class _TScalar >
235 template< class _TTriplets >
236 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
237 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
242 // Set up the multithreaded processing
243 _TLaplacianThreadStruct thrStr;
244 thrStr.Filter = this;
245 thrStr.A = reinterpret_cast< void* >( &A );
246 thrStr.R = reinterpret_cast< void* >( &R );
247 thrStr.B = reinterpret_cast< const void* >( &B );
249 // Configure threader
// NOTE(review): GetInputLabels( ) returns const TLabels* elsewhere in this
// file (_Boundary, _ThreadedBoundary), so this line only compiles when
// TLabels* converts to const TImage*. The intent is presumably the input
// image, matching _ThreadedLaplacian, which iterates GetInput( ):
//   const TImage* in = this->GetInput( );
250 const TImage* in = this->GetInputLabels( );
251 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
252 const unsigned int nThreads =
253 split->GetNumberOfSplits(
254 in->GetRequestedRegion( ), this->GetNumberOfThreads( )
257 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
258 threads->SetNumberOfThreads( nThreads );
259 threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
262 threads->SingleMethodExecute( );
265 // -------------------------------------------------------------------------
// _LaplacianCbk: itk::MultiThreader entry point registered by _Laplacian.
// It:
//  - recovers the _TLaplacianThreadStruct from ThreadInfoStruct::UserData;
//  - obtains this thread's piece of the requested region via
//    SplitRequestedRegion;
//  - calls _ThreadedLaplacian with the A / R / B vectors cast back from
//    void* to _TTriplets*.
// NOTE(review): the declaration of `region` and any `thrId < total` check
// are on original lines 279-280 / 282-284, which are not visible here.
266 template< class _TImage, class _TLabels, class _TScalar >
267 template< class _TTriplets >
268 ITK_THREAD_RETURN_TYPE
269 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
270 _LaplacianCbk( void* arg )
272 _TLaplacianThreadStruct* thrStr;
273 itk::ThreadIdType total, thrId, thrCount;
274 itk::MultiThreader::ThreadInfoStruct* thrInfo =
275 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
276 thrId = thrInfo->ThreadID;
277 thrCount = thrInfo->NumberOfThreads;
278 thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
281 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
283 thrStr->Filter->_ThreadedLaplacian(
285 reinterpret_cast< _TTriplets* >( thrStr->A ),
286 reinterpret_cast< _TTriplets* >( thrStr->R ),
287 reinterpret_cast< const _TTriplets* >( thrStr->B )
289 return( ITK_THREAD_RETURN_VALUE );
292 // -------------------------------------------------------------------------
// _ThreadedLaplacian: per-thread Laplacian assembly. The method name
// (original line 296) is not visible; it is invoked from _LaplacianCbk.
// For every pixel of `region` of the input image:
//  - the pixel is a seed when its InputLabels value is non-zero;
//  - si is its index in the reduced system:
//      * unseeded pixel: si = _NearSeedIndex( i, *B ), and the pixel also
//        receives the diagonal A( si, si ) = _L( idx, idx );
//      * seed pixel: si = _SeedIndex( i, *B ), its position among the
//        sorted seeds;
//  - for each face neighbour jdx inside the requested region, with
//    L = _L( idx, jdx ) and sj = _NearSeedIndex( j, *B ), an entry is
//    emitted either into A as ( si, sj, L ) or into R as ( sj, si, -L ).
// The branch conditions choosing between these cases (original lines
// 313-317, 338-342, 346-348) are not visible here. Presumably A receives the
// unseeded/unseeded pairs and R the seed/unseeded pairs — confirm.
// The shared A / R vectors are only appended to while m_Mutex is held.
// NOTE(review): InputLabels is indexed with the input image's indices, so the
// two images are assumed to share the same index space — confirm.
// NOTE(review): locking per triplet serializes the threads. Lock( ) /
// Unlock( ) is also not exception-safe if push_back throws.
293 template< class _TImage, class _TLabels, class _TScalar >
294 template< class _TTriplets >
295 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
297 const TRegion& region, const itk::ThreadIdType& id,
298 _TTriplets* A, _TTriplets* R, const _TTriplets* B
301 typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
302 typedef typename _TTriplets::value_type _TTriplet;
304 const TImage* in = this->GetInput( );
305 const TLabels* in_labels = this->GetInputLabels( );
306 TRegion reqRegion = in->GetRequestedRegion( );
307 _TIt it( in, region );
308 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
310 TIndex idx = it.GetIndex( );
311 bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
312 unsigned long i = Self::_1D( idx, reqRegion );
315 // A's diagonal values
318 si = Self::_NearSeedIndex( i, *B );
319 this->m_Mutex.Lock( );
320 A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
321 this->m_Mutex.Unlock( );
// Seed pixel: column index into R.
324 si = Self::_SeedIndex( i, *B );
326 // Neighbors (final matrix is symmetric)
327 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
329 for( int s = -1; s <= 1; s += 2 )
// jdx: neighbour of idx along axis d at offset s (built on hidden lines
// 330-332).
333 if( reqRegion.IsInside( jdx ) )
335 TScalar L = this->_L( idx, jdx );
336 unsigned long j = Self::_1D( jdx, reqRegion );
337 bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
340 unsigned long sj = Self::_NearSeedIndex( j, *B );
343 this->m_Mutex.Lock( );
344 A->push_back( _TTriplet( si, sj, L ) );
345 this->m_Mutex.Unlock( );
349 this->m_Mutex.Lock( );
350 R->push_back( _TTriplet( sj, si, -L ) );
351 this->m_Mutex.Unlock( );
366 // -------------------------------------------------------------------------
// _Output: writes the label output and OutputProbabilities in parallel. The
// method name (original line 370) is not visible; it is called as
// this->_Output( x, St, invLabels ).
// Inputs:
//  - X: the solution matrix;
//  - S: the row-sorted seed triplets;
//  - invLabels: the compact-index -> label table.
// Work is done by _ThreadedOutput through _OutputCbk< _TMatrix, _TTriplets >.
// The thread count is derived from the label output's requested region.
367 template< class _TImage, class _TLabels, class _TScalar >
368 template< class _TMatrix, class _TTriplets >
369 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
371 const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
374 // Set up the multithreaded processing
375 _TOutputThreadStruct thrStr;
376 thrStr.Filter = this;
377 thrStr.X = reinterpret_cast< const void* >( &X );
378 thrStr.S = reinterpret_cast< const void* >( &S );
379 thrStr.InvLabels = &invLabels;
381 // Configure threader
382 const TLabels* out = this->GetOutput( );
383 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
384 const unsigned int nThreads =
385 split->GetNumberOfSplits(
386 out->GetRequestedRegion( ), this->GetNumberOfThreads( )
389 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
390 threads->SetNumberOfThreads( nThreads );
391 threads->SetSingleMethod(
392 this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
396 threads->SingleMethodExecute( );
399 // -------------------------------------------------------------------------
// _OutputCbk: itk::MultiThreader entry point registered by _Output.
// It:
//  - recovers the _TOutputThreadStruct from ThreadInfoStruct::UserData;
//  - obtains this thread's piece of the requested region via
//    SplitRequestedRegion;
//  - calls _ThreadedOutput with X and S cast back from const void*.
// NOTE(review): the declaration of `region`, any `thrId < total` check and
// the remaining arguments are on original lines 413-414, 416, 418 and
// 421-422, which are not visible here.
400 template< class _TImage, class _TLabels, class _TScalar >
401 template< class _TMatrix, class _TTriplets >
402 ITK_THREAD_RETURN_TYPE
403 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
404 _OutputCbk( void* arg )
406 _TOutputThreadStruct* thrStr;
407 itk::ThreadIdType total, thrId, thrCount;
408 itk::MultiThreader::ThreadInfoStruct* thrInfo =
409 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
410 thrId = thrInfo->ThreadID;
411 thrCount = thrInfo->NumberOfThreads;
412 thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
415 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
417 thrStr->Filter->_ThreadedOutput(
419 reinterpret_cast< const _TMatrix* >( thrStr->X ),
420 reinterpret_cast< const _TTriplets* >( thrStr->S ),
423 return( ITK_THREAD_RETURN_VALUE );
426 // -------------------------------------------------------------------------
// _ThreadedOutput: per-thread output writer. The method name (original line
// 430) is not visible; it is invoked from _OutputCbk.
// Walks `region` with three lock-step iterators over InputLabels, the label
// output and OutputProbabilities:
//  - unseeded pixel (input label 0): row j = _NearSeedIndex( i, *S ) of X is
//    scanned over every label column, and the label with the largest
//    coefficient is written to the label output. The max update and the
//    probability write (original lines 458-465 / 467-470) are not visible.
//  - seed pixel: its label is copied and its probability set to 1.
// NOTE(review): i is computed relative to the label output's requested
// region, while _ThreadedLaplacian used the input image's requested region.
// They are presumably identical — confirm.
// NOTE(review): X->coeff( j, 0 ) assumes at least one label column. With no
// seeds at all, invLabels is empty and this access is out of range.
427 template< class _TImage, class _TLabels, class _TScalar >
428 template< class _TMatrix, class _TTriplets >
429 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
431 const TRegion& region, const itk::ThreadIdType& id,
432 const _TMatrix* X, const _TTriplets* S,
433 const std::vector< TLabel >* invLabels
437 const TLabels* in_labels = this->GetInputLabels( );
438 TLabels* out_labels = this->GetOutput( );
439 TScalarImage* out_probs = this->GetOutputProbabilities( );
440 TRegion reqRegion = out_labels->GetRequestedRegion( );
441 itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
442 itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
443 itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
447 for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
449 if( iIt.Get( ) == 0 )
// Unseeded pixel: arg-max over the label columns of its solution row.
451 unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
452 unsigned long j = Self::_NearSeedIndex( i, *S );
453 TScalar maxP = X->coeff( j, 0 );
454 unsigned long maxL = 0;
455 for( unsigned int s = 1; s < invLabels->size( ); ++s )
457 TScalar p = X->coeff( j, s );
466 oIt.Set( ( *invLabels )[ maxL ] );
// Seed pixel: keep its label with probability 1.
471 oIt.Set( iIt.Get( ) );
472 pIt.Set( TScalar( 1 ) );
479 // -------------------------------------------------------------------------
// _1D: flattens an N-d index into one linear offset using `region`'s size,
// with axis 0 varying fastest. `off` is the running product of the sizes of
// the lower axes. The accumulation of idx[ d ] * off and the return are on
// original lines 491-496, which are not visible here.
// NOTE(review): the visible code starts from idx[ 0 ] without subtracting
// region.GetIndex( )[ 0 ]. Unless the hidden lines correct for it, the result
// lies in [ 0, N ) only when the region starts at index 0. A negative index
// component also wraps when stored in unsigned long.
480 template< class _TImage, class _TLabels, class _TScalar >
482 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
483 _1D( const TIndex& idx, const TRegion& region )
485 unsigned long i = idx[ 0 ];
486 unsigned long off = 1;
487 typename TRegion::SizeType size = region.GetSize( );
488 for( unsigned int d = 1; d < TIndex::Dimension; ++d )
490 off *= size[ d - 1 ];
497 // -------------------------------------------------------------------------
// _SeedIndex: binary search in the row-sorted triplets t for the entry whose
// row equals the linear pixel index i.
// f starts at t.size( ) and acts as a "not found yet" sentinel in the loop
// condition; it ends up holding the matching position. _ThreadedLaplacian
// appears to call this only for seed pixels.
// NOTE(review): the return type, the declaration of s, the early-exit tests
// and the return (original lines 500, 503-504, 508-510, 514-516, 518-522)
// are not visible here.
// NOTE(review): for an empty t, e = f - 1 wraps to the maximum unsigned
// long, so callers must not pass an empty vector.
// NOTE(review): s = h (rather than h + 1) means the bounds stop moving once
// e == s + 1. Termination then relies on the hidden lines; line 517 looks
// like that final two-candidate resolution.
498 template< class _TImage, class _TLabels, class _TScalar >
499 template< class _TTriplets >
501 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
502 _SeedIndex( const unsigned long& i, const _TTriplets& t )
505 unsigned long f = t.size( );
506 unsigned long e = f - 1;
507 while( e > s && f == t.size( ) )
511 unsigned long h = ( e + s ) >> 1;
512 if ( i < t[ h ].row( ) ) e = h;
513 else if( t[ h ].row( ) < i ) s = h;
517 f = ( t[ s ].row( ) == i )? s: e;
523 // -------------------------------------------------------------------------
// _NearSeedIndex: maps a linear pixel index i to an index in the reduced
// (unseeded-only) system. It binary-searches the row-sorted seed triplets t,
// then compares i against the bracketing rows t[ s ] / t[ e ].
// Its results serve as row/column indices of A ( N - nSeeds square ) and as
// row indices of R. It therefore presumably returns i minus the number of
// seeds whose row is below i. The return expressions (original lines 541 and
// 543-548) are not visible here — confirm.
// NOTE(review): it appears to be called for every unseeded pixel, including
// when t is empty (no seeds). Then e becomes -1 on common platforms, and the
// visible accesses t[ h ] / t[ s ] are out of range.
// NOTE(review): i (unsigned long) is compared with row( ) (a signed Eigen
// index type), mixing signed and unsigned comparisons.
524 template< class _TImage, class _TLabels, class _TScalar >
525 template< class _TTriplets >
527 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
528 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
531 long e = t.size( ) - 1;
534 long h = ( e + s ) >> 1;
535 if ( i < t[ h ].row( ) ) e = h;
536 else if( t[ h ].row( ) < i ) s = h;
540 if( i < t[ s ].row( ) )
542 else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
550 #endif // __fpa__Common__OriginalRandomWalker__hxx__