1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
8 #include <itkImageRegionConstIteratorWithIndex.h>
9 #include <itkImageRegionIteratorWithIndex.h>
10 #include <Eigen/Sparse>
12 // -------------------------------------------------------------------------
// Registers one seed: appends `seed` to m_Seeds and `label` to m_Labels, so
// that entry k of both containers describes the same seed.
// NOTE(review): the boundary used by the solver is built from the
// InputLabels image (see _ThreadedBoundary), not from m_Seeds / m_Labels;
// confirm where these containers are consumed.
14 template< class _TImage, class _TLabels, class _TScalar >
15 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16 AddSeed( const TIndex& seed, const TLabel& label )
18 this->m_Seeds.push_back( seed );
19 this->m_Labels.push_back( label );
24 // -------------------------------------------------------------------------
// Constructor: uses the project macros fpaFilterInputConfigureMacro and
// fpaFilterOutputConfigureMacro to set up (presumably register) an extra
// input "InputLabels" of type TLabels and an extra output
// "OutputProbabilities" of type TScalarImage.
25 template< class _TImage, class _TLabels, class _TScalar >
26 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
27 OriginalRandomWalker( )
30 fpaFilterInputConfigureMacro( InputLabels, TLabels );
31 fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
34 // -------------------------------------------------------------------------
// Destructor.
35 template< class _TImage, class _TLabels, class _TScalar >
36 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
37 ~OriginalRandomWalker( )
41 // -------------------------------------------------------------------------
// Runs the random-walker segmentation:
//  1. _Boundary collects the labelled (seed) pixels of InputLabels as
//     triplets St and gives each distinct label value a dense column index
//     (map `labels`);
//  2. St is sorted by row (linear pixel index) so that _SeedIndex and
//     _NearSeedIndex can binary-search it;
//  3. _Laplacian assembles the Laplacian triplets of the unseeded pixels (At)
//     and the unseeded-to-seed coupling triplets (Rt);
//  4. the sparse system A x = C is factorized with SimplicialLDLT and solved,
//     giving one column of values per label;
//  5. _Output writes the per-pixel labels and probabilities.
// Eigen triplet lists are used to assemble the sparse matrices; A is
// symmetric (see the note in _ThreadedLaplacian), hence the LDLT solver.
// Throws (itkExceptionMacro) when no edge function is set, or when the
// factorization or the solve fails.
42 template< class _TImage, class _TLabels, class _TScalar >
43 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
47 typedef Eigen::Triplet< TScalar > _TTriplet;
48 typedef std::vector< _TTriplet > _TTriplets;
49 typedef Eigen::SparseMatrix< TScalar > _TMatrix;
50 typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
52 // Configure edge function
53 if( this->m_EdgeFunction.IsNull( ) )
54 itkExceptionMacro( << "Undefined edge function." );
// Bind the edge function to the input image; its Evaluate( i, j ) supplies
// the edge weights consumed by _L.
55 const TImage* input = this->GetInput( );
56 this->m_EdgeFunction->SetDataObject( input );
59 this->AllocateOutputs( );
// 1x1 placeholders; A is resized to the unknown-block size further down.
62 _TMatrix A( 1, 1 ), C( 1, 1 );
// invLabels[ k ] = label value whose dense column index is k.
64 std::vector< TLabel > invLabels;
67 // Build boundary triplets and count labels
// labels: label value -> dense column index, filled by _Boundary.
69 std::map< TLabel, unsigned long > labels;
70 itkDebugMacro( << "Building boundary matrix..." );
71 this->_Boundary( St, labels );
// Comparator: orders triplets by row only (the linear pixel index).
74 bool operator()( const _TTriplet& a, const _TTriplet& b )
76 return( a.row( ) < b.row( ) );
79 itkDebugMacro( << "Sorting boundary pixels..." );
80 std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
81 itkDebugMacro( << "Assigning boundary pixels..." );
// Re-index each sorted seed triplet as ( rank among seeds, dense label
// index, value ). The destination is presumably Bt, from which B is built
// below. labels[ ] does not insert here, because _ThreadedBoundary
// registered every label value it emitted.
82 for( unsigned long i = 0; i < St.size( ); ++i )
84 _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) )
// At: Laplacian entries among unseeded pixels; Rt: negated couplings from
// unseeded pixels to seeds. St must already be sorted (binary-searched).
88 itkDebugMacro( << "Building laplacian matrix..." );
90 this->_Laplacian( At, Rt, St );
// Sizes: N pixels in the input requested region, nSeeds of them fixed by
// labels, so the unknowns form an (N - nSeeds)-dimensional system.
93 TRegion region = input->GetRequestedRegion( );
94 unsigned long nSeeds = St.size( );
95 unsigned long nLabels = labels.size( );
96 unsigned long N = region.GetNumberOfPixels( );
98 itkDebugMacro( << "Creating inverse labels..." );
99 invLabels.resize( nLabels );
// NOTE(review): each map entry is copied per iteration; binding by const
// reference would avoid the copy.
100 for( typename std::map< TLabel, unsigned long >::value_type s: labels )
101 invLabels[ s.second ] = s.first;
// B: nSeeds x nLabels; each seed row holds its label's column (the seed
// triplets carry the value 1).
103 itkDebugMacro( << "Creating B matrix..." );
104 _TMatrix B( nSeeds, nLabels );
105 B.setFromTriplets( Bt.begin( ), Bt.end( ) );
// R: (N - nSeeds) x nSeeds coupling between unknowns and seeds.
108 itkDebugMacro( << "Creating R matrix..." );
109 _TMatrix R( N - nSeeds, nSeeds );
110 R.setFromTriplets( Rt.begin( ), Rt.end( ) );
// C: right-hand side with one column per label; presumably C = R * B,
// which has the matching (N - nSeeds) x nLabels shape — confirm.
113 itkDebugMacro( << "Creating C matrix..." );
// A: Laplacian restricted to the unseeded pixels.
116 itkDebugMacro( << "Creating A matrix..." );
117 A.resize( N - nSeeds, N - nSeeds );
118 A.setFromTriplets( At.begin( ), At.end( ) );
122 // Solve dirichlet problem
124 itkDebugMacro( << "Factorizing problem..." );
126 if( solver.info( ) != Eigen::Success )
127 itkExceptionMacro( << "Error decomposing matrix." );
128 itkDebugMacro( << "Solving problem..." );
// x: one row per unseeded pixel, one column per label.
129 _TMatrix x = solver.solve( C );
130 if( solver.info( ) != Eigen::Success )
131 itkExceptionMacro( << "Error solving system." );
// Rows of x are mapped back to pixels through St, and columns back to label
// values through invLabels.
134 itkDebugMacro( << "Filling output..." );
135 this->_Output( x, St, invLabels );
138 // -------------------------------------------------------------------------
// Returns the (i, j) entry of the image-graph Laplacian whose edge weights
// are given by m_EdgeFunction->Evaluate( i, j ):
//  - degree term: the sum of Evaluate( i, k ) over the neighbours k of i
//    (presumably i shifted by n = -1 and n = +1 along each axis d) that lie
//    inside the input requested region; presumably returned when i == j;
//  - otherwise: -Evaluate( i, j ).
139 template< class _TImage, class _TLabels, class _TScalar >
140 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
141 _L( const TIndex& i, const TIndex& j )
145 TRegion r = this->GetInput( )->GetRequestedRegion( );
146 TScalar s = TScalar( 0 );
147 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
149 for( int n = -1; n <= 1; n += 2 )
// Neighbours outside the requested region contribute no edge.
153 if( r.IsInside( k ) )
154 s += this->m_EdgeFunction->Evaluate( i, k );
162 return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
165 // -------------------------------------------------------------------------
// Collects the labelled (seed) pixels into B and fills `labels`
// (label value -> dense column index). The InputLabels requested region is
// split across threads. B and labels are shared by every thread through
// thrStr; _ThreadedBoundary serializes its writes with m_Mutex.
166 template< class _TImage, class _TLabels, class _TScalar >
167 template< class _TTriplets >
168 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
169 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
173 // Set up the multithreaded processing
174 _TBoundaryThreadStruct thrStr;
175 thrStr.Filter = this;
176 thrStr.Triplets = reinterpret_cast< void* >( &B );
177 thrStr.Labels = &labels;
179 // Configure threader
180 const TLabels* in_labels = this->GetInputLabels( );
181 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
// Use only as many threads as the region can actually be split into.
182 const unsigned int nThreads =
183 split->GetNumberOfSplits(
184 in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
187 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
188 threads->SetNumberOfThreads( nThreads );
189 threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
// SingleMethodExecute runs the callback on every thread and returns only
// after all of them finish, so the stack-allocated thrStr, B and labels
// stay valid for the threads' whole lifetime.
192 threads->SingleMethodExecute( );
195 // -------------------------------------------------------------------------
// Static thread entry point used by _Boundary. It unpacks the shared
// _TBoundaryThreadStruct from ITK's ThreadInfoStruct and asks the filter
// for this thread's piece of the requested region. SplitRequestedRegion
// returns in `total` the number of pieces it actually produced. The piece
// is then forwarded to _ThreadedBoundary.
// NOTE(review): threads with thrId >= total must not process anything;
// presumably the call below is guarded accordingly — confirm. ITK's
// SplitRequestedRegion splits the *output* requested region, while
// _Boundary sized the thread count from the InputLabels requested region;
// the two agree only if both regions match.
196 template< class _TImage, class _TLabels, class _TScalar >
197 template< class _TTriplets >
198 ITK_THREAD_RETURN_TYPE
199 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
200 _BoundaryCbk( void* arg )
202 _TBoundaryThreadStruct* thrStr;
203 itk::ThreadIdType total, thrId, thrCount;
204 itk::MultiThreader::ThreadInfoStruct* thrInfo =
205 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
206 thrId = thrInfo->ThreadID;
207 thrCount = thrInfo->NumberOfThreads;
208 thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
211 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
213 thrStr->Filter->_ThreadedBoundary(
215 reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
218 return( ITK_THREAD_RETURN_VALUE );
221 // -------------------------------------------------------------------------
// Per-thread worker for _Boundary. Walks `region` of the InputLabels image.
// For each seed pixel (presumably those with a non-zero label) it:
//  - appends the triplet ( linear index within the labels' requested region,
//    label value, 1 ) to B;
//  - gives each label value not yet seen the next dense index,
//    labels->size( ).
// All writes to the shared B and labels happen while m_Mutex is held.
// NOTE(review): the threads race for the mutex, so the dense index given to
// each label depends on scheduling. Results are mapped back through
// invLabels, so outputs should not depend on it, except possibly for ties
// in the final arg-max — confirm. Locking once per pixel also serializes
// this loop; per-thread buffers merged at the end would avoid the
// contention. The _TInvValue typedef is unused.
222 template< class _TImage, class _TLabels, class _TScalar >
223 template< class _TTriplets >
224 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
226 const TRegion& region, const itk::ThreadIdType& id,
228 std::map< TLabel, unsigned long >* labels
231 typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
232 typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
233 typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
234 typedef typename _TTriplets::value_type _TTriplet;
236 const TLabels* in_labels = this->GetInputLabels( );
237 TRegion reqRegion = in_labels->GetRequestedRegion( );
238 _TIt it( in_labels, region );
239 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
243 unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
// Shared containers are only touched while m_Mutex is held.
244 this->m_Mutex.Lock( );
245 B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
246 if( labels->find( it.Get( ) ) == labels->end( ) )
247 labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
248 this->m_Mutex.Unlock( );
255 // -------------------------------------------------------------------------
// Assembles the triplets of A (Laplacian among unseeded pixels) and R
// (negated couplings from unseeded pixels to seeds), in parallel over the
// input image's requested region. B must hold the seed triplets sorted by
// row, because _ThreadedLaplacian binary-searches it.
256 template< class _TImage, class _TLabels, class _TScalar >
257 template< class _TTriplets >
258 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
259 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
264 // Set up the multithreaded processing
265 _TLaplacianThreadStruct thrStr;
266 thrStr.Filter = this;
267 thrStr.A = reinterpret_cast< void* >( &A );
268 thrStr.R = reinterpret_cast< void* >( &R );
269 thrStr.B = reinterpret_cast< const void* >( &B );
271 // Configure threader
272 const TImage* in = this->GetInput( );
273 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
// Use only as many threads as the region can actually be split into.
274 const unsigned int nThreads =
275 split->GetNumberOfSplits(
276 in->GetRequestedRegion( ), this->GetNumberOfThreads( )
279 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
280 threads->SetNumberOfThreads( nThreads );
281 threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
// SingleMethodExecute returns only after all threads finish, so the
// stack-allocated thrStr, A, R and B stay valid for the threads.
284 threads->SingleMethodExecute( );
287 // -------------------------------------------------------------------------
// Static thread entry point used by _Laplacian. It unpacks the shared
// _TLaplacianThreadStruct from ITK's ThreadInfoStruct, obtains this
// thread's piece of the requested region via SplitRequestedRegion (`total`
// = pieces actually produced), and forwards it, with the type-erased A, R
// and B restored to _TTriplets, to _ThreadedLaplacian.
// NOTE(review): threads with thrId >= total must skip work; presumably the
// call below is guarded accordingly — confirm.
288 template< class _TImage, class _TLabels, class _TScalar >
289 template< class _TTriplets >
290 ITK_THREAD_RETURN_TYPE
291 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
292 _LaplacianCbk( void* arg )
294 _TLaplacianThreadStruct* thrStr;
295 itk::ThreadIdType total, thrId, thrCount;
296 itk::MultiThreader::ThreadInfoStruct* thrInfo =
297 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
298 thrId = thrInfo->ThreadID;
299 thrCount = thrInfo->NumberOfThreads;
300 thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
303 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
305 thrStr->Filter->_ThreadedLaplacian(
307 reinterpret_cast< _TTriplets* >( thrStr->A ),
308 reinterpret_cast< _TTriplets* >( thrStr->R ),
309 reinterpret_cast< const _TTriplets* >( thrStr->B )
311 return( ITK_THREAD_RETURN_VALUE );
314 // -------------------------------------------------------------------------
// Per-thread worker for _Laplacian. For every pixel i of `region`:
//  - iSeed tells whether i carries a non-zero label;
//  - the diagonal entry _L( i, i ) is appended to A at ( si, si ), where
//    si = _NearSeedIndex( i, *B ) is i's row in the unseeded system
//    (presumably only for unseeded i);
//  - for a seeded i, si = _SeedIndex( i, *B ) is its position in the sorted
//    seed list, i.e. its column in R;
//  - for each neighbour j inside the requested region, L = _L( i, j ) goes
//    into A at ( si, sj ) or, negated, into R at ( sj, si ). The branch
//    conditions presumably depend on iSeed / jSeed.
// Every append to the shared A / R happens while m_Mutex is held.
315 template< class _TImage, class _TLabels, class _TScalar >
316 template< class _TTriplets >
317 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
319 const TRegion& region, const itk::ThreadIdType& id,
320 _TTriplets* A, _TTriplets* R, const _TTriplets* B
323 typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
324 typedef typename _TTriplets::value_type _TTriplet;
326 const TImage* in = this->GetInput( );
327 const TLabels* in_labels = this->GetInputLabels( );
328 TRegion reqRegion = in->GetRequestedRegion( );
329 _TIt it( in, region );
330 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
332 TIndex idx = it.GetIndex( );
// Pixels with a non-zero label are seeds.
333 bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
334 unsigned long i = Self::_1D( idx, reqRegion );
337 // A's diagonal values
340 si = Self::_NearSeedIndex( i, *B );
341 this->m_Mutex.Lock( );
342 A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
343 this->m_Mutex.Unlock( );
// Seeded pixel (presumably): its position in the sorted seed list.
346 si = Self::_SeedIndex( i, *B );
348 // Neighbors (final matrix is symmetric)
349 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
351 for( int s = -1; s <= 1; s += 2 )
355 if( reqRegion.IsInside( jdx ) )
357 TScalar L = this->_L( idx, jdx );
358 unsigned long j = Self::_1D( jdx, reqRegion );
359 bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
362 unsigned long sj = Self::_NearSeedIndex( j, *B );
// Both pixels unseeded (presumably): off-diagonal entry of A.
365 this->m_Mutex.Lock( );
366 A->push_back( _TTriplet( si, sj, L ) );
367 this->m_Mutex.Unlock( );
// Seed i next to unseeded j (presumably): coupling -L stored in R at
// ( row of j in the unseeded system, seed index of i ).
371 this->m_Mutex.Lock( );
372 R->push_back( _TTriplet( sj, si, -L ) );
373 this->m_Mutex.Unlock( );
388 // -------------------------------------------------------------------------
// Writes the label and probability outputs from the solution X, splitting
// the output's requested region across threads. S is the row-sorted seed
// triplet list, used to map pixel indices to rows of X. invLabels maps
// column indices of X back to label values.
389 template< class _TImage, class _TLabels, class _TScalar >
390 template< class _TMatrix, class _TTriplets >
391 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
393 const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
396 // Set up the multithreaded processing
397 _TOutputThreadStruct thrStr;
398 thrStr.Filter = this;
399 thrStr.X = reinterpret_cast< const void* >( &X );
400 thrStr.S = reinterpret_cast< const void* >( &S );
401 thrStr.InvLabels = &invLabels;
403 // Configure threader
404 const TLabels* out = this->GetOutput( );
405 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
// Use only as many threads as the region can actually be split into.
406 const unsigned int nThreads =
407 split->GetNumberOfSplits(
408 out->GetRequestedRegion( ), this->GetNumberOfThreads( )
411 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
412 threads->SetNumberOfThreads( nThreads );
413 threads->SetSingleMethod(
414 this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
// SingleMethodExecute returns only after all threads finish, so the
// stack-allocated thrStr stays valid for the threads.
418 threads->SingleMethodExecute( );
421 // -------------------------------------------------------------------------
// Static thread entry point used by _Output. It unpacks the shared
// _TOutputThreadStruct from ITK's ThreadInfoStruct, obtains this thread's
// piece of the requested region via SplitRequestedRegion (`total` = pieces
// actually produced), and forwards it, with X and S restored to their
// types, to _ThreadedOutput.
// NOTE(review): threads with thrId >= total must skip work; presumably the
// call below is guarded accordingly — confirm.
422 template< class _TImage, class _TLabels, class _TScalar >
423 template< class _TMatrix, class _TTriplets >
424 ITK_THREAD_RETURN_TYPE
425 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
426 _OutputCbk( void* arg )
428 _TOutputThreadStruct* thrStr;
429 itk::ThreadIdType total, thrId, thrCount;
430 itk::MultiThreader::ThreadInfoStruct* thrInfo =
431 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
432 thrId = thrInfo->ThreadID;
433 thrCount = thrInfo->NumberOfThreads;
434 thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
437 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
439 thrStr->Filter->_ThreadedOutput(
441 reinterpret_cast< const _TMatrix* >( thrStr->X ),
442 reinterpret_cast< const _TTriplets* >( thrStr->S ),
445 return( ITK_THREAD_RETURN_VALUE );
448 // -------------------------------------------------------------------------
// Per-thread worker for _Output. Walks `region` of the input labels, the
// output labels and the output probabilities in lockstep:
//  - unlabelled pixel (input label 0): finds its row j of X with
//    _NearSeedIndex, picks the label column maxL with the largest value, and
//    writes invLabels[ maxL ] to the label output (the probability written
//    is presumably maxP — confirm);
//  - seed pixel: copies its input label and sets probability 1.
// NOTE(review): linear indices here are taken relative to the output's
// requested region. _ThreadedBoundary and _ThreadedLaplacian use the
// InputLabels and input image requested regions instead, so the row lookup
// is consistent only if all three regions are identical — confirm. The
// `id` parameter is unused.
449 template< class _TImage, class _TLabels, class _TScalar >
450 template< class _TMatrix, class _TTriplets >
451 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
453 const TRegion& region, const itk::ThreadIdType& id,
454 const _TMatrix* X, const _TTriplets* S,
455 const std::vector< TLabel >* invLabels
459 const TLabels* in_labels = this->GetInputLabels( );
460 TLabels* out_labels = this->GetOutput( );
461 TScalarImage* out_probs = this->GetOutputProbabilities( );
462 TRegion reqRegion = out_labels->GetRequestedRegion( );
463 itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
464 itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
465 itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
469 for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
// Unlabelled pixel: arg-max over the label columns of its row of X.
471 if( iIt.Get( ) == 0 )
473 unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
474 unsigned long j = Self::_NearSeedIndex( i, *S );
475 TScalar maxP = X->coeff( j, 0 );
476 unsigned long maxL = 0;
477 for( unsigned int s = 1; s < invLabels->size( ); ++s )
479 TScalar p = X->coeff( j, s );
488 oIt.Set( ( *invLabels )[ maxL ] );
// Seed pixel: keep its own label with probability 1.
493 oIt.Set( iIt.Get( ) );
494 pIt.Set( TScalar( 1 ) );
501 // -------------------------------------------------------------------------
// Maps an N-D index to a linear offset in which axis 0 varies fastest:
// presumably i = idx[0] + idx[1]*size[0] + idx[2]*size[0]*size[1] + ...,
// with `off` holding the running product of the region sizes.
// NOTE(review): idx is not shifted by region.GetIndex( ), so the result is a
// 0-based offset only when the region starts at index 0 — confirm callers
// guarantee this.
502 template< class _TImage, class _TLabels, class _TScalar >
504 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
505 _1D( const TIndex& idx, const TRegion& region )
507 unsigned long i = idx[ 0 ];
508 unsigned long off = 1;
509 typename TRegion::SizeType size = region.GetSize( );
510 for( unsigned int d = 1; d < TIndex::Dimension; ++d )
512 off *= size[ d - 1 ];
519 // -------------------------------------------------------------------------
// Binary search for pixel index i among the rows of t, which must be sorted
// by row. f starts at t.size( ) and is used as the "not found yet" sentinel
// of the loop. It ends as the position of the triplet whose row is i
// (presumably the returned value). That position is used as the seed's
// column in R.
// NOTE(review): when t is empty, e = f - 1 wraps to the largest unsigned
// long, and t[ h ] would be read out of bounds (assuming s starts at 0).
// Callers must ensure at least one seed exists.
520 template< class _TImage, class _TLabels, class _TScalar >
521 template< class _TTriplets >
523 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
524 _SeedIndex( const unsigned long& i, const _TTriplets& t )
527 unsigned long f = t.size( );
528 unsigned long e = f - 1;
529 while( e > s && f == t.size( ) )
533 unsigned long h = ( e + s ) >> 1;
534 if ( i < t[ h ].row( ) ) e = h;
535 else if( t[ h ].row( ) < i ) s = h;
539 f = ( t[ s ].row( ) == i )? s: e;
545 // -------------------------------------------------------------------------
// Binary search that brackets pixel index i between neighbouring entries
// s and e of the row-sorted seed list t. The result is used as i's row in
// the reduced (unseeded) system A and in the solution X. The final
// adjustment from t[ s ] / t[ e ] presumably yields i minus the number of
// seeds with a smaller pixel index — confirm.
// NOTE(review): e is a signed long computed from the unsigned t.size( ).
// For an empty t it becomes -1, and t[ s ] would presumably be indexed out
// of range.
546 template< class _TImage, class _TLabels, class _TScalar >
547 template< class _TTriplets >
549 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
550 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
553 long e = t.size( ) - 1;
556 long h = ( e + s ) >> 1;
557 if ( i < t[ h ].row( ) ) e = h;
558 else if( t[ h ].row( ) < i ) s = h;
562 if( i < t[ s ].row( ) )
564 else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
572 #endif // __fpa__Common__OriginalRandomWalker__hxx__