]> Creatis software - FrontAlgorithms.git/blob - lib/fpa/Common/OriginalRandomWalker.hxx
...
[FrontAlgorithms.git] / lib / fpa / Common / OriginalRandomWalker.hxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
7
8 #include <cmath>
9 #include <map>
10 #include <itkImageRegionConstIteratorWithIndex.h>
11 #include <Eigen/Sparse>
12
13 // -------------------------------------------------------------------------
14 template< class _TImage, class _TLabels, class _TScalar >
15 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16 AddSeed( const TIndex& seed, const TLabel& label )
17 {
18   this->m_Seeds.push_back( seed );
19   this->m_Labels.push_back( label );
20   this->Modified( );
21 }
22
23 // -------------------------------------------------------------------------
24 template< class _TImage, class _TLabels, class _TScalar >
25 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
26 OriginalRandomWalker( )
27   : Superclass( ),
28     m_Beta( TScalar( 90 ) ),
29     m_Epsilon( TScalar( 1e-5 ) ),
30     m_NormalizeWeights( true )
31 {
32   fpaFilterOptionalInputConfigureMacro( InputLabels, TLabels );
33   fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
34 }
35
36 // -------------------------------------------------------------------------
37 template< class _TImage, class _TLabels, class _TScalar >
38 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
39 ~OriginalRandomWalker( )
40 {
41 }
42
43 // -------------------------------------------------------------------------
44 template< class _TImage, class _TLabels, class _TScalar >
45 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
46 GenerateData( )
47 {
48   typedef Eigen::SparseMatrix< TScalar >           _TSparseMatrix;
49   typedef Eigen::SimplicialLDLT< _TSparseMatrix >  _TSparseSolver;
50   typedef Eigen::Triplet< TScalar >                _TTriplet;
51   typedef std::vector< _TTriplet >                 _TTriplets;
52   typedef std::map< unsigned long, unsigned long > _TMap;
53   typedef _TMap::value_type                        _TMapValue;
54
55   // Some input values
56   const TImage* in = this->GetInput( );
57   TRegion region = in->GetRequestedRegion( );
58   unsigned long N = region.GetNumberOfPixels( );
59
60   // Prepare seeds for linear algebra
61   _TMap seeds;
62   for( unsigned int i = 0; i < this->m_Seeds.size( ); ++i )
63     seeds.insert(
64       _TMapValue(
65         Self::_1D( this->m_Seeds[ i ], region ),
66         this->m_Labels[ i ]
67         )
68       );
69
70   // Use input labels
71   /* TODO
72      const TLabels* in_labels = this->GetInputLabels( );
73      if( in_labels != NULL )
74      {
75      itk::ImageRegionConstIteratorWithIndex< TLabels > lIt( in_labels, region );
76      for( lIt.GoToBegin( ); !lIt.IsAtEnd( ); ++lIt )
77      if( lIt.Get( ) != 0 )
78      seeds.insert(
79      _TMapValue( Self::_1D( lIt.GetIndex( ), region ), lIt.Get( ) )
80      );
81
82      } // fi
83   */
84
85   // Prepare label tables
86   _TMap labels, inv_labels, seeds_indexes;
87   _TMap::const_iterator hIt = seeds.begin( );
88   for( ; hIt != seeds.end( ); ++hIt )
89   {
90     seeds_indexes[ hIt->first ] = seeds_indexes.size( ) - 1;
91     if( labels.find( hIt->second ) == labels.end( ) )
92     {
93       labels[ hIt->second ] = labels.size( ) - 1;
94       inv_labels[ labels[ hIt->second ] ] = hIt->second;
95
96     } // fi
97
98   } // rof
99
100   // Prepare matrix/image index conversion
101   _TMap indexes;
102   unsigned long o = 0;
103   for( unsigned long n = 0; n < N; ++n )
104   {
105     if( seeds.find( n ) == seeds.end( ) )
106       indexes.insert( _TMap::value_type( n, n - o ) );
107     else
108       o++;
109
110   } // rof
111   unsigned long nLabels = labels.size( );
112
113   // Boundary matrix
114   _TTriplets Bt;
115   _TMap::const_iterator mIt = seeds.begin( );
116   for( unsigned long i = 0; mIt != seeds.end( ); ++mIt, ++i )
117     Bt.push_back( _TTriplet( i, labels[ mIt->second ], TScalar( 1 ) ) );
118   _TSparseMatrix B( seeds.size( ), nLabels );
119   B.setFromTriplets( Bt.begin( ), Bt.end( ) );
120   B.makeCompressed( );
121
122   // Laplacian matrix
123   _TTriplets Lt;
124   itk::ImageRegionConstIteratorWithIndex< TImage > it( in, region );
125   TScalar maxV = TScalar( -1 );
126   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
127   {
128     TIndex idx = it.GetIndex( );
129     TScalar vidx = TScalar( it.Get( ) );
130     unsigned long iidx = Self::_1D( idx, region );
131
132     // Neighbors
133     TScalar s = TScalar( 0 );
134     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
135     {
136       TIndex jdx;
137       for( int l = -1; l <= 1; l += 2 )
138       {
139         jdx = idx;
140         jdx[ d ] += l;
141         if( region.IsInside( jdx ) )
142         {
143           TScalar vjdx = TScalar( in->GetPixel( jdx ) );
144           unsigned long ijdx = Self::_1D( jdx, region );
145           TScalar v = std::fabs( vidx - vjdx );
146           Lt.push_back( _TTriplet( iidx, ijdx, v ) );
147           if( maxV < v )
148             maxV = v;
149
150         } // fi
151
152       } // rof
153
154     } // rof
155
156   } // rof
157   std::vector< TScalar > diag( N, TScalar( 0 ) );
158   typename _TTriplets::iterator tIt;
159   TScalar betaV = -this->m_Beta;
160   if( this->m_NormalizeWeights )
161     betaV /= maxV;
162   for( tIt = Lt.begin( ); tIt != Lt.end( ); ++tIt )
163   {
164     TScalar v = std::exp( betaV * tIt->value( ) );
165     if( v < this->m_Epsilon )
166       v = this->m_Epsilon;
167     *tIt = _TTriplet( tIt->row( ), tIt->col( ), -v );
168     diag[ tIt->col( ) ] += v;
169
170   } // rof
171   for( unsigned long i = 0; i < diag.size( ); ++i )
172     Lt.push_back( _TTriplet( i, i, diag[ i ] ) );
173
174   // Compute R and A
175   _TTriplets Rt, At;
176   for( tIt = Lt.begin( ); tIt != Lt.end( ); ++tIt )
177   {
178     _TMap::const_iterator cIt, rIt;
179     cIt = seeds.find( tIt->col( ) );
180     rIt = seeds.find( tIt->row( ) );
181     if( cIt != seeds.end( ) )
182     {
183       if( rIt == seeds.end( ) )
184       {
185         _TMap::const_iterator iIt = indexes.find( tIt->row( ) );
186         _TMap::const_iterator jIt = seeds_indexes.find( cIt->first );
187         Rt.push_back( _TTriplet( iIt->second, jIt->second, -tIt->value( ) ) );
188
189       } // fi
190     }
191     else
192     {
193       if( rIt == seeds.end( ) )
194       {
195         _TMap::const_iterator iIt = indexes.find( tIt->row( ) );
196         _TMap::const_iterator jIt = indexes.find( tIt->col( ) );
197         At.push_back( _TTriplet( iIt->second, jIt->second, tIt->value( ) ) );
198
199       } // fi
200
201     } // fi
202
203   } // rof
204
205   _TSparseMatrix R( N - seeds.size( ), seeds.size( ) );
206   R.setFromTriplets( Rt.begin( ), Rt.end( ) );
207   R.makeCompressed( );
208
209   _TSparseMatrix A( N - seeds.size( ), N - seeds.size( ) );
210   A.setFromTriplets( At.begin( ), At.end( ) );
211   A.makeCompressed( );
212
213   // Solve dirichlet problem
214   _TSparseSolver solver;
215   solver.compute( A );
216   if( solver.info( ) != Eigen::Success )
217   {
218     std::cerr << "Error computing." << std::endl;
219   } // fi
220   _TSparseMatrix x = solver.solve( R * B );
221   if( solver.info( ) != Eigen::Success )
222   {
223     std::cerr << "Error solving." << std::endl;
224   } // fi
225
226   // Fill outputs
227   TLabels* out_labels = this->GetOutput( );
228   out_labels->SetLargestPossibleRegion( in->GetLargestPossibleRegion( ) );
229   out_labels->SetRequestedRegion( in->GetRequestedRegion( ) );
230   out_labels->SetBufferedRegion( in->GetBufferedRegion( ) );
231   out_labels->SetSpacing( in->GetSpacing( ) );
232   out_labels->SetOrigin( in->GetOrigin( ) );
233   out_labels->SetDirection( in->GetDirection( ) );
234   out_labels->Allocate( );
235
236   TScalarImage* out_probs = this->GetOutputProbabilities( );
237   out_probs->SetLargestPossibleRegion( in->GetLargestPossibleRegion( ) );
238   out_probs->SetRequestedRegion( in->GetRequestedRegion( ) );
239   out_probs->SetBufferedRegion( in->GetBufferedRegion( ) );
240   out_probs->SetSpacing( in->GetSpacing( ) );
241   out_probs->SetOrigin( in->GetOrigin( ) );
242   out_probs->SetDirection( in->GetDirection( ) );
243   out_probs->Allocate( );
244
245   mIt = seeds.begin( );
246   o = 0;
247   unsigned long j = 0;
248   for( unsigned long j = 0; j < N; ++j )
249   {
250     TIndex idx;
251     TLabel lbl;
252     TScalar p;
253     if( mIt->first != j )
254     {
255       idx = Self::_ND( j, region );
256       TScalar maxP = x.coeffRef( j - o, 0 );
257       unsigned long maxL = 0;
258       for( unsigned long l = 1; l < nLabels; ++l )
259       {
260         TScalar vp = x.coeffRef( j - o, l );
261         if( maxP < vp )
262         {
263           maxP = vp;
264           maxL = l;
265
266         } // fi
267
268       } // rof
269       lbl = inv_labels[ maxL ];
270       p = maxP;
271     }
272     else
273     {
274       idx = Self::_ND( mIt->first, region );
275       lbl = mIt->second;
276       p = TScalar( 1 );
277       mIt++;
278       o++;
279
280     } // fi
281     out_labels->SetPixel( idx, lbl );
282     out_probs->SetPixel( idx, p );
283
284   } // rof
285 }
286
287 // -------------------------------------------------------------------------
288 template< class _TImage, class _TLabels, class _TScalar >
289 unsigned long
290 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
291 _1D( const TIndex& idx, const TRegion& region )
292 {
293   unsigned long i = idx[ 0 ];
294   unsigned long off = 1;
295   typename TRegion::SizeType size = region.GetSize( );
296   for( unsigned int d = 1; d < TIndex::Dimension; ++d )
297   {
298     off *= size[ d - 1 ];
299     i += idx[ d ] * off;
300
301   } // rof
302   return( i );
303 }
304
305 // -------------------------------------------------------------------------
306 template< class _TImage, class _TLabels, class _TScalar >
307 typename fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
308 TIndex fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
309 _ND( const unsigned long& i, const TRegion& region )
310 {
311   typename TRegion::SizeType size = region.GetSize( );
312
313   unsigned long j = i;
314   TIndex idx;
315   if( TIndex::Dimension == 3 )
316   {
317     unsigned long z = size[ 0 ] * size[ 1 ];
318     idx[ 2 ] = j / z;
319     j -= idx[ 2 ] * z;
320
321   } // fi
322   idx[ 1 ] = j / size[ 0 ];
323   idx[ 0 ] = j % size[ 0 ];
324   return( idx );
325 }
326
327 #endif // __fpa__Common__OriginalRandomWalker__hxx__
328 // eof - $RCSfile$