summary | refs | log | tree | commit | diff
path: root/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java')
-rw-r--r-- src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java | 129
1 files changed, 42 insertions, 87 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java
index c1285659..a70a3f6f 100644
--- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java
+++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java
@@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering;
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures
- Copyright (C) 2011
+ Copyright (C) 2012
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team
@@ -25,9 +25,10 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering;
import java.util.ArrayList;
import java.util.List;
-import java.util.Random;
import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm;
+import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeansInitialization;
+import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.RandomlyGeneratedInitialMeans;
import de.lmu.ifi.dbs.elki.data.Cluster;
import de.lmu.ifi.dbs.elki.data.Clustering;
import de.lmu.ifi.dbs.elki.data.NumberVector;
@@ -42,6 +43,7 @@ import de.lmu.ifi.dbs.elki.database.ids.DBID;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
+import de.lmu.ifi.dbs.elki.distance.distancefunction.EuclideanDistanceFunction;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.math.MathUtil;
import de.lmu.ifi.dbs.elki.math.linearalgebra.Matrix;
@@ -58,8 +60,7 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterEqualCons
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
-import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.LongParameter;
-import de.lmu.ifi.dbs.elki.utilities.pairs.Pair;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
/**
* Provides the EM algorithm (clustering by expectation maximization).
@@ -113,6 +114,11 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
*/
public static final OptionID DELTA_ID = OptionID.getOrCreateOptionID("em.delta", "The termination criterion for maximization of E(M): " + "E(M) - E(M') < em.delta");
+ /**
+ * Parameter to specify the initialization method
+ */
+ public static final OptionID INIT_ID = OptionID.getOrCreateOptionID("kmeans.initialization", "Method to choose the initial means.");
+
private static final double MIN_LOGLIKELIHOOD = -100000;
/**
@@ -121,32 +127,27 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
private double delta;
/**
- * Parameter to specify the random generator seed.
+ * Store the individual probabilities, for use by EMOutlierDetection etc.
*/
- public static final OptionID SEED_ID = OptionID.getOrCreateOptionID("em.seed", "The random number generator seed.");
+ private WritableDataStore<double[]> probClusterIGivenX;
/**
- * Holds the value of {@link #SEED_ID}.
+ * Class to choose the initial means
*/
- private Long seed;
-
- /**
- * Store the individual probabilities, for use by EMOutlierDetection etc.
- */
- private WritableDataStore<double[]> probClusterIGivenX;
+ private KMeansInitialization<V> initializer;
/**
* Constructor.
*
* @param k k parameter
* @param delta delta parameter
- * @param seed Seed parameter
+ * @param initializer Class to choose the initial means
*/
- public EM(int k, double delta, Long seed) {
+ public EM(int k, double delta, KMeansInitialization<V> initializer) {
super();
this.k = k;
this.delta = delta;
- this.seed = seed;
+ this.initializer = initializer;
}
/**
@@ -169,14 +170,14 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
if(logger.isVerbose()) {
logger.verbose("initializing " + k + " models");
}
- List<V> means = initialMeans(relation);
+ List<Vector> means = initializer.chooseInitialMeans(relation, k, EuclideanDistanceFunction.STATIC);
List<Matrix> covarianceMatrices = new ArrayList<Matrix>(k);
List<Double> normDistrFactor = new ArrayList<Double>(k);
List<Matrix> invCovMatr = new ArrayList<Matrix>(k);
List<Double> clusterWeights = new ArrayList<Double>(k);
probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_SORTED, double[].class);
- int dimensionality = means.get(0).getDimensionality();
+ final int dimensionality = means.get(0).getDimensionality();
for(int i = 0; i < k; i++) {
Matrix m = Matrix.identity(dimensionality, dimensionality);
covarianceMatrices.add(m);
@@ -211,12 +212,12 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
em = emNew;
// recompute models
- List<V> meanSums = new ArrayList<V>(k);
+ List<Vector> meanSums = new ArrayList<Vector>(k);
double[] sumOfClusterProbabilities = new double[k];
for(int i = 0; i < k; i++) {
clusterWeights.set(i, 0.0);
- meanSums.add(means.get(i).nullVector());
+ meanSums.add(new Vector(dimensionality));
covarianceMatrices.set(i, Matrix.zeroMatrix(dimensionality));
}
@@ -226,24 +227,23 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
for(int i = 0; i < k; i++) {
sumOfClusterProbabilities[i] += clusterProbabilities[i];
- V summand = relation.get(id).multiplicate(clusterProbabilities[i]);
- V currentMeanSum = meanSums.get(i).plus(summand);
- meanSums.set(i, currentMeanSum);
+ Vector summand = relation.get(id).getColumnVector().timesEquals(clusterProbabilities[i]);
+ meanSums.get(i).plusEquals(summand);
}
}
final int n = relation.size();
for(int i = 0; i < k; i++) {
clusterWeights.set(i, sumOfClusterProbabilities[i] / n);
- V newMean = meanSums.get(i).multiplicate(1 / sumOfClusterProbabilities[i]);
+ Vector newMean = meanSums.get(i).timesEquals(1 / sumOfClusterProbabilities[i]);
means.set(i, newMean);
}
// covariance matrices
for(DBID id : relation.iterDBIDs()) {
double[] clusterProbabilities = probClusterIGivenX.get(id);
- V instance = relation.get(id);
+ Vector instance = relation.get(id).getColumnVector();
for(int i = 0; i < k; i++) {
- V difference = instance.minus(means.get(i));
- covarianceMatrices.get(i).plusEquals(difference.getColumnVector().times(difference.getRowVector()).times(clusterProbabilities[i]));
+ Vector difference = instance.minus(means.get(i));
+ covarianceMatrices.get(i).plusEquals(difference.timesTranspose(difference).timesEquals(clusterProbabilities[i]));
}
}
for(int i = 0; i < k; i++) {
@@ -281,13 +281,14 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
}
hardClusters.get(maxIndex).add(id);
}
+ final V factory = DatabaseUtil.assumeVectorField(relation).getFactory();
Clustering<EMModel<V>> result = new Clustering<EMModel<V>>("EM Clustering", "em-clustering");
// provide models within the result
for(int i = 0; i < k; i++) {
// TODO: re-do labeling.
// SimpleClassLabel label = new SimpleClassLabel();
// label.init(result.canonicalClusterLabel(i));
- Cluster<EMModel<V>> model = new Cluster<EMModel<V>>(hardClusters.get(i), new EMModel<V>(means.get(i), covarianceMatrices.get(i)));
+ Cluster<EMModel<V>> model = new Cluster<EMModel<V>>(hardClusters.get(i), new EMModel<V>(factory.newNumberVector(means.get(i).getArrayRef()), covarianceMatrices.get(i)));
result.addCluster(model);
}
return result;
@@ -308,24 +309,20 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
* @param clusterWeights the weights of the current clusters
* @return the expectation value of the current mixture of distributions
*/
- protected double assignProbabilitiesToInstances(Relation<V> database, List<Double> normDistrFactor, List<V> means, List<Matrix> invCovMatr, List<Double> clusterWeights, WritableDataStore<double[]> probClusterIGivenX) {
+ protected double assignProbabilitiesToInstances(Relation<V> database, List<Double> normDistrFactor, List<Vector> means, List<Matrix> invCovMatr, List<Double> clusterWeights, WritableDataStore<double[]> probClusterIGivenX) {
double emSum = 0.0;
for(DBID id : database.iterDBIDs()) {
- V x = database.get(id);
+ Vector x = database.get(id).getColumnVector();
List<Double> probabilities = new ArrayList<Double>(k);
for(int i = 0; i < k; i++) {
- V difference = x.minus(means.get(i));
- Matrix differenceRow = difference.getRowVector();
- Vector differenceCol = difference.getColumnVector();
- Matrix rowTimesCov = differenceRow.times(invCovMatr.get(i));
- Vector rowTimesCovTimesCol = rowTimesCov.times(differenceCol);
- double power = rowTimesCovTimesCol.get(0, 0) / 2.0;
+ Vector difference = x.minus(means.get(i));
+ double rowTimesCovTimesCol = difference.transposeTimesTimes(invCovMatr.get(i), difference);
+ double power = rowTimesCovTimesCol / 2.0;
double prob = normDistrFactor.get(i) * Math.exp(-power);
if(logger.isDebuggingFinest()) {
- logger.debugFinest(" difference vector= ( " + difference.toString() + " )\n" + " differenceRow:\n" + FormatUtil.format(differenceRow, " ") + "\n" + " differenceCol:\n" + FormatUtil.format(differenceCol, " ") + "\n" + " rowTimesCov:\n" + FormatUtil.format(rowTimesCov, " ") + "\n" + " rowTimesCovTimesCol:\n" + FormatUtil.format(rowTimesCovTimesCol, " ") + "\n" + " power= " + power + "\n" + " prob=" + prob + "\n" + " inv cov matrix: \n" + FormatUtil.format(invCovMatr.get(i), " "));
+ logger.debugFinest(" difference vector= ( " + difference.toString() + " )\n" + " difference:\n" + FormatUtil.format(difference, " ") + "\n" + " rowTimesCovTimesCol:\n" + rowTimesCovTimesCol + "\n" + " power= " + power + "\n" + " prob=" + prob + "\n" + " inv cov matrix: \n" + FormatUtil.format(invCovMatr.get(i), " "));
}
-
probabilities.add(prob);
}
double priorProbability = 0.0;
@@ -356,48 +353,6 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
}
/**
- * Creates {@link #k k} random points distributed uniformly within the
- * attribute ranges of the given database.
- *
- * @param relation the database must contain enough points in order to
- * ascertain the range of attribute values. Less than two points would
- * make no sense. The content of the database is not touched otherwise.
- * @return a list of {@link #k k} random points distributed uniformly within
- * the attribute ranges of the given database
- */
- protected List<V> initialMeans(Relation<V> relation) {
- final Random random;
- if(this.seed != null) {
- random = new Random(this.seed);
- }
- else {
- random = new Random();
- }
- if(relation.size() > 0) {
- final int dim = DatabaseUtil.dimensionality(relation);
- Pair<V, V> minmax = DatabaseUtil.computeMinMax(relation);
- List<V> means = new ArrayList<V>(k);
- if(logger.isVerbose()) {
- logger.verbose("initializing random vectors");
- }
- for(int i = 0; i < k; i++) {
- double[] r = MathUtil.randomDoubleArray(dim, random);
- // Rescale
- for (int d = 0; d < dim; d++) {
- r[d] = minmax.first.doubleValue(d + 1) + (minmax.second.doubleValue(d + 1) - minmax.first.doubleValue(d + 1)) * r[d];
- }
- // Instantiate
- V randomVector = minmax.first.newInstance(r);
- means.add(randomVector);
- }
- return means;
- }
- else {
- return new ArrayList<V>(0);
- }
- }
-
- /**
* Get the probabilities for a given point.
*
* @param index Point ID
@@ -429,7 +384,7 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
protected double delta;
- protected Long seed;
+ protected KMeansInitialization<V> initializer;
@Override
protected void makeOptions(Parameterization config) {
@@ -439,20 +394,20 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri
k = kP.getValue();
}
+ ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<KMeansInitialization<V>>(INIT_ID, KMeansInitialization.class, RandomlyGeneratedInitialMeans.class);
+ if(config.grab(initialP)) {
+ initializer = initialP.instantiateClass(config);
+ }
+
DoubleParameter deltaP = new DoubleParameter(DELTA_ID, new GreaterEqualConstraint(0.0), 0.0);
if(config.grab(deltaP)) {
delta = deltaP.getValue();
}
-
- LongParameter seedP = new LongParameter(SEED_ID, true);
- if(config.grab(seedP)) {
- seed = seedP.getValue();
- }
}
@Override
protected EM<V> makeInstance() {
- return new EM<V>(k, delta, seed);
+ return new EM<V>(k, delta, initializer);
}
}
}
\ No newline at end of file