diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java | 106 |
1 files changed, 47 insertions, 59 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java b/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java index 40726793..f0d5b08f 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/KNNBenchmarkAlgorithm.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.benchmark; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -31,20 +31,18 @@ import de.lmu.ifi.dbs.elki.database.ids.DBIDIter; import de.lmu.ifi.dbs.elki.database.ids.DBIDRange; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; -import de.lmu.ifi.dbs.elki.database.ids.distance.KNNList; +import de.lmu.ifi.dbs.elki.database.ids.KNNList; import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery; import de.lmu.ifi.dbs.elki.database.query.knn.KNNQuery; import de.lmu.ifi.dbs.elki.database.relation.Relation; import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection; import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; -import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress; import de.lmu.ifi.dbs.elki.math.MeanVariance; +import de.lmu.ifi.dbs.elki.math.random.RandomFactory; import de.lmu.ifi.dbs.elki.result.Result; -import de.lmu.ifi.dbs.elki.utilities.RandomFactory; import de.lmu.ifi.dbs.elki.utilities.Util; import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException; import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; @@ -65,7 +63,7 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter; * * @apiviz.uses KNNQuery */ -public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDistanceBasedAlgorithm<O, D, Result> { +public class KNNBenchmarkAlgorithm<O> extends AbstractDistanceBasedAlgorithm<O, Result> { /** * The logger for this class. */ @@ -100,7 +98,7 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis * @param sampling Sampling rate * @param random Random factory */ - public KNNBenchmarkAlgorithm(DistanceFunction<? super O, D> distanceFunction, int k, DatabaseConnection queries, double sampling, RandomFactory random) { + public KNNBenchmarkAlgorithm(DistanceFunction<? super O> distanceFunction, int k, DatabaseConnection queries, double sampling, RandomFactory random) { super(distanceFunction); this.k = k; this.queries = queries; @@ -117,62 +115,58 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis */ public Result run(Database database, Relation<O> relation) { // Get a distance and kNN query instance. - DistanceQuery<O, D> distQuery = database.getDistanceQuery(relation, getDistanceFunction()); - KNNQuery<O, D> knnQuery = database.getKNNQuery(distQuery, k); + DistanceQuery<O> distQuery = database.getDistanceQuery(relation, getDistanceFunction()); + KNNQuery<O> knnQuery = database.getKNNQuery(distQuery, k); // No query set - use original database. - if (queries == null) { + if(queries == null) { final DBIDs sample; - if (sampling <= 0) { + if(sampling <= 0) { sample = relation.getDBIDs(); - } else if (sampling < 1.1) { + } + else if(sampling < 1.1) { int size = (int) Math.min(sampling * relation.size(), relation.size()); sample = DBIDUtil.randomSample(relation.getDBIDs(), size, random); - } else { + } + else { int size = (int) Math.min(sampling, relation.size()); sample = DBIDUtil.randomSample(relation.getDBIDs(), size, random); } FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null; int hash = 0; MeanVariance mv = new MeanVariance(), mvdist = new MeanVariance(); - for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { - KNNList<D> knns = knnQuery.getKNNForDBID(iditer, k); + for(DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { + KNNList knns = knnQuery.getKNNForDBID(iditer, k); int ichecksum = 0; - for (DBIDIter it = knns.iter(); it.valid(); it.advance()) { - ichecksum += it.internalGetIndex(); + for(DBIDIter it = knns.iter(); it.valid(); it.advance()) { + ichecksum += DBIDUtil.asInteger(it); } hash = Util.mixHashCodes(hash, ichecksum); mv.put(knns.size()); - D kdist = knns.getKNNDistance(); - if (kdist instanceof NumberDistance) { - mvdist.put(((NumberDistance<?, ?>) kdist).doubleValue()); - } - if (prog != null) { - prog.incrementProcessed(LOG); - } - } - if (prog != null) { - prog.ensureCompleted(LOG); + mvdist.put(knns.getKNNDistance()); + LOG.incrementProcessed(prog); } - if (LOG.isStatistics()) { + LOG.ensureCompleted(prog); + if(LOG.isStatistics()) { LOG.statistics("Result hashcode: " + hash); LOG.statistics("Mean number of results: " + mv.getMean() + " +- " + mv.getNaiveStddev()); - if (mvdist.getCount() > 0) { + if(mvdist.getCount() > 0) { LOG.statistics("Mean k-distance: " + mvdist.getMean() + " +- " + mvdist.getNaiveStddev()); } } - } else { + } + else { // Separate query set. TypeInformation res = getDistanceFunction().getInputTypeRestriction(); MultipleObjectsBundle bundle = queries.loadData(); int col = -1; - for (int i = 0; i < bundle.metaLength(); i++) { - if (res.isAssignableFromType(bundle.meta(i))) { + for(int i = 0; i < bundle.metaLength(); i++) { + if(res.isAssignableFromType(bundle.meta(i))) { col = i; break; } } - if (col < 0) { + if(col < 0) { throw new AbortException("No compatible data type in query input was found. Expected: " + res.toString()); } // Random sampling is a bit of hack, sorry. @@ -180,45 +174,40 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis DBIDRange sids = DBIDUtil.generateStaticDBIDRange(bundle.dataLength()); final DBIDs sample; - if (sampling <= 0) { + if(sampling <= 0) { sample = sids; - } else if (sampling < 1.1) { + } + else if(sampling < 1.1) { int size = (int) Math.min(sampling * relation.size(), relation.size()); sample = DBIDUtil.randomSample(sids, size, random); - } else { + } + else { int size = (int) Math.min(sampling, sids.size()); sample = DBIDUtil.randomSample(sids, size, random); } FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null; int hash = 0; MeanVariance mv = new MeanVariance(), mvdist = new MeanVariance(); - for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { + for(DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { int off = sids.binarySearch(iditer); assert (off >= 0); @SuppressWarnings("unchecked") O o = (O) bundle.data(off, col); - KNNList<D> knns = knnQuery.getKNNForObject(o, k); + KNNList knns = knnQuery.getKNNForObject(o, k); int ichecksum = 0; - for (DBIDIter it = knns.iter(); it.valid(); it.advance()) { - ichecksum += it.internalGetIndex(); + for(DBIDIter it = knns.iter(); it.valid(); it.advance()) { + ichecksum += DBIDUtil.asInteger(it); } hash = Util.mixHashCodes(hash, ichecksum); mv.put(knns.size()); - D kdist = knns.getKNNDistance(); - if (kdist instanceof NumberDistance) { - mvdist.put(((NumberDistance<?, ?>) kdist).doubleValue()); - } - if (prog != null) { - prog.incrementProcessed(LOG); - } - } - if (prog != null) { - prog.ensureCompleted(LOG); + mvdist.put(knns.getKNNDistance()); + LOG.incrementProcessed(prog); } - if (LOG.isStatistics()) { + LOG.ensureCompleted(prog); + if(LOG.isStatistics()) { LOG.statistics("Result hashcode: " + hash); LOG.statistics("Mean number of results: " + mv.getMean() + " +- " + mv.getNaiveStddev()); - if (mvdist.getCount() > 0) { + if(mvdist.getCount() > 0) { LOG.statistics("Mean k-distance: " + mvdist.getMean() + " +- " + mvdist.getNaiveStddev()); } } @@ -244,9 +233,8 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis * @author Erich Schubert * * @param <O> Object type - * @param <D> Distance type */ - public static class Parameterizer<O, D extends Distance<D>> extends AbstractDistanceBasedAlgorithm.Parameterizer<O, D> { + public static class Parameterizer<O> extends AbstractDistanceBasedAlgorithm.Parameterizer<O> { /** * Parameter for the number of neighbors. */ @@ -291,27 +279,27 @@ public class KNNBenchmarkAlgorithm<O, D extends Distance<D>> extends AbstractDis protected void makeOptions(Parameterization config) { super.makeOptions(config); IntParameter kP = new IntParameter(K_ID); - if (config.grab(kP)) { + if(config.grab(kP)) { k = kP.intValue(); } ObjectParameter<DatabaseConnection> queryP = new ObjectParameter<>(QUERY_ID, DatabaseConnection.class); queryP.setOptional(true); - if (config.grab(queryP)) { + if(config.grab(queryP)) { queries = queryP.instantiateClass(config); } DoubleParameter samplingP = new DoubleParameter(SAMPLING_ID); samplingP.setOptional(true); - if (config.grab(samplingP)) { + if(config.grab(samplingP)) { sampling = samplingP.doubleValue(); } RandomParameter randomP = new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT); - if (config.grab(randomP)) { + if(config.grab(randomP)) { random = randomP.getValue(); } } @Override - protected KNNBenchmarkAlgorithm<O, D> makeInstance() { + protected KNNBenchmarkAlgorithm<O> makeInstance() { return new KNNBenchmarkAlgorithm<>(distanceFunction, k, queries, sampling, random); } } |