summaryrefslogtreecommitdiff
path: root/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java')
-rw-r--r--src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java87
1 files changed, 54 insertions, 33 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java
index b1b40632..f43c2277 100644
--- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java
+++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansLloyd.java
@@ -27,19 +27,20 @@ import java.util.ArrayList;
import java.util.List;
import de.lmu.ifi.dbs.elki.algorithm.AbstractPrimitiveDistanceBasedAlgorithm;
-import de.lmu.ifi.dbs.elki.algorithm.clustering.ClusteringAlgorithm;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.NumberVector;
-import de.lmu.ifi.dbs.elki.data.model.MeanModel;
+import de.lmu.ifi.dbs.elki.data.model.KMeansModel;
import de.lmu.ifi.dbs.elki.database.Database;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
+import de.lmu.ifi.dbs.elki.database.relation.RelationUtil;
+import de.lmu.ifi.dbs.elki.distance.distancefunction.EuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction;
+import de.lmu.ifi.dbs.elki.distance.distancefunction.SquaredEuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance;
import de.lmu.ifi.dbs.elki.logging.Logging;
-import de.lmu.ifi.dbs.elki.utilities.DatabaseUtil;
import de.lmu.ifi.dbs.elki.utilities.documentation.Description;
import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
import de.lmu.ifi.dbs.elki.utilities.documentation.Title;
@@ -62,7 +63,8 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
*
* @author Arthur Zimek
*
- * @apiviz.has MeanModel
+ * @apiviz.landmark
+ * @apiviz.has KMeansModel
*
* @param <V> vector datatype
* @param <D> distance value type
@@ -70,11 +72,11 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
@Title("K-Means")
@Description("Finds a partitioning into k clusters.")
@Reference(authors = "S. Lloyd", title = "Least squares quantization in PCM", booktitle = "IEEE Transactions on Information Theory 28 (2): 129–137.", url = "http://dx.doi.org/10.1109/TIT.1982.1056489")
-public class KMeansLloyd<V extends NumberVector<V, ?>, D extends Distance<D>> extends AbstractKMeans<V, D> implements ClusteringAlgorithm<Clustering<MeanModel<V>>> {
+public class KMeansLloyd<V extends NumberVector<?>, D extends Distance<D>> extends AbstractKMeans<V, D, KMeansModel<V>> {
/**
* The logger for this class.
*/
- private static final Logging logger = Logging.getLogger(KMeansLloyd.class);
+ private static final Logging LOG = Logging.getLogger(KMeansLloyd.class);
/**
* Constructor.
@@ -82,55 +84,56 @@ public class KMeansLloyd<V extends NumberVector<V, ?>, D extends Distance<D>> ex
* @param distanceFunction distance function
* @param k k parameter
* @param maxiter Maxiter parameter
+ * @param initializer Initialization method
*/
- public KMeansLloyd(PrimitiveDistanceFunction<NumberVector<?, ?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) {
+ public KMeansLloyd(PrimitiveDistanceFunction<NumberVector<?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) {
super(distanceFunction, k, maxiter, initializer);
}
/**
- * Run k-means
+ * Run k-means.
*
* @param database Database
* @param relation relation to use
* @return result
*/
- public Clustering<MeanModel<V>> run(Database database, Relation<V> relation) {
- if(relation.size() <= 0) {
- return new Clustering<MeanModel<V>>("k-Means Clustering", "kmeans-clustering");
+ public Clustering<KMeansModel<V>> run(Database database, Relation<V> relation) {
+ if (relation.size() <= 0) {
+ return new Clustering<KMeansModel<V>>("k-Means Clustering", "kmeans-clustering");
}
// Choose initial means
- List<? extends NumberVector<?, ?>> means = initializer.chooseInitialMeans(relation, k, getDistanceFunction());
+ List<? extends NumberVector<?>> means = initializer.chooseInitialMeans(relation, k, getDistanceFunction());
// Setup cluster assignment store
List<ModifiableDBIDs> clusters = new ArrayList<ModifiableDBIDs>();
- for(int i = 0; i < k; i++) {
+ for (int i = 0; i < k; i++) {
clusters.add(DBIDUtil.newHashSet(relation.size() / k));
}
- for(int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) {
- if(logger.isVerbose()) {
- logger.verbose("K-Means iteration " + (iteration + 1));
+ for (int iteration = 0; maxiter <= 0 || iteration < maxiter; iteration++) {
+ if (LOG.isVerbose()) {
+ LOG.verbose("K-Means iteration " + (iteration + 1));
}
boolean changed = assignToNearestCluster(relation, means, clusters);
// Stop if no cluster assignment changed.
- if(!changed) {
+ if (!changed) {
break;
}
// Recompute means.
means = means(clusters, means, relation);
}
// Wrap result
- final V factory = DatabaseUtil.assumeVectorField(relation).getFactory();
- Clustering<MeanModel<V>> result = new Clustering<MeanModel<V>>("k-Means Clustering", "kmeans-clustering");
- for(int i = 0; i < clusters.size(); i++) {
- MeanModel<V> model = new MeanModel<V>(factory.newNumberVector(means.get(i).getColumnVector().getArrayRef()));
- result.addCluster(new Cluster<MeanModel<V>>(clusters.get(i), model));
+ final NumberVector.Factory<V, ?> factory = RelationUtil.getNumberVectorFactory(relation);
+ Clustering<KMeansModel<V>> result = new Clustering<KMeansModel<V>>("k-Means Clustering", "kmeans-clustering");
+ for (int i = 0; i < clusters.size(); i++) {
+ KMeansModel<V> model = new KMeansModel<V>(factory.newNumberVector(means.get(i).getColumnVector().getArrayRef()));
+ result.addCluster(new Cluster<KMeansModel<V>>(clusters.get(i), model));
}
return result;
}
@Override
protected Logging getLogger() {
- return logger;
+ return LOG;
}
/**
@@ -140,35 +143,53 @@ public class KMeansLloyd<V extends NumberVector<V, ?>, D extends Distance<D>> ex
*
* @apiviz.exclude
*/
- public static class Parameterizer<V extends NumberVector<V, ?>, D extends Distance<D>> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<NumberVector<?, ?>, D> {
+ public static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<NumberVector<?>, D> {
+ /**
+ * k Parameter.
+ */
protected int k;
+ /**
+ * Number of iterations.
+ */
protected int maxiter;
+ /**
+ * Initialization method.
+ */
protected KMeansInitialization<V> initializer;
@Override
protected void makeOptions(Parameterization config) {
- super.makeOptions(config);
- IntParameter kP = new IntParameter(K_ID, new GreaterConstraint(0));
- if(config.grab(kP)) {
+ ObjectParameter<PrimitiveDistanceFunction<NumberVector<?>, D>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class);
+ if(config.grab(distanceFunctionP)) {
+ distanceFunction = distanceFunctionP.instantiateClass(config);
+ if (!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) {
+ LOG.warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!");
+ }
+ }
+
+ IntParameter kP = new IntParameter(K_ID);
+ kP.addConstraint(new GreaterConstraint(0));
+ if (config.grab(kP)) {
k = kP.getValue();
}
ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<KMeansInitialization<V>>(INIT_ID, KMeansInitialization.class, RandomlyGeneratedInitialMeans.class);
- if(config.grab(initialP)) {
+ if (config.grab(initialP)) {
initializer = initialP.instantiateClass(config);
}
- IntParameter maxiterP = new IntParameter(MAXITER_ID, new GreaterEqualConstraint(0), 0);
- if(config.grab(maxiterP)) {
- maxiter = maxiterP.getValue();
+ IntParameter maxiterP = new IntParameter(MAXITER_ID, 0);
+ maxiterP.addConstraint(new GreaterEqualConstraint(0));
+ if (config.grab(maxiterP)) {
+ maxiter = maxiterP.intValue();
}
}
@Override
- protected AbstractKMeans<V, D> makeInstance() {
+ protected KMeansLloyd<V, D> makeInstance() {
return new KMeansLloyd<V, D>(distanceFunction, k, maxiter, initializer);
}
}
-} \ No newline at end of file
+}