summaryrefslogtreecommitdiff
path: root/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java')
-rw-r--r--src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java213
1 files changed, 213 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java
new file mode 100644
index 00000000..c7a2fa1d
--- /dev/null
+++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java
@@ -0,0 +1,213 @@
+package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans;
+
+/*
+ This file is part of ELKI:
+ Environment for Developing KDD-Applications Supported by Index-Structures
+
+ Copyright (C) 2012
+ Ludwig-Maximilians-Universität München
+ Lehr- und Forschungseinheit für Datenbanksysteme
+ ELKI Development Team
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+import de.lmu.ifi.dbs.elki.data.NumberVector;
+import de.lmu.ifi.dbs.elki.database.ids.ArrayDBIDs;
+import de.lmu.ifi.dbs.elki.database.ids.DBID;
+import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
+import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
+import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
+import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
+import de.lmu.ifi.dbs.elki.database.relation.Relation;
+import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction;
+import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction;
+import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance;
+import de.lmu.ifi.dbs.elki.logging.LoggingUtil;
+import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector;
+import de.lmu.ifi.dbs.elki.utilities.documentation.Reference;
+import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
+
+/**
+ * K-Means++ initialization for k-means.
+ *
+ * Reference:
+ * <p>
+ * D. Arthur, S. Vassilvitskii<br />
+ * k-means++: the advantages of careful seeding<br />
+ * In: Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms,
+ * SODA 2007
+ * </p>
+ *
+ * @author Erich Schubert
+ *
+ * @param <V> Vector type
+ * @param <D> Distance type
+ */
+@Reference(authors = "D. Arthur, S. Vassilvitskii", title = "k-means++: the advantages of careful seeding", booktitle = "Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, SODA 2007", url = "http://dx.doi.org/10.1145/1283383.1283494")
+public class KMeansPlusPlusInitialMeans<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractKMeansInitialization<V> {
+ /**
+ * Constructor.
+ *
+ * @param seed Random seed.
+ */
+ public KMeansPlusPlusInitialMeans(Long seed) {
+ super(seed);
+ }
+
+ @Override
+ public List<Vector> chooseInitialMeans(Relation<V> relation, int k, PrimitiveDistanceFunction<? super V, ?> distanceFunction) {
+ // Get a distance query
+ if(!(distanceFunction.getDistanceFactory() instanceof NumberDistance)) {
+ throw new AbortException("K-Means++ initialization can only be used with numerical distances.");
+ }
+ @SuppressWarnings("unchecked")
+ final PrimitiveDistanceFunction<? super V, D> distF = (PrimitiveDistanceFunction<? super V, D>) distanceFunction;
+ DistanceQuery<V, D> distQ = relation.getDatabase().getDistanceQuery(relation, distF);
+
+ // Chose first mean
+ List<Vector> means = new ArrayList<Vector>(k);
+
+ Random random = (seed != null) ? new Random(seed) : new Random();
+ DBID first = DBIDUtil.randomSample(relation.getDBIDs(), 1, random.nextLong()).iterator().next();
+ means.add(relation.get(first).getColumnVector());
+
+ ModifiableDBIDs chosen = DBIDUtil.newHashSet(k);
+ chosen.add(first);
+ ArrayDBIDs ids = DBIDUtil.ensureArray(relation.getDBIDs());
+ // Initialize weights
+ double[] weights = new double[ids.size()];
+ double weightsum = initialWeights(weights, ids, first, distQ);
+ while(means.size() < k) {
+ if(weightsum > Double.MAX_VALUE) {
+ LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - too many data points, too large squared distances?");
+ }
+ if(weightsum < Double.MIN_NORMAL) {
+ LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - to few data points?");
+ }
+ double r = random.nextDouble() * weightsum;
+ int pos = 0;
+ while(r > 0 && pos < weights.length) {
+ r -= weights[pos];
+ pos++;
+ }
+ // Add new mean:
+ DBID newmean = ids.get(pos);
+ means.add(relation.get(newmean).getColumnVector());
+ chosen.add(newmean);
+ // Update weights:
+ weights[pos] = 0.0;
+ // Choose optimized version for double distances, if applicable.
+ if (distF instanceof PrimitiveDoubleDistanceFunction) {
+ @SuppressWarnings("unchecked")
+ PrimitiveDoubleDistanceFunction<V> ddist = (PrimitiveDoubleDistanceFunction<V>) distF;
+ weightsum = updateWeights(weights, ids, newmean, ddist, relation);
+ } else {
+ weightsum = updateWeights(weights, ids, newmean, distQ);
+ }
+ }
+
+ return means;
+ }
+
+ /**
+ * Initialize the weight list.
+ *
+ * @param weights Weight list
+ * @param ids IDs
+ * @param latest Added ID
+ * @param distQ Distance query
+ * @return Weight sum
+ */
+ protected double initialWeights(double[] weights, ArrayDBIDs ids, DBID latest, DistanceQuery<V, D> distQ) {
+ double weightsum = 0.0;
+ DBIDIter it = ids.iter();
+ for(int i = 0; i < weights.length; i++, it.advance()) {
+ DBID id = it.getDBID();
+ if(latest.equals(id)) {
+ weights[i] = 0.0;
+ }
+ else {
+ double d = distQ.distance(latest, id).doubleValue();
+ weights[i] = d * d;
+ }
+ weightsum += weights[i];
+ }
+ return weightsum;
+ }
+
+ /**
+ * Update the weight list.
+ *
+ * @param weights Weight list
+ * @param ids IDs
+ * @param latest Added ID
+ * @param distQ Distance query
+ * @return Weight sum
+ */
+ protected double updateWeights(double[] weights, ArrayDBIDs ids, DBID latest, DistanceQuery<V, D> distQ) {
+ double weightsum = 0.0;
+ DBIDIter it = ids.iter();
+ for(int i = 0; i < weights.length; i++, it.advance()) {
+ DBID id = it.getDBID();
+ if(weights[i] > 0.0) {
+ double d = distQ.distance(latest, id).doubleValue();
+ weights[i] = Math.min(weights[i], d * d);
+ weightsum += weights[i];
+ }
+ }
+ return weightsum;
+ }
+
+ /**
+ * Update the weight list.
+ *
+ * @param weights Weight list
+ * @param ids IDs
+ * @param latest Added ID
+ * @param distF Distance function
+ * @return Weight sum
+ */
+ protected double updateWeights(double[] weights, ArrayDBIDs ids, DBID latest, PrimitiveDoubleDistanceFunction<V> distF, Relation<V> rel) {
+ final V lv = rel.get(latest);
+ double weightsum = 0.0;
+ DBIDIter it = ids.iter();
+ for(int i = 0; i < weights.length; i++, it.advance()) {
+ DBID id = it.getDBID();
+ if(weights[i] > 0.0) {
+ double d = distF.doubleDistance(lv, rel.get(id));
+ weights[i] = Math.min(weights[i], d * d);
+ weightsum += weights[i];
+ }
+ }
+ return weightsum;
+ }
+
+ /**
+ * Parameterization class.
+ *
+ * @author Erich Schubert
+ *
+ * @apiviz.exclude
+ */
+ public static class Parameterizer<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractKMeansInitialization.Parameterizer<V> {
+ @Override
+ protected KMeansPlusPlusInitialMeans<V, D> makeInstance() {
+ return new KMeansPlusPlusInitialMeans<V, D>(seed);
+ }
+ }
+} \ No newline at end of file