]> Creatis software - CreaPhase.git/blob - octave_packages/statistics-1.1.3/mnrnd.m
Add a useful package (from Source forge) for octave
[CreaPhase.git] / octave_packages / statistics-1.1.3 / mnrnd.m
1 ## Copyright (C) 2012  Arno Onken
2 ##
3 ## This program is free software: you can redistribute it and/or modify
4 ## it under the terms of the GNU General Public License as published by
5 ## the Free Software Foundation, either version 3 of the License, or
6 ## (at your option) any later version.
7 ##
8 ## This program is distributed in the hope that it will be useful,
9 ## but WITHOUT ANY WARRANTY; without even the implied warranty of
10 ## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 ## GNU General Public License for more details.
12 ##
13 ## You should have received a copy of the GNU General Public License
14 ## along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16 ## -*- texinfo -*-
17 ## @deftypefn {Function File} {@var{x} =} mnrnd (@var{n}, @var{p})
18 ## @deftypefnx {Function File} {@var{x} =} mnrnd (@var{n}, @var{p}, @var{s})
19 ## Generate random samples from the multinomial distribution.
20 ##
21 ## @subheading Arguments
22 ##
23 ## @itemize @bullet
24 ## @item
25 ## @var{n} is the first parameter of the multinomial distribution. @var{n} can
26 ## be scalar or a vector containing the number of trials of each multinomial
27 ## sample. The elements of @var{n} must be non-negative integers.
28 ##
29 ## @item
30 ## @var{p} is the second parameter of the multinomial distribution. @var{p} can
31 ## be a vector with the probabilities of the categories or a matrix with each
32 ## row containing the probabilities of a multinomial sample. If @var{p} has
33 ## more than one row and @var{n} is non-scalar, then the number of rows of
34 ## @var{p} must match the number of elements of @var{n}.
35 ##
36 ## @item
37 ## @var{s} is the number of multinomial samples to be generated. @var{s} must
38 ## be a non-negative integer. If @var{s} is specified, then @var{n} must be
39 ## scalar and @var{p} must be a vector.
40 ## @end itemize
41 ##
42 ## @subheading Return values
43 ##
44 ## @itemize @bullet
45 ## @item
46 ## @var{x} is a matrix of random samples from the multinomial distribution with
47 ## corresponding parameters @var{n} and @var{p}. Each row corresponds to one
48 ## multinomial sample. The number of columns, therefore, corresponds to the
49 ## number of columns of @var{p}. If @var{s} is not specified, then the number
50 ## of rows of @var{x} is the maximum of the number of elements of @var{n} and
51 ## the number of rows of @var{p}. If a row of @var{p} does not sum to @code{1},
52 ## then the corresponding row of @var{x} will contain only @code{NaN} values.
53 ## @end itemize
54 ##
55 ## @subheading Examples
56 ##
57 ## @example
58 ## @group
59 ## n = 10;
60 ## p = [0.2, 0.5, 0.3];
61 ## x = mnrnd (n, p);
62 ## @end group
63 ##
64 ## @group
65 ## n = 10 * ones (3, 1);
66 ## p = [0.2, 0.5, 0.3];
67 ## x = mnrnd (n, p);
68 ## @end group
69 ##
70 ## @group
71 ## n = (1:2)';
72 ## p = [0.2, 0.5, 0.3; 0.1, 0.1, 0.8];
73 ## x = mnrnd (n, p);
74 ## @end group
75 ## @end example
76 ##
77 ## @subheading References
78 ##
79 ## @enumerate
80 ## @item
81 ## Wendy L. Martinez and Angel R. Martinez. @cite{Computational Statistics
82 ## Handbook with MATLAB}. Appendix E, pages 547-557, Chapman & Hall/CRC, 2001.
83 ##
84 ## @item
85 ## Merran Evans, Nicholas Hastings and Brian Peacock. @cite{Statistical
86 ## Distributions}. pages 134-136, Wiley, New York, third edition, 2000.
87 ## @end enumerate
88 ## @end deftypefn
89
90 ## Author: Arno Onken <asnelt@asnelt.org>
91 ## Description: Random samples from the multinomial distribution
92
93 function x = mnrnd (n, p, s)
94
95   # Check arguments
96   if (nargin == 3)
97     if (! isscalar (n) || n < 0 || round (n) != n)
98       error ("mnrnd: n must be a non-negative integer");
99     endif
100     if (! isvector (p) || any (p < 0 | p > 1))
101       error ("mnrnd: p must be a vector of probabilities");
102     endif
103     if (! isscalar (s) || s < 0 || round (s) != s)
104       error ("mnrnd: s must be a non-negative integer");
105     endif
106   elseif (nargin == 2)
107     if (isvector (p) && size (p, 1) > 1)
108       p = p';
109     endif
110     if (! isvector (n) || any (n < 0 | round (n) != n) || size (n, 2) > 1)
111       error ("mnrnd: n must be a non-negative integer column vector");
112     endif
113     if (! ismatrix (p) || isempty (p) || any (p < 0 | p > 1))
114       error ("mnrnd: p must be a non-empty matrix with rows of probabilities");
115     endif
116     if (! isscalar (n) && size (p, 1) > 1 && length (n) != size (p, 1))
117       error ("mnrnd: the length of n must match the number of rows of p");
118     endif
119   else
120     print_usage ();
121   endif
122
123   # Adjust input sizes
124   if (nargin == 3)
125     n = n * ones (s, 1);
126     p = repmat (p(:)', s, 1);
127   elseif (nargin == 2)
128     if (isscalar (n) && size (p, 1) > 1)
129       n = n * ones (size (p, 1), 1);
130     elseif (size (p, 1) == 1)
131       p = repmat (p, length (n), 1);
132     endif
133   endif
134   sz = size (p);
135
136   # Upper bounds of categories
137   ub = cumsum (p, 2);
138   # Make sure that the greatest upper bound is 1
139   gub = ub(:, end);
140   ub(:, end) = 1;
141   # Lower bounds of categories
142   lb = [zeros(sz(1), 1) ub(:, 1:(end-1))];
143
144   # Draw multinomial samples
145   x = zeros (sz);
146   for i = 1:sz(1)
147     # Draw uniform random numbers
148     r = repmat (rand (n(i), 1), 1, sz(2));
149     # Compare the random numbers of r to the cumulated probabilities of p and
150     # count the number of samples for each category
151     x(i, :) =  sum (r <= repmat (ub(i, :), n(i), 1) & r > repmat (lb(i, :), n(i), 1), 1);
152   endfor
153   # Set invalid rows to NaN
154   k = (abs (gub - 1) > 1e-6);
155   x(k, :) = NaN;
156
157 endfunction
158
159 %!test
160 %! n = 10;
161 %! p = [0.2, 0.5, 0.3];
162 %! x = mnrnd (n, p);
163 %! assert (size (x), size (p));
164 %! assert (all (x >= 0));
165 %! assert (all (round (x) == x));
166 %! assert (sum (x) == n);
167
168 %!test
169 %! n = 10 * ones (3, 1);
170 %! p = [0.2, 0.5, 0.3];
171 %! x = mnrnd (n, p);
172 %! assert (size (x), [length(n), length(p)]);
173 %! assert (all (x >= 0));
174 %! assert (all (round (x) == x));
175 %! assert (all (sum (x, 2) == n));
176
177 %!test
178 %! n = (1:2)';
179 %! p = [0.2, 0.5, 0.3; 0.1, 0.1, 0.8];
180 %! x = mnrnd (n, p);
181 %! assert (size (x), size (p));
182 %! assert (all (x >= 0));
183 %! assert (all (round (x) == x));
184 %! assert (all (sum (x, 2) == n));