diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java | 213 |
1 files changed, 213 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java new file mode 100644 index 00000000..c7a2fa1d --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/KMeansPlusPlusInitialMeans.java @@ -0,0 +1,213 @@ +package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2012 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import de.lmu.ifi.dbs.elki.data.NumberVector; +import de.lmu.ifi.dbs.elki.database.ids.ArrayDBIDs; +import de.lmu.ifi.dbs.elki.database.ids.DBID; +import de.lmu.ifi.dbs.elki.database.ids.DBIDIter; +import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; +import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; +import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery; +import de.lmu.ifi.dbs.elki.database.relation.Relation; +import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; +import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction; +import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance; +import de.lmu.ifi.dbs.elki.logging.LoggingUtil; +import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; +import de.lmu.ifi.dbs.elki.utilities.documentation.Reference; +import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException; + +/** + * K-Means++ initialization for k-means. + * + * Reference: + * <p> + * D. Arthur, S. Vassilvitskii<br /> + * k-means++: the advantages of careful seeding<br /> + * In: Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, + * SODA 2007 + * </p> + * + * @author Erich Schubert + * + * @param <V> Vector type + * @param <D> Distance type + */ +@Reference(authors = "D. Arthur, S. Vassilvitskii", title = "k-means++: the advantages of careful seeding", booktitle = "Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, SODA 2007", url = "http://dx.doi.org/10.1145/1283383.1283494") +public class KMeansPlusPlusInitialMeans<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractKMeansInitialization<V> { + /** + * Constructor. + * + * @param seed Random seed. + */ + public KMeansPlusPlusInitialMeans(Long seed) { + super(seed); + } + + @Override + public List<Vector> chooseInitialMeans(Relation<V> relation, int k, PrimitiveDistanceFunction<? super V, ?> distanceFunction) { + // Get a distance query + if(!(distanceFunction.getDistanceFactory() instanceof NumberDistance)) { + throw new AbortException("K-Means++ initialization can only be used with numerical distances."); + } + @SuppressWarnings("unchecked") + final PrimitiveDistanceFunction<? super V, D> distF = (PrimitiveDistanceFunction<? super V, D>) distanceFunction; + DistanceQuery<V, D> distQ = relation.getDatabase().getDistanceQuery(relation, distF); + + // Chose first mean + List<Vector> means = new ArrayList<Vector>(k); + + Random random = (seed != null) ? new Random(seed) : new Random(); + DBID first = DBIDUtil.randomSample(relation.getDBIDs(), 1, random.nextLong()).iterator().next(); + means.add(relation.get(first).getColumnVector()); + + ModifiableDBIDs chosen = DBIDUtil.newHashSet(k); + chosen.add(first); + ArrayDBIDs ids = DBIDUtil.ensureArray(relation.getDBIDs()); + // Initialize weights + double[] weights = new double[ids.size()]; + double weightsum = initialWeights(weights, ids, first, distQ); + while(means.size() < k) { + if(weightsum > Double.MAX_VALUE) { + LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - too many data points, too large squared distances?"); + } + if(weightsum < Double.MIN_NORMAL) { + LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - to few data points?"); + } + double r = random.nextDouble() * weightsum; + int pos = 0; + while(r > 0 && pos < weights.length) { + r -= weights[pos]; + pos++; + } + // Add new mean: + DBID newmean = ids.get(pos); + means.add(relation.get(newmean).getColumnVector()); + chosen.add(newmean); + // Update weights: + weights[pos] = 0.0; + // Choose optimized version for double distances, if applicable. + if (distF instanceof PrimitiveDoubleDistanceFunction) { + @SuppressWarnings("unchecked") + PrimitiveDoubleDistanceFunction<V> ddist = (PrimitiveDoubleDistanceFunction<V>) distF; + weightsum = updateWeights(weights, ids, newmean, ddist, relation); + } else { + weightsum = updateWeights(weights, ids, newmean, distQ); + } + } + + return means; + } + + /** + * Initialize the weight list. + * + * @param weights Weight list + * @param ids IDs + * @param latest Added ID + * @param distQ Distance query + * @return Weight sum + */ + protected double initialWeights(double[] weights, ArrayDBIDs ids, DBID latest, DistanceQuery<V, D> distQ) { + double weightsum = 0.0; + DBIDIter it = ids.iter(); + for(int i = 0; i < weights.length; i++, it.advance()) { + DBID id = it.getDBID(); + if(latest.equals(id)) { + weights[i] = 0.0; + } + else { + double d = distQ.distance(latest, id).doubleValue(); + weights[i] = d * d; + } + weightsum += weights[i]; + } + return weightsum; + } + + /** + * Update the weight list. + * + * @param weights Weight list + * @param ids IDs + * @param latest Added ID + * @param distQ Distance query + * @return Weight sum + */ + protected double updateWeights(double[] weights, ArrayDBIDs ids, DBID latest, DistanceQuery<V, D> distQ) { + double weightsum = 0.0; + DBIDIter it = ids.iter(); + for(int i = 0; i < weights.length; i++, it.advance()) { + DBID id = it.getDBID(); + if(weights[i] > 0.0) { + double d = distQ.distance(latest, id).doubleValue(); + weights[i] = Math.min(weights[i], d * d); + weightsum += weights[i]; + } + } + return weightsum; + } + + /** + * Update the weight list. + * + * @param weights Weight list + * @param ids IDs + * @param latest Added ID + * @param distF Distance function + * @return Weight sum + */ + protected double updateWeights(double[] weights, ArrayDBIDs ids, DBID latest, PrimitiveDoubleDistanceFunction<V> distF, Relation<V> rel) { + final V lv = rel.get(latest); + double weightsum = 0.0; + DBIDIter it = ids.iter(); + for(int i = 0; i < weights.length; i++, it.advance()) { + DBID id = it.getDBID(); + if(weights[i] > 0.0) { + double d = distF.doubleDistance(lv, rel.get(id)); + weights[i] = Math.min(weights[i], d * d); + weightsum += weights[i]; + } + } + return weightsum; + } + + /** + * Parameterization class. + * + * @author Erich Schubert + * + * @apiviz.exclude + */ + public static class Parameterizer<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractKMeansInitialization.Parameterizer<V> { + @Override + protected KMeansPlusPlusInitialMeans<V, D> makeInstance() { + return new KMeansPlusPlusInitialMeans<V, D>(seed); + } + } +}
\ No newline at end of file |