diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/initialization/KMeansPlusPlusInitialMeans.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/initialization/KMeansPlusPlusInitialMeans.java | 220 |
1 files changed, 220 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/initialization/KMeansPlusPlusInitialMeans.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/initialization/KMeansPlusPlusInitialMeans.java new file mode 100644 index 00000000..977a4182 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/initialization/KMeansPlusPlusInitialMeans.java @@ -0,0 +1,220 @@ +package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import de.lmu.ifi.dbs.elki.data.NumberVector; +import de.lmu.ifi.dbs.elki.database.Database; +import de.lmu.ifi.dbs.elki.database.datastore.DataStoreFactory; +import de.lmu.ifi.dbs.elki.database.datastore.DataStoreUtil; +import de.lmu.ifi.dbs.elki.database.datastore.WritableDoubleDataStore; +import de.lmu.ifi.dbs.elki.database.ids.ArrayModifiableDBIDs; +import de.lmu.ifi.dbs.elki.database.ids.DBID; +import de.lmu.ifi.dbs.elki.database.ids.DBIDIter; +import de.lmu.ifi.dbs.elki.database.ids.DBIDRef; +import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; +import de.lmu.ifi.dbs.elki.database.ids.DBIDs; +import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery; +import de.lmu.ifi.dbs.elki.database.relation.Relation; +import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; +import de.lmu.ifi.dbs.elki.logging.LoggingUtil; +import de.lmu.ifi.dbs.elki.math.random.RandomFactory; +import de.lmu.ifi.dbs.elki.utilities.documentation.Reference; + +/** + * K-Means++ initialization for k-means. + * + * Reference: + * <p> + * D. Arthur, S. Vassilvitskii<br /> + * k-means++: the advantages of careful seeding<br /> + * In: Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, + * SODA 2007 + * </p> + * + * @author Erich Schubert + * + * @param <O> Vector type + */ +@Reference(authors = "D. Arthur, S. Vassilvitskii", title = "k-means++: the advantages of careful seeding", booktitle = "Proc. of the Eighteenth Annual ACM-SIAM Symposium on Discrete Algorithms, SODA 2007", url = "http://dx.doi.org/10.1145/1283383.1283494") +public class KMeansPlusPlusInitialMeans<O> extends AbstractKMeansInitialization<NumberVector> implements KMedoidsInitialization<O> { + /** + * Constructor. + * + * @param rnd Random generator. + */ + public KMeansPlusPlusInitialMeans(RandomFactory rnd) { + super(rnd); + } + + @Override + public <T extends NumberVector, V extends NumberVector> List<V> chooseInitialMeans(Database database, Relation<T> relation, int k, PrimitiveDistanceFunction<? super T> distanceFunction, NumberVector.Factory<V> factory) { + DistanceQuery<T> distQ = database.getDistanceQuery(relation, distanceFunction); + + DBIDs ids = relation.getDBIDs(); + WritableDoubleDataStore weights = DataStoreUtil.makeDoubleStorage(ids, DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP, 0.); + + // Chose first mean + List<V> means = new ArrayList<>(k); + + Random random = rnd.getSingleThreadedRandom(); + DBID first = DBIDUtil.deref(DBIDUtil.randomSample(ids, 1, random).iter()); + means.add(factory.newNumberVector(relation.get(first))); + + // Initialize weights + double weightsum = initialWeights(weights, ids, first, distQ); + while(means.size() < k) { + if(weightsum > Double.MAX_VALUE) { + LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - too many data points, too large squared distances?"); + } + if(weightsum < Double.MIN_NORMAL) { + LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - to few data points?"); + } + double r = random.nextDouble() * weightsum; + DBIDIter it = ids.iter(); + for(; r > 0. && it.valid(); it.advance()) { + double w = weights.doubleValue(it); + if(w != w) { + continue; // NaN: alrady chosen. + } + r -= w; + } + // Add new mean: + final T newmean = relation.get(it); + means.add(factory.newNumberVector(newmean)); + // Update weights: + weights.putDouble(it, Double.NaN); + // Choose optimized version for double distances, if applicable. + weightsum = updateWeights(weights, ids, newmean, distQ); + } + + // Explicitly destroy temporary data. + weights.destroy(); + + return means; + } + + @Override + public DBIDs chooseInitialMedoids(int k, DBIDs ids, DistanceQuery<? super O> distQ) { + @SuppressWarnings("unchecked") + final Relation<O> rel = (Relation<O>) distQ.getRelation(); + + ArrayModifiableDBIDs means = DBIDUtil.newArray(k); + + WritableDoubleDataStore weights = DataStoreUtil.makeDoubleStorage(ids, DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP, 0.); + + Random random = rnd.getSingleThreadedRandom(); + DBIDRef first = DBIDUtil.randomSample(ids, 1, random).iter(); + means.add(first); + + // Initialize weights + double weightsum = initialWeights(weights, ids, first, distQ); + while(means.size() < k) { + if(weightsum > Double.MAX_VALUE) { + LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - too many data points, too large squared distances?"); + } + if(weightsum < Double.MIN_NORMAL) { + LoggingUtil.warning("Could not choose a reasonable mean for k-means++ - to few data points?"); + } + double r = random.nextDouble() * weightsum; + DBIDIter it = ids.iter(); + for(; r > 0. && it.valid(); it.advance()) { + double w = weights.doubleValue(it); + if(w != w) { + continue; // NaN: alrady chosen. + } + r -= w; + } + // Add new mean: + means.add(it); + // Update weights: + weights.putDouble(it, Double.NaN); + weightsum = updateWeights(weights, ids, rel.get(it), distQ); + } + + return means; + } + + /** + * Initialize the weight list. + * + * @param weights Weight list + * @param ids IDs + * @param latest Added ID + * @param distQ Distance query + * @return Weight sum + */ + protected double initialWeights(WritableDoubleDataStore weights, DBIDs ids, DBIDRef latest, DistanceQuery<?> distQ) { + double weightsum = 0.; + for(DBIDIter it = ids.iter(); it.valid(); it.advance()) { + // Distance will usually already be squared + double weight = distQ.distance(latest, it); + weights.putDouble(it, weight); + weightsum += weight; + } + return weightsum; + } + + /** + * Update the weight list. + * + * @param weights Weight list + * @param ids IDs + * @param latest Added ID + * @param distQ Distance query + * @return Weight sum + */ + protected <T> double updateWeights(WritableDoubleDataStore weights, DBIDs ids, T latest, DistanceQuery<? super T> distQ) { + double weightsum = 0.; + for(DBIDIter it = ids.iter(); it.valid(); it.advance()) { + double weight = weights.doubleValue(it); + if(weight != weight) { + continue; // NaN: already chosen! + } + double newweight = distQ.distance(latest, it); + if(newweight < weight) { + weights.putDouble(it, newweight); + weight = newweight; + } + weightsum += weight; + } + return weightsum; + } + + /** + * Parameterization class. + * + * @author Erich Schubert + * + * @apiviz.exclude + */ + public static class Parameterizer<V> extends AbstractKMeansInitialization.Parameterizer { + @Override + protected KMeansPlusPlusInitialMeans<V> makeInstance() { + return new KMeansPlusPlusInitialMeans<>(rnd); + } + } +}
\ No newline at end of file |