diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java | 92 |
1 files changed, 87 insertions, 5 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java b/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java index 1b37bfff..f7fdd254 100644 --- a/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java +++ b/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.evaluation.roc; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2011 + Copyright (C) 2012 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -33,6 +33,7 @@ import de.lmu.ifi.dbs.elki.database.ids.DBID; import de.lmu.ifi.dbs.elki.database.ids.DBIDPair; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; +import de.lmu.ifi.dbs.elki.database.ids.SetDBIDs; import de.lmu.ifi.dbs.elki.database.query.DistanceResultPair; import de.lmu.ifi.dbs.elki.database.relation.Relation; import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; @@ -150,6 +151,87 @@ public class ROC { } /** + * Compute a ROC curve given a set of positive IDs and a sorted list of + * (comparable, ID)s, where the comparable object is used to decided when two + * objects are interchangeable. + * + * @param <C> Reference type + * @param size Database size + * @param ids Collection of positive IDs, should support efficient contains() + * @param nei List of neighbors along with some comparable object to detect + * 'same positions'. + * @return area under curve + */ + public static <C extends Comparable<? super C>> List<DoubleDoublePair> materializeROC(int size, SetDBIDs ids, Iterator<? extends PairInterface<C, DBID>> nei) { + int postot = ids.size(); + int negtot = size - postot; + int poscnt = 0; + int negcnt = 0; + ArrayList<DoubleDoublePair> res = new ArrayList<DoubleDoublePair>(postot + 2); + + // start in bottom left + res.add(new DoubleDoublePair(0.0, 0.0)); + + PairInterface<C, DBID> prev = null; + while(nei.hasNext()) { + // Previous positive rate - y axis + double curpos = ((double) poscnt) / postot; + // Previous negative rate - x axis + double curneg = ((double) negcnt) / negtot; + + // Analyze next point + PairInterface<C, DBID> cur = nei.next(); + // positive or negative match? + if(ids.contains(cur.getSecond())) { + poscnt += 1; + } + else { + negcnt += 1; + } + // defer calculation for ties + if((prev != null) && (prev.getFirst().compareTo(cur.getFirst()) == 0)) { + continue; + } + // simplify curve when possible: + if(res.size() >= 2) { + DoubleDoublePair last1 = res.get(res.size() - 2); + DoubleDoublePair last2 = res.get(res.size() - 1); + final double ldx = last2.first - last1.first; + final double cdx = curneg - last2.first; + final double ldy = last2.second - last1.second; + final double cdy = curpos - last2.second; + // vertical simplification + if((ldx == 0) && (cdx == 0)) { + res.remove(res.size() - 1); + } + // horizontal simplification + else if((ldy == 0) && (cdy == 0)) { + res.remove(res.size() - 1); + } + // diagonal simplification + else if(ldy > 0 && cdy > 0) { + if(Math.abs((ldx / ldy) - (cdx / cdy)) < 1E-10) { + res.remove(res.size() - 1); + } + } + } + // Add a new point (for the previous entry!) + res.add(new DoubleDoublePair(curneg, curpos)); + prev = cur; + } + // ensure we end up in the top right corner. + // Since we didn't add a point for the last entry yet, this likely is + // needed. + { + DoubleDoublePair last = res.get(res.size() - 1); + if(last.first < 1.0 || last.second < 1.0) { + res.add(new DoubleDoublePair(1.0, 1.0)); + } + } + return res; + } + + /** * This adapter can be used for an arbitrary collection of Integers, and uses * that id1.compareTo(id2) != 0 for id1 != id2 to satisfy the comparability. * @@ -207,14 +289,14 @@ public class ROC { /** * Original Iterator */ - private Iterator<DistanceResultPair<D>> iter; + private Iterator<? extends DistanceResultPair<D>> iter; /** * Constructor * * @param iter Iterator for distance results */ - public DistanceResultAdapter(Iterator<DistanceResultPair<D>> iter) { + public DistanceResultAdapter(Iterator<? extends DistanceResultPair<D>> iter) { super(); this.iter = iter; } @@ -326,7 +408,7 @@ public class ROC { * @param nei Query result * @return area under curve */ - public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, Cluster<?> clus, List<DistanceResultPair<D>> nei) { + public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, Cluster<?> clus, Iterable<? extends DistanceResultPair<D>> nei) { // TODO: ensure the collection has efficient "contains". return ROC.computeROCAUCDistanceResult(size, clus.getIDs(), nei); } @@ -340,7 +422,7 @@ public class ROC { * @param nei Query Result * @return area under curve */ - public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, DBIDs ids, List<DistanceResultPair<D>> nei) { + public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, DBIDs ids, Iterable<? extends DistanceResultPair<D>> nei) { // TODO: do not materialize the ROC, but introduce an iterator interface List<DoubleDoublePair> roc = materializeROC(size, DBIDUtil.ensureSet(ids), new DistanceResultAdapter<D>(nei.iterator())); return computeAUC(roc); |