]> Creatis software - FrontAlgorithms.git/blob - lib/fpa/Common/RandomWalker.hxx
...
[FrontAlgorithms.git] / lib / fpa / Common / RandomWalker.hxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__RandomWalker__hxx__
6 #define __fpa__Common__RandomWalker__hxx__
7
8 #include <itkImageRegionConstIteratorWithIndex.h>
9 #include <itkImageRegionIteratorWithIndex.h>
10 #ifdef USE_Eigen3
11 #  include <Eigen/Sparse>
12 #endif // USE_Eigen3
13
14 // -------------------------------------------------------------------------
15 template< class _TImage, class _TLabels, class _TScalar >
16 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
17 RandomWalker( )
18   : Superclass( )
19 {
20   fpaFilterInputConfigureMacro( InputLabels, TLabels );
21   fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
22 }
23
24 // -------------------------------------------------------------------------
25 template< class _TImage, class _TLabels, class _TScalar >
26 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
27 ~RandomWalker( )
28 {
29 }
30
31 // -------------------------------------------------------------------------
32 template< class _TImage, class _TLabels, class _TScalar >
33 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
34 GenerateData( )
35 {
36 #ifdef USE_Eigen3
37   // Useful typedefs
38   typedef Eigen::Triplet< TScalar >         _TTriplet;
39   typedef std::vector< _TTriplet >          _TTriplets;
40   typedef Eigen::SparseMatrix< TScalar >    _TMatrix;
41   typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
42
43   // Configure edge function
44   if( this->m_EdgeFunction.IsNull( ) )
45     itkExceptionMacro( << "Undefined edge function." );
46   const TImage* input = this->GetInput( );
47   this->m_EdgeFunction->SetDataObject( input );
48
49   // Allocate outputs
50   this->AllocateOutputs( );
51
52   // Persisting objects
53   _TMatrix A( 1, 1 ), C( 1, 1 );
54   _TTriplets St;
55   std::vector< TLabel > invLabels;
56
57   { // begin
58     // Build boundary triplets and count labels
59     _TTriplets Bt;
60     std::map< TLabel, unsigned long > labels;
61     itkDebugMacro( << "Building boundary matrix..." );
62     this->_Boundary( St, labels );
63     struct _TTripletsOrd
64     {
65       bool operator()( const _TTriplet& a, const _TTriplet& b )
66         {
67           return( a.row( ) < b.row( ) );
68         }
69     };
70     itkDebugMacro( << "Sorting boundary pixels..." );
71     std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
72     itkDebugMacro( << "Assigning boundary pixels..." );
73     for( unsigned long i = 0; i < St.size( ); ++i )
74       Bt.push_back(
75         _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) )
76         );
77
78     // Laplacian triplets
79     itkDebugMacro( << "Building laplacian matrix..." );
80     _TTriplets At, Rt;
81     this->_Laplacian( At, Rt, St );
82
83     // Matrices
84     TRegion region = input->GetRequestedRegion( );
85     unsigned long nSeeds = St.size( );
86     unsigned long nLabels = labels.size( );
87     unsigned long N = region.GetNumberOfPixels( );
88
89     itkDebugMacro( << "Creating inverse labels..." );
90     invLabels.resize( nLabels );
91     for( typename std::map< TLabel, unsigned long >::value_type s: labels )
92       invLabels[ s.second ] = s.first;
93
94     itkDebugMacro( << "Creating B matrix..." );
95     _TMatrix B( nSeeds, nLabels );
96     B.setFromTriplets( Bt.begin( ), Bt.end( ) );
97     B.makeCompressed( );
98
99     itkDebugMacro( << "Creating R matrix..." );
100     _TMatrix R( N - nSeeds, nSeeds );
101     R.setFromTriplets( Rt.begin( ), Rt.end( ) );
102     R.makeCompressed( );
103
104     itkDebugMacro( << "Creating C matrix..." );
105     C = R * B;
106
107     itkDebugMacro( << "Creating A matrix..." );
108     A.resize( N - nSeeds, N - nSeeds );
109     A.setFromTriplets( At.begin( ), At.end( ) );
110     A.makeCompressed( );
111   } // end
112
113   // Solve dirichlet problem
114   _TSolver solver;
115   itkDebugMacro( << "Factorizing problem..." );
116   solver.compute( A );
117   if( solver.info( ) != Eigen::Success )
118     itkExceptionMacro( << "Error decomposing matrix." );
119   itkDebugMacro( << "Solving problem..." );
120   _TMatrix x = solver.solve( C );
121   if( solver.info( ) != Eigen::Success )
122     itkExceptionMacro( << "Error solving system." );
123
124   // Fill outputs
125   itkDebugMacro( << "Filling output..." );
126   this->_Output( x, St, invLabels );
127 #endif // USE_Eigen3
128 }
129
130 // -------------------------------------------------------------------------
131 template< class _TImage, class _TLabels, class _TScalar >
132 _TScalar fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
133 _L( const TIndex& i, const TIndex& j )
134 {
135 #ifdef USE_Eigen3
136   if( i == j )
137   {
138     TRegion r = this->GetInput( )->GetRequestedRegion( );
139     TScalar s = TScalar( 0 );
140     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
141     {
142       for( int n = -1; n <= 1; n += 2 )
143       {
144         TIndex k = i;
145         k[ d ] += n;
146         if( r.IsInside( k ) )
147           s += this->m_EdgeFunction->Evaluate( i, k );
148
149       } // rof
150
151     } // rof
152     return( s );
153   }
154   else
155     return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
156 #else // USE_Eigen3
157   return( _TScalar( 0 ) );
158 #endif // USE_Eigen3
159 }
160
161 // -------------------------------------------------------------------------
162 template< class _TImage, class _TLabels, class _TScalar >
163 template< class _TTriplets >
164 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
165 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
166 {
167 #ifdef USE_Eigen3
168   B.clear( );
169
170   // Set up the multithreaded processing
171   _TBoundaryThreadStruct thrStr;
172   thrStr.Filter = this;
173   thrStr.Triplets = reinterpret_cast< void* >( &B );
174   thrStr.Labels = &labels;
175
176   // Configure threader
177   const TLabels* in_labels = this->GetInputLabels( );
178   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
179   const unsigned int nThreads =
180     split->GetNumberOfSplits(
181       in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
182       );
183
184   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
185   threads->SetNumberOfThreads( nThreads );
186   threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
187
188   // Execute threader
189   threads->SingleMethodExecute( );
190 #endif // USE_Eigen3
191 }
192
193 // -------------------------------------------------------------------------
194 template< class _TImage, class _TLabels, class _TScalar >
195 template< class _TTriplets >
196 ITK_THREAD_RETURN_TYPE
197 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
198 _BoundaryCbk( void* arg )
199 {
200 #ifdef USE_Eigen3
201   _TBoundaryThreadStruct* thrStr;
202   itk::ThreadIdType total, thrId, thrCount;
203   itk::MultiThreader::ThreadInfoStruct* thrInfo =
204     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
205   thrId = thrInfo->ThreadID;
206   thrCount = thrInfo->NumberOfThreads;
207   thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
208
209   TRegion region;
210   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
211   if( thrId < total )
212     thrStr->Filter->_ThreadedBoundary(
213       region, thrId,
214       reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
215       thrStr->Labels
216       );
217 #endif // USE_Eigen3
218   return( ITK_THREAD_RETURN_VALUE );
219 }
220
221 // -------------------------------------------------------------------------
222 template< class _TImage, class _TLabels, class _TScalar >
223 template< class _TTriplets >
224 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
225 _ThreadedBoundary(
226   const TRegion& region, const itk::ThreadIdType& id,
227   _TTriplets* B,
228   std::map< TLabel, unsigned long >* labels
229   )
230 {
231 #ifdef USE_Eigen3
232   typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
233   typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
234   typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
235   typedef typename _TTriplets::value_type _TTriplet;
236
237   const TLabels* in_labels = this->GetInputLabels( );
238   TRegion reqRegion = in_labels->GetRequestedRegion( );
239   _TIt it( in_labels, region );
240   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
241   {
242     if( it.Get( ) != 0 )
243     {
244       unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
245       this->m_Mutex.Lock( );
246       B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
247       if( labels->find( it.Get( ) ) == labels->end( ) )
248         labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
249       this->m_Mutex.Unlock( );
250
251     } // fi
252
253   } // rof
254 #endif // USE_Eigen3
255 }
256
257 // -------------------------------------------------------------------------
258 template< class _TImage, class _TLabels, class _TScalar >
259 template< class _TTriplets >
260 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
261 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
262 {
263 #ifdef USE_Eigen3
264   A.clear( );
265   R.clear( );
266
267   // Set up the multithreaded processing
268   _TLaplacianThreadStruct thrStr;
269   thrStr.Filter = this;
270   thrStr.A = reinterpret_cast< void* >( &A );
271   thrStr.R = reinterpret_cast< void* >( &R );
272   thrStr.B = reinterpret_cast< const void* >( &B );
273
274   // Configure threader
275   const TImage* in = this->GetInput( );
276   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
277   const unsigned int nThreads =
278     split->GetNumberOfSplits(
279       in->GetRequestedRegion( ), this->GetNumberOfThreads( )
280       );
281
282   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
283   threads->SetNumberOfThreads( nThreads );
284   threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
285
286   // Execute threader
287   threads->SingleMethodExecute( );
288 #endif // USE_Eigen3
289 }
290
291 // -------------------------------------------------------------------------
292 template< class _TImage, class _TLabels, class _TScalar >
293 template< class _TTriplets >
294 ITK_THREAD_RETURN_TYPE
295 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
296 _LaplacianCbk( void* arg )
297 {
298 #ifdef USE_Eigen3
299   _TLaplacianThreadStruct* thrStr;
300   itk::ThreadIdType total, thrId, thrCount;
301   itk::MultiThreader::ThreadInfoStruct* thrInfo =
302     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
303   thrId = thrInfo->ThreadID;
304   thrCount = thrInfo->NumberOfThreads;
305   thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
306
307   TRegion region;
308   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
309   if( thrId < total )
310     thrStr->Filter->_ThreadedLaplacian(
311       region, thrId,
312       reinterpret_cast< _TTriplets* >( thrStr->A ),
313       reinterpret_cast< _TTriplets* >( thrStr->R ),
314       reinterpret_cast< const _TTriplets* >( thrStr->B )
315       );
316 #endif // USE_Eigen3
317   return( ITK_THREAD_RETURN_VALUE );
318 }
319
320 // -------------------------------------------------------------------------
321 template< class _TImage, class _TLabels, class _TScalar >
322 template< class _TTriplets >
323 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
324 _ThreadedLaplacian(
325   const TRegion& region, const itk::ThreadIdType& id,
326   _TTriplets* A, _TTriplets* R, const _TTriplets* B
327   )
328 {
329 #ifdef USE_Eigen3
330   typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
331   typedef typename _TTriplets::value_type _TTriplet;
332
333   const TImage* in = this->GetInput( );
334   const TLabels* in_labels = this->GetInputLabels( );
335   TRegion reqRegion = in->GetRequestedRegion( );
336   _TIt it( in, region );
337   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
338   {
339     TIndex idx = it.GetIndex( );
340     bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
341     unsigned long i = Self::_1D( idx, reqRegion );
342     unsigned long si;
343
344     // A's diagonal values
345     if( !iSeed )
346     {
347       si = Self::_NearSeedIndex( i, *B );
348       this->m_Mutex.Lock( );
349       A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
350       this->m_Mutex.Unlock( );
351     }
352     else
353       si = Self::_SeedIndex( i, *B );
354
355     // Neighbors (final matrix is symmetric)
356     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
357     {
358       for( int s = -1; s <= 1; s += 2 )
359       {
360         TIndex jdx = idx;
361         jdx[ d ] += s;
362         if( reqRegion.IsInside( jdx ) )
363         {
364           TScalar L = this->_L( idx, jdx );
365           unsigned long j = Self::_1D( jdx, reqRegion );
366           bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
367           if( !jSeed )
368           {
369             unsigned long sj = Self::_NearSeedIndex( j, *B );
370             if( !iSeed )
371             {
372               this->m_Mutex.Lock( );
373               A->push_back( _TTriplet( si, sj, L ) );
374               this->m_Mutex.Unlock( );
375             }
376             else
377             {
378               this->m_Mutex.Lock( );
379               R->push_back( _TTriplet( sj, si, -L ) );
380               this->m_Mutex.Unlock( );
381
382             } // fi
383
384           } // fi
385
386         } // fi
387
388       } // rof
389
390     } // rof
391
392   } // rof
393 #endif // USE_Eigen3
394 }
395
396 // -------------------------------------------------------------------------
397 template< class _TImage, class _TLabels, class _TScalar >
398 template< class _TMatrix, class _TTriplets >
399 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
400 _Output(
401   const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
402   )
403 {
404 #ifdef USE_Eigen3
405   // Set up the multithreaded processing
406   _TOutputThreadStruct thrStr;
407   thrStr.Filter = this;
408   thrStr.X = reinterpret_cast< const void* >( &X );
409   thrStr.S = reinterpret_cast< const void* >( &S );
410   thrStr.InvLabels = &invLabels;
411
412   // Configure threader
413   const TLabels* out = this->GetOutput( );
414   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
415   const unsigned int nThreads =
416     split->GetNumberOfSplits(
417       out->GetRequestedRegion( ), this->GetNumberOfThreads( )
418       );
419
420   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
421   threads->SetNumberOfThreads( nThreads );
422   threads->SetSingleMethod(
423     this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
424     );
425
426   // Execute threader
427   threads->SingleMethodExecute( );
428 #endif // USE_Eigen3
429 }
430
431 // -------------------------------------------------------------------------
432 template< class _TImage, class _TLabels, class _TScalar >
433 template< class _TMatrix, class _TTriplets >
434 ITK_THREAD_RETURN_TYPE
435 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
436 _OutputCbk( void* arg )
437 {
438 #ifdef USE_Eigen3
439   _TOutputThreadStruct* thrStr;
440   itk::ThreadIdType total, thrId, thrCount;
441   itk::MultiThreader::ThreadInfoStruct* thrInfo =
442     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
443   thrId = thrInfo->ThreadID;
444   thrCount = thrInfo->NumberOfThreads;
445   thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
446
447   TRegion region;
448   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
449   if( thrId < total )
450     thrStr->Filter->_ThreadedOutput(
451       region, thrId,
452       reinterpret_cast< const _TMatrix* >( thrStr->X ),
453       reinterpret_cast< const _TTriplets* >( thrStr->S ),
454       thrStr->InvLabels
455       );
456 #endif // USE_Eigen3
457   return( ITK_THREAD_RETURN_VALUE );
458 }
459
460 // -------------------------------------------------------------------------
461 template< class _TImage, class _TLabels, class _TScalar >
462 template< class _TMatrix, class _TTriplets >
463 void fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
464 _ThreadedOutput(
465   const TRegion& region, const itk::ThreadIdType& id,
466   const _TMatrix* X, const _TTriplets* S,
467   const std::vector< TLabel >* invLabels
468   )
469 {
470 #ifdef USE_Eigen3
471   // Fill outputs
472   const TLabels* in_labels = this->GetInputLabels( );
473   TLabels* out_labels = this->GetOutput( );
474   TScalarImage* out_probs = this->GetOutputProbabilities( );
475   TRegion reqRegion = out_labels->GetRequestedRegion( );
476   itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
477   itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
478   itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
479   iIt.GoToBegin( );
480   oIt.GoToBegin( );
481   pIt.GoToBegin( );
482   for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
483   {
484     if( iIt.Get( ) == 0 )
485     {
486       unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
487       unsigned long j = Self::_NearSeedIndex( i, *S );
488       TScalar maxP = X->coeff( j, 0 );
489       unsigned long maxL = 0;
490       for( unsigned int s = 1; s < invLabels->size( ); ++s )
491       {
492         TScalar p = X->coeff( j, s );
493         if( maxP <= p )
494         {
495           maxP = p;
496           maxL = s;
497
498         } // fi
499
500       } // rof
501       oIt.Set( ( *invLabels )[ maxL ] );
502       pIt.Set( maxP );
503     }
504     else
505     {
506       oIt.Set( iIt.Get( ) );
507       pIt.Set( TScalar( 1 ) );
508
509     } // fi
510
511   } // rof
512 #endif // USE_Eigen3
513 }
514
515 // -------------------------------------------------------------------------
516 template< class _TImage, class _TLabels, class _TScalar >
517 unsigned long
518 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
519 _1D( const TIndex& idx, const TRegion& region )
520 {
521   unsigned long i = idx[ 0 ];
522   unsigned long off = 1;
523   typename TRegion::SizeType size = region.GetSize( );
524   for( unsigned int d = 1; d < TIndex::Dimension; ++d )
525   {
526     off *= size[ d - 1 ];
527     i += idx[ d ] * off;
528
529   } // rof
530   return( i );
531 }
532
533 // -------------------------------------------------------------------------
534 template< class _TImage, class _TLabels, class _TScalar >
535 template< class _TTriplets >
536 unsigned long
537 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
538 _SeedIndex( const unsigned long& i, const _TTriplets& t )
539 {
540   unsigned long s = 0;
541   unsigned long f = t.size( );
542   unsigned long e = f - 1;
543   while( e > s && f == t.size( ) )
544   {
545     if( e > s + 1 )
546     {
547       unsigned long h = ( e + s ) >> 1;
548       if     ( i < t[ h ].row( ) ) e = h;
549       else if( t[ h ].row( ) < i ) s = h;
550       else                         f = h;
551     }
552     else
553       f = ( t[ s ].row( ) == i )? s: e;
554
555   } // elihw
556   return( f );
557 }
558
559 // -------------------------------------------------------------------------
560 template< class _TImage, class _TLabels, class _TScalar >
561 template< class _TTriplets >
562 unsigned long
563 fpa::Common::RandomWalker< _TImage, _TLabels, _TScalar >::
564 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
565 {
566   long s = 0;
567   long e = t.size( ) - 1;
568   while( e > s + 1 )
569   {
570     long h = ( e + s ) >> 1;
571     if     ( i < t[ h ].row( ) ) e = h;
572     else if( t[ h ].row( ) < i ) s = h;
573
574   } // elihw
575   long d;
576   if( i < t[ s ].row( ) )
577     d = -1;
578   else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
579     d = s + 1;
580   else
581     d = e + 1;
582   if( d < 0 ) d = 0;
583   return( i - d );
584 }
585
586 #endif // __fpa__Common__RandomWalker__hxx__
587
588 // eof - $RCSfile$