1 ## Copyright (C) 2009 VZLU Prague, a.s., Czech Republic
3 ## This program is free software; you can redistribute it and/or modify it under
4 ## the terms of the GNU General Public License as published by the Free Software
5 ## Foundation; either version 3 of the License, or (at your option) any later
8 ## This program is distributed in the hope that it will be useful, but WITHOUT
9 ## ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
10 ## FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13 ## You should have received a copy of the GNU General Public License along with
14 ## this program; if not, see <http://www.gnu.org/licenses/>.
17 ## @deftypefn{Function File} {@var{x} =} smwsolve (@var{a}, @var{u}, @var{v}, @var{b})
18 ## @deftypefnx{Function File} {} smwsolve (@var{solver}, @var{u}, @var{v}, @var{b})
19 ## Solves the square system @code{(A + U*V')*X == B}, where @var{u} and @var{v} are
20 ## matrices with several columns, using the Sherman-Morrison-Woodbury formula,
21 ## so that only systems with @var{a} as the left-hand side actually need to be solved. This is
22 ## especially advantageous if @var{a} is diagonal, sparse, triangular or
24 ## @var{a} can be sparse or full; the other matrices are expected to be full.
25 ## Instead of a matrix @var{a}, a user may alternatively provide a function
26 ## @var{solver} that performs the left division operation.
29 ## Author: Jaroslav Hajek <highegg@gmail.com>
31 function x = smwsolve (a, u, v, b)
39 if (n != columns (v) || rows (a) != rows (u) || columns (a) != rows (v))
40 error ("smwsolve: dimension mismatch");
41 elseif (! issquare (a))
42 error ("smwsolve: need a square matrix");
51 elseif (isa (a, "function_handle"))
53 if (rows (xx) != rows (a) || columns (xx) != (nc + n))
54 error ("smwsolve: invalid result from a solver function");
57 error ("smwsolve: a must be a matrix or function handle");
65 vy = vxx(:,nc+1:nc+n);
67 x = x - y * ((eye (n) + vy) \ vx);
73 %! u = rand (10, 2); u /= diag (norm (u, "cols"));
74 %! v = rand (10, 2); v /= diag (norm (v, "cols"));
76 %! x1 = (A + u*v') \ b;
77 %! x2 = smwsolve (A, u, v, b);
78 %! assert (x1, x2, 1e-13);