diff options
author | Andrej Shadura <andrewsh@debian.org> | 2019-03-09 22:30:40 +0000 |
---|---|---|
committer | Andrej Shadura <andrewsh@debian.org> | 2019-03-09 22:30:40 +0000 |
commit | 337087b668d3a54f3afee3a9adb597a32e9f7e94 (patch) | |
tree | d860094269622472f8079d497ac7af02dbb4e038 /src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java | |
parent | 14a486343aef55f97f54082d6b542dedebf6f3ba (diff) |
Import Upstream version 0.6.5~20141030
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java | 224 |
1 files changed, 119 insertions, 105 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java index 5754e961..bdfc2f04 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -24,14 +24,17 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans; */ import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import de.lmu.ifi.dbs.elki.algorithm.AbstractPrimitiveDistanceBasedAlgorithm; import de.lmu.ifi.dbs.elki.algorithm.clustering.ClusteringAlgorithm; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.RandomlyChosenInitialMeans; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; import de.lmu.ifi.dbs.elki.data.VectorUtil.SortDBIDsBySingleDimension; -import de.lmu.ifi.dbs.elki.data.model.MeanModel; +import de.lmu.ifi.dbs.elki.data.model.Model; import de.lmu.ifi.dbs.elki.data.type.CombinedTypeInformation; import de.lmu.ifi.dbs.elki.data.type.TypeInformation; import de.lmu.ifi.dbs.elki.data.type.TypeUtil; @@ -43,11 +46,10 @@ import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.SquaredEuclideanDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.logging.Logging; +import de.lmu.ifi.dbs.elki.logging.statistics.DoubleStatistic; import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; import de.lmu.ifi.dbs.elki.utilities.datastructures.QuickSelect; import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints; @@ -60,28 +62,26 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter; * * @author Erich Schubert * - * @apiviz.has MeanModel * @apiviz.composedOf KMeansInitialization * * @param <V> Vector type - * @param <D> Distance type * @param <M> Cluster model type */ -public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distance<D>, M extends MeanModel<V>> extends AbstractPrimitiveDistanceBasedAlgorithm<NumberVector<?>, D, Clustering<M>> implements KMeans<V, D, M>, ClusteringAlgorithm<Clustering<M>> { +public abstract class AbstractKMeans<V extends NumberVector, M extends Model> extends AbstractPrimitiveDistanceBasedAlgorithm<NumberVector, Clustering<M>> implements KMeans<V, M>, ClusteringAlgorithm<Clustering<M>> { /** - * Holds the value of {@link #K_ID}. + * Number of cluster centers to initialize. */ protected int k; /** - * Holds the value of {@link #MAXITER_ID}. + * Maximum number of iterations */ protected int maxiter; /** * Method to choose initial means. */ - protected KMeansInitialization<V> initializer; + protected KMeansInitialization<? super V> initializer; /** * Constructor. @@ -91,7 +91,7 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param maxiter Maxiter parameter * @param initializer Function to generate the initial means */ - public AbstractKMeans(PrimitiveDistanceFunction<? super NumberVector<?>, D> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { + public AbstractKMeans(PrimitiveDistanceFunction<? super NumberVector> distanceFunction, int k, int maxiter, KMeansInitialization<? super V> initializer) { super(distanceFunction); this.k = k; this.maxiter = maxiter; @@ -106,43 +106,26 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param means a list of k means * @param clusters cluster assignment * @param assignment Current cluster assignment + * @param varsum Variance sum output * @return true when the object was reassigned */ - protected boolean assignToNearestCluster(Relation<V> relation, List<? extends NumberVector<?>> means, List<? extends ModifiableDBIDs> clusters, WritableIntegerDataStore assignment) { + protected boolean assignToNearestCluster(Relation<? extends V> relation, List<? extends NumberVector> means, List<? extends ModifiableDBIDs> clusters, WritableIntegerDataStore assignment, double[] varsum) { boolean changed = false; - - if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { - @SuppressWarnings("unchecked") - final PrimitiveDoubleDistanceFunction<? super NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?>>) getDistanceFunction(); - for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { - double mindist = Double.POSITIVE_INFINITY; - V fv = relation.get(iditer); - int minIndex = 0; - for(int i = 0; i < k; i++) { - double dist = df.doubleDistance(fv, means.get(i)); - if(dist < mindist) { - minIndex = i; - mindist = dist; - } + Arrays.fill(varsum, 0.); + final PrimitiveDistanceFunction<? super NumberVector> df = getDistanceFunction(); + for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { + double mindist = Double.POSITIVE_INFINITY; + V fv = relation.get(iditer); + int minIndex = 0; + for(int i = 0; i < k; i++) { + double dist = df.distance(fv, means.get(i)); + if(dist < mindist) { + minIndex = i; + mindist = dist; } - changed |= updateAssignment(iditer, clusters, assignment, minIndex); - } - } - else { - final PrimitiveDistanceFunction<? super NumberVector<?>, D> df = getDistanceFunction(); - for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { - D mindist = df.getDistanceFactory().infiniteDistance(); - V fv = relation.get(iditer); - int minIndex = 0; - for(int i = 0; i < k; i++) { - D dist = df.distance(fv, means.get(i)); - if(dist.compareTo(mindist) < 0) { - minIndex = i; - mindist = dist; - } - } - changed |= updateAssignment(iditer, clusters, assignment, minIndex); } + varsum[minIndex] += mindist; + changed |= updateAssignment(iditer, clusters, assignment, minIndex); } return changed; } @@ -173,7 +156,7 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param database the database containing the vectors * @return the mean vectors of the given clusters in the given database */ - protected List<Vector> means(List<? extends ModifiableDBIDs> clusters, List<? extends NumberVector<?>> means, Relation<V> database) { + protected List<Vector> means(List<? extends ModifiableDBIDs> clusters, List<? extends NumberVector> means, Relation<V> database) { // TODO: use Kahan summation for better numerical precision? List<Vector> newMeans = new ArrayList<>(k); for(int i = 0; i < k; i++) { @@ -187,7 +170,7 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan iter.advance(); // Update with remaining instances for(; iter.valid(); iter.advance()) { - NumberVector<?> vec = database.get(iter); + NumberVector vec = database.get(iter); for(int j = 0; j < mean.getDimensionality(); j++) { raw[j] += vec.doubleValue(j); } @@ -211,24 +194,23 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param database the database containing the vectors * @return the mean vectors of the given clusters in the given database */ - protected List<NumberVector<?>> medians(List<? extends ModifiableDBIDs> clusters, List<? extends NumberVector<?>> medians, Relation<V> database) { + protected List<Vector> medians(List<? extends ModifiableDBIDs> clusters, List<Vector> medians, Relation<V> database) { final int dim = medians.get(0).getDimensionality(); final SortDBIDsBySingleDimension sorter = new SortDBIDsBySingleDimension(database); - List<NumberVector<?>> newMedians = new ArrayList<>(k); + List<Vector> newMedians = new ArrayList<>(k); for(int i = 0; i < k; i++) { ArrayModifiableDBIDs list = DBIDUtil.newArray(clusters.get(i)); - if(list.size() > 0) { - Vector mean = new Vector(dim); - for(int d = 0; d < dim; d++) { - sorter.setDimension(d); - DBID id = QuickSelect.median(list, sorter); - mean.set(d, database.get(id).doubleValue(d)); - } - newMedians.add(mean); + if(list.size() <= 0) { + newMedians.add(medians.get(i)); + continue; } - else { - newMedians.add((NumberVector<?>) medians.get(i)); + Vector mean = new Vector(dim); + for(int d = 0; d < dim; d++) { + sorter.setDimension(d); + DBID id = QuickSelect.median(list, sorter); + mean.set(d, database.get(id).doubleValue(d)); } + newMedians.add(mean); } return newMedians; } @@ -256,49 +238,30 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param means Means * @param clusters Clusters * @param assignment Current cluster assignment + * @param varsum Variance sum output * @return true when the means have changed */ - protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters, WritableIntegerDataStore assignment) { + protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters, WritableIntegerDataStore assignment, double[] varsum) { boolean changed = false; - - if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { - // Raw distance function - @SuppressWarnings("unchecked") - final PrimitiveDoubleDistanceFunction<? super NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?>>) getDistanceFunction(); - - // Incremental update - for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { - double mindist = Double.POSITIVE_INFINITY; - V fv = relation.get(iditer); - int minIndex = 0; - for(int i = 0; i < k; i++) { - double dist = df.doubleDistance(fv, means.get(i)); - if(dist < mindist) { - minIndex = i; - mindist = dist; - } - } - changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); - } - } - else { - // Raw distance function - final PrimitiveDistanceFunction<? super NumberVector<?>, D> df = getDistanceFunction(); - - // Incremental update - for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { - D mindist = df.getDistanceFactory().infiniteDistance(); - V fv = relation.get(iditer); - int minIndex = 0; - for(int i = 0; i < k; i++) { - D dist = df.distance(fv, means.get(i)); - if(dist.compareTo(mindist) < 0) { - minIndex = i; - mindist = dist; - } + Arrays.fill(varsum, 0.); + + // Raw distance function + final PrimitiveDistanceFunction<? super NumberVector> df = getDistanceFunction(); + + // Incremental update + for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { + double mindist = Double.POSITIVE_INFINITY; + V fv = relation.get(iditer); + int minIndex = 0; + for(int i = 0; i < k; i++) { + double dist = df.distance(fv, means.get(i)); + if(dist < mindist) { + minIndex = i; + mindist = dist; } - changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); } + varsum[minIndex] += mindist; + changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); } return changed; } @@ -339,18 +302,36 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan } @Override - public void setDistanceFunction(PrimitiveDistanceFunction<? super NumberVector<?>, D> distanceFunction) { + public void setDistanceFunction(PrimitiveDistanceFunction<? super NumberVector> distanceFunction) { this.distanceFunction = distanceFunction; } /** + * Log statistics on the variance sum. + * + * @param varstat Statistics log instance + * @param varsum Variance sum per cluster + */ + protected void logVarstat(DoubleStatistic varstat, double[] varsum) { + if(varstat == null) { + return; + } + double s = 0.; + for(double v : varsum) { + s += v; + } + varstat.setDouble(s); + getLogger().statistics(varstat); + } + + /** * Parameterization class. * * @author Erich Schubert * * @apiviz.exclude */ - public abstract static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<NumberVector<?>, D> { + public abstract static class Parameterizer<V extends NumberVector> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<NumberVector> { /** * k Parameter. */ @@ -368,25 +349,58 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan @Override protected void makeOptions(Parameterization config) { - ObjectParameter<PrimitiveDistanceFunction<NumberVector<?>, D>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class); - if(config.grab(distanceFunctionP)) { - distanceFunction = distanceFunctionP.instantiateClass(config); - if(!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { - getLogger().warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!"); - } - } + getParameterK(config); + getParameterInitialization(config); + getParameterDistanceFunction(config); + getParameterMaxIter(config); + } + /** + * Get the k parameter. + * + * @param config Parameterization + */ + protected void getParameterK(Parameterization config) { IntParameter kP = new IntParameter(K_ID); kP.addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT); if(config.grab(kP)) { k = kP.getValue(); } + } + /** + * Get the distance function parameter. + * + * @param config Parameterization + */ + protected void getParameterDistanceFunction(Parameterization config) { + ObjectParameter<PrimitiveDistanceFunction<NumberVector>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class); + if(config.grab(distanceFunctionP)) { + distanceFunction = distanceFunctionP.instantiateClass(config); + if(!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { + getLogger().warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!"); + } + } + } + + /** + * Get the initialization method parameter. + * + * @param config Parameterization + */ + protected void getParameterInitialization(Parameterization config) { ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<>(INIT_ID, KMeansInitialization.class, RandomlyChosenInitialMeans.class); if(config.grab(initialP)) { initializer = initialP.instantiateClass(config); } + } + /** + * Get the max iterations parameter. + * + * @param config Parameterization + */ + protected void getParameterMaxIter(Parameterization config) { IntParameter maxiterP = new IntParameter(MAXITER_ID, 0); maxiterP.addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT); if(config.grab(maxiterP)) { @@ -402,6 +416,6 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan abstract protected Logging getLogger(); @Override - abstract protected AbstractKMeans<V, D, ?> makeInstance(); + abstract protected AbstractKMeans<V, ?> makeInstance(); } } |