diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java | 87 |
1 files changed, 54 insertions, 33 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java index b1b40632..f43c2277 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java @@ -27,19 +27,20 @@ import java.util.ArrayList; import java.util.List; import de.lmu.ifi.dbs.elki.algorithm.AbstractPrimitiveDistanceBasedAlgorithm; -import de.lmu.ifi.dbs.elki.algorithm.clustering.ClusteringAlgorithm; import de.lmu.ifi.dbs.elki.data.Cluster; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; -import de.lmu.ifi.dbs.elki.data.model.MeanModel; +import de.lmu.ifi.dbs.elki.data.model.KMeansModel; import de.lmu.ifi.dbs.elki.database.Database; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; +import de.lmu.ifi.dbs.elki.database.relation.RelationUtil; +import de.lmu.ifi.dbs.elki.distance.distancefunction.EuclideanDistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; +import de.lmu.ifi.dbs.elki.distance.distancefunction.SquaredEuclideanDistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.logging.Logging; -import de.lmu.ifi.dbs.elki.utilities.DatabaseUtil; import de.lmu.ifi.dbs.elki.utilities.documentation.Description; import de.lmu.ifi.dbs.elki.utilities.documentation.Reference; import de.lmu.ifi.dbs.elki.utilities.documentation.Title; @@ -62,7 +63,8 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter; * * @author Arthur Zimek * - * @apiviz.has MeanModel + * @apiviz.landmark + * @apiviz.has KMeansModel * * @param <V> vector datatype * @param <D> distance value type @@ -70,11 +72,11 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter; @Title("K-Means") @Description("Finds a partitioning into k clusters.") @Reference(authors = "S. Lloyd", title = "Least squares quantization in PCM", booktitle = "IEEE Transactions on Information Theory 28 (2): 129–137.", url = "http://dx.doi.org/10.1109/TIT.1982.1056489") -public class KMeansLloyd<V extends NumberVector<V, ?>, D extends Distance<D>> extends AbstractKMeans<V, D> implements ClusteringAlgorithm<Clustering<MeanModel<V>>> { +public class KMeansLloyd<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans<V, D, KMeansModel<V>> { /** * The logger for this class. */ - private static final Logging logger = Logging.getLogger(KMeansLloyd.class); + private static final Logging LOG = Logging.getLogger(KMeansLloyd.class); /** * Constructor. @@ -82,55 +84,56 @@ public class KMeansLloyd<V extends NumberVector<V, ?>, D extends Distance<D>> ex * @param distanceFunction distance function * @param k k parameter * @param maxiter Maxiter parameter + * @param initializer Initialization method */ - public KMeansLloyd(PrimitiveDistanceFunction<NumberVector<?, ?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { + public KMeansLloyd(PrimitiveDistanceFunction<NumberVector<?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { super(distanceFunction, k, maxiter, initializer); } /** - * Run k-means + * Run k-means. * * @param database Database * @param relation relation to use * @return result */ - public Clustering<MeanModel<V>> run(Database database, Relation<V> relation) { - if(relation.size() <= 0) { - return new Clustering<MeanModel<V>>("k-Means Clustering", "kmeans-clustering"); + public Clustering<KMeansModel<V>> run(Database database, Relation<V> relation) { + if (relation.size() <= 0) { + return new Clustering<KMeansModel<V>>("k-Means Clustering", "kmeans-clustering"); } // Choose initial means - List<? extends NumberVector<?, ?>> means = initializer.chooseInitialMeans(relation, k, getDistanceFunction()); + List<? extends NumberVector<?>> means = initializer.chooseInitialMeans(relation, k, getDistanceFunction()); // Setup cluster assignment store List<ModifiableDBIDs> clusters = new ArrayList<ModifiableDBIDs>(); - for(int i = 0; i < k; i++) { + for (int i = 0; i < k; i++) { clusters.add(DBIDUtil.newHashSet(relation.size() / k)); } - for(int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) { - if(logger.isVerbose()) { - logger.verbose("K-Means iteration " + (iteration + 1)); + for (int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) { + if (LOG.isVerbose()) { + LOG.verbose("K-Means iteration " + (iteration + 1)); } boolean changed = assignToNearestCluster(relation, means, clusters); // Stop if no cluster assignment changed. - if(!changed) { + if (!changed) { break; } // Recompute means. means = means(clusters, means, relation); } // Wrap result - final V factory = DatabaseUtil.assumeVectorField(relation).getFactory(); - Clustering<MeanModel<V>> result = new Clustering<MeanModel<V>>("k-Means Clustering", "kmeans-clustering"); - for(int i = 0; i < clusters.size(); i++) { - MeanModel<V> model = new MeanModel<V>(factory.newNumberVector(means.get(i).getColumnVector().getArrayRef())); - result.addCluster(new Cluster<MeanModel<V>>(clusters.get(i), model)); + final NumberVector.Factory<V, ?> factory = RelationUtil.getNumberVectorFactory(relation); + Clustering<KMeansModel<V>> result = new Clustering<KMeansModel<V>>("k-Means Clustering", "kmeans-clustering"); + for (int i = 0; i < clusters.size(); i++) { + KMeansModel<V> model = new KMeansModel<V>(factory.newNumberVector(means.get(i).getColumnVector().getArrayRef())); + result.addCluster(new Cluster<KMeansModel<V>>(clusters.get(i), model)); } return result; } @Override protected Logging getLogger() { - return logger; + return LOG; } /** @@ -140,35 +143,53 @@ public class KMeansLloyd<V extends NumberVector<V, ?>, D extends Distance<D>> ex * * @apiviz.exclude */ - public static class Parameterizer<V extends NumberVector<V, ?>, D extends Distance<D>> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<NumberVector<?, ?>, D> { + public static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<NumberVector<?>, D> { + /** + * k Parameter. + */ protected int k; + /** + * Number of iterations. + */ protected int maxiter; + /** + * Initialization method. + */ protected KMeansInitialization<V> initializer; @Override protected void makeOptions(Parameterization config) { - super.makeOptions(config); - IntParameter kP = new IntParameter(K_ID, new GreaterConstraint(0)); - if(config.grab(kP)) { + ObjectParameter<PrimitiveDistanceFunction<NumberVector<?>, D>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class); + if(config.grab(distanceFunctionP)) { + distanceFunction = distanceFunctionP.instantiateClass(config); + if (!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { + LOG.warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!"); + } + } + + IntParameter kP = new IntParameter(K_ID); + kP.addConstraint(new GreaterConstraint(0)); + if (config.grab(kP)) { k = kP.getValue(); } ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<KMeansInitialization<V>>(INIT_ID, KMeansInitialization.class, RandomlyGeneratedInitialMeans.class); - if(config.grab(initialP)) { + if (config.grab(initialP)) { initializer = initialP.instantiateClass(config); } - IntParameter maxiterP = new IntParameter(MAXITER_ID, new GreaterEqualConstraint(0), 0); - if(config.grab(maxiterP)) { - maxiter = maxiterP.getValue(); + IntParameter maxiterP = new IntParameter(MAXITER_ID, 0); + maxiterP.addConstraint(new GreaterEqualConstraint(0)); + if (config.grab(maxiterP)) { + maxiter = maxiterP.intValue(); } } @Override - protected AbstractKMeans<V, D> makeInstance() { + protected KMeansLloyd<V, D> makeInstance() { return new KMeansLloyd<V, D>(distanceFunction, k, maxiter, initializer); } } -}
\ No newline at end of file +} |