diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java | 129 |
1 files changed, 42 insertions, 87 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java index c1285659..a70a3f6f 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/clustering/EM.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2011 + Copyright (C) 2012 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -25,9 +25,10 @@ package de.lmu.ifi.dbs.elki.algorithm.clustering; import java.util.ArrayList; import java.util.List; -import java.util.Random; import de.lmu.ifi.dbs.elki.algorithm.AbstractAlgorithm; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeansInitialization; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.RandomlyGeneratedInitialMeans; import de.lmu.ifi.dbs.elki.data.Cluster; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; @@ -42,6 +43,7 @@ import de.lmu.ifi.dbs.elki.database.ids.DBID; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; +import de.lmu.ifi.dbs.elki.distance.distancefunction.EuclideanDistanceFunction; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.math.MathUtil; import de.lmu.ifi.dbs.elki.math.linearalgebra.Matrix; @@ -58,8 +60,7 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterEqualCons import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter; import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter; -import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.LongParameter; -import de.lmu.ifi.dbs.elki.utilities.pairs.Pair; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter; /** * Provides the EM algorithm (clustering by expectation maximization). @@ -113,6 +114,11 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri */ public static final OptionID DELTA_ID = OptionID.getOrCreateOptionID("em.delta", "The termination criterion for maximization of E(M): " + "E(M) - E(M') < em.delta"); + /** + * Parameter to specify the initialization method + */ + public static final OptionID INIT_ID = OptionID.getOrCreateOptionID("kmeans.initialization", "Method to choose the initial means."); + private static final double MIN_LOGLIKELIHOOD = -100000; /** @@ -121,32 +127,27 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri private double delta; /** - * Parameter to specify the random generator seed. + * Store the individual probabilities, for use by EMOutlierDetection etc. */ - public static final OptionID SEED_ID = OptionID.getOrCreateOptionID("em.seed", "The random number generator seed."); + private WritableDataStore<double[]> probClusterIGivenX; /** - * Holds the value of {@link #SEED_ID}. + * Class to choose the initial means */ - private Long seed; - - /** - * Store the individual probabilities, for use by EMOutlierDetection etc. - */ - private WritableDataStore<double[]> probClusterIGivenX; + private KMeansInitialization<V> initializer; /** * Constructor. * * @param k k parameter * @param delta delta parameter - * @param seed Seed parameter + * @param initializer Class to choose the initial means */ - public EM(int k, double delta, Long seed) { + public EM(int k, double delta, KMeansInitialization<V> initializer) { super(); this.k = k; this.delta = delta; - this.seed = seed; + this.initializer = initializer; } /** @@ -169,14 +170,14 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri if(logger.isVerbose()) { logger.verbose("initializing " + k + " models"); } - List<V> means = initialMeans(relation); + List<Vector> means = initializer.chooseInitialMeans(relation, k, EuclideanDistanceFunction.STATIC); List<Matrix> covarianceMatrices = new ArrayList<Matrix>(k); List<Double> normDistrFactor = new ArrayList<Double>(k); List<Matrix> invCovMatr = new ArrayList<Matrix>(k); List<Double> clusterWeights = new ArrayList<Double>(k); probClusterIGivenX = DataStoreUtil.makeStorage(relation.getDBIDs(), DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_SORTED, double[].class); - int dimensionality = means.get(0).getDimensionality(); + final int dimensionality = means.get(0).getDimensionality(); for(int i = 0; i < k; i++) { Matrix m = Matrix.identity(dimensionality, dimensionality); covarianceMatrices.add(m); @@ -211,12 +212,12 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri em = emNew; // recompute models - List<V> meanSums = new ArrayList<V>(k); + List<Vector> meanSums = new ArrayList<Vector>(k); double[] sumOfClusterProbabilities = new double[k]; for(int i = 0; i < k; i++) { clusterWeights.set(i, 0.0); - meanSums.add(means.get(i).nullVector()); + meanSums.add(new Vector(dimensionality)); covarianceMatrices.set(i, Matrix.zeroMatrix(dimensionality)); } @@ -226,24 +227,23 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri for(int i = 0; i < k; i++) { sumOfClusterProbabilities[i] += clusterProbabilities[i]; - V summand = relation.get(id).multiplicate(clusterProbabilities[i]); - V currentMeanSum = meanSums.get(i).plus(summand); - meanSums.set(i, currentMeanSum); + Vector summand = relation.get(id).getColumnVector().timesEquals(clusterProbabilities[i]); + meanSums.get(i).plusEquals(summand); } } final int n = relation.size(); for(int i = 0; i < k; i++) { clusterWeights.set(i, sumOfClusterProbabilities[i] / n); - V newMean = meanSums.get(i).multiplicate(1 / sumOfClusterProbabilities[i]); + Vector newMean = meanSums.get(i).timesEquals(1 / sumOfClusterProbabilities[i]); means.set(i, newMean); } // covariance matrices for(DBID id : relation.iterDBIDs()) { double[] clusterProbabilities = probClusterIGivenX.get(id); - V instance = relation.get(id); + Vector instance = relation.get(id).getColumnVector(); for(int i = 0; i < k; i++) { - V difference = instance.minus(means.get(i)); - covarianceMatrices.get(i).plusEquals(difference.getColumnVector().times(difference.getRowVector()).times(clusterProbabilities[i])); + Vector difference = instance.minus(means.get(i)); + covarianceMatrices.get(i).plusEquals(difference.timesTranspose(difference).timesEquals(clusterProbabilities[i])); } } for(int i = 0; i < k; i++) { @@ -281,13 +281,14 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri } hardClusters.get(maxIndex).add(id); } + final V factory = DatabaseUtil.assumeVectorField(relation).getFactory(); Clustering<EMModel<V>> result = new Clustering<EMModel<V>>("EM Clustering", "em-clustering"); // provide models within the result for(int i = 0; i < k; i++) { // TODO: re-do labeling. // SimpleClassLabel label = new SimpleClassLabel(); // label.init(result.canonicalClusterLabel(i)); - Cluster<EMModel<V>> model = new Cluster<EMModel<V>>(hardClusters.get(i), new EMModel<V>(means.get(i), covarianceMatrices.get(i))); + Cluster<EMModel<V>> model = new Cluster<EMModel<V>>(hardClusters.get(i), new EMModel<V>(factory.newNumberVector(means.get(i).getArrayRef()), covarianceMatrices.get(i))); result.addCluster(model); } return result; @@ -308,24 +309,20 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri * @param clusterWeights the weights of the current clusters * @return the expectation value of the current mixture of distributions */ - protected double assignProbabilitiesToInstances(Relation<V> database, List<Double> normDistrFactor, List<V> means, List<Matrix> invCovMatr, List<Double> clusterWeights, WritableDataStore<double[]> probClusterIGivenX) { + protected double assignProbabilitiesToInstances(Relation<V> database, List<Double> normDistrFactor, List<Vector> means, List<Matrix> invCovMatr, List<Double> clusterWeights, WritableDataStore<double[]> probClusterIGivenX) { double emSum = 0.0; for(DBID id : database.iterDBIDs()) { - V x = database.get(id); + Vector x = database.get(id).getColumnVector(); List<Double> probabilities = new ArrayList<Double>(k); for(int i = 0; i < k; i++) { - V difference = x.minus(means.get(i)); - Matrix differenceRow = difference.getRowVector(); - Vector differenceCol = difference.getColumnVector(); - Matrix rowTimesCov = differenceRow.times(invCovMatr.get(i)); - Vector rowTimesCovTimesCol = rowTimesCov.times(differenceCol); - double power = rowTimesCovTimesCol.get(0, 0) / 2.0; + Vector difference = x.minus(means.get(i)); + double rowTimesCovTimesCol = difference.transposeTimesTimes(invCovMatr.get(i), difference); + double power = rowTimesCovTimesCol / 2.0; double prob = normDistrFactor.get(i) * Math.exp(-power); if(logger.isDebuggingFinest()) { - logger.debugFinest(" difference vector= ( " + difference.toString() + " )\n" + " differenceRow:\n" + FormatUtil.format(differenceRow, " ") + "\n" + " differenceCol:\n" + FormatUtil.format(differenceCol, " ") + "\n" + " rowTimesCov:\n" + FormatUtil.format(rowTimesCov, " ") + "\n" + " rowTimesCovTimesCol:\n" + FormatUtil.format(rowTimesCovTimesCol, " ") + "\n" + " power= " + power + "\n" + " prob=" + prob + "\n" + " inv cov matrix: \n" + FormatUtil.format(invCovMatr.get(i), " ")); + logger.debugFinest(" difference vector= ( " + difference.toString() + " )\n" + " difference:\n" + FormatUtil.format(difference, " ") + "\n" + " rowTimesCovTimesCol:\n" + rowTimesCovTimesCol + "\n" + " power= " + power + "\n" + " prob=" + prob + "\n" + " inv cov matrix: \n" + FormatUtil.format(invCovMatr.get(i), " ")); } - probabilities.add(prob); } double priorProbability = 0.0; @@ -356,48 +353,6 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri } /** - * Creates {@link #k k} random points distributed uniformly within the - * attribute ranges of the given database. - * - * @param relation the database must contain enough points in order to - * ascertain the range of attribute values. Less than two points would - * make no sense. The content of the database is not touched otherwise. - * @return a list of {@link #k k} random points distributed uniformly within - * the attribute ranges of the given database - */ - protected List<V> initialMeans(Relation<V> relation) { - final Random random; - if(this.seed != null) { - random = new Random(this.seed); - } - else { - random = new Random(); - } - if(relation.size() > 0) { - final int dim = DatabaseUtil.dimensionality(relation); - Pair<V, V> minmax = DatabaseUtil.computeMinMax(relation); - List<V> means = new ArrayList<V>(k); - if(logger.isVerbose()) { - logger.verbose("initializing random vectors"); - } - for(int i = 0; i < k; i++) { - double[] r = MathUtil.randomDoubleArray(dim, random); - // Rescale - for (int d = 0; d < dim; d++) { - r[d] = minmax.first.doubleValue(d + 1) + (minmax.second.doubleValue(d + 1) - minmax.first.doubleValue(d + 1)) * r[d]; - } - // Instantiate - V randomVector = minmax.first.newInstance(r); - means.add(randomVector); - } - return means; - } - else { - return new ArrayList<V>(0); - } - } - - /** * Get the probabilities for a given point. * * @param index Point ID @@ -429,7 +384,7 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri protected double delta; - protected Long seed; + protected KMeansInitialization<V> initializer; @Override protected void makeOptions(Parameterization config) { @@ -439,20 +394,20 @@ public class EM<V extends NumberVector<V, ?>> extends AbstractAlgorithm<Clusteri k = kP.getValue(); } + ObjectParameter<KMeansInitialization<V>> initialP = new ObjectParameter<KMeansInitialization<V>>(INIT_ID, KMeansInitialization.class, RandomlyGeneratedInitialMeans.class); + if(config.grab(initialP)) { + initializer = initialP.instantiateClass(config); + } + DoubleParameter deltaP = new DoubleParameter(DELTA_ID, new GreaterEqualConstraint(0.0), 0.0); if(config.grab(deltaP)) { delta = deltaP.getValue(); } - - LongParameter seedP = new LongParameter(SEED_ID, true); - if(config.grab(seedP)) { - seed = seedP.getValue(); - } } @Override protected EM<V> makeInstance() { - return new EM<V>(k, delta, seed); + return new EM<V>(k, delta, initializer); } } }
\ No newline at end of file |