]> Creatis software - CreaPhase.git/blob - octave_packages/statistics-1.1.3/kmeans.m
Add a useful package (from Source forge) for octave
[CreaPhase.git] / octave_packages / statistics-1.1.3 / kmeans.m
1 ## Copyright (C) 2011 Soren Hauberg <soren@hauberg.org>
2 ## Copyright (C) 2012 Daniel Ward <dwa012@gmail.com>
3 ##
4 ## This program is free software; you can redistribute it and/or modify it under
5 ## the terms of the GNU General Public License as published by the Free Software
6 ## Foundation; either version 3 of the License, or (at your option) any later
7 ## version.
8 ##
9 ## This program is distributed in the hope that it will be useful, but WITHOUT
10 ## ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 ## FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
12 ## details.
13 ##
14 ## You should have received a copy of the GNU General Public License along with
15 ## this program; if not, see <http://www.gnu.org/licenses/>.
16
17 ## -*- texinfo -*-
18 ## @deftypefn {Function File} {[@var{idx}, @var{centers}] =} kmeans (@var{data}, @var{k}, @var{param1}, @var{value1}, @dots{})
19 ## K-means clustering.
20 ##
21 ## @seealso{linkage}
22 ## @end deftypefn
23
24 function [classes, centers, sumd, D] = kmeans (data, k, varargin)
25   [reg, prop] = parseparams (varargin);
26
27   ## defaults for options
28   emptyaction = "error";
29   start       = "sample";
30
31   #used for getting the number of samples
32   nRows = rows (data);
33
34   ## used to hold the distances from each sample to each class
35   D = zeros (nRows, k);
36   
37   #used for convergence of the centroids
38   err = 1;
39   
40   #initial sum of distances
41   sumd = Inf;
42   
43   ## Input checking, validate the matrix and k
44   if (!isnumeric (data) || !ismatrix (data) || !isreal (data))
45     error ("kmeans: first input argument must be a DxN real data matrix");
46   elseif (!isscalar (k))
47     error ("kmeans: second input argument must be a scalar");
48   endif
49   
50   if (length (varargin) > 0)
51     ## check for the 'emptyaction' property
52     found = find (strcmpi (prop, "emptyaction") == 1);
53     switch (lower (prop{found+1}))
54       case "singleton"
55         emptyaction = "singleton";
56       otherwise
57         error ("kmeans: unsupported empty cluster action parameter");
58     endswitch
59   endif
60   
61   ## check for the 'start' property
62   switch (lower (start))
63     case "sample"
64       idx = randperm (nRows) (1:k);
65       centers = data (idx, :);
66     otherwise
67       error ("kmeans: unsupported initial clustering parameter");
68   endswitch
69   
70   ## Run the algorithm
71   while err > .001
72     ## Compute distances
73     for i = 1:k
74       D (:, i) = sumsq (data - repmat (centers(i, :), nRows, 1), 2);
75     endfor
76     
77     ## Classify
78     [tmp, classes] = min (D, [], 2);
79     
80     ## Calcualte new centroids
81     for i = 1:k
82       ## Check for empty clusters
83       if (sum (classes == i) ==0 || length (mean (data(classes == i, :))) == 0)
84       
85         switch emptyaction
86           ## if 'singleton', then find the point that is the
87           ## farthest and add it to the empty cluster
88           case 'singleton'
89             classes(maxCostSampleIndex (data, centers(i,:))) = i;
90          ## if 'error' then throw the error
91           otherwise
92             error ("kmeans: empty cluster created");
93         endswitch
94      endif ## end check for empty clusters
95
96       ## update the centroids
97       centers(i, :) = mean (data(classes == i, :));
98     endfor
99
100     ## calculate the differnece in the sum of distances
101     err  = sumd - objCost (data, classes, centers);
102     ## update the current sum of distances
103     sumd = objCost (data, classes, centers);
104   endwhile
105 endfunction
106
107 ## calculate the sum of distances
108 function obj = objCost (data, classes, centers)
109   obj = 0;
110     for i=1:rows (data)
111       obj = obj + sumsq (data(i,:) - centers(classes(i),:));
112     endfor
113 endfunction
114
115 function index = maxCostSampleIndex (data, centers)
116   cost = 0;
117   for index = 1:rows (data)
118     if cost < sumsq (data(index,:) - centers)
119       cost = sumsq (data(index,:) - centers);
120     endif
121   endfor
122 endfunction
123
124 %!demo
125 %! ## Generate a two-cluster problem
126 %! C1 = randn (100, 2) + 1;
127 %! C2 = randn (100, 2) - 1;
128 %! data = [C1; C2];
129 %!
130 %! ## Perform clustering
131 %! [idx, centers] = kmeans (data, 2);
132 %!
133 %! ## Plot the result
134 %! figure
135 %! plot (data (idx==1, 1), data (idx==1, 2), 'ro');
136 %! hold on
137 %! plot (data (idx==2, 1), data (idx==2, 2), 'bs');
138 %! plot (centers (:, 1), centers (:, 2), 'kv', 'markersize', 10);
139 %! hold off