]> Creatis software - FrontAlgorithms.git/blob - tests/image/RandomWalker/Original.cxx
...
[FrontAlgorithms.git] / tests / image / RandomWalker / Original.cxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5
6 #include <algorithm>
7 #include <map>
8 #include <vector>
9 #include <itkImage.h>
10 #include <itkImageFileReader.h>
11 #include <itkImageFileWriter.h>
12 #include <itkImageRegionIteratorWithIndex.h>
13 #include <itkVariableSizeMatrix.h>
14
15 #include <Eigen/Sparse>
16 #include <Eigen/SparseQR>
17 #include <Eigen/OrderingMethods>
18
19 // -------------------------------------------------------------------------
20 const unsigned int Dim = 2;
21 typedef unsigned char TPixel;
22 typedef unsigned char TLabel;
23 typedef double TScalar;
24 typedef itk::Image< TPixel, Dim > TImage;
25 typedef itk::Image< TLabel, Dim > TLabelImage;
26
27 typedef Eigen::SparseMatrix< TScalar > TSparseMatrix;
28
29 /* TODO
30    class Eigen::AMDOrdering< StorageIndex >
31    class Eigen::COLAMDOrdering< StorageIndex >
32    class Eigen::NaturalOrdering< StorageIndex >
33 */
34 typedef Eigen::AMDOrdering< int > TSolverStorage;
35
36 typedef Eigen::SparseQR< TSparseMatrix, TSolverStorage > TSparseSolver;
37 typedef Eigen::Triplet< TScalar > TTriplet;
38 typedef std::vector< TTriplet > TTriplets;
39
40 // -------------------------------------------------------------------------
41 template< class _TIndex, class _TRegion >
42 unsigned long GetIndex( const _TIndex& idx, const _TRegion& region )
43 {
44   unsigned long i = 0;
45   unsigned long off = 1;
46   for( unsigned int d = 0; d < _TIndex::Dimension; ++d )
47   {
48     i += idx[ d ] * off;
49     off *= region.GetSize( )[ d ];
50
51   } // rof
52   return( i );
53 }
54
55 // -------------------------------------------------------------------------
56 int main( int argc, char* argv[] )
57 {
58   // Get arguments
59   /* TODO
60      if( argc < 2 )
61      {
62      std::cerr
63      << "Usage: " << argv[ 0 ] << " input_image"
64      << std::endl;
65      return( 1 );
66
67      } // fi
68      std::string input_image_filename = argv[ 1 ];
69   */
70   TScalar beta = 90;
71   TScalar eps = 1e-5;
72
73   TPixel img[ 5 ][ 5 ] =
74     {
75       { 1,  2 , 100,  2 , 1 },
76       { 1, 100, 100, 100, 1 },
77       { 1, 100, 100, 100, 1 },
78       { 1, 100, 100, 100, 1 },
79       { 1,  2 , 100,  2 , 1 },
80     };
81
82   // Read image
83   TImage::Pointer input;
84   { // begin
85     TImage::SizeType size;
86     size.Fill( 5 );
87
88     input = TImage::New( );
89     input->SetRegions( size );
90     input->Allocate( );
91     itk::ImageRegionIteratorWithIndex< TImage > it( input, input->GetLargestPossibleRegion( ) );
92     for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
93       it.Set( img[ it.GetIndex( )[ 1 ] ][ it.GetIndex( )[ 0 ] ] );
94
95     /* TODO
96        typedef itk::ImageFileReader< TImage > TImageReader;
97        TImageReader::Pointer input_image_reader = TImageReader::New( );
98        input_image_reader->SetFileName( input_image_filename );
99        input_image_reader->Update( );
100        input = input_image_reader->GetOutput( );
101        input->DisconnectPipeline( );
102     */
103   } // end
104   TImage::RegionType region = input->GetLargestPossibleRegion( );
105   unsigned long N = region.GetNumberOfPixels( );
106
107   // Seeds
108   TImage::IndexType s1, s2, s3;
109   s1[ 0 ] = 0;
110   s1[ 1 ] = 1;
111   s2[ 0 ] = 2;
112   s2[ 1 ] = 3;
113   s3[ 0 ] = 4;
114   s3[ 1 ] = 4;
115
116   /*
117     std::vector< unsigned long > seeds;
118     std::vector< unsigned long > labels;
119   */
120   std::map< unsigned long, unsigned long > seeds;
121   seeds[ GetIndex( s1, region ) ] = 1;
122   seeds[ GetIndex( s2, region ) ] = 2;
123   seeds[ GetIndex( s3, region ) ] = 2;
124   unsigned long nLabels = 2;
125
126   // Construct L
127   TTriplets Lt;
128   itk::ImageRegionIteratorWithIndex< TImage > it( input, region );
129   TScalar maxV = -1;
130   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
131   {
132     TImage::IndexType idx = it.GetIndex( );
133     TScalar vidx = TScalar( it.Get( ) );
134     unsigned long iidx = GetIndex( idx, region );
135
136     // Neighbors
137     TScalar s = TScalar( 0 );
138     for( unsigned int d = 0; d < Dim; ++d )
139     {
140       TImage::IndexType jdx;
141       for( int l = -1; l <= 1; l += 2 )
142       {
143         jdx = idx;
144         jdx[ d ] += l;
145         if( region.IsInside( jdx ) )
146         {
147           TScalar vjdx = TScalar( input->GetPixel( jdx ) );
148           unsigned long ijdx = GetIndex( jdx, region );
149           TScalar v = std::fabs( vidx - vjdx );
150           Lt.push_back( TTriplet( iidx, ijdx, v ) );
151           if( maxV < v )
152             maxV = v;
153
154         } // fi
155
156       } // rof
157
158     } // rof
159
160   } // rof
161
162   std::vector< TScalar > diag( N, TScalar( 0 ) );
163   for( TTriplets::iterator tIt = Lt.begin( ); tIt != Lt.end( ); ++tIt )
164   {
165     TScalar v = std::exp( -beta * tIt->value( ) / maxV );
166     if( v < eps )
167       v = eps;
168     *tIt = TTriplet( tIt->row( ), tIt->col( ), -v );
169     diag[ tIt->col( ) ] += v;
170
171   } // rof
172   for( unsigned long i = 0; i < diag.size( ); ++i )
173     Lt.push_back( TTriplet( i, i, diag[ i ] ) );
174   TSparseMatrix L( N, N );
175   L.setFromTriplets( Lt.begin( ), Lt.end( ) );
176   L.makeCompressed( );
177
178   // Boundary
179   TTriplets boundaryt;
180   // TODO: for( unsigned int i = 0; i < seeds.size( ); ++i )
181   std::map< unsigned long, unsigned long >::const_iterator mIt = seeds.begin( );
182   for( unsigned long i = 0; mIt != seeds.end( ); ++mIt, ++i )
183     boundaryt.push_back( TTriplet( i, mIt->second - 1, TScalar( 1 ) ) );
184   TSparseMatrix boundary( seeds.size( ), nLabels );
185   boundary.setFromTriplets( boundaryt.begin( ), boundaryt.end( ) );
186   boundary.makeCompressed( );
187
188   // Compute RHS
189   TTriplets RHSt;
190   unsigned long x = 0;
191   std::map< unsigned long, unsigned long >::const_iterator sIt = seeds.begin( );
192   for( ; sIt != seeds.end( ); ++sIt )
193   {
194     unsigned long y = 0;
195     for( unsigned long n = 0; n < N; ++n )
196     {
197       std::map< unsigned long, unsigned long >::const_iterator nIt =
198         seeds.find( n );
199       if( nIt == seeds.end( ) )
200         RHSt.push_back( TTriplet( y++, x, L.coeff( sIt->first, n ) ) );
201
202     } // rof
203     x++;
204
205   } // rof
206   TSparseMatrix RHS( N - seeds.size( ), seeds.size( ) );
207   RHS.setFromTriplets( RHSt.begin( ), RHSt.end( ) );
208   RHS.makeCompressed( );
209
210   // Compute A
211   TTriplets At;
212   x = 0;
213   for( unsigned long m = 0; m < N; ++m )
214   {
215     if( seeds.find( m ) == seeds.end( ) )
216     {
217       unsigned long y = 0;
218       for( unsigned long n = 0; n < N; ++n )
219         if( seeds.find( n ) == seeds.end( ) )
220           At.push_back( TTriplet( y++, x, L.coeff( m, n ) ) );
221       x++;
222
223     } // fi
224
225   } // rof
226   TSparseMatrix A( N - seeds.size( ), N - seeds.size( ) );
227   A.setFromTriplets( At.begin( ), At.end( ) );
228   A.makeCompressed( );
229
230   // Solve dirichlet problem
231   TSparseSolver solver;
232   solver.compute( A );
233   if( solver.info( ) != Eigen::Success )
234   {
235     std::cerr << "Error computing." << std::endl;
236   } // fi
237   TSparseMatrix sol = solver.solve( ( RHS * TScalar( -1 ) ) * boundary );
238   if( solver.info( ) != Eigen::Success )
239   {
240     std::cerr << "Error solving." << std::endl;
241   } // fi
242
243   TLabelImage::Pointer output = TLabelImage::New( );
244   output->SetLargestPossibleRegion( input->GetLargestPossibleRegion( ) );
245   output->SetRequestedRegion( input->GetRequestedRegion( ) );
246   output->SetBufferedRegion( input->GetBufferedRegion( ) );
247   output->SetSpacing( input->GetSpacing( ) );
248   output->SetOrigin( input->GetOrigin( ) );
249   output->SetDirection( input->GetDirection( ) );
250   output->Allocate( );
251
252   /* TODO
253      itk::ImageRegionIteratorWithIndex< TLabelImage > lIt( output, output->GetRequestedRegion( ) );
254      for( lIt.GoToBegin( ); !lIt.IsAtEnd( ); ++lIt )
255      {
256      unsigned long i = GetIndex( lIt.GetIndex( ), region );
257      std::map< unsigned long, unsigned long >::const_iterator sIt =
258      seeds.find( i );
259      if( sIt == seeds.end( ) )
260      {
261      std::cout << i << std::endl;
262      TScalar maxProb = TScalar( -1 );
263      for( unsigned long k = 0; k < nLabels; ++k )
264      {
265      std::cout << "\t\t" << i << " " << k << std::endl;
266      TScalar p = sol.coeff( i, k );
267      if( maxProb < p )
268      maxProb = p;
269
270      } // rof
271      std::cout << "\t" << maxProb << std::endl;
272      }
273      else
274      {
275      std::cout << "---> " << i << std::endl;
276      lIt.Set( sIt->second );
277
278      } // fi
279
280      } // rof
281   */
282
283   return( 0 );
284 }
285
286 // eof - $RCSfile$