]> Creatis software - FrontAlgorithms.git/blob - lib/fpa/Common/OriginalRandomWalker.hxx
...
[FrontAlgorithms.git] / lib / fpa / Common / OriginalRandomWalker.hxx
1 // =========================================================================
2 // @author Leonardo Florez Valencia
3 // @email florez-l@javeriana.edu.co
4 // =========================================================================
5 #ifndef __fpa__Common__OriginalRandomWalker__hxx__
6 #define __fpa__Common__OriginalRandomWalker__hxx__
7
8 #include <itkImageRegionConstIteratorWithIndex.h>
9 #include <itkImageRegionIteratorWithIndex.h>
10 #include <Eigen/Sparse>
11
12 // -------------------------------------------------------------------------
13 /* TODO
14    template< class _TImage, class _TLabels, class _TScalar >
15    void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
16    AddSeed( const TIndex& seed, const TLabel& label )
17    {
18    this->m_Seeds.push_back( seed );
19    this->m_Labels.push_back( label );
20    this->Modified( );
21    }
22 */
23
24 // -------------------------------------------------------------------------
25 template< class _TImage, class _TLabels, class _TScalar >
26 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
27 OriginalRandomWalker( )
28   : Superclass( )
29 {
30   fpaFilterInputConfigureMacro( InputLabels, TLabels );
31   fpaFilterOutputConfigureMacro( OutputProbabilities, TScalarImage );
32 }
33
34 // -------------------------------------------------------------------------
35 template< class _TImage, class _TLabels, class _TScalar >
36 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
37 ~OriginalRandomWalker( )
38 {
39 }
40
41 // -------------------------------------------------------------------------
42 template< class _TImage, class _TLabels, class _TScalar >
43 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
44 GenerateData( )
45 {
46   // Useful typedefs
47   typedef Eigen::Triplet< TScalar >         _TTriplet;
48   typedef std::vector< _TTriplet >          _TTriplets;
49   typedef Eigen::SparseMatrix< TScalar >    _TMatrix;
50   typedef Eigen::SimplicialLDLT< _TMatrix > _TSolver;
51
52   // Configure edge function
53   if( this->m_EdgeFunction.IsNull( ) )
54     itkExceptionMacro( << "Undefined edge function." );
55   const TImage* input = this->GetInput( );
56   this->m_EdgeFunction->SetDataObject( input );
57
58   // Allocate outputs
59   this->AllocateOutputs( );
60
61   // Persisting objects
62   _TMatrix A( 1, 1 ), C( 1, 1 );
63   _TTriplets St;
64   std::vector< TLabel > invLabels;
65
66   { // begin
67     // Build boundary triplets and count labels
68     _TTriplets Bt;
69     std::map< TLabel, unsigned long > labels;
70     itkDebugMacro( << "Building boundary matrix..." );
71     this->_Boundary( St, labels );
72     struct _TTripletsOrd
73     {
74       bool operator()( const _TTriplet& a, const _TTriplet& b )
75         {
76           return( a.row( ) < b.row( ) );
77         }
78     };
79     itkDebugMacro( << "Sorting boundary pixels..." );
80     std::sort( St.begin( ), St.end( ), _TTripletsOrd( ) );
81     itkDebugMacro( << "Assigning boundary pixels..." );
82     for( unsigned long i = 0; i < St.size( ); ++i )
83       Bt.push_back(
84         _TTriplet( i, labels[ St[ i ].col( ) ], St[ i ].value( ) )
85         );
86
87     // Laplacian triplets
88     itkDebugMacro( << "Building laplacian matrix..." );
89     _TTriplets At, Rt;
90     this->_Laplacian( At, Rt, St );
91
92     // Matrices
93     TRegion region = input->GetRequestedRegion( );
94     unsigned long nSeeds = St.size( );
95     unsigned long nLabels = labels.size( );
96     unsigned long N = region.GetNumberOfPixels( );
97
98     itkDebugMacro( << "Creating inverse labels..." );
99     invLabels.resize( nLabels );
100     for( typename std::map< TLabel, unsigned long >::value_type s: labels )
101       invLabels[ s.second ] = s.first;
102
103     itkDebugMacro( << "Creating B matrix..." );
104     _TMatrix B( nSeeds, nLabels );
105     B.setFromTriplets( Bt.begin( ), Bt.end( ) );
106     B.makeCompressed( );
107
108     itkDebugMacro( << "Creating R matrix..." );
109     _TMatrix R( N - nSeeds, nSeeds );
110     R.setFromTriplets( Rt.begin( ), Rt.end( ) );
111     R.makeCompressed( );
112
113     itkDebugMacro( << "Creating C matrix..." );
114     C = R * B;
115
116     itkDebugMacro( << "Creating A matrix..." );
117     A.resize( N - nSeeds, N - nSeeds );
118     A.setFromTriplets( At.begin( ), At.end( ) );
119     A.makeCompressed( );
120   } // end
121
122   // Solve dirichlet problem
123   _TSolver solver;
124   itkDebugMacro( << "Factorizing problem..." );
125   solver.compute( A );
126   if( solver.info( ) != Eigen::Success )
127     itkExceptionMacro( << "Error decomposing matrix." );
128   itkDebugMacro( << "Solving problem..." );
129   _TMatrix x = solver.solve( C );
130   if( solver.info( ) != Eigen::Success )
131     itkExceptionMacro( << "Error solving system." );
132
133   // Fill outputs
134   itkDebugMacro( << "Filling output..." );
135   this->_Output( x, St, invLabels );
136 }
137
138 // -------------------------------------------------------------------------
139 template< class _TImage, class _TLabels, class _TScalar >
140 _TScalar fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
141 _L( const TIndex& i, const TIndex& j )
142 {
143   if( i == j )
144   {
145     TRegion r = this->GetInput( )->GetRequestedRegion( );
146     TScalar s = TScalar( 0 );
147     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
148     {
149       for( int n = -1; n <= 1; n += 2 )
150       {
151         TIndex k = i;
152         k[ d ] += n;
153         if( r.IsInside( k ) )
154           s += this->m_EdgeFunction->Evaluate( i, k );
155
156       } // rof
157
158     } // rof
159     return( s );
160   }
161   else
162     return( -( this->m_EdgeFunction->Evaluate( i, j ) ) );
163 }
164
165 // -------------------------------------------------------------------------
166 template< class _TImage, class _TLabels, class _TScalar >
167 template< class _TTriplets >
168 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
169 _Boundary( _TTriplets& B, std::map< TLabel, unsigned long >& labels )
170 {
171   B.clear( );
172
173   // Set up the multithreaded processing
174   _TBoundaryThreadStruct thrStr;
175   thrStr.Filter = this;
176   thrStr.Triplets = reinterpret_cast< void* >( &B );
177   thrStr.Labels = &labels;
178
179   // Configure threader
180   const TLabels* in_labels = this->GetInputLabels( );
181   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
182   const unsigned int nThreads =
183     split->GetNumberOfSplits(
184       in_labels->GetRequestedRegion( ), this->GetNumberOfThreads( )
185       );
186
187   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
188   threads->SetNumberOfThreads( nThreads );
189   threads->SetSingleMethod( this->_BoundaryCbk< _TTriplets >, &thrStr );
190
191   // Execute threader
192   threads->SingleMethodExecute( );
193 }
194
195 // -------------------------------------------------------------------------
196 template< class _TImage, class _TLabels, class _TScalar >
197 template< class _TTriplets >
198 ITK_THREAD_RETURN_TYPE
199 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
200 _BoundaryCbk( void* arg )
201 {
202   _TBoundaryThreadStruct* thrStr;
203   itk::ThreadIdType total, thrId, thrCount;
204   itk::MultiThreader::ThreadInfoStruct* thrInfo =
205     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
206   thrId = thrInfo->ThreadID;
207   thrCount = thrInfo->NumberOfThreads;
208   thrStr = reinterpret_cast< _TBoundaryThreadStruct* >( thrInfo->UserData );
209
210   TRegion region;
211   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
212   if( thrId < total )
213     thrStr->Filter->_ThreadedBoundary(
214       region, thrId,
215       reinterpret_cast< _TTriplets* >( thrStr->Triplets ),
216       thrStr->Labels
217       );
218   return( ITK_THREAD_RETURN_VALUE );
219 }
220
221 // -------------------------------------------------------------------------
222 template< class _TImage, class _TLabels, class _TScalar >
223 template< class _TTriplets >
224 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
225 _ThreadedBoundary(
226   const TRegion& region, const itk::ThreadIdType& id,
227   _TTriplets* B,
228   std::map< TLabel, unsigned long >* labels
229   )
230 {
231   typedef itk::ImageRegionConstIteratorWithIndex< TLabels > _TIt;
232   typedef typename std::map< TLabel, unsigned long >::value_type _TMapValue;
233   typedef typename std::map< unsigned long, TLabel >::value_type _TInvValue;
234   typedef typename _TTriplets::value_type _TTriplet;
235
236   const TLabels* in_labels = this->GetInputLabels( );
237   TRegion reqRegion = in_labels->GetRequestedRegion( );
238   _TIt it( in_labels, region );
239   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
240   {
241     if( it.Get( ) != 0 )
242     {
243       unsigned long i = Self::_1D( it.GetIndex( ), reqRegion );
244       this->m_Mutex.Lock( );
245       B->push_back( _TTriplet( i, it.Get( ), TScalar( 1 ) ) );
246       if( labels->find( it.Get( ) ) == labels->end( ) )
247         labels->insert( _TMapValue( it.Get( ), labels->size( ) ) );
248       this->m_Mutex.Unlock( );
249
250     } // fi
251
252   } // rof
253 }
254
255 // -------------------------------------------------------------------------
256 template< class _TImage, class _TLabels, class _TScalar >
257 template< class _TTriplets >
258 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
259 _Laplacian( _TTriplets& A, _TTriplets& R, const _TTriplets& B )
260 {
261   A.clear( );
262   R.clear( );
263
264   // Set up the multithreaded processing
265   _TLaplacianThreadStruct thrStr;
266   thrStr.Filter = this;
267   thrStr.A = reinterpret_cast< void* >( &A );
268   thrStr.R = reinterpret_cast< void* >( &R );
269   thrStr.B = reinterpret_cast< const void* >( &B );
270
271   // Configure threader
272   const TImage* in = this->GetInput( );
273   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
274   const unsigned int nThreads =
275     split->GetNumberOfSplits(
276       in->GetRequestedRegion( ), this->GetNumberOfThreads( )
277       );
278
279   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
280   threads->SetNumberOfThreads( nThreads );
281   threads->SetSingleMethod( this->_LaplacianCbk< _TTriplets >, &thrStr );
282
283   // Execute threader
284   threads->SingleMethodExecute( );
285 }
286
287 // -------------------------------------------------------------------------
288 template< class _TImage, class _TLabels, class _TScalar >
289 template< class _TTriplets >
290 ITK_THREAD_RETURN_TYPE
291 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
292 _LaplacianCbk( void* arg )
293 {
294   _TLaplacianThreadStruct* thrStr;
295   itk::ThreadIdType total, thrId, thrCount;
296   itk::MultiThreader::ThreadInfoStruct* thrInfo =
297     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
298   thrId = thrInfo->ThreadID;
299   thrCount = thrInfo->NumberOfThreads;
300   thrStr = reinterpret_cast< _TLaplacianThreadStruct* >( thrInfo->UserData );
301
302   TRegion region;
303   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
304   if( thrId < total )
305     thrStr->Filter->_ThreadedLaplacian(
306       region, thrId,
307       reinterpret_cast< _TTriplets* >( thrStr->A ),
308       reinterpret_cast< _TTriplets* >( thrStr->R ),
309       reinterpret_cast< const _TTriplets* >( thrStr->B )
310       );
311   return( ITK_THREAD_RETURN_VALUE );
312 }
313
314 // -------------------------------------------------------------------------
315 template< class _TImage, class _TLabels, class _TScalar >
316 template< class _TTriplets >
317 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
318 _ThreadedLaplacian(
319   const TRegion& region, const itk::ThreadIdType& id,
320   _TTriplets* A, _TTriplets* R, const _TTriplets* B
321   )
322 {
323   typedef itk::ImageRegionConstIteratorWithIndex< TImage > _TIt;
324   typedef typename _TTriplets::value_type _TTriplet;
325
326   const TImage* in = this->GetInput( );
327   const TLabels* in_labels = this->GetInputLabels( );
328   TRegion reqRegion = in->GetRequestedRegion( );
329   _TIt it( in, region );
330   for( it.GoToBegin( ); !it.IsAtEnd( ); ++it )
331   {
332     TIndex idx = it.GetIndex( );
333     bool iSeed = ( in_labels->GetPixel( idx ) != 0 );
334     unsigned long i = Self::_1D( idx, reqRegion );
335     unsigned long si;
336
337     // A's diagonal values
338     if( !iSeed )
339     {
340       si = Self::_NearSeedIndex( i, *B );
341       this->m_Mutex.Lock( );
342       A->push_back( _TTriplet( si, si, this->_L( idx, idx ) ) );
343       this->m_Mutex.Unlock( );
344     }
345     else
346       si = Self::_SeedIndex( i, *B );
347
348     // Neighbors (final matrix is symmetric)
349     for( unsigned int d = 0; d < TImage::ImageDimension; ++d )
350     {
351       for( int s = -1; s <= 1; s += 2 )
352       {
353         TIndex jdx = idx;
354         jdx[ d ] += s;
355         if( reqRegion.IsInside( jdx ) )
356         {
357           TScalar L = this->_L( idx, jdx );
358           unsigned long j = Self::_1D( jdx, reqRegion );
359           bool jSeed = ( in_labels->GetPixel( jdx ) != 0 );
360           if( !jSeed )
361           {
362             unsigned long sj = Self::_NearSeedIndex( j, *B );
363             if( !iSeed )
364             {
365               this->m_Mutex.Lock( );
366               A->push_back( _TTriplet( si, sj, L ) );
367               this->m_Mutex.Unlock( );
368             }
369             else
370             {
371               this->m_Mutex.Lock( );
372               R->push_back( _TTriplet( sj, si, -L ) );
373               this->m_Mutex.Unlock( );
374
375             } // fi
376
377           } // fi
378
379         } // fi
380
381       } // rof
382
383     } // rof
384
385   } // rof
386 }
387
388 // -------------------------------------------------------------------------
389 template< class _TImage, class _TLabels, class _TScalar >
390 template< class _TMatrix, class _TTriplets >
391 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
392 _Output(
393   const _TMatrix& X, const _TTriplets& S, const std::vector< TLabel >& invLabels
394   )
395 {
396   // Set up the multithreaded processing
397   _TOutputThreadStruct thrStr;
398   thrStr.Filter = this;
399   thrStr.X = reinterpret_cast< const void* >( &X );
400   thrStr.S = reinterpret_cast< const void* >( &S );
401   thrStr.InvLabels = &invLabels;
402
403   // Configure threader
404   const TLabels* out = this->GetOutput( );
405   const itk::ImageRegionSplitterBase* split = this->GetImageRegionSplitter( );
406   const unsigned int nThreads =
407     split->GetNumberOfSplits(
408       out->GetRequestedRegion( ), this->GetNumberOfThreads( )
409       );
410
411   itk::MultiThreader::Pointer threads = itk::MultiThreader::New( );
412   threads->SetNumberOfThreads( nThreads );
413   threads->SetSingleMethod(
414     this->_OutputCbk< _TMatrix, _TTriplets >, &thrStr
415     );
416
417   // Execute threader
418   threads->SingleMethodExecute( );
419 }
420
421 // -------------------------------------------------------------------------
422 template< class _TImage, class _TLabels, class _TScalar >
423 template< class _TMatrix, class _TTriplets >
424 ITK_THREAD_RETURN_TYPE
425 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
426 _OutputCbk( void* arg )
427 {
428   _TOutputThreadStruct* thrStr;
429   itk::ThreadIdType total, thrId, thrCount;
430   itk::MultiThreader::ThreadInfoStruct* thrInfo =
431     reinterpret_cast< itk::MultiThreader::ThreadInfoStruct* >( arg );
432   thrId = thrInfo->ThreadID;
433   thrCount = thrInfo->NumberOfThreads;
434   thrStr = reinterpret_cast< _TOutputThreadStruct* >( thrInfo->UserData );
435
436   TRegion region;
437   total = thrStr->Filter->SplitRequestedRegion( thrId, thrCount, region );
438   if( thrId < total )
439     thrStr->Filter->_ThreadedOutput(
440       region, thrId,
441       reinterpret_cast< const _TMatrix* >( thrStr->X ),
442       reinterpret_cast< const _TTriplets* >( thrStr->S ),
443       thrStr->InvLabels
444       );
445   return( ITK_THREAD_RETURN_VALUE );
446 }
447
448 // -------------------------------------------------------------------------
449 template< class _TImage, class _TLabels, class _TScalar >
450 template< class _TMatrix, class _TTriplets >
451 void fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
452 _ThreadedOutput(
453   const TRegion& region, const itk::ThreadIdType& id,
454   const _TMatrix* X, const _TTriplets* S,
455   const std::vector< TLabel >* invLabels
456   )
457 {
458   // Fill outputs
459   const TLabels* in_labels = this->GetInputLabels( );
460   TLabels* out_labels = this->GetOutput( );
461   TScalarImage* out_probs = this->GetOutputProbabilities( );
462   TRegion reqRegion = out_labels->GetRequestedRegion( );
463   itk::ImageRegionConstIteratorWithIndex< TLabels > iIt( in_labels, region );
464   itk::ImageRegionIteratorWithIndex< TLabels > oIt( out_labels, region );
465   itk::ImageRegionIteratorWithIndex< TScalarImage > pIt( out_probs, region );
466   iIt.GoToBegin( );
467   oIt.GoToBegin( );
468   pIt.GoToBegin( );
469   for( ; !iIt.IsAtEnd( ); ++iIt, ++oIt, ++pIt )
470   {
471     if( iIt.Get( ) == 0 )
472     {
473       unsigned long i = Self::_1D( iIt.GetIndex( ), reqRegion );
474       unsigned long j = Self::_NearSeedIndex( i, *S );
475       TScalar maxP = X->coeff( j, 0 );
476       unsigned long maxL = 0;
477       for( unsigned int s = 1; s < invLabels->size( ); ++s )
478       {
479         TScalar p = X->coeff( j, s );
480         if( maxP <= p )
481         {
482           maxP = p;
483           maxL = s;
484
485         } // fi
486
487       } // rof
488       oIt.Set( ( *invLabels )[ maxL ] );
489       pIt.Set( maxP );
490     }
491     else
492     {
493       oIt.Set( iIt.Get( ) );
494       pIt.Set( TScalar( 1 ) );
495
496     } // fi
497
498   } // rof
499 }
500
501 // -------------------------------------------------------------------------
502 template< class _TImage, class _TLabels, class _TScalar >
503 unsigned long
504 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
505 _1D( const TIndex& idx, const TRegion& region )
506 {
507   unsigned long i = idx[ 0 ];
508   unsigned long off = 1;
509   typename TRegion::SizeType size = region.GetSize( );
510   for( unsigned int d = 1; d < TIndex::Dimension; ++d )
511   {
512     off *= size[ d - 1 ];
513     i += idx[ d ] * off;
514
515   } // rof
516   return( i );
517 }
518
519 // -------------------------------------------------------------------------
520 template< class _TImage, class _TLabels, class _TScalar >
521 template< class _TTriplets >
522 unsigned long
523 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
524 _SeedIndex( const unsigned long& i, const _TTriplets& t )
525 {
526   unsigned long s = 0;
527   unsigned long f = t.size( );
528   unsigned long e = f - 1;
529   while( e > s && f == t.size( ) )
530   {
531     if( e > s + 1 )
532     {
533       unsigned long h = ( e + s ) >> 1;
534       if     ( i < t[ h ].row( ) ) e = h;
535       else if( t[ h ].row( ) < i ) s = h;
536       else                         f = h;
537     }
538     else
539       f = ( t[ s ].row( ) == i )? s: e;
540
541   } // elihw
542   return( f );
543 }
544
545 // -------------------------------------------------------------------------
546 template< class _TImage, class _TLabels, class _TScalar >
547 template< class _TTriplets >
548 unsigned long
549 fpa::Common::OriginalRandomWalker< _TImage, _TLabels, _TScalar >::
550 _NearSeedIndex( const unsigned long& i, const _TTriplets& t )
551 {
552   long s = 0;
553   long e = t.size( ) - 1;
554   while( e > s + 1 )
555   {
556     long h = ( e + s ) >> 1;
557     if     ( i < t[ h ].row( ) ) e = h;
558     else if( t[ h ].row( ) < i ) s = h;
559
560   } // elihw
561   long d;
562   if( i < t[ s ].row( ) )
563     d = -1;
564   else if( t[ s ].row( ) < i && i < t[ e ].row( ) )
565     d = s + 1;
566   else
567     d = e + 1;
568   if( d < 0 ) d = 0;
569   return( i - d );
570 }
571
572 #endif // __fpa__Common__OriginalRandomWalker__hxx__
573 // eof - $RCSfile$