diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansMacQueen.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansMacQueen.java | 70 |
1 files changed, 37 insertions, 33 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansMacQueen.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansMacQueen.java index a0f4bb3f..7d2f805a 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansMacQueen.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansMacQueen.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -26,6 +26,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; import java.util.ArrayList; import java.util.List; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization; import de.lmu.ifi.dbs.elki.data.Cluster; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; @@ -38,18 +39,22 @@ import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; -import de.lmu.ifi.dbs.elki.database.relation.RelationUtil; import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.logging.progress.IndefiniteProgress; +import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic; import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; import de.lmu.ifi.dbs.elki.utilities.documentation.Description; import de.lmu.ifi.dbs.elki.utilities.documentation.Reference; import de.lmu.ifi.dbs.elki.utilities.documentation.Title; /** - * Provides the k-means algorithm, using MacQueen style incremental updates. + * The original k-means algorithm, using MacQueen style incremental updates; + * making this effectively an "online" (streaming) algorithm. + * + * This implementation will by default iterate over the data set until + * convergence, although MacQueen likely only meant to do a single pass over the + * data. * * <p> * Reference:<br /> @@ -62,12 +67,14 @@ import de.lmu.ifi.dbs.elki.utilities.documentation.Title; * @apiviz.has KMeansModel * * @param <V> vector type to use - * @param <D> distance function value type */ @Title("K-Means") -@Description("Finds a partitioning into k clusters.") -@Reference(authors = "J. MacQueen", title = "Some Methods for Classification and Analysis of Multivariate Observations", booktitle = "5th Berkeley Symp. Math. Statist. Prob., Vol. 1, 1967, pp 281-297", url = "http://projecteuclid.org/euclid.bsmsp/1200512992") -public class KMeansMacQueen<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans<V, D, KMeansModel<V>> { +@Description("Finds a least-squares partitioning into k clusters.") +@Reference(authors = "J. MacQueen", // +title = "Some Methods for Classification and Analysis of Multivariate Observations", // +booktitle = "5th Berkeley Symp. Math. Statist. Prob., Vol. 1, 1967, pp 281-297", // +url = "http://projecteuclid.org/euclid.bsmsp/1200512992") +public class KMeansMacQueen<V extends NumberVector> extends AbstractKMeans<V, KMeansModel> { /** * The logger for this class. */ @@ -81,47 +88,44 @@ public class KMeansMacQueen<V extends NumberVector<?>, D extends Distance<D>> ex * @param maxiter Maxiter parameter * @param initializer Initialization method */ - public KMeansMacQueen(PrimitiveDistanceFunction<NumberVector<?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { + public KMeansMacQueen(PrimitiveDistanceFunction<NumberVector> distanceFunction, int k, int maxiter, KMeansInitialization<? super V> initializer) { super(distanceFunction, k, maxiter, initializer); } @Override - public Clustering<KMeansModel<V>> run(Database database, Relation<V> relation) { - if (relation.size() <= 0) { + public Clustering<KMeansModel> run(Database database, Relation<V> relation) { + if(relation.size() <= 0) { return new Clustering<>("k-Means Clustering", "kmeans-clustering"); } // Choose initial means - List<Vector> means = new ArrayList<>(k); - for (NumberVector<?> nv : initializer.chooseInitialMeans(database, relation, k, getDistanceFunction())) { - means.add(nv.getColumnVector()); - } - // Initialize cluster and assign objects + List<Vector> means = initializer.chooseInitialMeans(database, relation, k, getDistanceFunction(), Vector.FACTORY); List<ModifiableDBIDs> clusters = new ArrayList<>(); - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { clusters.add(DBIDUtil.newHashSet((int) (relation.size() * 2. / k))); } WritableIntegerDataStore assignment = DataStoreUtil.makeIntegerStorage(relation.getDBIDs(), DataStoreFactory.HINT_TEMP | DataStoreFactory.HINT_HOT, -1); + double[] varsum = new double[k]; IndefiniteProgress prog = LOG.isVerbose() ? new IndefiniteProgress("K-Means iteration", LOG) : null; - // Refine result - for (int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) { - if (prog != null) { - prog.incrementProcessed(LOG); - } - boolean changed = macQueenIterate(relation, means, clusters, assignment); - if (!changed) { + DoubleStatistic varstat = LOG.isStatistics() ? new DoubleStatistic(this.getClass().getName() + ".variance-sum") : null; + // Iterate MacQueen + for(int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) { + LOG.incrementProcessed(prog); + boolean changed = macQueenIterate(relation, means, clusters, assignment, varsum); + logVarstat(varstat, varsum); + if(!changed) { break; } } - if (prog != null) { - prog.setCompleted(LOG); - } + LOG.setCompleted(prog); - final NumberVector.Factory<V, ?> factory = RelationUtil.getNumberVectorFactory(relation); - Clustering<KMeansModel<V>> result = new Clustering<>("k-Means Clustering", "kmeans-clustering"); - for (int i = 0; i < clusters.size(); i++) { + Clustering<KMeansModel> result = new Clustering<>("k-Means Clustering", "kmeans-clustering"); + for(int i = 0; i < clusters.size(); i++) { DBIDs ids = clusters.get(i); - KMeansModel<V> model = new KMeansModel<>(factory.newNumberVector(means.get(i).getArrayRef())); + if(ids.size() == 0) { + continue; + } + KMeansModel model = new KMeansModel(means.get(i), varsum[i]); result.addToplevelCluster(new Cluster<>(ids, model)); } return result; @@ -139,14 +143,14 @@ public class KMeansMacQueen<V extends NumberVector<?>, D extends Distance<D>> ex * * @apiviz.exclude */ - public static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans.Parameterizer<V, D> { + public static class Parameterizer<V extends NumberVector> extends AbstractKMeans.Parameterizer<V> { @Override protected Logging getLogger() { return LOG; } @Override - protected KMeansMacQueen<V, D> makeInstance() { + protected KMeansMacQueen<V> makeInstance() { return new KMeansMacQueen<>(distanceFunction, k, maxiter, initializer); } } |