summaryrefslogtreecommitdiff
path: root/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java')
-rw-r--r--src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java106
1 files changed, 47 insertions, 59 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java b/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java
index 40726793..f0d5b08f 100644
--- a/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java
+++ b/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java
@@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.benchmark;
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures
- Copyright (C) 2013
+ Copyright (C) 2014
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team
@@ -31,20 +31,18 @@ import de.lmu.ifi.dbs.elki.database.ids.DBIDIter;
import de.lmu.ifi.dbs.elki.database.ids.DBIDRange;
import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil;
import de.lmu.ifi.dbs.elki.database.ids.DBIDs;
-import de.lmu.ifi.dbs.elki.database.ids.distance.KNNList;
+import de.lmu.ifi.dbs.elki.database.ids.KNNList;
import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery;
import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery;
import de.lmu.ifi.dbs.elki.database.relation.Relation;
import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction;
-import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance;
-import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance;
import de.lmu.ifi.dbs.elki.logging.Logging;
import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress;
import de.lmu.ifi.dbs.elki.math.MeanVariance;
+import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
import de.lmu.ifi.dbs.elki.result.Result;
-import de.lmu.ifi.dbs.elki.utilities.RandomFactory;
import de.lmu.ifi.dbs.elki.utilities.Util;
import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
@@ -65,7 +63,7 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter;
*
* @apiviz.uses KNNQuery
*/
-public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDistanceBasedAlgorithm<O, D, Result> {
+public class KNNBenchmarkAlgorithm<O> extends AbstractDistanceBasedAlgorithm<O, Result> {
/**
* The logger for this class.
*/
@@ -100,7 +98,7 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis
* @param sampling Sampling rate
* @param random Random factory
*/
- public KNNBenchmarkAlgorithm(DistanceFunction<? super O, D> distanceFunction, int k, DatabaseConnection queries, double sampling, RandomFactory random) {
+ public KNNBenchmarkAlgorithm(DistanceFunction<? super O> distanceFunction, int k, DatabaseConnection queries, double sampling, RandomFactory random) {
super(distanceFunction);
this.k = k;
this.queries = queries;
@@ -117,62 +115,58 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis
*/
public Result run(Database database, Relation<O> relation) {
// Get a distance and kNN query instance.
- DistanceQuery<O, D> distQuery = database.getDistanceQuery(relation, getDistanceFunction());
- KNNQuery<O, D> knnQuery = database.getKNNQuery(distQuery, k);
+ DistanceQuery<O> distQuery = database.getDistanceQuery(relation, getDistanceFunction());
+ KNNQuery<O> knnQuery = database.getKNNQuery(distQuery, k);
// No query set - use original database.
- if (queries == null) {
+ if(queries == null) {
final DBIDs sample;
- if (sampling <= 0) {
+ if(sampling <= 0) {
sample = relation.getDBIDs();
- } else if (sampling < 1.1) {
+ }
+ else if(sampling < 1.1) {
int size = (int) Math.min(sampling * relation.size(), relation.size());
sample = DBIDUtil.randomSample(relation.getDBIDs(), size, random);
- } else {
+ }
+ else {
int size = (int) Math.min(sampling, relation.size());
sample = DBIDUtil.randomSample(relation.getDBIDs(), size, random);
}
FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
int hash = 0;
MeanVariance mv = new MeanVariance(), mvdist = new MeanVariance();
- for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
- KNNList<D> knns = knnQuery.getKNNForDBID(iditer, k);
+ for(DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
+ KNNList knns = knnQuery.getKNNForDBID(iditer, k);
int ichecksum = 0;
- for (DBIDIter it = knns.iter(); it.valid(); it.advance()) {
- ichecksum += it.internalGetIndex();
+ for(DBIDIter it = knns.iter(); it.valid(); it.advance()) {
+ ichecksum += DBIDUtil.asInteger(it);
}
hash = Util.mixHashCodes(hash, ichecksum);
mv.put(knns.size());
- D kdist = knns.getKNNDistance();
- if (kdist instanceof NumberDistance) {
- mvdist.put(((NumberDistance<?, ?>) kdist).doubleValue());
- }
- if (prog != null) {
- prog.incrementProcessed(LOG);
- }
- }
- if (prog != null) {
- prog.ensureCompleted(LOG);
+ mvdist.put(knns.getKNNDistance());
+ LOG.incrementProcessed(prog);
}
- if (LOG.isStatistics()) {
+ LOG.ensureCompleted(prog);
+ if(LOG.isStatistics()) {
LOG.statistics("Result hashcode: " + hash);
LOG.statistics("Mean number of results: " + mv.getMean() + " +- " + mv.getNaiveStddev());
- if (mvdist.getCount() > 0) {
+ if(mvdist.getCount() > 0) {
LOG.statistics("Mean k-distance: " + mvdist.getMean() + " +- " + mvdist.getNaiveStddev());
}
}
- } else {
+ }
+ else {
// Separate query set.
TypeInformation res = getDistanceFunction().getInputTypeRestriction();
MultipleObjectsBundle bundle = queries.loadData();
int col = -1;
- for (int i = 0; i < bundle.metaLength(); i++) {
- if (res.isAssignableFromType(bundle.meta(i))) {
+ for(int i = 0; i < bundle.metaLength(); i++) {
+ if(res.isAssignableFromType(bundle.meta(i))) {
col = i;
break;
}
}
- if (col < 0) {
+ if(col < 0) {
throw new AbortException("No compatible data type in query input was found. Expected: " + res.toString());
}
// Random sampling is a bit of hack, sorry.
@@ -180,45 +174,40 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis
DBIDRange sids = DBIDUtil.generateStaticDBIDRange(bundle.dataLength());
final DBIDs sample;
- if (sampling <= 0) {
+ if(sampling <= 0) {
sample = sids;
- } else if (sampling < 1.1) {
+ }
+ else if(sampling < 1.1) {
int size = (int) Math.min(sampling * relation.size(), relation.size());
sample = DBIDUtil.randomSample(sids, size, random);
- } else {
+ }
+ else {
int size = (int) Math.min(sampling, sids.size());
sample = DBIDUtil.randomSample(sids, size, random);
}
FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null;
int hash = 0;
MeanVariance mv = new MeanVariance(), mvdist = new MeanVariance();
- for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
+ for(DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) {
int off = sids.binarySearch(iditer);
assert (off >= 0);
@SuppressWarnings("unchecked")
O o = (O) bundle.data(off, col);
- KNNList<D> knns = knnQuery.getKNNForObject(o, k);
+ KNNList knns = knnQuery.getKNNForObject(o, k);
int ichecksum = 0;
- for (DBIDIter it = knns.iter(); it.valid(); it.advance()) {
- ichecksum += it.internalGetIndex();
+ for(DBIDIter it = knns.iter(); it.valid(); it.advance()) {
+ ichecksum += DBIDUtil.asInteger(it);
}
hash = Util.mixHashCodes(hash, ichecksum);
mv.put(knns.size());
- D kdist = knns.getKNNDistance();
- if (kdist instanceof NumberDistance) {
- mvdist.put(((NumberDistance<?, ?>) kdist).doubleValue());
- }
- if (prog != null) {
- prog.incrementProcessed(LOG);
- }
- }
- if (prog != null) {
- prog.ensureCompleted(LOG);
+ mvdist.put(knns.getKNNDistance());
+ LOG.incrementProcessed(prog);
}
- if (LOG.isStatistics()) {
+ LOG.ensureCompleted(prog);
+ if(LOG.isStatistics()) {
LOG.statistics("Result hashcode: " + hash);
LOG.statistics("Mean number of results: " + mv.getMean() + " +- " + mv.getNaiveStddev());
- if (mvdist.getCount() > 0) {
+ if(mvdist.getCount() > 0) {
LOG.statistics("Mean k-distance: " + mvdist.getMean() + " +- " + mvdist.getNaiveStddev());
}
}
@@ -244,9 +233,8 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis
* @author Erich Schubert
*
* @param <O> Object type
- * @param <D> Distance type
*/
- public static class Parameterizer<O, D extends Distance<D>> extends AbstractDistanceBasedAlgorithm.Parameterizer<O, D> {
+ public static class Parameterizer<O> extends AbstractDistanceBasedAlgorithm.Parameterizer<O> {
/**
* Parameter for the number of neighbors.
*/
@@ -291,27 +279,27 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis
protected void makeOptions(Parameterization config) {
super.makeOptions(config);
IntParameter kP = new IntParameter(K_ID);
- if (config.grab(kP)) {
+ if(config.grab(kP)) {
k = kP.intValue();
}
ObjectParameter<DatabaseConnection> queryP = new ObjectParameter<>(QUERY_ID, DatabaseConnection.class);
queryP.setOptional(true);
- if (config.grab(queryP)) {
+ if(config.grab(queryP)) {
queries = queryP.instantiateClass(config);
}
DoubleParameter samplingP = new DoubleParameter(SAMPLING_ID);
samplingP.setOptional(true);
- if (config.grab(samplingP)) {
+ if(config.grab(samplingP)) {
sampling = samplingP.doubleValue();
}
RandomParameter randomP = new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT);
- if (config.grab(randomP)) {
+ if(config.grab(randomP)) {
random = randomP.getValue();
}
}
@Override
- protected KNNBenchmarkAlgorithm<O, D> makeInstance() {
+ protected KNNBenchmarkAlgorithm<O> makeInstance() {
return new KNNBenchmarkAlgorithm<>(distanceFunction, k, queries, sampling, random);
}
}