diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java | 64 |
1 files changed, 35 insertions, 29 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java index 686e2076..ed92190d 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -26,6 +26,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; import java.util.ArrayList; import java.util.List; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization; import de.lmu.ifi.dbs.elki.data.Cluster; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; @@ -35,19 +36,20 @@ import de.lmu.ifi.dbs.elki.database.datastore.DataStoreFactory; import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil; import de.lmu.ifi.dbs.elki.database.datastore.WritableIntegerDataStore; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; +import de.lmu.ifi.dbs.elki.database.ids.DBIDs; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; -import de.lmu.ifi.dbs.elki.database.relation.RelationUtil; import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.logging.progress.IndefiniteProgress; +import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic; +import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; import de.lmu.ifi.dbs.elki.utilities.documentation.Description; import de.lmu.ifi.dbs.elki.utilities.documentation.Reference; import de.lmu.ifi.dbs.elki.utilities.documentation.Title; /** - * Provides the k-means algorithm, using Lloyd-style bulk iterations. + * The standard k-means algorithm, using Lloyd-style bulk iterations. * * <p> * Reference:<br /> @@ -63,12 +65,14 @@ import de.lmu.ifi.dbs.elki.utilities.documentation.Title; * @apiviz.has KMeansModel * * @param <V> vector datatype - * @param <D> distance value type */ @Title("K-Means") -@Description("Finds a partitioning into k clusters.") -@Reference(authors = "S. Lloyd", title = "Least squares quantization in PCM", booktitle = "IEEE Transactions on Information Theory 28 (2): 129–137.", url = "http://dx.doi.org/10.1109/TIT.1982.1056489") -public class KMeansLloyd<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans<V, D, KMeansModel<V>> { +@Description("Finds a least-squared partitioning into k clusters.") +@Reference(authors = "S. Lloyd", // +title = "Least squares quantization in PCM", // +booktitle = "IEEE Transactions on Information Theory 28 (2): 129–137.", // +url = "http://dx.doi.org/10.1109/TIT.1982.1056489") +public class KMeansLloyd<V extends NumberVector> extends AbstractKMeans<V, KMeansModel> { /** * The logger for this class. */ @@ -82,47 +86,49 @@ public class KMeansLloyd<V extends NumberVector<?>, D extends Distance<D>> exten * @param maxiter Maxiter parameter * @param initializer Initialization method */ - public KMeansLloyd(PrimitiveDistanceFunction<NumberVector<?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { + public KMeansLloyd(PrimitiveDistanceFunction<NumberVector> distanceFunction, int k, int maxiter, KMeansInitialization<? super V> initializer) { super(distanceFunction, k, maxiter, initializer); } @Override - public Clustering<KMeansModel<V>> run(Database database, Relation<V> relation) { - if (relation.size() <= 0) { + public Clustering<KMeansModel> run(Database database, Relation<V> relation) { + if(relation.size() <= 0) { return new Clustering<>("k-Means Clustering", "kmeans-clustering"); } // Choose initial means - List<? extends NumberVector<?>> means = initializer.chooseInitialMeans(database, relation, k, getDistanceFunction()); + List<Vector> means = initializer.chooseInitialMeans(database, relation, k, getDistanceFunction(), Vector.FACTORY); // Setup cluster assignment store List<ModifiableDBIDs> clusters = new ArrayList<>(); - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { clusters.add(DBIDUtil.newHashSet((int) (relation.size() * 2. / k))); } WritableIntegerDataStore assignment = DataStoreUtil.makeIntegerStorage(relation.getDBIDs(), DataStoreFactory.HINT_TEMP | DataStoreFactory.HINT_HOT, -1); + double[] varsum = new double[k]; IndefiniteProgress prog = LOG.isVerbose() ? new IndefiniteProgress("K-Means iteration", LOG) : null; - for (int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) { - if (prog != null) { - prog.incrementProcessed(LOG); - } - boolean changed = assignToNearestCluster(relation, means, clusters, assignment); + DoubleStatistic varstat = LOG.isStatistics() ? new DoubleStatistic(this.getClass().getName() + ".variance-sum") : null; + for(int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) { + LOG.incrementProcessed(prog); + boolean changed = assignToNearestCluster(relation, means, clusters, assignment, varsum); + logVarstat(varstat, varsum); // Stop if no cluster assignment changed. - if (!changed) { + if(!changed) { break; } // Recompute means. means = means(clusters, means, relation); } - if (prog != null) { - prog.setCompleted(LOG); - } + LOG.setCompleted(prog); // Wrap result - final NumberVector.Factory<V, ?> factory = RelationUtil.getNumberVectorFactory(relation); - Clustering<KMeansModel<V>> result = new Clustering<>("k-Means Clustering", "kmeans-clustering"); - for (int i = 0; i < clusters.size(); i++) { - KMeansModel<V> model = new KMeansModel<>(factory.newNumberVector(means.get(i).getColumnVector().getArrayRef())); - result.addToplevelCluster(new Cluster<>(clusters.get(i), model)); + Clustering<KMeansModel> result = new Clustering<>("k-Means Clustering", "kmeans-clustering"); + for(int i = 0; i < clusters.size(); i++) { + DBIDs ids = clusters.get(i); + if(ids.size() == 0) { + continue; + } + KMeansModel model = new KMeansModel(means.get(i), varsum[i]); + result.addToplevelCluster(new Cluster<>(ids, model)); } return result; } @@ -139,14 +145,14 @@ public class KMeansLloyd<V extends NumberVector<?>, D extends Distance<D>> exten * * @apiviz.exclude */ - public static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans.Parameterizer<V, D> { + public static class Parameterizer<V extends NumberVector> extends AbstractKMeans.Parameterizer<V> { @Override protected Logging getLogger() { return LOG; } @Override - protected KMeansLloyd<V, D> makeInstance() { + protected KMeansLloyd<V> makeInstance() { return new KMeansLloyd<>(distanceFunction, k, maxiter, initializer); } } |