diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansHybridLloydMacQueen.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansHybridLloydMacQueen.java | 71 |
1 files changed, 34 insertions, 37 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansHybridLloydMacQueen.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansHybridLloydMacQueen.java index 2a60ef27..a978745a 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansHybridLloydMacQueen.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansHybridLloydMacQueen.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -26,6 +26,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; import java.util.ArrayList; import java.util.List; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization; import de.lmu.ifi.dbs.elki.data.Cluster; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; @@ -35,28 +36,26 @@ import de.lmu.ifi.dbs.elki.database.datastore.DataStoreFactory; import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil; import de.lmu.ifi.dbs.elki.database.datastore.WritableIntegerDataStore; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; +import de.lmu.ifi.dbs.elki.database.ids.DBIDs; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; -import de.lmu.ifi.dbs.elki.database.relation.RelationUtil; import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.logging.progress.IndefiniteProgress; +import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic; import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; /** - * Provides the k-means algorithm, alternating between MacQueen-style - * incremental processing and Lloyd-Style batch steps. + * A hybrid k-means algorithm, alternating between MacQueen-style incremental + * processing and Lloyd-Style batch steps. * * @author Erich Schubert * - * @apiviz.landmark * @apiviz.has KMeansModel * * @param <V> vector datatype - * @param <D> distance value type */ -public class KMeansHybridLloydMacQueen<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans<V, D, KMeansModel<V>> { +public class KMeansHybridLloydMacQueen<V extends NumberVector> extends AbstractKMeans<V, KMeansModel> { /** * The logger for this class. */ @@ -70,61 +69,59 @@ public class KMeansHybridLloydMacQueen<V extends NumberVector<?>, D extends Dist * @param maxiter Maxiter parameter * @param initializer Initialization method */ - public KMeansHybridLloydMacQueen(PrimitiveDistanceFunction<NumberVector<?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { + public KMeansHybridLloydMacQueen(PrimitiveDistanceFunction<NumberVector> distanceFunction, int k, int maxiter, KMeansInitialization<? super V> initializer) { super(distanceFunction, k, maxiter, initializer); } @Override - public Clustering<KMeansModel<V>> run(Database database, Relation<V> relation) { - if (relation.size() <= 0) { + public Clustering<KMeansModel> run(Database database, Relation<V> relation) { + if(relation.size() <= 0) { return new Clustering<>("k-Means Clustering", "kmeans-clustering"); } // Choose initial means - List<Vector> means = new ArrayList<>(k); - for (NumberVector<?> nv : initializer.chooseInitialMeans(database, relation, k, getDistanceFunction())) { - means.add(nv.getColumnVector()); - } + List<Vector> means = initializer.chooseInitialMeans(database, relation, k, getDistanceFunction(), Vector.FACTORY); // Setup cluster assignment store List<ModifiableDBIDs> clusters = new ArrayList<>(); - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { clusters.add(DBIDUtil.newHashSet((int) (relation.size() * 2. / k))); } WritableIntegerDataStore assignment = DataStoreUtil.makeIntegerStorage(relation.getDBIDs(), DataStoreFactory.HINT_TEMP | DataStoreFactory.HINT_HOT, -1); + double[] varsum = new double[k]; IndefiniteProgress prog = LOG.isVerbose() ? new IndefiniteProgress("K-Means iteration", LOG) : null; - for (int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration += 2) { + DoubleStatistic varstat = LOG.isStatistics() ? new DoubleStatistic(this.getClass().getName() + ".variance-sum") : null; + for(int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration += 2) { { // MacQueen - if (prog != null) { - prog.incrementProcessed(LOG); - } - boolean changed = macQueenIterate(relation, means, clusters, assignment); - if (!changed) { + LOG.incrementProcessed(prog); + boolean changed = macQueenIterate(relation, means, clusters, assignment, varsum); + logVarstat(varstat, varsum); + if(!changed) { break; } } { // Lloyd - if (prog != null) { - prog.incrementProcessed(LOG); - } - boolean changed = assignToNearestCluster(relation, means, clusters, assignment); + LOG.incrementProcessed(prog); + boolean changed = assignToNearestCluster(relation, means, clusters, assignment, varsum); + logVarstat(varstat, varsum); // Stop if no cluster assignment changed. - if (!changed) { + if(!changed) { break; } // Recompute means. means = means(clusters, means, relation); } } - if (prog != null) { - prog.setCompleted(LOG); - } + LOG.setCompleted(prog); // Wrap result - final NumberVector.Factory<V, ?> factory = RelationUtil.getNumberVectorFactory(relation); - Clustering<KMeansModel<V>> result = new Clustering<>("k-Means Clustering", "kmeans-clustering"); - for (int i = 0; i < clusters.size(); i++) { - KMeansModel<V> model = new KMeansModel<>(factory.newNumberVector(means.get(i).getColumnVector().getArrayRef())); - result.addToplevelCluster(new Cluster<>(clusters.get(i), model)); + Clustering<KMeansModel> result = new Clustering<>("k-Means Clustering", "kmeans-clustering"); + for(int i = 0; i < clusters.size(); i++) { + DBIDs ids = clusters.get(i); + if(ids.size() == 0) { + continue; + } + KMeansModel model = new KMeansModel(means.get(i), varsum[i]); + result.addToplevelCluster(new Cluster<>(ids, model)); } return result; } @@ -141,14 +138,14 @@ public class KMeansHybridLloydMacQueen<V extends NumberVector<?>, D extends Dist * * @apiviz.exclude */ - public static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans.Parameterizer<V, D> { + public static class Parameterizer<V extends NumberVector> extends AbstractKMeans.Parameterizer<V> { @Override protected Logging getLogger() { return LOG; } @Override - protected KMeansHybridLloydMacQueen<V, D> makeInstance() { + protected KMeansHybridLloydMacQueen<V> makeInstance() { return new KMeansHybridLloydMacQueen<>(distanceFunction, k, maxiter, initializer); } } |