summaryrefslogtreecommitdiff
path: root/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java')
-rw-r--r--src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java92
1 files changed, 87 insertions, 5 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java b/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java
index 1b37bfff..f7fdd254 100644
--- a/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java
+++ b/src/de/lmu/ifi/dbs/elki/evaluation/roc/ROC.java
@@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.evaluation.roc;
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures
- Copyright (C) 2011
+ Copyright (C) 2012
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team
@@ -33,6 +33,7 @@ import de.lmu.ifi.dbs.elki.database.ids.DBID;
import de.lmu.ifi.dbs.elki.database.ids.DBIDPair;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
+import de.lmu.ifi.dbs.elki.database.ids.SetDBIDs;
import de.lmu.ifi.dbs.elki.database.query.DistanceResultPair;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance;
@@ -150,6 +151,87 @@ public class ROC {
}
/**
+ * Compute a ROC curve given a set of positive IDs and a sorted list of
+ * (comparable, ID)s, where the comparable object is used to decide when two
+ * objects are interchangeable.
+ *
+ * @param <C> Reference type
+ * @param size Database size
+ * @param ids Collection of positive IDs, should support efficient contains()
+ * @param nei List of neighbors along with some comparable object to detect
+ * 'same positions'.
+ * @return ROC curve as a list of (negative rate, positive rate) points
+ */
+ public static <C extends Comparable<? super C>> List<DoubleDoublePair> materializeROC(int size, SetDBIDs ids, Iterator<? extends PairInterface<C, DBID>> nei) {
+ int postot = ids.size(); // total number of positives
+ int negtot = size - postot; // everything not in ids counts as negative
+ int poscnt = 0; // positives counted so far
+ int negcnt = 0; // negatives counted so far
+ ArrayList<DoubleDoublePair> res = new ArrayList<DoubleDoublePair>(postot + 2);
+
+ // The curve starts in the bottom left corner (0, 0).
+ res.add(new DoubleDoublePair(0.0, 0.0));
+
+ PairInterface<C, DBID> prev = null; // last entry at which a point was emitted; not updated on ties
+ while(nei.hasNext()) {
+ // Positive rate before counting the current entry - y axis
+ double curpos = ((double) poscnt) / postot; // NOTE(review): NaN if postot == 0 - confirm callers never pass an empty positive set
+ // Negative rate before counting the current entry - x axis
+ double curneg = ((double) negcnt) / negtot; // NOTE(review): NaN/Infinity if negtot == 0 - confirm size > ids.size()
+
+ // Fetch the next entry of the ranking
+ PairInterface<C, DBID> cur = nei.next();
+ // Count it as positive if its ID is in the positive set, otherwise as negative.
+ if(ids.contains(cur.getSecond())) {
+ poscnt += 1;
+ }
+ else {
+ negcnt += 1;
+ }
+ // Ties: an entry comparing equal to prev is only counted; no point is emitted for it here.
+ if((prev != null) && (prev.getFirst().compareTo(cur.getFirst()) == 0)) {
+ continue;
+ }
+ // Simplify the curve: drop the last stored point when it lies on a straight line between its predecessor and the new point.
+ if(res.size() >= 2) {
+ DoubleDoublePair last1 = res.get(res.size() - 2);
+ DoubleDoublePair last2 = res.get(res.size() - 1);
+ final double ldx = last2.first - last1.first;
+ final double cdx = curneg - last2.first;
+ final double ldy = last2.second - last1.second;
+ final double cdy = curpos - last2.second;
+ // vertical simplification: x unchanged across both segments
+ if((ldx == 0) && (cdx == 0)) {
+ res.remove(res.size() - 1);
+ }
+ // horizontal simplification: y unchanged across both segments
+ else if((ldy == 0) && (cdy == 0)) {
+ res.remove(res.size() - 1);
+ }
+ // diagonal simplification: both segments rise with (nearly) the same slope
+ else if(ldy > 0 && cdy > 0) {
+ if(Math.abs((ldx / ldy) - (cdx / cdy)) < 1E-10) { // compares dx/dy ratios with a small tolerance
+ res.remove(res.size() - 1);
+ }
+ }
+ }
+ // Add a new point using the rates computed before counting cur (i.e. for the previous entry!)
+ res.add(new DoubleDoublePair(curneg, curpos));
+ prev = cur;
+ }
+ // Ensure we end up in the top right corner (1, 1).
+ // The loop only emits rates computed before counting the current entry, so
+ // the final counts have not been added as a point yet; this likely is needed.
+ {
+ DoubleDoublePair last = res.get(res.size() - 1);
+ if(last.first < 1.0 || last.second < 1.0) { // false for NaN coordinates, so no closing point is added then
+ res.add(new DoubleDoublePair(1.0, 1.0));
+ }
+ }
+ return res;
+ }
+
+ /**
* This adapter can be used for an arbitrary collection of Integers, and uses
* that id1.compareTo(id2) != 0 for id1 != id2 to satisfy the comparability.
*
@@ -207,14 +289,14 @@ public class ROC {
/**
* Original Iterator
*/
- private Iterator<DistanceResultPair<D>> iter;
+ private Iterator<? extends DistanceResultPair<D>> iter;
/**
* Constructor
*
* @param iter Iterator for distance results
*/
- public DistanceResultAdapter(Iterator<DistanceResultPair<D>> iter) {
+ public DistanceResultAdapter(Iterator<? extends DistanceResultPair<D>> iter) {
super();
this.iter = iter;
}
@@ -326,7 +408,7 @@ public class ROC {
* @param nei Query result
* @return area under curve
*/
- public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, Cluster<?> clus, List<DistanceResultPair<D>> nei) {
+ public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, Cluster<?> clus, Iterable<? extends DistanceResultPair<D>> nei) {
// TODO: ensure the collection has efficient "contains".
return ROC.computeROCAUCDistanceResult(size, clus.getIDs(), nei);
}
@@ -340,7 +422,7 @@ public class ROC {
* @param nei Query Result
* @return area under curve
*/
- public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, DBIDs ids, List<DistanceResultPair<D>> nei) {
+ public static <D extends Distance<D>> double computeROCAUCDistanceResult(int size, DBIDs ids, Iterable<? extends DistanceResultPair<D>> nei) {
// TODO: do not materialize the ROC, but introduce an iterator interface
List<DoubleDoublePair> roc = materializeROC(size, DBIDUtil.ensureSet(ids), new DistanceResultAdapter<D>(nei.iterator()));
return computeAUC(roc);