## Copyright (C) 2011 Soren Hauberg
##
## This program is free software; you can redistribute it and/or modify
## it under the terms of the GNU General Public License as published by
## the Free Software Foundation; either version 3 of the License, or
## (at your option) any later version.
##
## This program is distributed in the hope that it will be useful,
## but WITHOUT ANY WARRANTY; without even the implied warranty of
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
## GNU General Public License for more details.
##
## You should have received a copy of the GNU General Public License
## along with this program; if not, write to the Free Software
## Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA

## [centers, classes, covariances] = kmeans (data, centers, maxiter)
##
## Cluster the columns of the D-by-N real matrix DATA with the k-means
## (Lloyd) iteration, starting from the D-by-C real matrix of initial
## cluster CENTERS (one center per column).
##
## Each iteration assigns every column of DATA to the center with the
## smallest squared Euclidean distance, then moves each center to the mean
## of the columns assigned to it.  A center that receives no columns keeps
## its previous position.  The iteration stops when the centers no longer
## change, or after MAXITER iterations (optional scalar >= 1, default Inf,
## i.e. run until the centers stop changing).
##
## Outputs:
##   centers     - D-by-C matrix of final cluster centers.
##   classes     - 1-by-N row vector; classes(j) is the index of the center
##                 that column j of DATA was assigned to in the last
##                 iteration (i.e. relative to the centers before their
##                 final update; identical to the final centers when the
##                 iteration converged).
##   covariances - computed only when requested: 1-by-C cell array whose
##                 i-th entry is the D-by-D sample covariance (via cov) of
##                 the columns in cluster i.  A single-member cluster gets
##                 zeros (D, D); an empty cluster gets NaN (D, D).

function [centers, classes, covariances] = kmeans (data, centers, maxiter)

  ## MAXITER is optional.  It is defaulted here instead of with the
  ## non-portable "maxiter = Inf" signature syntax; callers see no change.
  if (nargin < 2)
    error ("kmeans: at least two input arguments (data, centers) are required");
  endif
  if (nargin < 3)
    maxiter = Inf;
  endif

  ## Input checking
  if (! ismatrix (data) || ! isreal (data))
    error ("kmeans: first input argument must be a DxN real data matrix");
  endif
  if (! ismatrix (centers) || ! isreal (centers))
    error ("kmeans: second input argument must be a DxC real matrix of initial cluster centers");
  endif

  [D, N] = size (data);
  C = size (centers, 2);
  if (size (centers, 1) != D)
    error ("kmeans: number of rows in first and second input argument must be equal");
  endif
  if (C < 1)
    error ("kmeans: at least one initial cluster center is required");
  endif
  if (! isscalar (maxiter))
    error ("kmeans: third input argument must be a scalar corresponding to the maximum number of iterations");
  endif
  if (maxiter < 1)
    error ("kmeans: maximum number of iterations must be at least 1");
  endif

  ## Run the algorithm.  DIST is kept separate from D (the data dimension),
  ## which is still needed for the covariance output.
  dist = zeros (C, N);
  iterations = 0;
  prevcenters = centers;
  while (true)
    ## Squared Euclidean distance from every center to every data column.
    ## Summing explicitly along dimension 1 keeps this correct for D == 1,
    ## where a plain sum () would collapse the 1-by-N row to a scalar.
    for i = 1:C
      dist(i, :) = sum ((data - repmat (centers(:, i), 1, N)).^2, 1);
    endfor

    ## Assign each column to its nearest center.  Taking the minimum along
    ## dimension 1 keeps this correct for C == 1, where a plain min () would
    ## reduce the 1-by-N row to a single value.
    [~, classes] = min (dist, [], 1);

    ## Move each center to the mean of its members.  An empty cluster keeps
    ## its previous center: the mean of zero columns is NaN, and because
    ## NaN never compares equal to itself the convergence test below would
    ## never succeed (an endless loop when maxiter is Inf).
    for i = 1:C
      members = (classes == i);
      if (any (members))
        centers(:, i) = mean (data(:, members), 2);
      endif
    endfor

    ## Check for convergence
    iterations = iterations + 1;
    if (all (centers(:) == prevcenters(:)) || iterations >= maxiter)
      break;
    endif
    prevcenters = centers;
  endwhile

  ## Compute extra output arguments if requested
  if (nargout >= 3)
    covariances = cell (1, C);
    for i = 1:C
      members = data(:, classes == i);
      nmembers = size (members, 2);
      if (nmembers == 0)
        ## Covariance of no observations is undefined.
        covariances{i} = NaN (D, D);
      elseif (nmembers == 1)
        ## cov () of a single 1-by-D row would treat it as one variable and
        ## return a scalar; a single observation has zero spread.
        covariances{i} = zeros (D, D);
      else
        ## cov () expects observations in rows, hence the transpose.
        covariances{i} = cov (members');
      endif
    endfor
  endif

endfunction