// =========================================================================
// @author Leonardo Florez Valencia
// @email florez-l@javeriana.edu.co
// =========================================================================
10 #include <itkImageFileReader.h>
11 #include <itkImageFileWriter.h>
12 #include <itkImageRegionIteratorWithIndex.h>
13 #include <itkVariableSizeMatrix.h>
15 #include <Eigen/Sparse>
16 #include <Eigen/SparseQR>
17 #include <Eigen/OrderingMethods>
19 // -------------------------------------------------------------------------
20 const unsigned int Dim = 2;
21 typedef unsigned char TPixel;
22 typedef unsigned char TLabel;
23 typedef double TScalar;
24 typedef itk::Image< TPixel, Dim > TImage;
25 typedef itk::Image< TLabel, Dim > TLabelImage;
27 typedef Eigen::SparseMatrix< TScalar > TSparseMatrix;
30 class Eigen::AMDOrdering< StorageIndex >
31 class Eigen::COLAMDOrdering< StorageIndex >
32 class Eigen::NaturalOrdering< StorageIndex >
34 typedef Eigen::AMDOrdering< int > TSolverStorage;
36 typedef Eigen::SparseQR< TSparseMatrix, TSolverStorage > TSparseSolver;
37 typedef Eigen::Triplet< TScalar > TTriplet;
38 typedef std::vector< TTriplet > TTriplets;
40 // -------------------------------------------------------------------------
/**
 * Linear offset of an N-D index inside a region, with axis 0 varying
 * fastest (stride of axis d = product of the sizes of axes 0..d-1).
 *
 * The region's start index is subtracted first, so any index inside
 * `region` maps to [0, region.GetNumberOfPixels()) even when the region is
 * not anchored at the origin. This keeps the result usable as a row/column
 * of the N x N sparse matrices built from the same region.
 *
 * @param idx    N-D index; expected to lie inside `region`.
 * @param region region providing GetIndex() (start) and GetSize().
 * @return       linear offset of `idx` relative to the region start.
 */
