]> Creatis software - FrontAlgorithms.git/blob - lib/fpa/Common/OriginalRandomWalker.hxx
...
[FrontAlgorithms.git] / lib / fpa / Common / OriginalRandomWalker.hxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
7
8 #include <itkImageRegionConstIteratorWithIndex.h>
9 #include <itkImageRegionIteratorWithIndex.h>
10 #include <Eigen/Sparse>
11
12 // -------------------------------------------------------------------------
13 /* TODO
14    template< class _TImage, class _TLabels, class _TScalar >
15    void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16    AddSeed( const TIndex& seed, const TLabel& label )
17    {
18    this->m_Seeds.push_back( seed );
19    this->m_Labels.push_back( label );
20    this->Modified( );
21    }
22 */
23
24 // -------------------------------------------------------------------------
25 template< class _TImage, class _TLabels, class _TScalar >
26 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
27 OriginalRandomWalker( )
28   : Superclass( )
29 {
30   fpaFilterInputConfigureMacro( InputLabels, TLabels );
31   fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
32 }
33
34 // -------------------------------------------------------------------------
35 template< class _TImage, class _TLabels, class _TScalar >
36 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
37 ~OriginalRandomWalker( )
38 {
39 }
40
41 // -------------------------------------------------------------------------
42 template< class _TImage, class _TLabels, class _TScalar >
43 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
44 GenerateData( )
45 {
46   // Useful typedefs
47   typedef Eigen::Triplet< TScalar >         _TTriplet;
48   typedef std::vector< _TTriplet >          _TTriplets;
49   typedef Eigen::SparseMatrix< TScalar >    _TMatrix;
50   typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
51
52   // Configure edge function
53   if( this->m_EdgeFunction.IsNull( ) )
54     itkExceptionMacro( << "Undefined edge function." );
55   const TImage* input = this->GetInput( );
56   this->m_EdgeFunction->SetDataObject( input );
57
58   // Allocate outputs
59   this->AllocateOutputs( );
60
61   // Build boundary triplets and count labels
62   _TTriplets St, Bt;
63   std::map< TLabel, unsigned long > labels;
64   this->_Boundary( St, labels );
65   struct _TTripletsOrd
66   {
67     bool operator()( const _TTriplet& a, const _TTriplet& b )
68       {
69         return( a.row( ) < b.row( ) );
70       }
71   };
72   std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
73   for( unsigned long i = 0; i < St.size( ); ++i )
74     Bt.push_back( _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) ) );
75
76   // Laplacian triplets
77   _TTriplets At, Rt;
78   this->_Laplacian( At, Rt, St );
79
80   // Matrices
81   TRegion region = input->GetRequestedRegion( );
82   unsigned long nSeeds = St.size( );
83   unsigned long nLabels = labels.size( );
84   unsigned long N = region.GetNumberOfPixels( );
85
86   std::vector< TLabel > invLabels( nLabels );
87   for( typename std::map< TLabel, unsigned long >::value_type s: labels )
88     invLabels[ s.second ] = s.first;
89
90   _TMatrix B( nSeeds, nLabels );
91   B.setFromTriplets( Bt.begin( ), Bt.end( ) );
92   B.makeCompressed( );
93
94   _TMatrix R( N - nSeeds, nSeeds );
95   R.setFromTriplets( Rt.begin( ), Rt.end( ) );
96   R.makeCompressed( );
97
98   _TMatrix A( N - nSeeds, N - nSeeds );
99   A.setFromTriplets( At.begin( ), At.end( ) );
100   A.makeCompressed( );
101
102   // Solve dirichlet problem
103   _TSolver solver;
104   solver.compute( A );
105   if( solver.info( ) != Eigen::Success )
106     itkExceptionMacro( << "Error decomposing matrix." );
107   _TMatrix x = solver.solve( R * B );
108   if( solver.info( ) != Eigen::Success )
109     itkExceptionMacro( << "Error solving system." );
110
111   // Fill outputs
112   this->_Output( x, St, invLabels );
113 }
114
115 // -------------------------------------------------------------------------
116 template< class _TImage, class _TLabels, class _TScalar >
117 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
118 _L( const TIndex& i, const TIndex& j )
119 {
120   if( i == j )
121   {
122     TRegion r = this->GetInput( )->GetRequestedRegion( );
123     TScalar s = TScalar( 0 );
124     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
125     {
126       for( int n = -1; n <= 1; n += 2 )
127       {
128         TIndex k = i;
129         k[ d ] += n;
130         if( r.IsInside( k ) )
131           s += this->m_EdgeFunction->Evaluate( i, k );
132
133       } // rof
134
135     } // rof
136     return( s );
137   }
138   else
139     return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
140 }
141
142 // -------------------------------------------------------------------------
143 template< class _TImage, class _TLabels, class _TScalar >
144 template< class _TTriplets >
145 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
146 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
147 {
148   B.clear( );
149
150   // Set up the multithreaded processing
151   _TBoundaryThreadStruct thrStr;
152   thrStr.Filter = this;
153   thrStr.Triplets = reinterpret_cast< void* >( &B );
154   thrStr.Labels = &labels;
155
156   // Configure threader
157   const TLabels* in_labels = this->GetInputLabels( );
158   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
159   const unsigned int nThreads =
160     split->GetNumberOfSplits(
161       in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
162       );
163
164   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
165   threads->SetNumberOfThreads( nThreads );
166   threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
167
168   // Execute threader
169   threads->SingleMethodExecute( );
170 }
171
172 // -------------------------------------------------------------------------
173 template< class _TImage, class _TLabels, class _TScalar >
174 template< class _TTriplets >
175 ITK_THREAD_RETURN_TYPE
176 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
177 _BoundaryCbk( void* arg )
178 {
179   _TBoundaryThreadStruct* thrStr;
180   itk::ThreadIdType total, thrId, thrCount;
181   itk::MultiThreader::ThreadInfoStruct* thrInfo =
182     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
183   thrId = thrInfo->ThreadID;
184   thrCount = thrInfo->NumberOfThreads;
185   thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
186
187   TRegion region;
188   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
189   if( thrId < total )
190     thrStr->Filter->_ThreadedBoundary(
191       region, thrId,
192       reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
193       thrStr->Labels
194       );
195   return( ITK_THREAD_RETURN_VALUE );
196 }
197
198 // -------------------------------------------------------------------------
199 template< class _TImage, class _TLabels, class _TScalar >
200 template< class _TTriplets >
201 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
202 _ThreadedBoundary(
203   const TRegion& region, const itk::ThreadIdType& id,
204   _TTriplets* B,
205   std::map< TLabel, unsigned long >* labels
206   )
207 {
208   typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
209   typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
210   typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
211   typedef typename _TTriplets::value_type _TTriplet;
212
213   const TLabels* in_labels = this->GetInputLabels( );
214   TRegion reqRegion = in_labels->GetRequestedRegion( );
215   _TIt it( in_labels, region );
216   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
217   {
218     if( it.Get( ) != 0 )
219     {
220       unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
221       this->m_Mutex.Lock( );
222       B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
223       if( labels->find( it.Get( ) ) == labels->end( ) )
224         labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
225       this->m_Mutex.Unlock( );
226
227     } // fi
228
229   } // rof
230 }
231
232 // -------------------------------------------------------------------------
233 template< class _TImage, class _TLabels, class _TScalar >
234 template< class _TTriplets >
235 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
236 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
237 {
238   A.clear( );
239   R.clear( );
240
241   // Set up the multithreaded processing
242   _TLaplacianThreadStruct thrStr;
243   thrStr.Filter = this;
244   thrStr.A = reinterpret_cast< void* >( &A );
245   thrStr.R = reinterpret_cast< void* >( &R );
246   thrStr.B = reinterpret_cast< const void* >( &B );
247
248   // Configure threader
249   const TImage* in = this->GetInput( );
250   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
251   const unsigned int nThreads =
252     split->GetNumberOfSplits(
253       in->GetRequestedRegion( ), this->GetNumberOfThreads( )
254       );
255
256   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
257   threads->SetNumberOfThreads( nThreads );
258   threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
259
260   // Execute threader
261   threads->SingleMethodExecute( );
262 }
263
264 // -------------------------------------------------------------------------
265 template< class _TImage, class _TLabels, class _TScalar >
266 template< class _TTriplets >
267 ITK_THREAD_RETURN_TYPE
268 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
269 _LaplacianCbk( void* arg )
270 {
271   _TLaplacianThreadStruct* thrStr;
272   itk::ThreadIdType total, thrId, thrCount;
273   itk::MultiThreader::ThreadInfoStruct* thrInfo =
274     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
275   thrId = thrInfo->ThreadID;
276   thrCount = thrInfo->NumberOfThreads;
277   thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
278
279   TRegion region;
280   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
281   if( thrId < total )
282     thrStr->Filter->_ThreadedLaplacian(
283       region, thrId,
284       reinterpret_cast< _TTriplets* >( thrStr->A ),
285       reinterpret_cast< _TTriplets* >( thrStr->R ),
286       reinterpret_cast< const _TTriplets* >( thrStr->B )
287       );
288   return( ITK_THREAD_RETURN_VALUE );
289 }
290
291 // -------------------------------------------------------------------------
292 template< class _TImage, class _TLabels, class _TScalar >
293 template< class _TTriplets >
294 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
295 _ThreadedLaplacian(
296   const TRegion& region, const itk::ThreadIdType& id,
297   _TTriplets* A, _TTriplets* R, const _TTriplets* B
298   )
299 {
300   typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
301   typedef typename _TTriplets::value_type _TTriplet;
302
303   const TImage* in = this->GetInput( );
304   const TLabels* in_labels = this->GetInputLabels( );
305   TRegion reqRegion = in->GetRequestedRegion( );
306   _TIt it( in, region );
307   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
308   {
309     TIndex idx = it.GetIndex( );
310     bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
311     unsigned long i = Self::_1D( idx, reqRegion );
312     unsigned long si;
313
314     // A's diagonal values
315     if( !iSeed )
316     {
317       si = Self::_NearSeedIndex( i, *B );
318       this->m_Mutex.Lock( );
319       A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
320       this->m_Mutex.Unlock( );
321     }
322     else
323       si = Self::_SeedIndex( i, *B );
324
325     // Neighbors (final matrix is symmetric)
326     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
327     {
328       for( int s = -1; s <= 1; s += 2 )
329       {
330         TIndex jdx = idx;
331         jdx[ d ] += s;
332         if( reqRegion.IsInside( jdx ) )
333         {
334           TScalar L = this->_L( idx, jdx );
335           unsigned long j = Self::_1D( jdx, reqRegion );
336           bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
337           if( !jSeed )
338           {
339             unsigned long sj = Self::_NearSeedIndex( j, *B );
340             if( !iSeed )
341             {
342               this->m_Mutex.Lock( );
343               A->push_back( _TTriplet( si, sj, L ) );
344               this->m_Mutex.Unlock( );
345             }
346             else
347             {
348               this->m_Mutex.Lock( );
349               R->push_back( _TTriplet( sj, si, -L ) );
350               this->m_Mutex.Unlock( );
351
352             } // fi
353
354           } // fi
355
356         } // fi
357
358       } // rof
359
360     } // rof
361
362   } // rof
363 }
364
365 // -------------------------------------------------------------------------
366 template< class _TImage, class _TLabels, class _TScalar >
367 template< class _TMatrix, class _TTriplets >
368 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
369 _Output(
370   const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
371   )
372 {
373   // Set up the multithreaded processing
374   _TOutputThreadStruct thrStr;
375   thrStr.Filter = this;
376   thrStr.X = reinterpret_cast< const void* >( &X );
377   thrStr.S = reinterpret_cast< const void* >( &S );
378   thrStr.InvLabels = &invLabels;
379
380   // Configure threader
381   const TLabels* out = this->GetOutput( );
382   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
383   const unsigned int nThreads =
384     split->GetNumberOfSplits(
385       out->GetRequestedRegion( ), this->GetNumberOfThreads( )
386       );
387
388   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
389   threads->SetNumberOfThreads( nThreads );
390   threads->SetSingleMethod(
391     this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
392     );
393
394   // Execute threader
395   threads->SingleMethodExecute( );
396 }
397
398 // -------------------------------------------------------------------------
399 template< class _TImage, class _TLabels, class _TScalar >
400 template< class _TMatrix, class _TTriplets >
401 ITK_THREAD_RETURN_TYPE
402 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
403 _OutputCbk( void* arg )
404 {
405   _TOutputThreadStruct* thrStr;
406   itk::ThreadIdType total, thrId, thrCount;
407   itk::MultiThreader::ThreadInfoStruct* thrInfo =
408     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
409   thrId = thrInfo->ThreadID;
410   thrCount = thrInfo->NumberOfThreads;
411   thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
412
413   TRegion region;
414   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
415   if( thrId < total )
416     thrStr->Filter->_ThreadedOutput(
417       region, thrId,
418       reinterpret_cast< const _TMatrix* >( thrStr->X ),
419       reinterpret_cast< const _TTriplets* >( thrStr->S ),
420       thrStr->InvLabels
421       );
422   return( ITK_THREAD_RETURN_VALUE );
423 }
424
425 // -------------------------------------------------------------------------
426 template< class _TImage, class _TLabels, class _TScalar >
427 template< class _TMatrix, class _TTriplets >
428 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
429 _ThreadedOutput(
430   const TRegion& region, const itk::ThreadIdType& id,
431   const _TMatrix* X, const _TTriplets* S,
432   const std::vector< TLabel >* invLabels
433   )
434 {
435   // Fill outputs
436   const TLabels* in_labels = this->GetInputLabels( );
437   TLabels* out_labels = this->GetOutput( );
438   TScalarImage* out_probs = this->GetOutputProbabilities( );
439   TRegion reqRegion = out_labels->GetRequestedRegion( );
440   itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
441   itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
442   itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
443   iIt.GoToBegin( );
444   oIt.GoToBegin( );
445   pIt.GoToBegin( );
446   for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
447   {
448     if( iIt.Get( ) == 0 )
449     {
450       unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
451       unsigned long j = Self::_NearSeedIndex( i, *S );
452       TScalar maxP = X->coeff( j, 0 );
453       unsigned long maxL = 0;
454       for( unsigned int s = 1; s < invLabels->size( ); ++s )
455       {
456         TScalar p = X->coeff( j, s );
457         if( maxP <= p )
458         {
459           maxP = p;
460           maxL = s;
461
462         } // fi
463
464       } // rof
465       oIt.Set( ( *invLabels )[ maxL ] );
466       pIt.Set( maxP );
467     }
468     else
469     {
470       oIt.Set( iIt.Get( ) );
471       pIt.Set( TScalar( 1 ) );
472
473     } // fi
474
475   } // rof
476 }
477
478 // -------------------------------------------------------------------------
479 template< class _TImage, class _TLabels, class _TScalar >
480 unsigned long
481 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
482 _1D( const TIndex& idx, const TRegion& region )
483 {
484   unsigned long i = idx[ 0 ];
485   unsigned long off = 1;
486   typename TRegion::SizeType size = region.GetSize( );
487   for( unsigned int d = 1; d < TIndex::Dimension; ++d )
488   {
489     off *= size[ d - 1 ];
490     i += idx[ d ] * off;
491
492   } // rof
493   return( i );
494 }
495
496 // -------------------------------------------------------------------------
497 template< class _TImage, class _TLabels, class _TScalar >
498 template< class _TTriplets >
499 unsigned long
500 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
501 _SeedIndex( const unsigned long& i, const _TTriplets& t )
502 {
503   unsigned long s = 0;
504   unsigned long f = t.size( );
505   unsigned long e = f - 1;
506   while( e > s && f == t.size( ) )
507   {
508     if( e > s + 1 )
509     {
510       unsigned long h = ( e + s ) >> 1;
511       if     ( i < t[ h ].row( ) ) e = h;
512       else if( t[ h ].row( ) < i ) s = h;
513       else                         f = h;
514     }
515     else
516       f = ( t[ s ].row( ) == i )? s: e;
517
518   } // elihw
519   return( f );
520 }
521
522 // -------------------------------------------------------------------------
523 template< class _TImage, class _TLabels, class _TScalar >
524 template< class _TTriplets >
525 unsigned long
526 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
527 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
528 {
529   long s = 0;
530   long e = t.size( ) - 1;
531   while( e > s + 1 )
532   {
533     long h = ( e + s ) >> 1;
534     if     ( i < t[ h ].row( ) ) e = h;
535     else if( t[ h ].row( ) < i ) s = h;
536
537   } // elihw
538   long d;
539   if( i < t[ s ].row( ) )
540     d = -1;
541   else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
542     d = s + 1;
543   else
544     d = e + 1;
545   if( d < 0 ) d = 0;
546   return( i - d );
547 }
548
549 #endif // __fpa__Common__OriginalRandomWalker__hxx__
550 // eof - $RCSfile$