summary refs log tree commit diff
path: root/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java')
-rw-r--r-- src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java 195
1 files changed, 99 insertions, 96 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java
index dc1fa47c..5754e961 100644
--- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java
+++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/kmeans/AbstractKMeans.java
@@ -35,6 +35,7 @@ import de.lmu.ifi.dbs.elki.data.model.MeanModel;
import de.lmu.ifi.dbs.elki.data.type.CombinedTypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
+import de.lmu.ifi.dbs.elki.database.datastore.WritableIntegerDataStore;
import de.lmu.ifi.dbs.elki.database.ids.ArrayModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.ids.DBID;
import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
@@ -49,8 +50,7 @@ import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector;
import de.lmu.ifi.dbs.elki.utilities.datastructures.QuickSelect;
-import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterConstraint;
-import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterEqualConstraint;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
@@ -105,68 +105,61 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan
* @param relation the database to cluster
* @param means a list of k means
* @param clusters cluster assignment
+ * @param assignment Current cluster assignment
* @return true when the object was reassigned
*/
- protected boolean assignToNearestCluster(Relation<V> relation, List<? extends NumberVector<?>> means, List<? extends ModifiableDBIDs> clusters) {
+ protected boolean assignToNearestCluster(Relation<V> relation, List<? extends NumberVector<?>> means, List<? extends ModifiableDBIDs> clusters, WritableIntegerDataStore assignment) {
boolean changed = false; // becomes true as soon as one object switches clusters
- if (getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) {
+ if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) { // fast path: distances are compared as primitive doubles
@SuppressWarnings("unchecked")
final PrimitiveDoubleDistanceFunction<? super NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?>>) getDistanceFunction();
- for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
+ for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
double mindist = Double.POSITIVE_INFINITY;
V fv = relation.get(iditer);
int minIndex = 0; // stays 0 when no distance is strictly below the initial infinity
- for (int i = 0; i < k; i++) {
+ for(int i = 0; i < k; i++) {
double dist = df.doubleDistance(fv, means.get(i));
- if (dist < mindist) {
+ if(dist < mindist) { // strict comparison: on ties the lowest-index mean wins
minIndex = i;
mindist = dist;
}
}
- if (clusters.get(minIndex).add(iditer)) {
- changed = true;
- // Remove from previous cluster
- // TODO: keep a list of cluster assignments to save this search?
- for (int i = 0; i < k; i++) {
- if (i != minIndex) {
- if (clusters.get(i).remove(iditer)) {
- break;
- }
- }
- }
- }
+ changed |= updateAssignment(iditer, clusters, assignment, minIndex); // the assignment store replaces the removed search through all clusters
}
- } else {
+ }
+ else {
final PrimitiveDistanceFunction<? super NumberVector<?>, D> df = getDistanceFunction(); // generic path: distances are D objects ordered via compareTo
- for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
+ for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
D mindist = df.getDistanceFactory().infiniteDistance();
V fv = relation.get(iditer);
int minIndex = 0; // same fallback to cluster 0 as in the double path
- for (int i = 0; i < k; i++) {
+ for(int i = 0; i < k; i++) {
D dist = df.distance(fv, means.get(i));
- if (dist.compareTo(mindist) < 0) {
+ if(dist.compareTo(mindist) < 0) { // strict comparison: on ties the lowest-index mean wins
minIndex = i;
mindist = dist;
}
}
- if (clusters.get(minIndex).add(iditer)) {
- changed = true;
- // Remove from previous cluster
- // TODO: keep a list of cluster assignments to save this search?
- for (int i = 0; i < k; i++) {
- if (i != minIndex) {
- if (clusters.get(i).remove(iditer)) {
- break;
- }
- }
- }
- }
+ changed |= updateAssignment(iditer, clusters, assignment, minIndex); // the assignment store replaces the removed search through all clusters
}
}
return changed;
}
+ /**
+  * Record a new cluster assignment for one object and keep the cluster
+  * member sets in sync with it.
+  *
+  * @param iditer Object to (re)assign
+  * @param clusters Member sets, indexed by cluster number
+  * @param assignment Store holding the current cluster number per object; a
+  *        negative stored value is treated as "not in any cluster yet"
+  * @param newA Cluster number the object should belong to
+  * @return {@code true} when the stored assignment differed from
+  *         {@code newA} and was updated
+  */
+ protected boolean updateAssignment(DBIDIter iditer, List<? extends ModifiableDBIDs> clusters, WritableIntegerDataStore assignment, int newA) {
+   final int previous = assignment.intValue(iditer);
+   final boolean moved = previous != newA;
+   if(moved) {
+     // Join the new cluster first, then record it, then leave the old one.
+     final ModifiableDBIDs target = clusters.get(newA);
+     target.add(iditer);
+     assignment.putInt(iditer, newA);
+     if(previous >= 0) {
+       clusters.get(previous).remove(iditer);
+     }
+   }
+   return moved;
+ }
+
@Override
public TypeInformation[] getInputTypeRestriction() {
return TypeUtil.array(new CombinedTypeInformation(TypeUtil.NUMBER_VECTOR_FIELD, getDistanceFunction().getInputTypeRestriction()));
@@ -181,24 +174,28 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan
* @return the mean vectors of the given clusters in the given database
*/
protected List<Vector> means(List<? extends ModifiableDBIDs> clusters, List<? extends NumberVector<?>> means, Relation<V> database) {
+ // TODO: use Kahan summation for better numerical precision?
List<Vector> newMeans = new ArrayList<>(k);
- for (int i = 0; i < k; i++) {
+ for(int i = 0; i < k; i++) {
ModifiableDBIDs list = clusters.get(i);
Vector mean = null;
- if (list.size() > 0) {
- double s = 1.0 / list.size();
+ if(list.size() > 0) {
DBIDIter iter = list.iter();
- assert (iter.valid());
- mean = database.get(iter).getColumnVector().timesEquals(s);
+ // Initialize with first.
+ mean = database.get(iter).getColumnVector();
double[] raw = mean.getArrayRef();
iter.advance();
- for (; iter.valid(); iter.advance()) {
+ // Update with remaining instances
+ for(; iter.valid(); iter.advance()) {
NumberVector<?> vec = database.get(iter);
- for (int j = 0; j < mean.getDimensionality(); j++) {
- raw[j] += s * vec.doubleValue(j);
+ for(int j = 0; j < mean.getDimensionality(); j++) {
+ raw[j] += vec.doubleValue(j);
}
}
- } else {
+ mean.timesEquals(1.0 / list.size());
+ }
+ else {
+ // Keep degenerated means as-is for now.
mean = means.get(i).getColumnVector();
}
newMeans.add(mean);
@@ -218,17 +215,18 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan
final int dim = medians.get(0).getDimensionality();
final SortDBIDsBySingleDimension sorter = new SortDBIDsBySingleDimension(database);
List<NumberVector<?>> newMedians = new ArrayList<>(k);
- for (int i = 0; i < k; i++) {
+ for(int i = 0; i < k; i++) {
ArrayModifiableDBIDs list = DBIDUtil.newArray(clusters.get(i));
- if (list.size() > 0) {
+ if(list.size() > 0) {
Vector mean = new Vector(dim);
- for (int d = 0; d < dim; d++) {
+ for(int d = 0; d < dim; d++) {
sorter.setDimension(d);
DBID id = QuickSelect.median(list, sorter);
mean.set(d, database.get(id).doubleValue(d));
}
newMedians.add(mean);
- } else {
+ }
+ else {
newMedians.add((NumberVector<?>) medians.get(i));
}
}
@@ -244,14 +242,11 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan
* @param op Cluster size change / Weight change
*/
protected void incrementalUpdateMean(Vector mean, V vec, int newsize, double op) {
- if (newsize == 0) {
+ if(newsize == 0) {
return; // Keep old mean; this also skips the op / newsize division below
}
- Vector delta = vec.getColumnVector();
- // Compute difference from mean
- delta.minusEquals(mean);
- delta.timesEquals(op / newsize);
- mean.plusEquals(delta);
+ Vector delta = vec.getColumnVector().minusEquals(mean); // delta = vec - mean; the removed lines show minusEquals modifying the column vector in place
+ mean.plusTimesEquals(delta, op / newsize); // presumably mean += delta * (op / newsize), fusing the removed timesEquals/plusEquals pair; confirm against Vector
}
/**
@@ -260,76 +255,84 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan
* @param relation Relation
* @param means Means
* @param clusters Clusters
+ * @param assignment Current cluster assignment
* @return true when the means have changed
*/
- protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters) {
+ protected boolean macQueenIterate(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters, WritableIntegerDataStore assignment) {
boolean changed = false; // becomes true as soon as one object switches clusters
- if (getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) {
+ if(getDistanceFunction() instanceof PrimitiveDoubleDistanceFunction) {
// Raw distance function: distances are compared as primitive doubles
@SuppressWarnings("unchecked")
final PrimitiveDoubleDistanceFunction<? super NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<? super NumberVector<?>>) getDistanceFunction();
// Incremental update: means are adjusted per object, so later objects are compared against the updated means
- for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
+ for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
double mindist = Double.POSITIVE_INFINITY;
V fv = relation.get(iditer);
int minIndex = 0; // stays 0 when no distance is strictly below the initial infinity
- for (int i = 0; i < k; i++) {
+ for(int i = 0; i < k; i++) {
double dist = df.doubleDistance(fv, means.get(i));
- if (dist < mindist) {
+ if(dist < mindist) { // strict comparison: on ties the lowest-index mean wins
minIndex = i;
mindist = dist;
}
}
- // Update the cluster mean incrementally:
- for (int i = 0; i < k; i++) {
- ModifiableDBIDs ci = clusters.get(i);
- if (i == minIndex) {
- if (ci.add(iditer)) {
- incrementalUpdateMean(means.get(i), fv, ci.size(), +1);
- changed = true;
- }
- } else if (ci.remove(iditer)) {
- incrementalUpdateMean(means.get(i), fv, ci.size() + 1, -1);
- changed = true;
- }
- }
+ changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); // touches only the old and new cluster instead of looping over all k
}
- } else {
+ }
+ else {
// Generic distance function: distances are D objects ordered via compareTo
final PrimitiveDistanceFunction<? super NumberVector<?>, D> df = getDistanceFunction();
// Incremental update: means are adjusted per object, so later objects are compared against the updated means
- for (DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
+ for(DBIDIter iditer = relation.iterDBIDs(); iditer.valid(); iditer.advance()) {
D mindist = df.getDistanceFactory().infiniteDistance();
V fv = relation.get(iditer);
int minIndex = 0; // same fallback to cluster 0 as in the double path
- for (int i = 0; i < k; i++) {
+ for(int i = 0; i < k; i++) {
D dist = df.distance(fv, means.get(i));
- if (dist.compareTo(mindist) < 0) {
+ if(dist.compareTo(mindist) < 0) { // strict comparison: on ties the lowest-index mean wins
minIndex = i;
mindist = dist;
}
}
- // Update the cluster mean incrementally:
- for (int i = 0; i < k; i++) {
- ModifiableDBIDs ci = clusters.get(i);
- if (i == minIndex) {
- if (ci.add(iditer)) {
- incrementalUpdateMean(means.get(i), fv, ci.size(), +1);
- changed = true;
- }
- } else if (ci.remove(iditer)) {
- incrementalUpdateMean(means.get(i), fv, ci.size() + 1, -1);
- changed = true;
- }
- }
+ changed |= updateMeanAndAssignment(clusters, means, minIndex, fv, iditer, assignment); // touches only the old and new cluster instead of looping over all k
}
}
return changed;
}
+ /**
+  * Move one object to its nearest cluster and incrementally adjust the means
+  * of the cluster it joins and of the cluster it leaves.
+  *
+  * Nothing is modified when the object is already assigned to
+  * {@code minIndex}. A negative stored assignment is treated as "not in any
+  * cluster yet", in which case only the receiving cluster is updated.
+  *
+  * @param clusters Current clusters
+  * @param means Means to update (modified in place)
+  * @param minIndex Cluster to assign to
+  * @param fv Vector
+  * @param iditer Object ID
+  * @param assignment Current cluster assignment
+  * @return {@code true} when assignment changed
+  */
+ private boolean updateMeanAndAssignment(List<ModifiableDBIDs> clusters, List<Vector> means, int minIndex, V fv, DBIDIter iditer, WritableIntegerDataStore assignment) {
+   final int cur = assignment.intValue(iditer);
+   if(cur == minIndex) {
+     return false;
+   }
+   // Adding x to a cluster that then has n members:
+   // mean' = mean + (x - mean) / n, with n the size after insertion.
+   final ModifiableDBIDs newclus = clusters.get(minIndex);
+   newclus.add(iditer);
+   incrementalUpdateMean(means.get(minIndex), fv, newclus.size(), +1);
+
+   if(cur >= 0) {
+     final ModifiableDBIDs oldclus = clusters.get(cur);
+     oldclus.remove(iditer);
+     // Removing x from a cluster of n members leaves n - 1:
+     // mean' = (n * mean - x) / (n - 1) = mean - (x - mean) / (n - 1).
+     // The divisor is therefore the size AFTER removal; passing size + 1
+     // (the old size) under-corrects the mean. A cluster emptied by the
+     // removal reports size 0, for which incrementalUpdateMean keeps the
+     // previous mean.
+     incrementalUpdateMean(means.get(cur), fv, oldclus.size(), -1);
+   }
+
+   assignment.putInt(iditer, minIndex);
+   return true;
+ }
+
@Override
public void setK(int k) {
this.k = k;
@@ -366,27 +369,27 @@ public abstract class AbstractKMeans<V extends NumberVector<?>, D extends Distan
@Override
protected void makeOptions(Parameterization config) { // reads the distance function, k, the initializer and maxiter from config
ObjectParameter<PrimitiveDistanceFunction<NumberVector<?>, D>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class);
- if (config.grab(distanceFunctionP)) {
+ if(config.grab(distanceFunctionP)) {
distanceFunction = distanceFunctionP.instantiateClass(config);
- if (!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) {
+ if(!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { // other distances are accepted, but only with a warning
getLogger().warning("k-means optimizes the sum of squares - it should be used with squared euclidean distance and may stop converging otherwise!");
}
}
IntParameter kP = new IntParameter(K_ID);
- kP.addConstraint(new GreaterConstraint(0));
- if (config.grab(kP)) {
+ kP.addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT); // presumably k >= 1, the same integer bound as the removed GreaterConstraint(0); confirm against CommonConstraints
+ if(config.grab(kP)) {
k = kP.getValue();
}
ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<>(INIT_ID, KMeansInitialization.class, RandomlyChosenInitialMeans.class);
- if (config.grab(initialP)) {
+ if(config.grab(initialP)) {
initializer = initialP.instantiateClass(config);
}
IntParameter maxiterP = new IntParameter(MAXITER_ID, 0);
- maxiterP.addConstraint(new GreaterEqualConstraint(0));
- if (config.grab(maxiterP)) {
+ maxiterP.addConstraint(CommonConstraints.GREATER_EQUAL_ZERO_INT); // presumably maxiter >= 0, matching the removed GreaterEqualConstraint(0); confirm against CommonConstraints
+ if(config.grab(maxiterP)) {
maxiter = maxiterP.getValue();
}
}