1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
10 #include <itkImageRegionConstIteratorWithIndex.h>
11 #include <Eigen/Sparse>
13 // -------------------------------------------------------------------------
// Registers one user seed: the pixel index `seed` is appended to m_Seeds
// and its class `label` to m_Labels. Both containers are appended together,
// so position i in m_Seeds and position i in m_Labels describe the same seed.
// NOTE(review): this listing skips original lines 17 and 20-22 (presumably
// braces and possibly a Modified( ) call) - confirm against the real source.
14 template< class _TImage, class _TLabels, class _TScalar >
15 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16 AddSeed( const TIndex& seed, const TLabel& label )
18 this->m_Seeds.push_back( seed );
19 this->m_Labels.push_back( label );
23 // -------------------------------------------------------------------------
// Default constructor: sets the edge-weight parameters (Beta = 90,
// Epsilon = 1e-5, weight normalization enabled) and declares the filter's
// optional "InputLabels" input (type TLabels) and its "OutputProbabilities"
// output (type TScalarImage).
// NOTE(review): the fpa*ConfigureMacro definitions live elsewhere;
// presumably they create/register the extra input and output - confirm.
// NOTE(review): original lines 27 and 31 are missing from this listing
// (presumably the base-class initializer and the opening brace).
24 template< class _TImage, class _TLabels, class _TScalar >
25 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
26 OriginalRandomWalker( )
// Beta scales the intensity cost inside exp( -Beta * cost ) for each edge.
28 m_Beta( TScalar( 90 ) ),
// Epsilon is the threshold below which an edge weight is treated specially.
29 m_Epsilon( TScalar( 1e-5 ) ),
// Enables the optional normalization step applied before weighting.
30 m_NormalizeWeights( true )
32 fpaFilterOptionalInputConfigureMacro( InputLabels, TLabels );
33 fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
36 // -------------------------------------------------------------------------
// Destructor. Its body (original lines 40-42) is not present in this
// listing; presumably it is empty - confirm against the real source.
37 template< class _TImage, class _TLabels, class _TScalar >
38 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
39 ~OriginalRandomWalker( )
43 // -------------------------------------------------------------------------
// Runs the random-walker segmentation (a Dirichlet problem on the image
// graph) over the input's requested region:
//   1. collects seeds (node -> label) from m_Seeds/m_Labels and, when it is
//      connected, from the optional InputLabels image;
//   2. builds a graph Laplacian whose off-diagonal entries are
//      -exp( -Beta * |I(i) - I(j)| ) for neighbouring pixels i, j;
//   3. solves A * x = R * B with a sparse LDLT factorization. A holds the
//      Laplacian rows/columns of unseeded nodes, R the negated
//      unseeded-to-seed entries, and B the seed/label indicator matrix;
//   4. writes, for every pixel, the label with the largest x value to the
//      label output and a value p to OutputProbabilities.
// NOTE(review): this listing omits many original lines (46-47 with the
// method name, 63-71, 101-102, 147-156, 215, 249-280, ...), so several
// steps below are only partially visible. Presumably this is GenerateData( ).
44 template< class _TImage, class _TLabels, class _TScalar >
45 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
// Sparse linear-algebra types (Eigen) and a generic node -> value map.
48 typedef Eigen::SparseMatrix< TScalar > _TSparseMatrix;
49 typedef Eigen::SimplicialLDLT< _TSparseMatrix > _TSparseSolver;
50 typedef Eigen::Triplet< TScalar > _TTriplet;
51 typedef std::vector< _TTriplet > _TTriplets;
52 typedef std::map< unsigned long, unsigned long > _TMap;
53 typedef _TMap::value_type _TMapValue;
// Every pixel of the requested region is one graph node, numbered
// 0..N-1 through Self::_1D( index, region ).
56 const TImage* in = this->GetInput( );
57 TRegion region = in->GetRequestedRegion( );
58 unsigned long N = region.GetNumberOfPixels( );
60 // Prepare seeds for linear algebra
62 for( unsigned int i = 0; i < this->m_Seeds.size( ); ++i )
65 Self::_1D( this->m_Seeds[ i ], region ),
// Pixels of the optional label image are added to the seed map as well.
// The condition that selects which label pixels count as seeds is on lines
// not shown here.
72 const TLabels* in_labels = this->GetInputLabels( );
73 if( in_labels != NULL )
75 itk::ImageRegionConstIteratorWithIndex< TLabels > lIt( in_labels, region );
76 for( lIt.GoToBegin( ); !lIt.IsAtEnd( ); ++lIt )
79 _TMapValue( Self::_1D( lIt.GetIndex( ), region ), lIt.Get( ) )
85 // Prepare label tables
// labels: label value -> compact column index; inv_labels: the reverse;
// seeds_indexes: seed node -> seed ordinal (the column of R).
// NOTE(review): BUG under C++17. In `m[ k ] = m.size( ) - 1` the right
// operand of `=` is sequenced before the left, so size( ) is read BEFORE
// operator[] inserts k. The first entry then gets 0 - 1 (wraps to the
// maximum unsigned long) and every later entry is off by one. Before C++17
// the order is unspecified. Read the size into a local first, e.g.
// `unsigned long next = m.size( ); m[ k ] = next;`.
86 _TMap labels, inv_labels, seeds_indexes;
87 _TMap::const_iterator hIt = seeds.begin( );
88 for( ; hIt != seeds.end( ); ++hIt )
90 seeds_indexes[ hIt->first ] = seeds_indexes.size( ) - 1;
91 if( labels.find( hIt->second ) == labels.end( ) )
93 labels[ hIt->second ] = labels.size( ) - 1;
94 inv_labels[ labels[ hIt->second ] ] = hIt->second;
100 // Prepare matrix/image index conversion
// indexes: unseeded node n -> row n - o of A and R. Presumably o counts the
// seeds preceding n; its declaration and update are on lines not shown here.
103 for( unsigned long n = 0; n < N; ++n )
105 if( seeds.find( n ) == seeds.end( ) )
106 indexes.insert( _TMap::value_type( n, n - o ) );
// B (#seeds x nLabels): B( i, c ) = 1 when the i-th seed, in ascending node
// order (the std::map iteration order), carries the label with column c.
111 unsigned long nLabels = labels.size( );
115 _TMap::const_iterator mIt = seeds.begin( );
116 for( unsigned long i = 0; mIt != seeds.end( ); ++mIt, ++i )
117 Bt.push_back( _TTriplet( i, labels[ mIt->second ], TScalar( 1 ) ) );
118 _TSparseMatrix B( seeds.size( ), nLabels );
119 B.setFromTriplets( Bt.begin( ), Bt.end( ) );
// Raw edge costs |I(i) - I(j)|, one triplet per ordered (pixel, neighbour)
// pair whose neighbour jdx lies inside the region. Presumably jdx is idx
// shifted by l along axis d; that computation is on lines not shown here.
// maxV presumably records the largest cost for the optional normalization.
// NOTE(review): Grady's original random walker uses the squared intensity
// difference; confirm the absolute difference is intended.
124 itk::ImageRegionConstIteratorWithIndex< TImage > it( in, region );
125 TScalar maxV = TScalar( -1 );
126 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
128 TIndex idx = it.GetIndex( );
129 TScalar vidx = TScalar( it.Get( ) );
130 unsigned long iidx = Self::_1D( idx, region );
133 TScalar s = TScalar( 0 );
134 for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
137 for( int l = -1; l <= 1; l += 2 )
141 if( region.IsInside( jdx ) )
143 TScalar vjdx = TScalar( in->GetPixel( jdx ) );
144 unsigned long ijdx = Self::_1D( jdx, region );
145 TScalar v = std::fabs( vidx - vjdx );
146 Lt.push_back( _TTriplet( iidx, ijdx, v ) );
// Laplacian L = D - W. Each cost becomes the weight w = exp( -Beta * cost )
// and is stored negated off the diagonal; w is also accumulated into the
// diagonal entry of its column node, and the diagonal is appended afterwards.
// What m_NormalizeWeights changes, and what happens when w < m_Epsilon,
// is on lines not shown here.
157 std::vector< TScalar > diag( N, TScalar( 0 ) );
158 typename _TTriplets::iterator tIt;
159 TScalar betaV = -this->m_Beta;
160 if( this->m_NormalizeWeights )
162 for( tIt = Lt.begin( ); tIt != Lt.end( ); ++tIt )
164 TScalar v = std::exp( betaV * tIt->value( ) );
165 if( v < this->m_Epsilon )
167 *tIt = _TTriplet( tIt->row( ), tIt->col( ), -v );
168 diag[ tIt->col( ) ] += v;
171 for( unsigned long i = 0; i < diag.size( ); ++i )
172 Lt.push_back( _TTriplet( i, i, diag[ i ] ) );
// Split L by node type. Entries with an unseeded row and a seed column go,
// negated, into R at ( indexes[ row ], seeds_indexes[ col ] ). Entries with
// an unseeded row and (presumably) an unseeded column go into A at
// ( indexes[ row ], indexes[ col ] ).
176 for( tIt = Lt.begin( ); tIt != Lt.end( ); ++tIt )
178 _TMap::const_iterator cIt, rIt;
179 cIt = seeds.find( tIt->col( ) );
180 rIt = seeds.find( tIt->row( ) );
181 if( cIt != seeds.end( ) )
183 if( rIt == seeds.end( ) )
185 _TMap::const_iterator iIt = indexes.find( tIt->row( ) );
186 _TMap::const_iterator jIt = seeds_indexes.find( cIt->first );
187 Rt.push_back( _TTriplet( iIt->second, jIt->second, -tIt->value( ) ) );
193 if( rIt == seeds.end( ) )
195 _TMap::const_iterator iIt = indexes.find( tIt->row( ) );
196 _TMap::const_iterator jIt = indexes.find( tIt->col( ) );
197 At.push_back( _TTriplet( iIt->second, jIt->second, tIt->value( ) ) );
205 _TSparseMatrix R( N - seeds.size( ), seeds.size( ) );
206 R.setFromTriplets( Rt.begin( ), Rt.end( ) );
209 _TSparseMatrix A( N - seeds.size( ), N - seeds.size( ) );
210 A.setFromTriplets( At.begin( ), At.end( ) );
213 // Solve Dirichlet problem
// Column c of x holds, for each unseeded node, the random-walker value for
// label column c. The factorization call itself is on a line not shown here.
// NOTE(review): the visible code only prints to std::cerr on failure.
// Whether it also returns is not visible - confirm, because x is used below.
// NOTE(review): R * B is sparse, so x is produced as a sparse matrix even
// though the solution is generally dense. Solving against a dense right-hand
// side would avoid the sparse bookkeeping.
214 _TSparseSolver solver;
216 if( solver.info( ) != Eigen::Success )
218 std::cerr << "Error computing." << std::endl;
220 _TSparseMatrix x = solver.solve( R * B );
221 if( solver.info( ) != Eigen::Success )
223 std::cerr << "Error solving." << std::endl;
// Both outputs copy the input's regions and geometry, then allocate.
227 TLabels* out_labels = this->GetOutput( );
228 out_labels->SetLargestPossibleRegion( in->GetLargestPossibleRegion( ) );
229 out_labels->SetRequestedRegion( in->GetRequestedRegion( ) );
230 out_labels->SetBufferedRegion( in->GetBufferedRegion( ) );
231 out_labels->SetSpacing( in->GetSpacing( ) );
232 out_labels->SetOrigin( in->GetOrigin( ) );
233 out_labels->SetDirection( in->GetDirection( ) );
234 out_labels->Allocate( );
236 TScalarImage* out_probs = this->GetOutputProbabilities( );
237 out_probs->SetLargestPossibleRegion( in->GetLargestPossibleRegion( ) );
238 out_probs->SetRequestedRegion( in->GetRequestedRegion( ) );
239 out_probs->SetBufferedRegion( in->GetBufferedRegion( ) );
240 out_probs->SetSpacing( in->GetSpacing( ) );
241 out_probs->SetOrigin( in->GetOrigin( ) );
242 out_probs->SetDirection( in->GetDirection( ) );
243 out_probs->Allocate( );
// Walk all nodes in ascending order while mIt tracks the next seed.
// Unseeded nodes take the arg-max label over row j - o of x; seeded nodes
// are handled in the other branch.
// NOTE(review): mIt->first is read with no visible check against
// seeds.end( ). Confirm the hidden lines guard it; otherwise every node after
// the last seed dereferences the end iterator (undefined behaviour).
// NOTE(review): SparseMatrix::coeffRef inserts missing entries and may
// decompress the matrix. The read-only coeff( ) would suffice here.
245 mIt = seeds.begin( );
248 for( unsigned long j = 0; j < N; ++j )
253 if( mIt->first != j )
255 idx = Self::_ND( j, region );
256 TScalar maxP = x.coeffRef( j - o, 0 );
257 unsigned long maxL = 0;
258 for( unsigned long l = 1; l < nLabels; ++l )
260 TScalar vp = x.coeffRef( j - o, l );
269 lbl = inv_labels[ maxL ];
274 idx = Self::_ND( mIt->first, region );
281 out_labels->SetPixel( idx, lbl );
282 out_probs->SetPixel( idx, p );
287 // -------------------------------------------------------------------------
// Linearizes an N-D pixel index into a graph-node number, with dimension 0
// varying fastest. The offset `off` grows as size[0], size[0]*size[1], ...
// Presumably i accumulates idx[d] * off; that line, the return statement
// and the return-type line (original 289) are not shown in this listing.
// NOTE(review): idx is not offset by region.GetIndex( ), so the result lies
// in [0, N) only when the region starts at index 0. For a requested region
// with a non-zero start, node numbers would exceed the matrix sizes used by
// the caller. Confirm, or subtract the region start.
288 template< class _TImage, class _TLabels, class _TScalar >
290 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
291 _1D( const TIndex& idx, const TRegion& region )
293 unsigned long i = idx[ 0 ];
294 unsigned long off = 1;
295 typename TRegion::SizeType size = region.GetSize( );
296 for( unsigned int d = 1; d < TIndex::Dimension; ++d )
298 off *= size[ d - 1 ];
305 // -------------------------------------------------------------------------
// Rebuilds an N-D index from graph-node number i for the given region;
// presumably this is the inverse of _1D. For 3-D indices the slice size
// size[0]*size[1] is used (the idx[2] and j computations are on lines not
// shown here). idx[1] and idx[0] are then taken from the in-slice
// remainder j.
// NOTE(review): no handling for dimensions other than 2 and 3 is visible.
// A 1-D index would be written out of range at idx[1], and 4-D input would
// be mis-mapped. Like _1D, no addition of region.GetIndex( ) is visible.
306 template< class _TImage, class _TLabels, class _TScalar >
307 typename fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
308 TIndex fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
309 _ND( const unsigned long& i, const TRegion& region )
311 typename TRegion::SizeType size = region.GetSize( );
315 if( TIndex::Dimension == 3 )
317 unsigned long z = size[ 0 ] * size[ 1 ];
322 idx[ 1 ] = j / size[ 0 ];
323 idx[ 0 ] = j % size[ 0 ];
327 #endif // __fpa__Common__OriginalRandomWalker__hxx__