1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
9 #include <itkImageRegionConstIteratorWithIndex.h>
10 #include <itkImageRegionIteratorWithIndex.h>
11 #include <Eigen/Sparse>
13 // -------------------------------------------------------------------------
// AddSeed: appends `seed` to m_Seeds and `label` to m_Labels, keeping the two
// vectors index-aligned (entry k of each describes the same seed).
// NOTE(review): m_Seeds / m_Labels are not read anywhere else in this listing;
// the routine that calls _Boundary gathers seeds from the InputLabels image
// instead. Confirm whether seeds registered here are actually consumed.
// (The closing lines of this method are not visible in this listing.)
14 template< class _TImage, class _TLabels, class _TScalar >
15 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16 AddSeed( const TIndex& seed, const TLabel& label )
18 this->m_Seeds.push_back( seed );
19 this->m_Labels.push_back( label );
23 // -------------------------------------------------------------------------
// Constructor: sets parameter defaults and registers the extra input
// (InputLabels) and output (OutputProbabilities) via the project's
// fpaFilter*ConfigureMacro helpers.
24 template< class _TImage, class _TLabels, class _TScalar >
25 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
26 OriginalRandomWalker( )
28 m_Beta( TScalar( 90 ) ), // scale applied to the intensity difference in _W
29 m_Epsilon( TScalar( 1e-5 ) ), // lower bound for every edge weight returned by _W
30 m_NormalizeWeights( true ) // NOTE(review): not read in the visible code; confirm its use
// Labels image: pixels with a nonzero label are treated as seeds (see _ThreadedLaplacian).
32 fpaFilterInputConfigureMacro( InputLabels, TLabels );
// Scalar image of per-pixel probabilities; seed pixels are written as 1 in _ThreadedOutput.
33 fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
36 // -------------------------------------------------------------------------
// Destructor; its body is not visible in this listing.
37 template< class _TImage, class _TLabels, class _TScalar >
38 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
39 ~OriginalRandomWalker( )
43 // -------------------------------------------------------------------------
// Filter entry point. Its name line is not visible in this listing; it is
// presumably GenerateData. It builds and solves the random-walker Dirichlet
// problem in four steps:
//   1. collect seed pixels of InputLabels as triplets and number their labels,
//   2. assemble the unseeded/unseeded block A and the (negated) coupling block R,
//   3. solve A * X = R * B, where B is the seed-to-label one-hot matrix,
//   4. write the label and probability outputs through _Output.
44 template< class _TImage, class _TLabels, class _TScalar >
45 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
49 typedef Eigen::Triplet< TScalar > _TTriplet;
50 typedef std::vector< _TTriplet > _TTriplets;
51 typedef Eigen::SparseMatrix< TScalar > _TMatrix;
52 typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
55 this->AllocateOutputs( );
57 // Build boundary triplets and count labels
// St receives one ( pixel 1D index, label value, 1 ) triplet per seed pixel.
// `labels` maps each label value to a 0-based column index, assigned in order
// of first discovery.
// NOTE(review): _Boundary is multi-threaded, so that discovery order (and thus
// the column numbering) may vary between runs. invLabels below maps columns back
// to label values, so the output labels are unaffected.
59 std::map< TLabel, unsigned long > labels;
60 this->_Boundary( St, labels );
// Sort seeds by pixel index so _SeedIndex / _NearSeedIndex can binary-search St.
63 bool operator()( const _TTriplet& a, const _TTriplet& b )
65 return( a.row( ) < b.row( ) );
68 std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
// Bt: row k (k-th seed in sorted order) has a single entry, equal to St's value
// (1), in the column of that seed's label; i.e. B is a one-hot seed/label matrix.
69 for( unsigned long i = 0; i < St.size( ); ++i )
70 Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) );
// At: Laplacian entries among unseeded pixels; Rt: negated unseeded/seed couplings.
74 this->_Laplacian( At, Rt, St );
// Problem sizes: N pixels in the input's requested region, split into nSeeds
// seeds and N - nSeeds unknowns.
77 TRegion region = this->GetInput( )->GetRequestedRegion( );
78 unsigned long nSeeds = St.size( );
79 unsigned long nLabels = labels.size( );
80 unsigned long N = region.GetNumberOfPixels( );
// Inverse of `labels`: column index -> label value.
82 std::vector< TLabel > invLabels( nLabels );
83 for( typename std::map< TLabel, unsigned long >::value_type s: labels )
84 invLabels[ s.second ] = s.first;
86 _TMatrix B( nSeeds, nLabels );
87 B.setFromTriplets( Bt.begin( ), Bt.end( ) );
90 _TMatrix R( N - nSeeds, nSeeds );
91 R.setFromTriplets( Rt.begin( ), Rt.end( ) );
94 _TMatrix A( N - nSeeds, N - nSeeds );
95 A.setFromTriplets( At.begin( ), At.end( ) );
98 // Solve dirichlet problem
// Eigen's SimplicialLDLT expects a symmetric positive-definite A.
101 if( solver.info( ) != Eigen::Success )
103 std::cerr << "Error computing." << std::endl;
// NOTE(review): the lines after these error messages are not visible here.
// Confirm the method stops instead of solving and writing outputs after a failed
// factorization or solve.
// x: one row per unseeded pixel, one column per label; _ThreadedOutput picks
// a label per row from these values.
105 _TMatrix x = solver.solve( R * B );
106 if( solver.info( ) != Eigen::Success )
108 std::cerr << "Error solving." << std::endl;
112 this->_Output( x, St, invLabels );
115 // -------------------------------------------------------------------------
116 template< class _TImage, class _TLabels, class _TScalar >
117 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
118 _W( const TIndex& i, const TIndex& j )
120 const TImage* in = this->GetInput( );
121 TScalar a = TScalar( in->GetPixel( i ) );
122 TScalar b = TScalar( in->GetPixel( j ) );
123 TScalar v = std::exp( -this->m_Beta * std::fabs( a - b ) );
124 if( v < this->m_Epsilon )
125 return( this->m_Epsilon );
130 // -------------------------------------------------------------------------
// _L: entry (i, j) of the graph Laplacian built from _W over the 2*Dim
// axis-aligned neighbours, restricted to the input's requested region.
// The visible loop accumulates _W( i, k ) over neighbours k inside that region;
// this is the degree, presumably returned when i == j (the branch condition and
// the construction of k are not visible). The last visible statement returns
// -_W( i, j ) for the off-diagonal case.
131 template< class _TImage, class _TLabels, class _TScalar >
132 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
133 _L( const TIndex& i, const TIndex& j )
137 TRegion r = this->GetInput( )->GetRequestedRegion( );
138 TScalar s = TScalar( 0 );
139 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
141 for( int n = -1; n <= 1; n += 2 )
145 if( r.IsInside( k ) )
146 s += this->_W( i, k );
154 return( this->_W( i, j ) * TScalar( -1 ) );
157 // -------------------------------------------------------------------------
// _Boundary: collects the seed triplets into B and fills `labels`. It does so by
// running _ThreadedBoundary on pieces of the InputLabels requested region, using
// an ITK MultiThreader with as many threads as the region splitter produces.
// B and `labels` are shared by all threads; _ThreadedBoundary serializes access
// with m_Mutex.
158 template< class _TImage, class _TLabels, class _TScalar >
159 template< class _TTriplets >
160 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
161 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
165 // Set up the multithreaded processing
// The triplet container is type-erased through void*; _BoundaryCbk casts it back
// to _TTriplets*.
166 _TBoundaryThreadStruct thrStr;
167 thrStr.Filter = this;
168 thrStr.Triplets = reinterpret_cast< void* >( &B );
169 thrStr.Labels = &labels;
171 // Configure threader
172 const TLabels* in_labels = this->GetInputLabels( );
173 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
174 const unsigned int nThreads =
175 split->GetNumberOfSplits(
176 in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
179 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
180 threads->SetNumberOfThreads( nThreads );
181 threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
184 threads->SingleMethodExecute( );
187 // -------------------------------------------------------------------------
// _BoundaryCbk: MultiThreader entry point for _Boundary. It:
//   - reads the thread id/count from ThreadInfoStruct,
//   - recovers the _TBoundaryThreadStruct from UserData,
//   - asks the filter for this thread's piece of the requested region,
//   - forwards that piece to _ThreadedBoundary.
// NOTE(review): the declaration of `region` and any check on `total` are on
// lines not visible here. Presumably they skip threads that received no piece;
// confirm.
188 template< class _TImage, class _TLabels, class _TScalar >
189 template< class _TTriplets >
190 ITK_THREAD_RETURN_TYPE
191 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
192 _BoundaryCbk( void* arg )
194 _TBoundaryThreadStruct* thrStr;
195 itk::ThreadIdType total, thrId, thrCount;
196 itk::MultiThreader::ThreadInfoStruct* thrInfo =
197 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
198 thrId = thrInfo->ThreadID;
199 thrCount = thrInfo->NumberOfThreads;
200 thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
203 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
205 thrStr->Filter->_ThreadedBoundary(
207 reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
210 return( ITK_THREAD_RETURN_VALUE );
213 // -------------------------------------------------------------------------
// _ThreadedBoundary: scans `region` (one thread's piece) of InputLabels. For each
// seed pixel it:
//   - appends ( 1D index in the full requested region, label value, 1 ) to B,
//   - if the label is new, maps it to the next free column index.
// The condition that selects seed pixels is on lines not visible here.
// The label value is stored in the triplet's column slot, so it must be
// representable as Eigen's (integer) storage index.
// NOTE(review): m_Mutex is locked and unlocked manually. If push_back or insert
// throws (e.g. std::bad_alloc), the mutex stays locked; a scoped lock holder
// would be safer. Locking once per seed pixel also serializes the threads.
217 template< class _TImage, class _TLabels, class _TScalar >
218 template< class _TTriplets >
219 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
218 const TRegion& region, const itk::ThreadIdType& id,
220 std::map< TLabel, unsigned long >* labels
223 typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
224 typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
// NOTE(review): _TInvValue is not used in the visible lines.
225 typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
226 typedef typename _TTriplets::value_type _TTriplet;
228 const TLabels* in_labels = this->GetInputLabels( );
229 TRegion reqRegion = in_labels->GetRequestedRegion( );
230 _TIt it( in_labels, region );
231 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
// Linearize against the full requested region (not the thread piece) so that
// indices agree across threads.
235 unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
236 this->m_Mutex.Lock( );
237 B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
238 if( labels->find( it.Get( ) ) == labels->end( ) )
239 labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
240 this->m_Mutex.Unlock( );
247 // -------------------------------------------------------------------------
// _Laplacian: fills A (Laplacian entries among unseeded pixels) and R (negated
// unseeded/seed couplings) by running _ThreadedLaplacian over pieces of the
// requested region. B must be the seed triplets already sorted by row, because
// it is binary-searched.
248 template< class _TImage, class _TLabels, class _TScalar >
249 template< class _TTriplets >
250 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
251 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
256 // Set up the multithreaded processing
257 _TLaplacianThreadStruct thrStr;
258 thrStr.Filter = this;
259 thrStr.A = reinterpret_cast< void* >( &A );
260 thrStr.R = reinterpret_cast< void* >( &R );
261 thrStr.B = reinterpret_cast< const void* >( &B );
263 // Configure threader
// NOTE(review): _Boundary stores GetInputLabels() as `const TLabels*`. Assigning
// it to `const TImage*` only compiles when TImage and TLabels are the same type.
// This probably should be `const TLabels*`, or use GetInput(), whose region
// _ThreadedLaplacian iterates. Confirm and fix.
264 const TImage* in = this->GetInputLabels( );
265 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
266 const unsigned int nThreads =
267 split->GetNumberOfSplits(
268 in->GetRequestedRegion( ), this->GetNumberOfThreads( )
271 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
272 threads->SetNumberOfThreads( nThreads );
273 threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
276 threads->SingleMethodExecute( );
279 // -------------------------------------------------------------------------
// _LaplacianCbk: MultiThreader entry point for _Laplacian. It recovers the
// _TLaplacianThreadStruct from ThreadInfoStruct::UserData, obtains this thread's
// piece of the requested region, and forwards it to _ThreadedLaplacian. The
// type-erased A/R/B pointers are cast back to their _TTriplets types.
// NOTE(review): the declaration of `region` and any check on `total` are not
// visible here.
280 template< class _TImage, class _TLabels, class _TScalar >
281 template< class _TTriplets >
282 ITK_THREAD_RETURN_TYPE
283 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
284 _LaplacianCbk( void* arg )
286 _TLaplacianThreadStruct* thrStr;
287 itk::ThreadIdType total, thrId, thrCount;
288 itk::MultiThreader::ThreadInfoStruct* thrInfo =
289 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
290 thrId = thrInfo->ThreadID;
291 thrCount = thrInfo->NumberOfThreads;
292 thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
295 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
297 thrStr->Filter->_ThreadedLaplacian(
299 reinterpret_cast< _TTriplets* >( thrStr->A ),
300 reinterpret_cast< _TTriplets* >( thrStr->R ),
301 reinterpret_cast< const _TTriplets* >( thrStr->B )
303 return( ITK_THREAD_RETURN_VALUE );
306 // -------------------------------------------------------------------------
// _ThreadedLaplacian: for every pixel of `region` (one thread's piece of the input
// image), emits Laplacian triplets. A pixel is a seed when its InputLabels value
// is nonzero; InputLabels is assumed to share the input image's index space.
//   * Diagonal _L( idx, idx ) goes into A at si = _NearSeedIndex( i, *B ), the
//     pixel's row among the unknowns. This is presumably done only for unseeded
//     pixels; the guarding branch is not visible.
//   * For seed pixels, si = _SeedIndex( i, *B ), the seed's position in B.
//   * For each of the 2*Dim neighbours inside the requested region, either
//     A( si, sj ) = _L( idx, jdx ) or R( sj, si ) = -_L( idx, jdx ) is emitted,
//     depending on whether i/j are seeds (branch conditions not visible).
// NOTE(review): one mutex lock per pushed triplet serializes all threads.
// Per-thread vectors merged once at the end would scale better, and a scoped lock
// would keep m_Mutex from staying locked if push_back throws.
307 template< class _TImage, class _TLabels, class _TScalar >
308 template< class _TTriplets >
309 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
311 const TRegion& region, const itk::ThreadIdType& id,
312 _TTriplets* A, _TTriplets* R, const _TTriplets* B
315 typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
316 typedef typename _TTriplets::value_type _TTriplet;
318 const TImage* in = this->GetInput( );
319 const TLabels* in_labels = this->GetInputLabels( );
320 TRegion rqRegion = in->GetRequestedRegion( );
321 _TIt it( in, region );
322 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
324 TIndex idx = it.GetIndex( );
325 bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
326 unsigned long i = Self::_1D( idx, rqRegion );
329 // A's diagonal values
332 si = Self::_NearSeedIndex( i, *B );
333 this->m_Mutex.Lock( );
334 A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
335 this->m_Mutex.Unlock( );
338 si = Self::_SeedIndex( i, *B );
340 // Neighbors (final matrix is symmetric)
341 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
343 for( int s = -1; s <= 1; s += 2 )
347 if( rqRegion.IsInside( jdx ) )
349 TScalar L = this->_L( idx, jdx );
350 unsigned long j = Self::_1D( jdx, rqRegion );
351 bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
// sj: the neighbour's row among the unknowns.
354 unsigned long sj = Self::_NearSeedIndex( j, *B );
357 this->m_Mutex.Lock( );
358 A->push_back( _TTriplet( si, sj, L ) );
359 this->m_Mutex.Unlock( );
// R's rows are unknowns and its columns are seeds, hence ( sj, si ) with the sign flipped.
363 this->m_Mutex.Lock( );
364 R->push_back( _TTriplet( sj, si, -L ) );
365 this->m_Mutex.Unlock( );
380 // -------------------------------------------------------------------------
// _Output: writes the label output and OutputProbabilities by running
// _ThreadedOutput over pieces of the output's requested region.
//   X: solution matrix (rows = unknowns, columns = labels).
//   S: seed triplets sorted by row.
//   invLabels: maps a column index to its label value.
381 template< class _TImage, class _TLabels, class _TScalar >
382 template< class _TMatrix, class _TTriplets >
383 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
385 const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
388 // Set up the multithreaded processing
389 _TOutputThreadStruct thrStr;
390 thrStr.Filter = this;
391 thrStr.X = reinterpret_cast< const void* >( &X );
392 thrStr.S = reinterpret_cast< const void* >( &S );
393 thrStr.InvLabels = &invLabels;
395 // Configure threader
// The output is only used here for its requested region.
396 const TLabels* out = this->GetOutput( );
397 const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
398 const unsigned int nThreads =
399 split->GetNumberOfSplits(
400 out->GetRequestedRegion( ), this->GetNumberOfThreads( )
403 itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
404 threads->SetNumberOfThreads( nThreads );
405 threads->SetSingleMethod(
406 this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
410 threads->SingleMethodExecute( );
413 // -------------------------------------------------------------------------
// _OutputCbk: MultiThreader entry point for _Output. It recovers the
// _TOutputThreadStruct from ThreadInfoStruct::UserData, obtains this thread's
// piece of the requested region, and forwards it to _ThreadedOutput with the
// type-erased X and S cast back to their types.
// NOTE(review): the declaration of `region` and any check on `total` are not
// visible here.
414 template< class _TImage, class _TLabels, class _TScalar >
415 template< class _TMatrix, class _TTriplets >
416 ITK_THREAD_RETURN_TYPE
417 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
418 _OutputCbk( void* arg )
420 _TOutputThreadStruct* thrStr;
421 itk::ThreadIdType total, thrId, thrCount;
422 itk::MultiThreader::ThreadInfoStruct* thrInfo =
423 reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
424 thrId = thrInfo->ThreadID;
425 thrCount = thrInfo->NumberOfThreads;
426 thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
429 total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
431 thrStr->Filter->_ThreadedOutput(
433 reinterpret_cast< const _TMatrix* >( thrStr->X ),
434 reinterpret_cast< const _TTriplets* >( thrStr->S ),
437 return( ITK_THREAD_RETURN_VALUE );
440 // -------------------------------------------------------------------------
// _ThreadedOutput: for each pixel of `region` (one thread's piece):
//   * unseeded pixel (input label 0): scans X's row for that pixel across all
//     label columns, tracking maxP/maxL, and writes invLabels[ maxL ] to the label
//     output. The comparison and the probability write are on lines not visible.
//   * seed pixel: copies the input label and sets the probability to 1.
// NOTE(review): BUG. The 1D index below is computed against `region`, this
// thread's piece. _ThreadedBoundary and _ThreadedLaplacian linearize against the
// full requested region, and _1D depends on the region's size. When the output
// is split across more than one thread, the rows looked up in X are wrong. It
// should be Self::_1D( iIt.GetIndex( ), <full requested region> ), as in
// _ThreadedLaplacian.
441 template< class _TImage, class _TLabels, class _TScalar >
442 template< class _TMatrix, class _TTriplets >
443 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
445 const TRegion& region, const itk::ThreadIdType& id,
446 const _TMatrix* X, const _TTriplets* S,
447 const std::vector< TLabel >* invLabels
451 const TLabels* in_labels = this->GetInputLabels( );
452 TLabels* out_labels = this->GetOutput( );
453 TScalarImage* out_probs = this->GetOutputProbabilities( );
454 itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
455 itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
456 itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
460 for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
462 if( iIt.Get( ) == 0 )
464 unsigned long i = Self::_1D( iIt.GetIndex( ), region );
// j: this pixel's row among the unknowns, i.e. its row in X.
465 unsigned long j = Self::_NearSeedIndex( i, *S );
466 TScalar maxP = X->coeff( j, 0 );
467 unsigned long maxL = 0;
468 for( unsigned int s = 1; s < invLabels->size( ); ++s )
470 TScalar p = X->coeff( j, s );
479 oIt.Set( ( *invLabels )[ maxL ] );
484 oIt.Set( iIt.Get( ) );
485 pIt.Set( TScalar( 1 ) );
492 // -------------------------------------------------------------------------
// _1D: linearizes an N-D index into a scalar offset using the sizes of `region`.
// Dimension 0 varies fastest; the multiplier for dimension d is
// size[0] * ... * size[d-1]. The accumulation line is not visible; it is
// presumably i += idx[ d ] * off.
// NOTE(review): idx[ 0 ] is used as-is, not idx[ 0 ] - region.GetIndex( )[ 0 ].
// Unless hidden lines compensate, results are 0-based only when `region` starts
// at the origin. Callers must pass the same region for all indices they compare.
493 template< class _TImage, class _TLabels, class _TScalar >
495 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
496 _1D( const TIndex& idx, const TRegion& region )
498 unsigned long i = idx[ 0 ];
499 unsigned long off = 1;
500 typename TRegion::SizeType size = region.GetSize( );
501 for( unsigned int d = 1; d < TIndex::Dimension; ++d )
503 off *= size[ d - 1 ];
510 // -------------------------------------------------------------------------
// _SeedIndex: binary search in t (sorted by row()) for the triplet whose row()
// equals pixel index i. The result f starts at t.size() and the loop stops as
// soon as f changes, so t.size() presumably means "not found". The initial
// value of s and the return line are not visible.
// NOTE(review): if t is empty, `e = f - 1` wraps to the maximum unsigned long.
// With s presumably 0, `e > s` holds and t[ h ] is read out of bounds. Guard the
// empty case (e.g. when InputLabels contains no seeds).
511 template< class _TImage, class _TLabels, class _TScalar >
512 template< class _TTriplets >
514 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
515 _SeedIndex( const unsigned long& i, const _TTriplets& t )
518 unsigned long f = t.size( );
519 unsigned long e = f - 1;
520 while( e > s && f == t.size( ) )
524 unsigned long h = ( e + s ) >> 1;
525 if ( i < t[ h ].row( ) ) e = h;
526 else if( t[ h ].row( ) < i ) s = h;
530 f = ( t[ s ].row( ) == i )? s: e;
536 // -------------------------------------------------------------------------
// _NearSeedIndex: binary search in t (sorted by row()) around pixel index i.
// Callers use the result as the pixel's row among the unknowns (rows of A and X).
// It presumably returns i minus the number of seeds whose row is below i; the
// loop header and return statements are not visible.
// NOTE(review): `long e = t.size( ) - 1` is -1 in practice when t is empty, and
// t[ h ] / t[ e ] would then be read out of range. Comparisons also mix the
// unsigned long i with Triplet::row(), which is Eigen's signed storage index.
537 template< class _TImage, class _TLabels, class _TScalar >
538 template< class _TTriplets >
540 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
541 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
544 long e = t.size( ) - 1;
547 long h = ( e + s ) >> 1;
548 if ( i < t[ h ].row( ) ) e = h;
549 else if( t[ h ].row( ) < i ) s = h;
553 if( i < t[ s ].row( ) )
555 else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
563 #endif // __fpa__Common__OriginalRandomWalker__hxx__