]> Creatis software - CreaPhase.git/blob - mldivide.m
3733aab008aeab5715000aafbe552976f4e9b342
[CreaPhase.git] / mldivide.m
1 ## Copyright (C) 2010 VZLU Prague
2 ## 
3 ## This program is free software; you can redistribute it and/or modify
4 ## it under the terms of the GNU General Public License as published by
5 ## the Free Software Foundation; either version 3 of the License, or
6 ## (at your option) any later version.
7 ## 
8 ## This program is distributed in the hope that it will be useful,
9 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
10 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 ## GNU General Public License for more details.
12 ## 
13 ## You should have received a copy of the GNU General Public License
14 ## along with Octave; see the file COPYING.  If not, see
15 ## <http://www.gnu.org/licenses/>.
16
17 ## -*- texinfo -*-
18 ## @deftypefn {Function File} mldivide (@var{x}, @var{y})
19 ## Performs a left division with a block sparse matrix.
20 ## If @var{x} is a block sparse matrix, it must be either diagonal
21 ## or triangular, and @var{y} must be full.
22 ## If @var{x} is built-in sparse or full, @var{y} is converted
23 ## accordingly, then the built-in division is used.
24 ## @end deftypefn
25
26 function c = mldivide (a, b)
27   if (isa (a, "blksparse"))
28     if (issparse (b))
29       error ("blksparse: block sparse \\ sparse not implemented");
30     else
31       c = mldivide_sm (a, b);
32     endif
33   elseif (issparse (a))
34     c = a \ sparse (b);
35   else
36     c = a \ full (b);
37   endif
38 endfunction
39
40 function y = mldivide_sm (s, x)
41   siz = s.siz;
42   bsiz = s.bsiz;
43   if (bsiz(1) != bsiz(2) || siz(1) != siz(2))
44     error ("blksparse: can only divide by square matrices with square blocks");
45   endif
46
47   ## Check sizes.
48   [xr, xc] = size (x);
49   if (xr != siz(1)*bsiz(1))
50     gripe_nonconformant (siz.*bsiz, [xr, xc]);
51   endif
52
53   if (isempty (s) || isempty (x))
54     y = x;
55     return;
56   endif
57
58   ## Form blocks.
59   x = reshape (x, bsiz(1), siz(1), xc);
60   x = permute (x, [1, 3, 2]);
61
62   sv = s.sv;
63   si = s.i;
64   sj = s.j;
65   ns = size (sv, 3);
66
67   n = siz(1);
68   nb = bsiz(1);
69   d = find (si == sj);
70   full_diag = length (d) == n;
71
72   isdiag = full_diag && ns == n; # block diagonal
73   islower = full_diag && all (si >= sj); # block upper triangular
74   isupper = full_diag && all (si <= sj); # block lower triangular
75
76   if (isdiag)
77     xx = num2cell (x, [1, 2]);
78     ss = num2cell (sv, [1, 2]);
79     yy = cellfun (@mldivide, ss, xx, "uniformoutput", false);
80     y = cat (3, yy{:});
81     clear xx ss yy;
82   elseif (islower)
83     y = x;
84     ## this is the axpy version
85     for j = 1:n-1
86       y(:,:,j) = sv(:,:,d(j)) \ y(:,:,j);
87       k = d(j)+1:d(j+1)-1;
88       xy = y(:,:,j*ones (1, length (k)));
89       y(:,:,si(k)) -= blkmm (sv(:,:,k), xy);
90     endfor
91     y(:,:,n) = sv(:,:,ns) \ y(:,:,n);
92   elseif (isupper)
93     y = x;
94     ## this is the axpy version
95     for j = n:-1:2
96       y(:,:,j) = sv(:,:,d(j)) \ y(:,:,j);
97       k = d(j-1)+1:d(j)-1;
98       xy = y(:,:,j*ones (1, length (k)));
99       y(:,:,si(k)) -= blkmm (sv(:,:,k), xy);
100     endfor
101     y(:,:,1) = sv(:,:,1) \ y(:,:,1);
102   else
103     error ("blksparse: mldivide: matrix must be block triangular or diagonal");
104   endif
105
106   ## Narrow blocks.
107   y = permute (y, [1, 3, 2]);
108   y = reshape (y, bsiz(1)*siz(1), xc);
109 endfunction
110
111 function gripe_nonconformant (s1, s2, what = "arguments")
112   error ("Octave:nonconformant-args", ...
113   "nonconformant %s (op1 is %dx%d, op2 is %dx%d)", what, s1, s2);
114 endfunction
115