template< class TIndex, class TRegion >
unsigned long GetIndex( const TIndex& idx, const TRegion& region )
{
  unsigned long linear = 0;
  unsigned long stride = 1;
  for( unsigned int d = 0; d < TIndex::Dimension; ++d )
  {
    // Offset relative to the region start along axis d.
    linear +=
      static_cast< unsigned long >( idx[ d ] - region.GetIndex( )[ d ] ) * stride;
    stride *= region.GetSize( )[ d ];
  }
  return linear;
}
55 // -------------------------------------------------------------------------
56 int main( int argc, char* argv[] )
63 << "Usage: " << argv[ 0 ] << " input_image"
68 std::string input_image_filename = argv[ 1 ];
73 TPixel img[ 5 ][ 5 ] =
75 { 1, 2 , 100, 2 , 1 },
76 { 1, 100, 100, 100, 1 },
77 { 1, 100, 100, 100, 1 },
78 { 1, 100, 100, 100, 1 },
79 { 1, 2 , 100, 2 , 1 },
83 TImage::Pointer input;
85 TImage::SizeType size;
88 input = TImage::New( );
89 input->SetRegions( size );
91 itk::ImageRegionIteratorWithIndex< TImage > it( input, input->GetLargestPossibleRegion( ) );
92 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
93 it.Set( img[ it.GetIndex( )[ 1 ] ][ it.GetIndex( )[ 0 ] ] );
96 typedef itk::ImageFileReader< TImage > TImageReader;
97 TImageReader::Pointer input_image_reader = TImageReader::New( );
98 input_image_reader->SetFileName( input_image_filename );
99 input_image_reader->Update( );
100 input = input_image_reader->GetOutput( );
101 input->DisconnectPipeline( );
104 TImage::RegionType region = input->GetLargestPossibleRegion( );
105 unsigned long N = region.GetNumberOfPixels( );
108 TImage::IndexType s1, s2, s3;
117 std::vector< unsigned long > seeds;
118 std::vector< unsigned long > labels;
120 std::map< unsigned long, unsigned long > seeds;
121 seeds[ GetIndex( s1, region ) ] = 1;
122 seeds[ GetIndex( s2, region ) ] = 2;
123 seeds[ GetIndex( s3, region ) ] = 2;
124 unsigned long nLabels = 2;
128 itk::ImageRegionIteratorWithIndex< TImage > it( input, region );
130 for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
132 TImage::IndexType idx = it.GetIndex( );
133 TScalar vidx = TScalar( it.Get( ) );
134 unsigned long iidx = GetIndex( idx, region );
137 TScalar s = TScalar( 0 );
138 for( unsigned int d = 0; d < Dim; ++d )
140 TImage::IndexType jdx;
141 for( int l = -1; l <= 1; l += 2 )
145 if( region.IsInside( jdx ) )
147 TScalar vjdx = TScalar( input->GetPixel( jdx ) );
148 unsigned long ijdx = GetIndex( jdx, region );
149 TScalar v = std::fabs( vidx - vjdx );
150 Lt.push_back( TTriplet( iidx, ijdx, v ) );
162 std::vector< TScalar > diag( N, TScalar( 0 ) );
163 for( TTriplets::iterator tIt = Lt.begin( ); tIt != Lt.end( ); ++tIt )
165 TScalar v = std::exp( -beta * tIt->value( ) / maxV );
168 *tIt = TTriplet( tIt->row( ), tIt->col( ), -v );
169 diag[ tIt->col( ) ] += v;
172 for( unsigned long i = 0; i < diag.size( ); ++i )
173 Lt.push_back( TTriplet( i, i, diag[ i ] ) );
174 TSparseMatrix L( N, N );
175 L.setFromTriplets( Lt.begin( ), Lt.end( ) );
180 // TODO: for( unsigned int i = 0; i < seeds.size( ); ++i )
181 std::map< unsigned long, unsigned long >::const_iterator mIt = seeds.begin( );
182 for( unsigned long i = 0; mIt != seeds.end( ); ++mIt, ++i )
183 boundaryt.push_back( TTriplet( i, mIt->second - 1, TScalar( 1 ) ) );
184 TSparseMatrix boundary( seeds.size( ), nLabels );
185 boundary.setFromTriplets( boundaryt.begin( ), boundaryt.end( ) );
186 boundary.makeCompressed( );
191 std::map< unsigned long, unsigned long >::const_iterator sIt = seeds.begin( );
192 for( ; sIt != seeds.end( ); ++sIt )
195 for( unsigned long n = 0; n < N; ++n )
197 std::map< unsigned long, unsigned long >::const_iterator nIt =
199 if( nIt == seeds.end( ) )
200 RHSt.push_back( TTriplet( y++, x, L.coeff( sIt->first, n ) ) );
206 TSparseMatrix RHS( N - seeds.size( ), seeds.size( ) );
207 RHS.setFromTriplets( RHSt.begin( ), RHSt.end( ) );
208 RHS.makeCompressed( );
213 for( unsigned long m = 0; m < N; ++m )
215 if( seeds.find( m ) == seeds.end( ) )
218 for( unsigned long n = 0; n < N; ++n )
219 if( seeds.find( n ) == seeds.end( ) )
220 At.push_back( TTriplet( y++, x, L.coeff( m, n ) ) );
226 TSparseMatrix A( N - seeds.size( ), N - seeds.size( ) );
227 A.setFromTriplets( At.begin( ), At.end( ) );
230 // Solve dirichlet problem
231 TSparseSolver solver;
233 if( solver.info( ) != Eigen::Success )
235 std::cerr << "Error computing." << std::endl;
237 TSparseMatrix sol = solver.solve( ( RHS * TScalar( -1 ) ) * boundary );
238 if( solver.info( ) != Eigen::Success )
240 std::cerr << "Error solving." << std::endl;
243 TLabelImage::Pointer output = TLabelImage::New( );
244 output->SetLargestPossibleRegion( input->GetLargestPossibleRegion( ) );
245 output->SetRequestedRegion( input->GetRequestedRegion( ) );
246 output->SetBufferedRegion( input->GetBufferedRegion( ) );
247 output->SetSpacing( input->GetSpacing( ) );
248 output->SetOrigin( input->GetOrigin( ) );
249 output->SetDirection( input->GetDirection( ) );
253 itk::ImageRegionIteratorWithIndex< TLabelImage > lIt( output, output->GetRequestedRegion( ) );
254 for( lIt.GoToBegin( ); !lIt.IsAtEnd( ); ++lIt )
256 unsigned long i = GetIndex( lIt.GetIndex( ), region );
257 std::map< unsigned long, unsigned long >::const_iterator sIt =
259 if( sIt == seeds.end( ) )
261 std::cout << i << std::endl;
262 TScalar maxProb = TScalar( -1 );
263 for( unsigned long k = 0; k < nLabels; ++k )
265 std::cout << "\t\t" << i << " " << k << std::endl;
266 TScalar p = sol.coeff( i, k );
271 std::cout << "\t" << maxProb << std::endl;
275 std::cout << "---> " << i << std::endl;
276 lIt.Set( sIt->second );