]> Creatis software - FrontAlgorithms.git/blob - lib/fpa/Common/OriginalRandomWalker.hxx
...
[FrontAlgorithms.git] / lib / fpa / Common / OriginalRandomWalker.hxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
7
8 #include <cmath>
9 #include <itkImageRegionConstIteratorWithIndex.h>
10 #include <itkImageRegionIteratorWithIndex.h>
11 #include <Eigen/Sparse>
12
13 // -------------------------------------------------------------------------
14 template< class _TImage, class _TLabels, class _TScalar >
15 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16 AddSeed( const TIndex& seed, const TLabel& label )
17 {
18   this->m_Seeds.push_back( seed );
19   this->m_Labels.push_back( label );
20   this->Modified( );
21 }
22
23 // -------------------------------------------------------------------------
24 template< class _TImage, class _TLabels, class _TScalar >
25 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
26 OriginalRandomWalker( )
27   : Superclass( ),
28     m_Beta( TScalar( 90 ) ),
29     m_Epsilon( TScalar( 1e-5 ) ),
30     m_NormalizeWeights( true )
31 {
32   fpaFilterInputConfigureMacro( InputLabels, TLabels );
33   fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
34 }
35
36 // -------------------------------------------------------------------------
37 template< class _TImage, class _TLabels, class _TScalar >
38 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
39 ~OriginalRandomWalker( )
40 {
41 }
42
43 // -------------------------------------------------------------------------
44 template< class _TImage, class _TLabels, class _TScalar >
45 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
46 GenerateData( )
47 {
48   // Useful typedefs
49   typedef Eigen::Triplet< TScalar >         _TTriplet;
50   typedef std::vector< _TTriplet >          _TTriplets;
51   typedef Eigen::SparseMatrix< TScalar >    _TMatrix;
52   typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
53
54   // Allocate outputs
55   this->AllocateOutputs( );
56
57   // Build boundary triplets and count labels
58   _TTriplets St, Bt;
59   std::map< TLabel, unsigned long > labels;
60   this->_Boundary( St, labels );
61   struct _TTripletsOrd
62   {
63     bool operator()( const _TTriplet& a, const _TTriplet& b )
64       {
65         return( a.row( ) < b.row( ) );
66       }
67   };
68   std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
69   for( unsigned long i = 0; i < St.size( ); ++i )
70     Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) );
71
72   // Laplacian triplets
73   _TTriplets At, Rt;
74   this->_Laplacian( At, Rt, St );
75
76   // Matrices
77   TRegion region = this->GetInput( )->GetRequestedRegion( );
78   unsigned long nSeeds = St.size( );
79   unsigned long nLabels = labels.size( );
80   unsigned long N = region.GetNumberOfPixels( );
81
82   std::vector< TLabel > invLabels( nLabels );
83   for( typename std::map< TLabel, unsigned long >::value_type s: labels )
84     invLabels[ s.second ] = s.first;
85
86   _TMatrix B( nSeeds, nLabels );
87   B.setFromTriplets( Bt.begin( ), Bt.end( ) );
88   B.makeCompressed( );
89
90   _TMatrix R( N - nSeeds, nSeeds );
91   R.setFromTriplets( Rt.begin( ), Rt.end( ) );
92   R.makeCompressed( );
93
94   _TMatrix A( N - nSeeds, N - nSeeds );
95   A.setFromTriplets( At.begin( ), At.end( ) );
96   A.makeCompressed( );
97
98   // Solve dirichlet problem
99   _TSolver solver;
100   solver.compute( A );
101   if( solver.info( ) != Eigen::Success )
102   {
103     std::cerr << "Error computing." << std::endl;
104   } // fi
105   _TMatrix x = solver.solve( R * B );
106   if( solver.info( ) != Eigen::Success )
107   {
108     std::cerr << "Error solving." << std::endl;
109   } // fi
110
111   // Fill outputs
112   this->_Output( x, St, invLabels );
113 }
114
115 // -------------------------------------------------------------------------
116 template< class _TImage, class _TLabels, class _TScalar >
117 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
118 _W( const TIndex& i, const TIndex& j )
119 {
120   const TImage* in = this->GetInput( );
121   TScalar a = TScalar( in->GetPixel( i ) );
122   TScalar b = TScalar( in->GetPixel( j ) );
123   TScalar v = std::exp( -this->m_Beta * std::fabs( a - b ) );
124   if( v < this->m_Epsilon )
125     return( this->m_Epsilon );
126   else
127     return( v );
128 }
129
130 // -------------------------------------------------------------------------
131 template< class _TImage, class _TLabels, class _TScalar >
132 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
133 _L( const TIndex& i, const TIndex& j )
134 {
135   if( i == j )
136   {
137     TRegion r = this->GetInput( )->GetRequestedRegion( );
138     TScalar s = TScalar( 0 );
139     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
140     {
141       for( int n = -1; n <= 1; n += 2 )
142       {
143         TIndex k = i;
144         k[ d ] += n;
145         if( r.IsInside( k ) )
146           s += this->_W( i, k );
147
148       } // rof
149
150     } // rof
151     return( s );
152   }
153   else
154     return( this->_W( i, j ) * TScalar( -1 ) );
155 }
156
157 // -------------------------------------------------------------------------
158 template< class _TImage, class _TLabels, class _TScalar >
159 template< class _TTriplets >
160 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
161 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
162 {
163   B.clear( );
164
165   // Set up the multithreaded processing
166   _TBoundaryThreadStruct thrStr;
167   thrStr.Filter = this;
168   thrStr.Triplets = reinterpret_cast< void* >( &B );
169   thrStr.Labels = &labels;
170
171   // Configure threader
172   const TLabels* in_labels = this->GetInputLabels( );
173   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
174   const unsigned int nThreads =
175     split->GetNumberOfSplits(
176       in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
177       );
178
179   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
180   threads->SetNumberOfThreads( nThreads );
181   threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
182
183   // Execute threader
184   threads->SingleMethodExecute( );
185 }
186
187 // -------------------------------------------------------------------------
188 template< class _TImage, class _TLabels, class _TScalar >
189 template< class _TTriplets >
190 ITK_THREAD_RETURN_TYPE
191 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
192 _BoundaryCbk( void* arg )
193 {
194   _TBoundaryThreadStruct* thrStr;
195   itk::ThreadIdType total, thrId, thrCount;
196   itk::MultiThreader::ThreadInfoStruct* thrInfo =
197     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
198   thrId = thrInfo->ThreadID;
199   thrCount = thrInfo->NumberOfThreads;
200   thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
201
202   TRegion region;
203   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
204   if( thrId < total )
205     thrStr->Filter->_ThreadedBoundary(
206       region, thrId,
207       reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
208       thrStr->Labels
209       );
210   return( ITK_THREAD_RETURN_VALUE );
211 }
212
213 // -------------------------------------------------------------------------
214 template< class _TImage, class _TLabels, class _TScalar >
215 template< class _TTriplets >
216 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
217 _ThreadedBoundary(
218   const TRegion& region, const itk::ThreadIdType& id,
219   _TTriplets* B,
220   std::map< TLabel, unsigned long >* labels
221   )
222 {
223   typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
224   typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
225   typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
226   typedef typename _TTriplets::value_type _TTriplet;
227
228   const TLabels* in_labels = this->GetInputLabels( );
229   TRegion reqRegion = in_labels->GetRequestedRegion( );
230   _TIt it( in_labels, region );
231   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
232   {
233     if( it.Get( ) != 0 )
234     {
235       unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
236       this->m_Mutex.Lock( );
237       B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
238       if( labels->find( it.Get( ) ) == labels->end( ) )
239         labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
240       this->m_Mutex.Unlock( );
241
242     } // fi
243
244   } // rof
245 }
246
247 // -------------------------------------------------------------------------
248 template< class _TImage, class _TLabels, class _TScalar >
249 template< class _TTriplets >
250 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
251 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
252 {
253   A.clear( );
254   R.clear( );
255
256   // Set up the multithreaded processing
257   _TLaplacianThreadStruct thrStr;
258   thrStr.Filter = this;
259   thrStr.A = reinterpret_cast< void* >( &A );
260   thrStr.R = reinterpret_cast< void* >( &R );
261   thrStr.B = reinterpret_cast< const void* >( &B );
262
263   // Configure threader
264   const TImage* in = this->GetInputLabels( );
265   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
266   const unsigned int nThreads =
267     split->GetNumberOfSplits(
268       in->GetRequestedRegion( ), this->GetNumberOfThreads( )
269       );
270
271   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
272   threads->SetNumberOfThreads( nThreads );
273   threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
274
275   // Execute threader
276   threads->SingleMethodExecute( );
277 }
278
279 // -------------------------------------------------------------------------
280 template< class _TImage, class _TLabels, class _TScalar >
281 template< class _TTriplets >
282 ITK_THREAD_RETURN_TYPE
283 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
284 _LaplacianCbk( void* arg )
285 {
286   _TLaplacianThreadStruct* thrStr;
287   itk::ThreadIdType total, thrId, thrCount;
288   itk::MultiThreader::ThreadInfoStruct* thrInfo =
289     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
290   thrId = thrInfo->ThreadID;
291   thrCount = thrInfo->NumberOfThreads;
292   thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
293
294   TRegion region;
295   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
296   if( thrId < total )
297     thrStr->Filter->_ThreadedLaplacian(
298       region, thrId,
299       reinterpret_cast< _TTriplets* >( thrStr->A ),
300       reinterpret_cast< _TTriplets* >( thrStr->R ),
301       reinterpret_cast< const _TTriplets* >( thrStr->B )
302       );
303   return( ITK_THREAD_RETURN_VALUE );
304 }
305
306 // -------------------------------------------------------------------------
307 template< class _TImage, class _TLabels, class _TScalar >
308 template< class _TTriplets >
309 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
310 _ThreadedLaplacian(
311   const TRegion& region, const itk::ThreadIdType& id,
312   _TTriplets* A, _TTriplets* R, const _TTriplets* B
313   )
314 {
315   typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
316   typedef typename _TTriplets::value_type _TTriplet;
317
318   const TImage* in = this->GetInput( );
319   const TLabels* in_labels = this->GetInputLabels( );
320   TRegion rqRegion = in->GetRequestedRegion( );
321   _TIt it( in, region );
322   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
323   {
324     TIndex idx = it.GetIndex( );
325     bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
326     unsigned long i = Self::_1D( idx, rqRegion );
327     unsigned long si;
328
329     // A's diagonal values
330     if( !iSeed )
331     {
332       si = Self::_NearSeedIndex( i, *B );
333       this->m_Mutex.Lock( );
334       A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
335       this->m_Mutex.Unlock( );
336     }
337     else
338       si = Self::_SeedIndex( i, *B );
339
340     // Neighbors (final matrix is symmetric)
341     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
342     {
343       for( int s = -1; s <= 1; s += 2 )
344       {
345         TIndex jdx = idx;
346         jdx[ d ] += s;
347         if( rqRegion.IsInside( jdx ) )
348         {
349           TScalar L = this->_L( idx, jdx );
350           unsigned long j = Self::_1D( jdx, rqRegion );
351           bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
352           if( !jSeed )
353           {
354             unsigned long sj = Self::_NearSeedIndex( j, *B );
355             if( !iSeed )
356             {
357               this->m_Mutex.Lock( );
358               A->push_back( _TTriplet( si, sj, L ) );
359               this->m_Mutex.Unlock( );
360             }
361             else
362             {
363               this->m_Mutex.Lock( );
364               R->push_back( _TTriplet( sj, si, -L ) );
365               this->m_Mutex.Unlock( );
366
367             } // fi
368             
369           } // fi
370
371         } // fi
372
373       } // rof
374
375     } // rof
376
377   } // rof
378 }
379
380 // -------------------------------------------------------------------------
381 template< class _TImage, class _TLabels, class _TScalar >
382 template< class _TMatrix, class _TTriplets >
383 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
384 _Output(
385   const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
386   )
387 {
388   // Set up the multithreaded processing
389   _TOutputThreadStruct thrStr;
390   thrStr.Filter = this;
391   thrStr.X = reinterpret_cast< const void* >( &X );
392   thrStr.S = reinterpret_cast< const void* >( &S );
393   thrStr.InvLabels = &invLabels;
394
395   // Configure threader
396   const TLabels* out = this->GetOutput( );
397   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
398   const unsigned int nThreads =
399     split->GetNumberOfSplits(
400       out->GetRequestedRegion( ), this->GetNumberOfThreads( )
401       );
402
403   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
404   threads->SetNumberOfThreads( nThreads );
405   threads->SetSingleMethod(
406     this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
407     );
408
409   // Execute threader
410   threads->SingleMethodExecute( );
411 }
412
413 // -------------------------------------------------------------------------
414 template< class _TImage, class _TLabels, class _TScalar >
415 template< class _TMatrix, class _TTriplets >
416 ITK_THREAD_RETURN_TYPE
417 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
418 _OutputCbk( void* arg )
419 {
420   _TOutputThreadStruct* thrStr;
421   itk::ThreadIdType total, thrId, thrCount;
422   itk::MultiThreader::ThreadInfoStruct* thrInfo =
423     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
424   thrId = thrInfo->ThreadID;
425   thrCount = thrInfo->NumberOfThreads;
426   thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
427
428   TRegion region;
429   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
430   if( thrId < total )
431     thrStr->Filter->_ThreadedOutput(
432       region, thrId,
433       reinterpret_cast< const _TMatrix* >( thrStr->X ),
434       reinterpret_cast< const _TTriplets* >( thrStr->S ),
435       thrStr->InvLabels
436       );
437   return( ITK_THREAD_RETURN_VALUE );
438 }
439
440 // -------------------------------------------------------------------------
441 template< class _TImage, class _TLabels, class _TScalar >
442 template< class _TMatrix, class _TTriplets >
443 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
444 _ThreadedOutput(
445   const TRegion& region, const itk::ThreadIdType& id,
446   const _TMatrix* X, const _TTriplets* S,
447   const std::vector< TLabel >* invLabels
448   )
449 {
450   // Fill outputs
451   const TLabels* in_labels = this->GetInputLabels( );
452   TLabels* out_labels = this->GetOutput( );
453   TScalarImage* out_probs = this->GetOutputProbabilities( );
454   itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
455   itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
456   itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
457   iIt.GoToBegin( );
458   oIt.GoToBegin( );
459   pIt.GoToBegin( );
460   for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
461   {
462     if( iIt.Get( ) == 0 )
463     {
464       unsigned long i = Self::_1D( iIt.GetIndex( ), region );
465       unsigned long j = Self::_NearSeedIndex( i, *S );
466       TScalar maxP = X->coeff( j, 0 );
467       unsigned long maxL = 0;
468       for( unsigned int s = 1; s < invLabels->size( ); ++s )
469       {
470         TScalar p = X->coeff( j, s );
471         if( maxP <= p )
472         {
473           maxP = p;
474           maxL = s;
475
476         } // fi
477
478       } // rof
479       oIt.Set( ( *invLabels )[ maxL ] );
480       pIt.Set( maxP );
481     }
482     else
483     {
484       oIt.Set( iIt.Get( ) );
485       pIt.Set( TScalar( 1 ) );
486
487     } // fi
488
489   } // rof
490 }
491
492 // -------------------------------------------------------------------------
493 template< class _TImage, class _TLabels, class _TScalar >
494 unsigned long
495 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
496 _1D( const TIndex& idx, const TRegion& region )
497 {
498   unsigned long i = idx[ 0 ];
499   unsigned long off = 1;
500   typename TRegion::SizeType size = region.GetSize( );
501   for( unsigned int d = 1; d < TIndex::Dimension; ++d )
502   {
503     off *= size[ d - 1 ];
504     i += idx[ d ] * off;
505
506   } // rof
507   return( i );
508 }
509
510 // -------------------------------------------------------------------------
511 template< class _TImage, class _TLabels, class _TScalar >
512 template< class _TTriplets >
513 unsigned long
514 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
515 _SeedIndex( const unsigned long& i, const _TTriplets& t )
516 {
517   unsigned long s = 0;
518   unsigned long f = t.size( );
519   unsigned long e = f - 1;
520   while( e > s && f == t.size( ) )
521   {
522     if( e > s + 1 )
523     {
524       unsigned long h = ( e + s ) >> 1;
525       if     ( i < t[ h ].row( ) ) e = h;
526       else if( t[ h ].row( ) < i ) s = h;
527       else                         f = h;
528     }
529     else
530       f = ( t[ s ].row( ) == i )? s: e;
531
532   } // elihw
533   return( f );
534 }
535
536 // -------------------------------------------------------------------------
537 template< class _TImage, class _TLabels, class _TScalar >
538 template< class _TTriplets >
539 unsigned long
540 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
541 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
542 {
543   long s = 0;
544   long e = t.size( ) - 1;
545   while( e > s + 1 )
546   {
547     long h = ( e + s ) >> 1;
548     if     ( i < t[ h ].row( ) ) e = h;
549     else if( t[ h ].row( ) < i ) s = h;
550
551   } // elihw
552   long d;
553   if( i < t[ s ].row( ) )
554     d = -1;
555   else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
556     d = s + 1;
557   else
558     d = e + 1;
559   if( d < 0 ) d = 0;
560   return( i - d );
561 }
562
563 #endif // __fpa__Common__OriginalRandomWalker__hxx__
564 // eof - $RCSfile$