diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/statistics/EvaluateRankingQuality.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/statistics/EvaluateRankingQuality.java | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/statistics/EvaluateRankingQuality.java b/src/de/lmu/ifi/dbs/elki/algorithm/statistics/EvaluateRankingQuality.java new file mode 100644 index 00000000..cbbe578e --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/algorithm/statistics/EvaluateRankingQuality.java @@ -0,0 +1,219 @@ +package de.lmu.ifi.dbs.elki.algorithm.statistics; +/* +This file is part of ELKI: +Environment for Developing KDD-Applications Supported by Index-Structures + +Copyright (C) 2011 +Ludwig-Maximilians-Universität München +Lehr- und Forschungseinheit für Datenbanksysteme +ELKI Development Team + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU Affero General Public License as published by +the Free Software Foundation, either version 3 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU Affero General Public License for more details. + +You should have received a copy of the GNU Affero General Public License +along with this program. If not, see <http://www.gnu.org/licenses/>. 
+*/ + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; + +import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm; +import de.lmu.ifi.dbs.elki.algorithm.clustering.trivial.ByLabelClustering; +import de.lmu.ifi.dbs.elki.data.Cluster; +import de.lmu.ifi.dbs.elki.data.DoubleVector; +import de.lmu.ifi.dbs.elki.data.NumberVector; +import de.lmu.ifi.dbs.elki.data.model.Model; +import de.lmu.ifi.dbs.elki.data.type.CombinedTypeInformation; +import de.lmu.ifi.dbs.elki.data.type.TypeInformation; +import de.lmu.ifi.dbs.elki.data.type.TypeUtil; +import de.lmu.ifi.dbs.elki.database.Database; +import de.lmu.ifi.dbs.elki.database.ids.DBID; +import de.lmu.ifi.dbs.elki.database.query.DistanceResultPair; +import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery; +import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery; +import de.lmu.ifi.dbs.elki.database.relation.Relation; +import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction; +import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance; +import de.lmu.ifi.dbs.elki.evaluation.roc.ROC; +import de.lmu.ifi.dbs.elki.logging.Logging; +import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress; +import de.lmu.ifi.dbs.elki.math.AggregatingHistogram; +import de.lmu.ifi.dbs.elki.math.MathUtil; +import de.lmu.ifi.dbs.elki.math.MeanVariance; +import de.lmu.ifi.dbs.elki.math.linearalgebra.Matrix; +import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; +import de.lmu.ifi.dbs.elki.result.CollectionResult; +import de.lmu.ifi.dbs.elki.result.HistogramResult; +import de.lmu.ifi.dbs.elki.utilities.DatabaseUtil; +import de.lmu.ifi.dbs.elki.utilities.documentation.Description; +import de.lmu.ifi.dbs.elki.utilities.documentation.Title; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.GreaterEqualConstraint; +import 
de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
import de.lmu.ifi.dbs.elki.utilities.pairs.FCPair;
import de.lmu.ifi.dbs.elki.utilities.pairs.Pair;

/**
 * Evaluate a distance function with respect to kNN queries. For each point, the
 * neighbors are sorted by distance, then the ROC AUC is computed. A score of 1
 * means that the distance function provides a perfect ordering of relevant
 * neighbors first, then irrelevant neighbors. A value of 0.5 can be obtained by
 * random sorting. A value of 0 means the distance function is inverted, i.e. a
 * similarity.
 *
 * In contrast to {@link RankingQualityHistogram}, this method uses a binning
 * based on the centrality of objects. This allows analyzing whether or not a
 * particular distance degrades for the outer parts of a cluster.
 *
 * TODO: Allow fixed binning range, configurable
 *
 * TODO: Add sampling
 *
 * @author Erich Schubert
 * @param <V> Vector type
 * @param <D> Distance type
 */
@Title("Evaluate Ranking Quality")
@Description("Evaluates the effectiveness of a distance function via the obtained rankings.")
public class EvaluateRankingQuality<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractDistanceBasedAlgorithm<V, D, CollectionResult<DoubleVector>> {
  /**
   * The logger for this class.
   */
  private static final Logging logger = Logging.getLogger(EvaluateRankingQuality.class);

  /**
   * Option to configure the number of bins to use.
   */
  public static final OptionID HISTOGRAM_BINS_ID = OptionID.getOrCreateOptionID("rankqual.bins", "Number of bins to use in the histogram");

  /**
   * Number of bins to use. Always assigned by the constructor.
   */
  final int numbins;

  /**
   * Constructor.
   *
   * @param distanceFunction Distance function whose rankings are evaluated
   * @param numbins Number of histogram bins (the {@link Parameterizer}
   *        enforces a minimum of 2; this constructor does not validate it)
   */
  public EvaluateRankingQuality(DistanceFunction<? super V, D> distanceFunction, int numbins) {
    super(distanceFunction);
    this.numbins = numbins;
  }

