]> Creatis software - FrontAlgorithms.git/blob - lib/fpa/Common/OriginalRandomWalker.hxx
...
[FrontAlgorithms.git] / lib / fpa / Common / OriginalRandomWalker.hxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
7
8 #include <cmath>
9 #include <itkImageRegionConstIteratorWithIndex.h>
10 #include <itkImageRegionIteratorWithIndex.h>
11 #include <Eigen/Sparse>
12
13 // -------------------------------------------------------------------------
14 /* TODO
15    template< class _TImage, class _TLabels, class _TScalar >
16    void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
17    AddSeed( const TIndex& seed, const TLabel& label )
18    {
19    this->m_Seeds.push_back( seed );
20    this->m_Labels.push_back( label );
21    this->Modified( );
22    }
23 */
24
25 // -------------------------------------------------------------------------
26 template< class _TImage, class _TLabels, class _TScalar >
27 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
28 OriginalRandomWalker( )
29   : Superclass( )
30 {
31   fpaFilterInputConfigureMacro( InputLabels, TLabels );
32   fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
33 }
34
35 // -------------------------------------------------------------------------
36 template< class _TImage, class _TLabels, class _TScalar >
37 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
38 ~OriginalRandomWalker( )
39 {
40 }
41
42 // -------------------------------------------------------------------------
43 template< class _TImage, class _TLabels, class _TScalar >
44 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
45 GenerateData( )
46 {
47   // Useful typedefs
48   typedef Eigen::Triplet< TScalar >         _TTriplet;
49   typedef std::vector< _TTriplet >          _TTriplets;
50   typedef Eigen::SparseMatrix< TScalar >    _TMatrix;
51   typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
52
53   // Configure edge function
54   if( this->m_EdgeFunction.IsNull( ) )
55     itkExceptionMacro( << "Undefined edge function." );
56   const TImage* input = this->GetInput( );
57   this->m_EdgeFunction->SetDataObject( input );
58
59   // Allocate outputs
60   this->AllocateOutputs( );
61
62   // Build boundary triplets and count labels
63   _TTriplets St, Bt;
64   std::map< TLabel, unsigned long > labels;
65   this->_Boundary( St, labels );
66   struct _TTripletsOrd
67   {
68     bool operator()( const _TTriplet& a, const _TTriplet& b )
69       {
70         return( a.row( ) < b.row( ) );
71       }
72   };
73   std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
74   for( unsigned long i = 0; i < St.size( ); ++i )
75     Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) );
76
77   // Laplacian triplets
78   _TTriplets At, Rt;
79   this->_Laplacian( At, Rt, St );
80
81   // Matrices
82   TRegion region = input->GetRequestedRegion( );
83   unsigned long nSeeds = St.size( );
84   unsigned long nLabels = labels.size( );
85   unsigned long N = region.GetNumberOfPixels( );
86
87   std::vector< TLabel > invLabels( nLabels );
88   for( typename std::map< TLabel, unsigned long >::value_type s: labels )
89     invLabels[ s.second ] = s.first;
90
91   _TMatrix B( nSeeds, nLabels );
92   B.setFromTriplets( Bt.begin( ), Bt.end( ) );
93   B.makeCompressed( );
94
95   _TMatrix R( N - nSeeds, nSeeds );
96   R.setFromTriplets( Rt.begin( ), Rt.end( ) );
97   R.makeCompressed( );
98
99   _TMatrix A( N - nSeeds, N - nSeeds );
100   A.setFromTriplets( At.begin( ), At.end( ) );
101   A.makeCompressed( );
102
103   // Solve dirichlet problem
104   _TSolver solver;
105   solver.compute( A );
106   if( solver.info( ) != Eigen::Success )
107     itkExceptionMacro( << "Error decomposing matrix." );
108   _TMatrix x = solver.solve( R * B );
109   if( solver.info( ) != Eigen::Success )
110     itkExceptionMacro( << "Error solving system." );
111
112   // Fill outputs
113   this->_Output( x, St, invLabels );
114 }
115
116 // -------------------------------------------------------------------------
117 template< class _TImage, class _TLabels, class _TScalar >
118 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
119 _L( const TIndex& i, const TIndex& j )
120 {
121   if( i == j )
122   {
123     TRegion r = this->GetInput( )->GetRequestedRegion( );
124     TScalar s = TScalar( 0 );
125     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
126     {
127       for( int n = -1; n <= 1; n += 2 )
128       {
129         TIndex k = i;
130         k[ d ] += n;
131         if( r.IsInside( k ) )
132           s += this->m_EdgeFunction->Evaluate( i, k );
133
134       } // rof
135
136     } // rof
137     return( s );
138   }
139   else
140     return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
141 }
142
143 // -------------------------------------------------------------------------
144 template< class _TImage, class _TLabels, class _TScalar >
145 template< class _TTriplets >
146 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
147 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
148 {
149   B.clear( );
150
151   // Set up the multithreaded processing
152   _TBoundaryThreadStruct thrStr;
153   thrStr.Filter = this;
154   thrStr.Triplets = reinterpret_cast< void* >( &B );
155   thrStr.Labels = &labels;
156
157   // Configure threader
158   const TLabels* in_labels = this->GetInputLabels( );
159   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
160   const unsigned int nThreads =
161     split->GetNumberOfSplits(
162       in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
163       );
164
165   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
166   threads->SetNumberOfThreads( nThreads );
167   threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
168
169   // Execute threader
170   threads->SingleMethodExecute( );
171 }
172
173 // -------------------------------------------------------------------------
174 template< class _TImage, class _TLabels, class _TScalar >
175 template< class _TTriplets >
176 ITK_THREAD_RETURN_TYPE
177 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
178 _BoundaryCbk( void* arg )
179 {
180   _TBoundaryThreadStruct* thrStr;
181   itk::ThreadIdType total, thrId, thrCount;
182   itk::MultiThreader::ThreadInfoStruct* thrInfo =
183     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
184   thrId = thrInfo->ThreadID;
185   thrCount = thrInfo->NumberOfThreads;
186   thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
187
188   TRegion region;
189   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
190   if( thrId < total )
191     thrStr->Filter->_ThreadedBoundary(
192       region, thrId,
193       reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
194       thrStr->Labels
195       );
196   return( ITK_THREAD_RETURN_VALUE );
197 }
198
199 // -------------------------------------------------------------------------
200 template< class _TImage, class _TLabels, class _TScalar >
201 template< class _TTriplets >
202 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
203 _ThreadedBoundary(
204   const TRegion& region, const itk::ThreadIdType& id,
205   _TTriplets* B,
206   std::map< TLabel, unsigned long >* labels
207   )
208 {
209   typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
210   typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
211   typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
212   typedef typename _TTriplets::value_type _TTriplet;
213
214   const TLabels* in_labels = this->GetInputLabels( );
215   TRegion reqRegion = in_labels->GetRequestedRegion( );
216   _TIt it( in_labels, region );
217   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
218   {
219     if( it.Get( ) != 0 )
220     {
221       unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
222       this->m_Mutex.Lock( );
223       B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
224       if( labels->find( it.Get( ) ) == labels->end( ) )
225         labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
226       this->m_Mutex.Unlock( );
227
228     } // fi
229
230   } // rof
231 }
232
233 // -------------------------------------------------------------------------
234 template< class _TImage, class _TLabels, class _TScalar >
235 template< class _TTriplets >
236 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
237 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
238 {
239   A.clear( );
240   R.clear( );
241
242   // Set up the multithreaded processing
243   _TLaplacianThreadStruct thrStr;
244   thrStr.Filter = this;
245   thrStr.A = reinterpret_cast< void* >( &A );
246   thrStr.R = reinterpret_cast< void* >( &R );
247   thrStr.B = reinterpret_cast< const void* >( &B );
248
249   // Configure threader
250   const TImage* in = this->GetInputLabels( );
251   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
252   const unsigned int nThreads =
253     split->GetNumberOfSplits(
254       in->GetRequestedRegion( ), this->GetNumberOfThreads( )
255       );
256
257   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
258   threads->SetNumberOfThreads( nThreads );
259   threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
260
261   // Execute threader
262   threads->SingleMethodExecute( );
263 }
264
265 // -------------------------------------------------------------------------
266 template< class _TImage, class _TLabels, class _TScalar >
267 template< class _TTriplets >
268 ITK_THREAD_RETURN_TYPE
269 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
270 _LaplacianCbk( void* arg )
271 {
272   _TLaplacianThreadStruct* thrStr;
273   itk::ThreadIdType total, thrId, thrCount;
274   itk::MultiThreader::ThreadInfoStruct* thrInfo =
275     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
276   thrId = thrInfo->ThreadID;
277   thrCount = thrInfo->NumberOfThreads;
278   thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
279
280   TRegion region;
281   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
282   if( thrId < total )
283     thrStr->Filter->_ThreadedLaplacian(
284       region, thrId,
285       reinterpret_cast< _TTriplets* >( thrStr->A ),
286       reinterpret_cast< _TTriplets* >( thrStr->R ),
287       reinterpret_cast< const _TTriplets* >( thrStr->B )
288       );
289   return( ITK_THREAD_RETURN_VALUE );
290 }
291
292 // -------------------------------------------------------------------------
293 template< class _TImage, class _TLabels, class _TScalar >
294 template< class _TTriplets >
295 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
296 _ThreadedLaplacian(
297   const TRegion& region, const itk::ThreadIdType& id,
298   _TTriplets* A, _TTriplets* R, const _TTriplets* B
299   )
300 {
301   typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
302   typedef typename _TTriplets::value_type _TTriplet;
303
304   const TImage* in = this->GetInput( );
305   const TLabels* in_labels = this->GetInputLabels( );
306   TRegion reqRegion = in->GetRequestedRegion( );
307   _TIt it( in, region );
308   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
309   {
310     TIndex idx = it.GetIndex( );
311     bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
312     unsigned long i = Self::_1D( idx, reqRegion );
313     unsigned long si;
314
315     // A's diagonal values
316     if( !iSeed )
317     {
318       si = Self::_NearSeedIndex( i, *B );
319       this->m_Mutex.Lock( );
320       A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
321       this->m_Mutex.Unlock( );
322     }
323     else
324       si = Self::_SeedIndex( i, *B );
325
326     // Neighbors (final matrix is symmetric)
327     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
328     {
329       for( int s = -1; s <= 1; s += 2 )
330       {
331         TIndex jdx = idx;
332         jdx[ d ] += s;
333         if( reqRegion.IsInside( jdx ) )
334         {
335           TScalar L = this->_L( idx, jdx );
336           unsigned long j = Self::_1D( jdx, reqRegion );
337           bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
338           if( !jSeed )
339           {
340             unsigned long sj = Self::_NearSeedIndex( j, *B );
341             if( !iSeed )
342             {
343               this->m_Mutex.Lock( );
344               A->push_back( _TTriplet( si, sj, L ) );
345               this->m_Mutex.Unlock( );
346             }
347             else
348             {
349               this->m_Mutex.Lock( );
350               R->push_back( _TTriplet( sj, si, -L ) );
351               this->m_Mutex.Unlock( );
352
353             } // fi
354
355           } // fi
356
357         } // fi
358
359       } // rof
360
361     } // rof
362
363   } // rof
364 }
365
366 // -------------------------------------------------------------------------
367 template< class _TImage, class _TLabels, class _TScalar >
368 template< class _TMatrix, class _TTriplets >
369 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
370 _Output(
371   const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
372   )
373 {
374   // Set up the multithreaded processing
375   _TOutputThreadStruct thrStr;
376   thrStr.Filter = this;
377   thrStr.X = reinterpret_cast< const void* >( &X );
378   thrStr.S = reinterpret_cast< const void* >( &S );
379   thrStr.InvLabels = &invLabels;
380
381   // Configure threader
382   const TLabels* out = this->GetOutput( );
383   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
384   const unsigned int nThreads =
385     split->GetNumberOfSplits(
386       out->GetRequestedRegion( ), this->GetNumberOfThreads( )
387       );
388
389   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
390   threads->SetNumberOfThreads( nThreads );
391   threads->SetSingleMethod(
392     this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
393     );
394
395   // Execute threader
396   threads->SingleMethodExecute( );
397 }
398
399 // -------------------------------------------------------------------------
400 template< class _TImage, class _TLabels, class _TScalar >
401 template< class _TMatrix, class _TTriplets >
402 ITK_THREAD_RETURN_TYPE
403 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
404 _OutputCbk( void* arg )
405 {
406   _TOutputThreadStruct* thrStr;
407   itk::ThreadIdType total, thrId, thrCount;
408   itk::MultiThreader::ThreadInfoStruct* thrInfo =
409     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
410   thrId = thrInfo->ThreadID;
411   thrCount = thrInfo->NumberOfThreads;
412   thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
413
414   TRegion region;
415   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
416   if( thrId < total )
417     thrStr->Filter->_ThreadedOutput(
418       region, thrId,
419       reinterpret_cast< const _TMatrix* >( thrStr->X ),
420       reinterpret_cast< const _TTriplets* >( thrStr->S ),
421       thrStr->InvLabels
422       );
423   return( ITK_THREAD_RETURN_VALUE );
424 }
425
426 // -------------------------------------------------------------------------
427 template< class _TImage, class _TLabels, class _TScalar >
428 template< class _TMatrix, class _TTriplets >
429 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
430 _ThreadedOutput(
431   const TRegion& region, const itk::ThreadIdType& id,
432   const _TMatrix* X, const _TTriplets* S,
433   const std::vector< TLabel >* invLabels
434   )
435 {
436   // Fill outputs
437   const TLabels* in_labels = this->GetInputLabels( );
438   TLabels* out_labels = this->GetOutput( );
439   TScalarImage* out_probs = this->GetOutputProbabilities( );
440   TRegion reqRegion = out_labels->GetRequestedRegion( );
441   itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
442   itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
443   itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
444   iIt.GoToBegin( );
445   oIt.GoToBegin( );
446   pIt.GoToBegin( );
447   for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
448   {
449     if( iIt.Get( ) == 0 )
450     {
451       unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
452       unsigned long j = Self::_NearSeedIndex( i, *S );
453       TScalar maxP = X->coeff( j, 0 );
454       unsigned long maxL = 0;
455       for( unsigned int s = 1; s < invLabels->size( ); ++s )
456       {
457         TScalar p = X->coeff( j, s );
458         if( maxP <= p )
459         {
460           maxP = p;
461           maxL = s;
462
463         } // fi
464
465       } // rof
466       oIt.Set( ( *invLabels )[ maxL ] );
467       pIt.Set( maxP );
468     }
469     else
470     {
471       oIt.Set( iIt.Get( ) );
472       pIt.Set( TScalar( 1 ) );
473
474     } // fi
475
476   } // rof
477 }
478
479 // -------------------------------------------------------------------------
480 template< class _TImage, class _TLabels, class _TScalar >
481 unsigned long
482 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
483 _1D( const TIndex& idx, const TRegion& region )
484 {
485   unsigned long i = idx[ 0 ];
486   unsigned long off = 1;
487   typename TRegion::SizeType size = region.GetSize( );
488   for( unsigned int d = 1; d < TIndex::Dimension; ++d )
489   {
490     off *= size[ d - 1 ];
491     i += idx[ d ] * off;
492
493   } // rof
494   return( i );
495 }
496
497 // -------------------------------------------------------------------------
498 template< class _TImage, class _TLabels, class _TScalar >
499 template< class _TTriplets >
500 unsigned long
501 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
502 _SeedIndex( const unsigned long& i, const _TTriplets& t )
503 {
504   unsigned long s = 0;
505   unsigned long f = t.size( );
506   unsigned long e = f - 1;
507   while( e > s && f == t.size( ) )
508   {
509     if( e > s + 1 )
510     {
511       unsigned long h = ( e + s ) >> 1;
512       if     ( i < t[ h ].row( ) ) e = h;
513       else if( t[ h ].row( ) < i ) s = h;
514       else                         f = h;
515     }
516     else
517       f = ( t[ s ].row( ) == i )? s: e;
518
519   } // elihw
520   return( f );
521 }
522
523 // -------------------------------------------------------------------------
524 template< class _TImage, class _TLabels, class _TScalar >
525 template< class _TTriplets >
526 unsigned long
527 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
528 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
529 {
530   long s = 0;
531   long e = t.size( ) - 1;
532   while( e > s + 1 )
533   {
534     long h = ( e + s ) >> 1;
535     if     ( i < t[ h ].row( ) ) e = h;
536     else if( t[ h ].row( ) < i ) s = h;
537
538   } // elihw
539   long d;
540   if( i < t[ s ].row( ) )
541     d = -1;
542   else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
543     d = s + 1;
544   else
545     d = e + 1;
546   if( d < 0 ) d = 0;
547   return( i - d );
548 }
549
550 #endif // __fpa__Common__OriginalRandomWalker__hxx__
551 // eof - $RCSfile$