diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java | 195 |
1 files changed, 99 insertions, 96 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java index dc1fa47c..5754e961 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java @@ -35,6 +35,7 @@ import de.lmu.ifi.dbs.elki.data.model.MeanModel; import de.lmu.ifi.dbs.elki.data.type.CombinedTypeInformation; import de.lmu.ifi.dbs.elki.data.type.TypeInformation; import de.lmu.ifi.dbs.elki.data.type.TypeUtil; +import de.lmu.ifi.dbs.elki.database.datastore.WritableIntegerDataStore; import de.lmu.ifi.dbs.elki.database.ids.ArrayModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.ids.DBID; import de.lmu.ifi.dbs.elki.database.ids.DBIDIter; @@ -49,8 +50,7 @@ import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; import de.lmu.ifi.dbs.elki.utilities.datastructures.QuickSelect; -import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterConstraint; -import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterEqualConstraint; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter; @@ -105,68 +105,61 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param relation the database to cluster * @param means a list of k means * @param clusters cluster assignment + * @param assignment Current cluster assignment * @return true when the object was reassigned */ - protected boolean assignToNearestCluster(Relation<V> relation, List<? extends NumberVector<?>> means, List<? extends ModifiableDBIDs> clusters) { + protected boolean assignToNearestCluster(Relation<V> relation, List<? extends NumberVector<?>> means, List<? extends ModifiableDBIDs> clusters, WritableIntegerDataStore assignment) { boolean changed = false; - if (getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { + if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { @SuppressWarnings("unchecked") final PrimitiveDoubleDistanceFunction<? super NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?>>) getDistanceFunction(); - for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { + for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { double mindist = Double.POSITIVE_INFINITY; V fv = relation.get(iditer); int minIndex = 0; - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { double dist = df.doubleDistance(fv, means.get(i)); - if (dist < mindist) { + if(dist < mindist) { minIndex = i; mindist = dist; } } - if (clusters.get(minIndex).add(iditer)) { - changed = true; - // Remove from previous cluster - // TODO: keep a list of cluster assignments to save this search? - for (int i = 0; i < k; i++) { - if (i != minIndex) { - if (clusters.get(i).remove(iditer)) { - break; - } - } - } - } + changed |= updateAssignment(iditer, clusters, assignment, minIndex); } - } else { + } + else { final PrimitiveDistanceFunction<? super NumberVector<?>, D> df = getDistanceFunction(); - for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { + for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { D mindist = df.getDistanceFactory().infiniteDistance(); V fv = relation.get(iditer); int minIndex = 0; - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { D dist = df.distance(fv, means.get(i)); - if (dist.compareTo(mindist) < 0) { + if(dist.compareTo(mindist) < 0) { minIndex = i; mindist = dist; } } - if (clusters.get(minIndex).add(iditer)) { - changed = true; - // Remove from previous cluster - // TODO: keep a list of cluster assignments to save this search? - for (int i = 0; i < k; i++) { - if (i != minIndex) { - if (clusters.get(i).remove(iditer)) { - break; - } - } - } - } + changed |= updateAssignment(iditer, clusters, assignment, minIndex); } } return changed; } + protected boolean updateAssignment(DBIDIter iditer, List<? extends ModifiableDBIDs> clusters, WritableIntegerDataStore assignment, int newA) { + final int oldA = assignment.intValue(iditer); + if(oldA == newA) { + return false; + } + clusters.get(newA).add(iditer); + assignment.putInt(iditer, newA); + if(oldA >= 0) { + clusters.get(oldA).remove(iditer); + } + return true; + } + @Override public TypeInformation[] getInputTypeRestriction() { return TypeUtil.array(new CombinedTypeInformation(TypeUtil.NUMBER_VECTOR_FIELD, getDistanceFunction().getInputTypeRestriction())); @@ -181,24 +174,28 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @return the mean vectors of the given clusters in the given database */ protected List<Vector> means(List<? extends ModifiableDBIDs> clusters, List<? extends NumberVector<?>> means, Relation<V> database) { + // TODO: use Kahan summation for better numerical precision? List<Vector> newMeans = new ArrayList<>(k); - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { ModifiableDBIDs list = clusters.get(i); Vector mean = null; - if (list.size() > 0) { - double s = 1.0 / list.size(); + if(list.size() > 0) { DBIDIter iter = list.iter(); - assert (iter.valid()); - mean = database.get(iter).getColumnVector().timesEquals(s); + // Initialize with first. + mean = database.get(iter).getColumnVector(); double[] raw = mean.getArrayRef(); iter.advance(); - for (; iter.valid(); iter.advance()) { + // Update with remaining instances + for(; iter.valid(); iter.advance()) { NumberVector<?> vec = database.get(iter); - for (int j = 0; j < mean.getDimensionality(); j++) { - raw[j] += s * vec.doubleValue(j); + for(int j = 0; j < mean.getDimensionality(); j++) { + raw[j] += vec.doubleValue(j); } } - } else { + mean.timesEquals(1.0 / list.size()); + } + else { + // Keep degenerated means as-is for now. mean = means.get(i).getColumnVector(); } newMeans.add(mean); @@ -218,17 +215,18 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan final int dim = medians.get(0).getDimensionality(); final SortDBIDsBySingleDimension sorter = new SortDBIDsBySingleDimension(database); List<NumberVector<?>> newMedians = new ArrayList<>(k); - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { ArrayModifiableDBIDs list = DBIDUtil.newArray(clusters.get(i)); - if (list.size() > 0) { + if(list.size() > 0) { Vector mean = new Vector(dim); - for (int d = 0; d < dim; d++) { + for(int d = 0; d < dim; d++) { sorter.setDimension(d); DBID id = QuickSelect.median(list, sorter); mean.set(d, database.get(id).doubleValue(d)); } newMedians.add(mean); - } else { + } + else { newMedians.add((NumberVector<?>) medians.get(i)); } } @@ -244,14 +242,11 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param op Cluster size change / Weight change */ protected void incrementalUpdateMean(Vector mean, V vec, int newsize, double op) { - if (newsize == 0) { + if(newsize == 0) { return; // Keep old mean } - Vector delta = vec.getColumnVector(); - // Compute difference from mean - delta.minusEquals(mean); - delta.timesEquals(op / newsize); - mean.plusEquals(delta); + Vector delta = vec.getColumnVector().minusEquals(mean); + mean.plusTimesEquals(delta, op / newsize); } /** @@ -260,76 +255,84 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan * @param relation Relation * @param means Means * @param clusters Clusters + * @param assignment Current cluster assignment * @return true when the means have changed */ - protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters) { + protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters, WritableIntegerDataStore assignment) { boolean changed = false; - if (getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { + if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { // Raw distance function @SuppressWarnings("unchecked") final PrimitiveDoubleDistanceFunction<? super NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?>>) getDistanceFunction(); // Incremental update - for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { + for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { double mindist = Double.POSITIVE_INFINITY; V fv = relation.get(iditer); int minIndex = 0; - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { double dist = df.doubleDistance(fv, means.get(i)); - if (dist < mindist) { + if(dist < mindist) { minIndex = i; mindist = dist; } } - // Update the cluster mean incrementally: - for (int i = 0; i < k; i++) { - ModifiableDBIDs ci = clusters.get(i); - if (i == minIndex) { - if (ci.add(iditer)) { - incrementalUpdateMean(means.get(i), fv, ci.size(), +1); - changed = true; - } - } else if (ci.remove(iditer)) { - incrementalUpdateMean(means.get(i), fv, ci.size() + 1, -1); - changed = true; - } - } + changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); } - } else { + } + else { // Raw distance function final PrimitiveDistanceFunction<? super NumberVector<?>, D> df = getDistanceFunction(); // Incremental update - for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { + for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) { D mindist = df.getDistanceFactory().infiniteDistance(); V fv = relation.get(iditer); int minIndex = 0; - for (int i = 0; i < k; i++) { + for(int i = 0; i < k; i++) { D dist = df.distance(fv, means.get(i)); - if (dist.compareTo(mindist) < 0) { + if(dist.compareTo(mindist) < 0) { minIndex = i; mindist = dist; } } - // Update the cluster mean incrementally: - for (int i = 0; i < k; i++) { - ModifiableDBIDs ci = clusters.get(i); - if (i == minIndex) { - if (ci.add(iditer)) { - incrementalUpdateMean(means.get(i), fv, ci.size(), +1); - changed = true; - } - } else if (ci.remove(iditer)) { - incrementalUpdateMean(means.get(i), fv, ci.size() + 1, -1); - changed = true; - } - } + changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); } } return changed; } + /** + * Try to update the cluster assignment. + * + * @param clusters Current clusters + * @param means Means to update + * @param minIndex Cluster to assign to + * @param fv Vector + * @param iditer Object ID + * @param assignment Current cluster assignment + * @return {@code true} when assignment changed + */ + private boolean updateMeanAndAssignment(List<ModifiableDBIDs> clusters, List<Vector> means, int minIndex, V fv, DBIDIter iditer, WritableIntegerDataStore assignment) { + int cur = assignment.intValue(iditer); + if(cur == minIndex) { + return false; + } + final ModifiableDBIDs curclus = clusters.get(minIndex); + curclus.add(iditer); + incrementalUpdateMean(means.get(minIndex), fv, curclus.size(), +1); + + if(cur >= 0) { + ModifiableDBIDs ci = clusters.get(cur); + ci.remove(iditer); + incrementalUpdateMean(means.get(cur), fv, ci.size() + 1, -1); + } + + assignment.putInt(iditer, minIndex); + return true; + } + @Override public void setK(int k) { this.k = k; @@ -366,27 +369,27 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan @Override protected void makeOptions(Parameterization config) { ObjectParameter<PrimitiveDistanceFunction<NumberVector<?>, D>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class); - if (config.grab(distanceFunctionP)) { + if(config.grab(distanceFunctionP)) { distanceFunction = distanceFunctionP.instantiateClass(config); - if (!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { + if(!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { getLogger().warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!"); } } IntParameter kP = new IntParameter(K_ID); - kP.addConstraint(new GreaterConstraint(0)); - if (config.grab(kP)) { + kP.addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT); + if(config.grab(kP)) { k = kP.getValue(); } ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<>(INIT_ID, KMeansInitialization.class, RandomlyChosenInitialMeans.class); - if (config.grab(initialP)) { + if(config.grab(initialP)) { initializer = initialP.instantiateClass(config); } IntParameter maxiterP = new IntParameter(MAXITER_ID, 0); - maxiterP.addConstraint(new GreaterEqualConstraint(0)); - if (config.grab(maxiterP)) { + maxiterP.addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT); + if(config.grab(maxiterP)) { maxiter = maxiterP.getValue(); } } |