summaryrefslogtreecommitdiff
path: root/src/de/lmu/ifi/dbs/elki/algorithm/statistics/MeanAveragePrecisionForDistance.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/statistics/MeanAveragePrecisionForDistance.java')
-rw-r--r--src/de/lmu/ifi/dbs/elki/algorithm/statistics/MeanAveragePrecisionForDistance.java388
1 files changed, 388 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/statistics/MeanAveragePrecisionForDistance.java b/src/de/lmu/ifi/dbs/elki/algorithm/statistics/MeanAveragePrecisionForDistance.java
new file mode 100644
index 00000000..20a4bdac
--- /dev/null
+++ b/src/de/lmu/ifi/dbs/elki/algorithm/statistics/MeanAveragePrecisionForDistance.java
@@ -0,0 +1,388 @@
+package de.lmu.ifi.dbs.elki.algorithm.statistics;
+
+/*
+ This file is part of ELKI:
+ Environment for Developing KDD-Applications Supported by Index-Structures
+
+ Copyright (C) 2014
+ Ludwig-Maximilians-Universität München
+ Lehr- und Forschungseinheit für Datenbanksysteme
+ ELKI Development Team
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm;
+import de.lmu.ifi.dbs.elki.data.LabelList;
+import de.lmu.ifi.dbs.elki.data.type.AlternativeTypeInformation;
+import de.lmu.ifi.dbs.elki.data.type.TypeInformation;
+import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
+import de.lmu.ifi.dbs.elki.database.Database;
+import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
+import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
+import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
+import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs;
+import de.lmu.ifi.dbs.elki.database.ids.ModifiableDoubleDBIDList;
+import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
+import de.lmu.ifi.dbs.elki.database.relation.Relation;
+import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
+import de.lmu.ifi.dbs.elki.evaluation.scores.AveragePrecisionEvaluation;
+import de.lmu.ifi.dbs.elki.evaluation.scores.ROCEvaluation;
+import de.lmu.ifi.dbs.elki.logging.Logging;
+import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
+import de.lmu.ifi.dbs.elki.math.Mean;
+import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
+import de.lmu.ifi.dbs.elki.result.Result;
+import de.lmu.ifi.dbs.elki.result.textwriter.TextWriteable;
+import de.lmu.ifi.dbs.elki.result.textwriter.TextWriterStream;
+import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.DoubleParameter;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.Flag;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
+
+/**
+ * Evaluate a distance functions performance by computing the mean average
+ * precision, when ranking the objects by distance.
+ *
+ * @author Erich Schubert
+ *
+ * @apiviz.has MAPResult
+ *
+ * @param <O> Object type
+ */
+public class MeanAveragePrecisionForDistance<O> extends AbstractDistanceBasedAlgorithm<O, MeanAveragePrecisionForDistance.MAPResult> {
+ /**
+ * The logger for this class.
+ */
+ private static final Logging LOG = Logging.getLogger(MeanAveragePrecisionForDistance.class);
+
+ /**
+ * Relative number of object to use in sampling.
+ */
+ private double sampling = 1.0;
+
+ /**
+ * Random sampling seed.
+ */
+ private RandomFactory random = null;
+
+ /**
+ * Include query object in evaluation.
+ */
+ private boolean includeSelf;
+
+ /**
+ * Constructor.
+ *
+ * @param distanceFunction Distance function
+ * @param sampling Sampling rate
+ * @param random Random sampling generator
+ * @param includeSelf Include query object in evaluation
+ */
+ public MeanAveragePrecisionForDistance(DistanceFunction<? super O> distanceFunction, double sampling, RandomFactory random, boolean includeSelf) {
+ super(distanceFunction);
+ this.sampling = sampling;
+ this.random = random;
+ this.includeSelf = includeSelf;
+ }
+
+ /**
+ * Run the algorithm
+ *
+ * @param database Database to run on (for kNN queries)
+ * @param relation Relation for distance computations
+ * @param lrelation Relation for class label comparison
+ * @return Vectors containing mean and standard deviation.
+ */
+ public MAPResult run(Database database, Relation<O> relation, Relation<?> lrelation) {
+ final DistanceQuery<O> distQuery = database.getDistanceQuery(relation, getDistanceFunction());
+
+ final DBIDs ids;
+ if(sampling < 1.) {
+ int size = Math.max(1, (int) (sampling * relation.size()));
+ ids = DBIDUtil.randomSample(relation.getDBIDs(), size, random);
+ }
+ else if(sampling > 1.) {
+ ids = DBIDUtil.randomSample(relation.getDBIDs(), (int) sampling, random);
+ }
+ else {
+ ids = relation.getDBIDs();
+ }
+
+ // For storing the positive neighbors.
+ ModifiableDBIDs posn = DBIDUtil.newHashSet();
+ // Distance storage.
+ ModifiableDoubleDBIDList nlist = DBIDUtil.newDistanceDBIDList(relation.size());
+
+ // Statistics tracking
+ Mean map = new Mean(), mroc = new Mean();
+
+ FiniteProgress objloop = LOG.isVerbose() ? new FiniteProgress("Processing query objects", ids.size(), LOG) : null;
+ for(DBIDIter iter = ids.iter(); iter.valid(); iter.advance()) {
+ Object label = lrelation.get(iter);
+ findMatches(posn, lrelation, label);
+ if(posn.size() > 0) {
+ computeDistances(nlist, iter, distQuery, relation);
+ if(nlist.size() != relation.size() - (includeSelf ? 0 : 1)) {
+ LOG.warning("Neighbor list does not have the desired size: " + nlist.size());
+ }
+ map.put(new AveragePrecisionEvaluation().evaluate(posn, nlist));
+ // We may as well compute ROC AUC while we're at it.
+ mroc.put(new ROCEvaluation().evaluate(posn, nlist));
+ }
+ LOG.incrementProcessed(objloop);
+ }
+ LOG.ensureCompleted(objloop);
+ if(map.getCount() < 1) {
+ throw new AbortException("No object matched - are labels parsed correctly?");
+ }
+ if(!(map.getMean() >= 0) || !(mroc.getMean() >= 0)) {
+ throw new AbortException("NaN in MAP/ROC.");
+ }
+
+ return new MAPResult(map.getMean(), mroc.getMean(), ids.size());
+ }
+
+ /**
+ * Test whether two relation agree.
+ *
+ * @param ref Reference object
+ * @param test Test object
+ * @return {@code true} if the objects match
+ */
+ protected static boolean match(Object ref, Object test) {
+ if(ref == null) {
+ return false;
+ }
+ // Cheap and fast, may hold for class labels!
+ if(ref == test) {
+ return true;
+ }
+ if(ref instanceof LabelList && test instanceof LabelList) {
+ final LabelList lref = (LabelList) ref;
+ final LabelList ltest = (LabelList) test;
+ final int s1 = lref.size(), s2 = ltest.size();
+ if(s1 == 0 || s2 == 0) {
+ return false;
+ }
+ for(int i = 0; i < s1; i++) {
+ String l1 = lref.get(i);
+ if(l1 == null) {
+ continue;
+ }
+ for(int j = 0; j < s2; j++) {
+ if(l1.equals(ltest.get(j))) {
+ return true;
+ }
+ }
+ }
+ }
+ // Fallback to equality, e.g. on class labels
+ return ref.equals(test);
+ }
+
+ /**
+ * Find all matching objects.
+ *
+ * @param posn Output set.
+ * @param lrelation Label relation
+ * @param label Query object label
+ */
+ private void findMatches(ModifiableDBIDs posn, Relation<?> lrelation, Object label) {
+ posn.clear();
+ for(DBIDIter ri = lrelation.iterDBIDs(); ri.valid(); ri.advance()) {
+ if(match(label, lrelation.get(ri))) {
+ posn.add(ri);
+ }
+ }
+ }
+
+ /**
+ * Compute the distances to the neighbor objects.
+ *
+ * @param nlist Neighbor list (output)
+ * @param query Query object
+ * @param distQuery Distance function
+ * @param relation Data relation
+ */
+ private void computeDistances(ModifiableDoubleDBIDList nlist, DBIDIter query, final DistanceQuery<O> distQuery, Relation<O> relation) {
+ nlist.clear();
+ O qo = relation.get(query);
+ for(DBIDIter ri = relation.iterDBIDs(); ri.valid(); ri.advance()) {
+ if(!includeSelf && DBIDUtil.equal(ri, query)) {
+ continue;
+ }
+ double dist = distQuery.distance(qo, ri);
+ if(dist != dist) { /* NaN */
+ dist = Double.POSITIVE_INFINITY;
+ }
+ nlist.add(dist, ri);
+ }
+ nlist.sort();
+ }
+
+ @Override
+ public TypeInformation[] getInputTypeRestriction() {
+ TypeInformation cls = new AlternativeTypeInformation(TypeUtil.CLASSLABEL, TypeUtil.LABELLIST);
+ return TypeUtil.array(getDistanceFunction().getInputTypeRestriction(), cls);
+ }
+
+ @Override
+ protected Logging getLogger() {
+ return LOG;
+ }
+
+ /**
+ * Result object for MAP scores.
+ *
+ * @author Erich Schubert
+ */
+ public static class MAPResult implements Result, TextWriteable {
+ /**
+ * MAP value
+ */
+ private double map;
+
+ /**
+ * ROC AUC value
+ */
+ private double rocauc;
+
+ /**
+ * Sample size
+ */
+ private int samplesize;
+
+ /**
+ * Constructor.
+ *
+ * @param map MAP value
+ * @param rocauc ROC AUC value
+ * @param samplesize Sample size
+ */
+ public MAPResult(double map, double rocauc, int samplesize) {
+ super();
+ this.map = map;
+ this.rocauc = rocauc;
+ this.samplesize = samplesize;
+ }
+
+ /**
+ * @return the area under curve
+ */
+ public double getROCAUC() {
+ return rocauc;
+ }
+
+ /**
+ * @return the medium average precision
+ */
+ public double getMAP() {
+ return map;
+ }
+
+ @Override
+ public String getLongName() {
+ return "MAP score";
+ }
+
+ @Override
+ public String getShortName() {
+ return "map-score";
+ }
+
+ @Override
+ public void writeToText(TextWriterStream out, String label) {
+ out.inlinePrintNoQuotes("MAP");
+ out.inlinePrint(map);
+ out.flush();
+ out.inlinePrintNoQuotes("ROCAUC");
+ out.inlinePrint(rocauc);
+ out.flush();
+ out.inlinePrintNoQuotes("Samplesize");
+ out.inlinePrint(samplesize);
+ out.flush();
+ }
+ }
+
+ /**
+ * Parameterization class.
+ *
+ * @author Erich Schubert
+ *
+ * @apiviz.exclude
+ *
+ * @param <O> Object type
+ */
+ public static class Parameterizer<O> extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
+ /**
+ * Parameter to enable sampling.
+ */
+ public static final OptionID SAMPLING_ID = new OptionID("map.sampling", "Relative amount of object to sample.");
+
+ /**
+ * Parameter to control the sampling random seed.
+ */
+ public static final OptionID SEED_ID = new OptionID("map.sampling-seed", "Random seed for deterministic sampling.");
+
+ /**
+ * Parameter to include the query object.
+ */
+ public static final OptionID INCLUDESELF_ID = new OptionID("map.includeself", "Include the query object in the evaluation.");
+
+ /**
+ * Relative amount of data to sample.
+ */
+ protected double sampling = 1.0;
+
+ /**
+ * Random sampling seed.
+ */
+ protected RandomFactory seed = null;
+
+ /**
+ * Include query object in evaluation.
+ */
+ protected boolean includeSelf;
+
+ @Override
+ protected void makeOptions(Parameterization config) {
+ super.makeOptions(config);
+ final DoubleParameter samplingP = new DoubleParameter(SAMPLING_ID);
+ samplingP.addConstraint(CommonConstraints.GREATER_THAN_ZERO_DOUBLE);
+ samplingP.addConstraint(CommonConstraints.LESS_EQUAL_ONE_DOUBLE);
+ samplingP.setOptional(true);
+ if(config.grab(samplingP)) {
+ sampling = samplingP.getValue();
+ }
+ final RandomParameter rndP = new RandomParameter(SEED_ID);
+ rndP.setOptional(true);
+ if(config.grab(rndP)) {
+ seed = rndP.getValue();
+ }
+ final Flag includeP = new Flag(INCLUDESELF_ID);
+ if(config.grab(includeP)) {
+ includeSelf = includeP.isTrue();
+ }
+ }
+
+ @Override
+ protected MeanAveragePrecisionForDistance<O> makeInstance() {
+ return new MeanAveragePrecisionForDistance<>(distanceFunction, sampling, seed, includeSelf);
+ }
+ }
+}