// ========================================================================= // @author Leonardo Florez Valencia // @email florez-l@javeriana.edu.co // ========================================================================= #ifndef __fpa__Common__RandomWalker__hxx__ #define __fpa__Common__RandomWalker__hxx__ #include #include #ifdef USE_Eigen3 # include #endif // USE_Eigen3 // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: RandomWalker( ) : Superclass( ) { ivqITKInputConfigureMacro( InputLabels, TLabels ); ivqITKOutputConfigureMacro( OutputProbabilities, TScalarImage ); } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: ~RandomWalker( ) { } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: GenerateData( ) { #ifdef USE_Eigen3 // Useful typedefs typedef Eigen::Triplet< TScalar > _TTriplet; typedef std::vector< _TTriplet > _TTriplets; typedef Eigen::SparseMatrix< TScalar > _TMatrix; typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver; // Configure edge function if( this->m_EdgeFunction.IsNull( ) ) itkExceptionMacro( << "Undefined edge function." ); const TImage* input = this->GetInput( ); this->m_EdgeFunction->SetDataObject( input ); // Allocate outputs this->AllocateOutputs( ); // Persisting objects _TMatrix A( 1, 1 ), C( 1, 1 ); _TTriplets St; std::vector< TLabel > invLabels; { // begin // Build boundary triplets and count labels _TTriplets Bt; std::map< TLabel, unsigned long > labels; itkDebugMacro( << "Building boundary matrix..." ); this->_Boundary( St, labels ); struct _TTripletsOrd { bool operator()( const _TTriplet& a, const _TTriplet& b ) { return( a.row( ) < b.row( ) ); } }; itkDebugMacro( << "Sorting boundary pixels..." ); std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) ); itkDebugMacro( << "Assigning boundary pixels..." ); for( unsigned long i = 0; i < St.size( ); ++i ) Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) ); // Laplacian triplets itkDebugMacro( << "Building laplacian matrix..." ); _TTriplets At, Rt; this->_Laplacian( At, Rt, St ); // Matrices TRegion region = input->GetRequestedRegion( ); unsigned long nSeeds = St.size( ); unsigned long nLabels = labels.size( ); unsigned long N = region.GetNumberOfPixels( ); itkDebugMacro( << "Creating inverse labels..." ); invLabels.resize( nLabels ); for( typename std::map< TLabel, unsigned long >::value_type s: labels ) invLabels[ s.second ] = s.first; itkDebugMacro( << "Creating B matrix..." ); _TMatrix B( nSeeds, nLabels ); B.setFromTriplets( Bt.begin( ), Bt.end( ) ); B.makeCompressed( ); itkDebugMacro( << "Creating R matrix..." ); _TMatrix R( N - nSeeds, nSeeds ); R.setFromTriplets( Rt.begin( ), Rt.end( ) ); R.makeCompressed( ); itkDebugMacro( << "Creating C matrix..." ); C = R * B; itkDebugMacro( << "Creating A matrix..." ); A.resize( N - nSeeds, N - nSeeds ); A.setFromTriplets( At.begin( ), At.end( ) ); A.makeCompressed( ); } // end // Solve dirichlet problem _TSolver solver; itkDebugMacro( << "Factorizing problem..." ); solver.compute( A ); if( solver.info( ) != Eigen::Success ) itkExceptionMacro( << "Error decomposing matrix." ); itkDebugMacro( << "Solving problem..." ); _TMatrix x = solver.solve( C ); if( solver.info( ) != Eigen::Success ) itkExceptionMacro( << "Error solving system." ); // Fill outputs itkDebugMacro( << "Filling output..." ); this->_Output( x, St, invLabels ); #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > _TScalar fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _L( const TIndex& i, const TIndex& j ) { #ifdef USE_Eigen3 if( i == j ) { TRegion r = this->GetInput( )->GetRequestedRegion( ); TScalar s = TScalar( 0 ); for( unsigned int d = 0; d < TImage::ImageDimension; ++d ) { for( int n = -1; n <= 1; n += 2 ) { TIndex k = i; k[ d ] += n; if( r.IsInside( k ) ) s += this->m_EdgeFunction->Evaluate( i, k ); } // rof } // rof return( s ); } else return( -( this->m_EdgeFunction->Evaluate( i, j ) ) ); #else // USE_Eigen3 return( _TScalar( 0 ) ); #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels ) { #ifdef USE_Eigen3 B.clear( ); // Set up the multithreaded processing _TBoundaryThreadStruct thrStr; thrStr.Filter = this; thrStr.Triplets = reinterpret_cast< void* >( &B ); thrStr.Labels = &labels; // Configure threader const TLabels* in_labels = this->GetInputLabels( ); const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( ); const unsigned int nThreads = split->GetNumberOfSplits( in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( ) ); itk::MultiThreader::Pointer threads = itk::MultiThreader::New( ); threads->SetNumberOfThreads( nThreads ); threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr ); // Execute threader threads->SingleMethodExecute( ); #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > ITK_THREAD_RETURN_TYPE fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _BoundaryCbk( void* arg ) { #ifdef USE_Eigen3 _TBoundaryThreadStruct* thrStr; itk::ThreadIdType total, thrId, thrCount; itk::MultiThreader::ThreadInfoStruct* thrInfo = reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg ); thrId = thrInfo->ThreadID; thrCount = thrInfo->NumberOfThreads; thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData ); TRegion region; total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region ); if( thrId < total ) thrStr->Filter->_ThreadedBoundary( region, thrId, reinterpret_cast< _TTriplets* >( thrStr->Triplets ), thrStr->Labels ); #endif // USE_Eigen3 return( ITK_THREAD_RETURN_VALUE ); } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _ThreadedBoundary( const TRegion& region, const itk::ThreadIdType& id, _TTriplets* B, std::map< TLabel, unsigned long >* labels ) { #ifdef USE_Eigen3 typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt; typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue; typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue; typedef typename _TTriplets::value_type _TTriplet; const TLabels* in_labels = this->GetInputLabels( ); TRegion reqRegion = in_labels->GetRequestedRegion( ); _TIt it( in_labels, region ); for( it.GoToBegin( ); !it.IsAtEnd( ); ++it ) { if( it.Get( ) != 0 ) { unsigned long i = Self::_1D( it.GetIndex( ), reqRegion ); this->m_Mutex.Lock( ); B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) ); if( labels->find( it.Get( ) ) == labels->end( ) ) labels->insert( _TMapValue( it.Get( ), labels->size( ) ) ); this->m_Mutex.Unlock( ); } // fi } // rof #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B ) { #ifdef USE_Eigen3 A.clear( ); R.clear( ); // Set up the multithreaded processing _TLaplacianThreadStruct thrStr; thrStr.Filter = this; thrStr.A = reinterpret_cast< void* >( &A ); thrStr.R = reinterpret_cast< void* >( &R ); thrStr.B = reinterpret_cast< const void* >( &B ); // Configure threader const TImage* in = this->GetInput( ); const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( ); const unsigned int nThreads = split->GetNumberOfSplits( in->GetRequestedRegion( ), this->GetNumberOfThreads( ) ); itk::MultiThreader::Pointer threads = itk::MultiThreader::New( ); threads->SetNumberOfThreads( nThreads ); threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr ); // Execute threader threads->SingleMethodExecute( ); #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > ITK_THREAD_RETURN_TYPE fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _LaplacianCbk( void* arg ) { #ifdef USE_Eigen3 _TLaplacianThreadStruct* thrStr; itk::ThreadIdType total, thrId, thrCount; itk::MultiThreader::ThreadInfoStruct* thrInfo = reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg ); thrId = thrInfo->ThreadID; thrCount = thrInfo->NumberOfThreads; thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData ); TRegion region; total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region ); if( thrId < total ) thrStr->Filter->_ThreadedLaplacian( region, thrId, reinterpret_cast< _TTriplets* >( thrStr->A ), reinterpret_cast< _TTriplets* >( thrStr->R ), reinterpret_cast< const _TTriplets* >( thrStr->B ) ); #endif // USE_Eigen3 return( ITK_THREAD_RETURN_VALUE ); } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _ThreadedLaplacian( const TRegion& region, const itk::ThreadIdType& id, _TTriplets* A, _TTriplets* R, const _TTriplets* B ) { #ifdef USE_Eigen3 typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt; typedef typename _TTriplets::value_type _TTriplet; const TImage* in = this->GetInput( ); const TLabels* in_labels = this->GetInputLabels( ); TRegion reqRegion = in->GetRequestedRegion( ); _TIt it( in, region ); for( it.GoToBegin( ); !it.IsAtEnd( ); ++it ) { TIndex idx = it.GetIndex( ); bool iSeed = ( in_labels->GetPixel( idx ) != 0 ); unsigned long i = Self::_1D( idx, reqRegion ); unsigned long si; // A's diagonal values if( !iSeed ) { si = Self::_NearSeedIndex( i, *B ); this->m_Mutex.Lock( ); A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) ); this->m_Mutex.Unlock( ); } else si = Self::_SeedIndex( i, *B ); // Neighbors (final matrix is symmetric) for( unsigned int d = 0; d < TImage::ImageDimension; ++d ) { for( int s = -1; s <= 1; s += 2 ) { TIndex jdx = idx; jdx[ d ] += s; if( reqRegion.IsInside( jdx ) ) { TScalar L = this->_L( idx, jdx ); unsigned long j = Self::_1D( jdx, reqRegion ); bool jSeed = ( in_labels->GetPixel( jdx ) != 0 ); if( !jSeed ) { unsigned long sj = Self::_NearSeedIndex( j, *B ); if( !iSeed ) { this->m_Mutex.Lock( ); A->push_back( _TTriplet( si, sj, L ) ); this->m_Mutex.Unlock( ); } else { this->m_Mutex.Lock( ); R->push_back( _TTriplet( sj, si, -L ) ); this->m_Mutex.Unlock( ); } // fi } // fi } // fi } // rof } // rof } // rof #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TMatrix, class _TTriplets > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _Output( const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels ) { #ifdef USE_Eigen3 // Set up the multithreaded processing _TOutputThreadStruct thrStr; thrStr.Filter = this; thrStr.X = reinterpret_cast< const void* >( &X ); thrStr.S = reinterpret_cast< const void* >( &S ); thrStr.InvLabels = &invLabels; // Configure threader const TLabels* out = this->GetOutput( ); const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( ); const unsigned int nThreads = split->GetNumberOfSplits( out->GetRequestedRegion( ), this->GetNumberOfThreads( ) ); itk::MultiThreader::Pointer threads = itk::MultiThreader::New( ); threads->SetNumberOfThreads( nThreads ); threads->SetSingleMethod( this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr ); // Execute threader threads->SingleMethodExecute( ); #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TMatrix, class _TTriplets > ITK_THREAD_RETURN_TYPE fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _OutputCbk( void* arg ) { #ifdef USE_Eigen3 _TOutputThreadStruct* thrStr; itk::ThreadIdType total, thrId, thrCount; itk::MultiThreader::ThreadInfoStruct* thrInfo = reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg ); thrId = thrInfo->ThreadID; thrCount = thrInfo->NumberOfThreads; thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData ); TRegion region; total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region ); if( thrId < total ) thrStr->Filter->_ThreadedOutput( region, thrId, reinterpret_cast< const _TMatrix* >( thrStr->X ), reinterpret_cast< const _TTriplets* >( thrStr->S ), thrStr->InvLabels ); #endif // USE_Eigen3 return( ITK_THREAD_RETURN_VALUE ); } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TMatrix, class _TTriplets > void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _ThreadedOutput( const TRegion& region, const itk::ThreadIdType& id, const _TMatrix* X, const _TTriplets* S, const std::vector< TLabel >* invLabels ) { #ifdef USE_Eigen3 // Fill outputs const TLabels* in_labels = this->GetInputLabels( ); TLabels* out_labels = this->GetOutput( ); TScalarImage* out_probs = this->GetOutputProbabilities( ); TRegion reqRegion = out_labels->GetRequestedRegion( ); itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region ); itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region ); itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region ); iIt.GoToBegin( ); oIt.GoToBegin( ); pIt.GoToBegin( ); for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt ) { if( iIt.Get( ) == 0 ) { unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion ); unsigned long j = Self::_NearSeedIndex( i, *S ); TScalar maxP = X->coeff( j, 0 ); unsigned long maxL = 0; for( unsigned int s = 1; s < invLabels->size( ); ++s ) { TScalar p = X->coeff( j, s ); if( maxP <= p ) { maxP = p; maxL = s; } // fi } // rof oIt.Set( ( *invLabels )[ maxL ] ); pIt.Set( maxP ); } else { oIt.Set( iIt.Get( ) ); pIt.Set( TScalar( 1 ) ); } // fi } // rof #endif // USE_Eigen3 } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > unsigned long fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _1D( const TIndex& idx, const TRegion& region ) { unsigned long i = idx[ 0 ]; unsigned long off = 1; typename TRegion::SizeType size = region.GetSize( ); for( unsigned int d = 1; d < TIndex::Dimension; ++d ) { off *= size[ d - 1 ]; i += idx[ d ] * off; } // rof return( i ); } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > unsigned long fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _SeedIndex( const unsigned long& i, const _TTriplets& t ) { unsigned long s = 0; unsigned long f = t.size( ); unsigned long e = f - 1; while( e > s && f == t.size( ) ) { if( e > s + 1 ) { unsigned long h = ( e + s ) >> 1; if ( i < t[ h ].row( ) ) e = h; else if( t[ h ].row( ) < i ) s = h; else f = h; } else f = ( t[ s ].row( ) == i )? s: e; } // elihw return( f ); } // ------------------------------------------------------------------------- template< class _TImage, class _TLabels, class _TScalar > template< class _TTriplets > unsigned long fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >:: _NearSeedIndex( const unsigned long& i, const _TTriplets& t ) { long s = 0; long e = t.size( ) - 1; while( e > s + 1 ) { long h = ( e + s ) >> 1; if ( i < t[ h ].row( ) ) e = h; else if( t[ h ].row( ) < i ) s = h; } // elihw long d; if( i < t[ s ].row( ) ) d = -1; else if( t[ s ].row( ) < i && i < t[ e ].row( ) ) d = s + 1; else d = e + 1; if( d < 0 ) d = 0; return( i - d ); } #endif // __fpa__Common__RandomWalker__hxx__ // eof - $RCSfile$