  /**
   * Run the algorithm.
   *
   * The objects are split into clusters by label. Within each cluster, the
   * members are ordered by their distance to the cluster mean, and the ROC AUC
   * of each member's full neighbor ranking (with the member's own cluster as
   * the relevant set) is aggregated into a histogram over the relative
   * position in that ordering.
   *
   * @param database Database to process
   * @return Histogram result with one row per bin: bin key, number of
   *         aggregated values, mean ROC AUC, and sample variance
   */
  @Override
  public HistogramResult<DoubleVector> run(Database database) throws IllegalStateException {
    final Relation<V> relation = database.getRelation(getInputTypeRestriction()[0]);
    final DistanceQuery<V, D> distQuery = database.getDistanceQuery(relation, getDistanceFunction());
    final int size = relation.size();
    final KNNQuery<V, D> knnQuery = database.getKNNQuery(distQuery, size);

    if(logger.isVerbose()) {
      logger.verbose("Preprocessing clusters...");
    }
    // Cluster by labels
    Collection<Cluster<Model>> split = (new ByLabelClustering()).run(database).getAllClusters();

    // Compute cluster averages and covariance matrix
    HashMap<Cluster<?>, V> averages = new HashMap<Cluster<?>, V>(split.size());
    HashMap<Cluster<?>, Matrix> covmats = new HashMap<Cluster<?>, Matrix>(split.size());
    for(Cluster<?> clus : split) {
      averages.put(clus, DatabaseUtil.centroid(relation, clus.getIDs()));
      covmats.put(clus, DatabaseUtil.covarianceMatrix(relation, clus.getIDs()));
    }

    AggregatingHistogram<MeanVariance, Double> hist = AggregatingHistogram.MeanVarianceHistogram(numbins, 0.0, 1.0);

    if(logger.isVerbose()) {
      logger.verbose("Processing points...");
    }
    FiniteProgress rocloop = logger.isVerbose() ? new FiniteProgress("Computing ROC AUC values", size, logger) : null;

    for(Cluster<?> clus : split) {
      // Order the cluster members by their distance to the cluster mean.
      ArrayList<FCPair<Double, DBID>> cmem = new ArrayList<FCPair<Double, DBID>>(clus.size());
      Vector av = averages.get(clus).getColumnVector();
      Matrix covm = covmats.get(clus);

      for(DBID i1 : clus.getIDs()) {
        // NOTE(review): covm is the covariance matrix itself, not its inverse.
        // A Mahalanobis distance is conventionally weighted by the inverse
        // covariance - confirm what MathUtil.mahalanobisDistance expects.
        final double d = MathUtil.mahalanobisDistance(covm, av.minus(relation.get(i1).getColumnVector()));
        cmem.add(new FCPair<Double, DBID>(d, i1));
      }
      Collections.sort(cmem);

      for(int ind = 0; ind < cmem.size(); ind++) {
        DBID i1 = cmem.get(ind).getSecond();
        List<DistanceResultPair<D>> knn = knnQuery.getKNNForDBID(i1, size);
        double result = ROC.computeROCAUCDistanceResult(size, clus, knn);

        // Relative position within the sorted cluster, in [0, 1).
        hist.aggregate(((double) ind) / clus.size(), result);

        if(rocloop != null) {
          rocloop.incrementProcessed(logger);
        }
      }
    }
    if(rocloop != null) {
      rocloop.ensureCompleted(logger);
    }

    Collection<DoubleVector> res = toRows(hist, numbins);
    return new HistogramResult<DoubleVector>("Ranking Quality Histogram", "ranking-histogram", res);
  }

  /**
   * Transform the histogram into a collection of double vectors, one per bin.
   *
   * @param hist Histogram to convert
   * @param expectedRows Capacity hint for the result collection
   * @return Rows of the form (bin key, count, mean, sample variance)
   */
  private static Collection<DoubleVector> toRows(AggregatingHistogram<MeanVariance, Double> hist, int expectedRows) {
    Collection<DoubleVector> rows = new ArrayList<DoubleVector>(Math.max(expectedRows, 0));
    for(Pair<Double, MeanVariance> pair : hist) {
      final MeanVariance mv = pair.getSecond();
      rows.add(new DoubleVector(new double[] { pair.getFirst(), mv.getCount(), mv.getMean(), mv.getSampleVariance() }));
    }
    return rows;
  }

  @Override
  public TypeInformation[] getInputTypeRestriction() {
    return TypeUtil.array(new CombinedTypeInformation(getDistanceFunction().getInputTypeRestriction(), TypeUtil.NUMBER_VECTOR_FIELD));
  }

  @Override
  protected Logging getLogger() {
    return logger;
  }

  /**
   * Parameterization class.
   *
   * @author Erich Schubert
   *
   * @apiviz.exclude
   */
  public static class Parameterizer<V extends NumberVector<V, ?>, D extends NumberDistance<D, ?>> extends AbstractDistanceBasedAlgorithm.Parameterizer<V, D> {
    /**
     * Number of bins to use; defaults to 20 when the option is not set.
     */
    protected int numbins = 20;

    @Override
    protected void makeOptions(Parameterization config) {
      super.makeOptions(config);
      // At least 2 bins are required; the default is 20.
      final IntParameter param = new IntParameter(HISTOGRAM_BINS_ID, new GreaterEqualConstraint(2), 20);
      if(config.grab(param)) {
        numbins = param.getValue();
      }
    }

    @Override
    protected EvaluateRankingQuality<V, D> makeInstance() {
      return new EvaluateRankingQuality<V, D>(distanceFunction, numbins);
    }
  }
}
\ No newline at end of file |