diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/benchmark/RangeQueryBenchmarkAlgorithm.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/benchmark/RangeQueryBenchmarkAlgorithm.java | 121 |
1 files changed, 56 insertions, 65 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/RangeQueryBenchmarkAlgorithm.java b/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/RangeQueryBenchmarkAlgorithm.java index 1b5e827b..b5336aac 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/RangeQueryBenchmarkAlgorithm.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/benchmark/RangeQueryBenchmarkAlgorithm.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm.benchmark; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -25,7 +25,6 @@ package de.lmu.ifi.dbs.elki.algorithm.benchmark; import de.lmu.ifi.dbs.elki.algorithm.AbstractDistanceBasedAlgorithm; import de.lmu.ifi.dbs.elki.data.NumberVector; -import de.lmu.ifi.dbs.elki.data.NumberVector.Factory; import de.lmu.ifi.dbs.elki.data.type.TypeInformation; import de.lmu.ifi.dbs.elki.data.type.TypeUtil; import de.lmu.ifi.dbs.elki.data.type.VectorFieldTypeInformation; @@ -34,7 +33,7 @@ import de.lmu.ifi.dbs.elki.database.ids.DBIDIter; import de.lmu.ifi.dbs.elki.database.ids.DBIDRange; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; -import de.lmu.ifi.dbs.elki.database.ids.distance.DistanceDBIDList; +import de.lmu.ifi.dbs.elki.database.ids.DoubleDBIDList; import de.lmu.ifi.dbs.elki.database.query.distance.DistanceQuery; import de.lmu.ifi.dbs.elki.database.query.range.RangeQuery; import de.lmu.ifi.dbs.elki.database.relation.Relation; @@ -42,12 +41,11 @@ import de.lmu.ifi.dbs.elki.database.relation.RelationUtil; import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection; import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.NumberDistance; import de.lmu.ifi.dbs.elki.logging.Logging; import de.lmu.ifi.dbs.elki.logging.progress.FiniteProgress; import de.lmu.ifi.dbs.elki.math.MeanVariance; +import de.lmu.ifi.dbs.elki.math.random.RandomFactory; import de.lmu.ifi.dbs.elki.result.Result; -import de.lmu.ifi.dbs.elki.utilities.RandomFactory; import de.lmu.ifi.dbs.elki.utilities.Util; import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException; import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; @@ -93,11 +91,10 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter; * @author Erich Schubert * * @param <O> Vector type - * @param <D> Distance type * * @apiviz.uses RangeQuery */ -public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends NumberDistance<D, ?>> extends AbstractDistanceBasedAlgorithm<O, D, Result> { +public class RangeQueryBenchmarkAlgorithm<O extends NumberVector> extends AbstractDistanceBasedAlgorithm<O, Result> { /** * The logger for this class. */ @@ -126,7 +123,7 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N * @param sampling Sampling rate * @param random Random factory */ - public RangeQueryBenchmarkAlgorithm(DistanceFunction<? super O, D> distanceFunction, DatabaseConnection queries, double sampling, RandomFactory random) { + public RangeQueryBenchmarkAlgorithm(DistanceFunction<? super O> distanceFunction, DatabaseConnection queries, double sampling, RandomFactory random) { super(distanceFunction); this.queries = queries; this.sampling = sampling; @@ -141,45 +138,42 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N * @param radrel Radius relation * @return Null result */ - public Result run(Database database, Relation<O> relation, Relation<NumberVector<?>> radrel) { - if (queries != null) { + public Result run(Database database, Relation<O> relation, Relation<NumberVector> radrel) { + if(queries != null) { throw new AbortException("This 'run' method will not use the given query set!"); } // Get a distance and kNN query instance. - DistanceQuery<O, D> distQuery = database.getDistanceQuery(relation, getDistanceFunction()); - RangeQuery<O, D> rangeQuery = database.getRangeQuery(distQuery); - D dfactory = distQuery.getDistanceFactory(); + DistanceQuery<O> distQuery = database.getDistanceQuery(relation, getDistanceFunction()); + RangeQuery<O> rangeQuery = database.getRangeQuery(distQuery); final DBIDs sample; - if (sampling <= 0) { + if(sampling <= 0) { sample = relation.getDBIDs(); - } else if (sampling < 1.1) { + } + else if(sampling < 1.1) { int size = (int) Math.min(sampling * relation.size(), relation.size()); sample = DBIDUtil.randomSample(relation.getDBIDs(), size, random); - } else { + } + else { int size = (int) Math.min(sampling, relation.size()); sample = DBIDUtil.randomSample(relation.getDBIDs(), size, random); } FiniteProgress prog = LOG.isVeryVerbose() ? new FiniteProgress("kNN queries", sample.size(), LOG) : null; int hash = 0; MeanVariance mv = new MeanVariance(); - for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { - D r = dfactory.fromDouble(radrel.get(iditer).doubleValue(0)); - DistanceDBIDList<D> rres = rangeQuery.getRangeForDBID(iditer, r); + for(DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { + double r = radrel.get(iditer).doubleValue(0); + DoubleDBIDList rres = rangeQuery.getRangeForDBID(iditer, r); int ichecksum = 0; - for (DBIDIter it = rres.iter(); it.valid(); it.advance()) { - ichecksum += it.internalGetIndex(); + for(DBIDIter it = rres.iter(); it.valid(); it.advance()) { + ichecksum += DBIDUtil.asInteger(it); } hash = Util.mixHashCodes(hash, ichecksum); mv.put(rres.size()); - if (prog != null) { - prog.incrementProcessed(LOG); - } + LOG.incrementProcessed(prog); } - if (prog != null) { - prog.ensureCompleted(LOG); - } - if (LOG.isStatistics()) { + LOG.ensureCompleted(prog); + if(LOG.isStatistics()) { LOG.statistics("Result hashcode: " + hash); LOG.statistics("Mean number of results: " + mv.getMean() + " +- " + mv.getNaiveStddev()); } @@ -194,33 +188,32 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N * @return Null result */ public Result run(Database database, Relation<O> relation) { - if (queries == null) { + if(queries == null) { throw new AbortException("A query set is required for this 'run' method."); } // Get a distance and kNN query instance. - DistanceQuery<O, D> distQuery = database.getDistanceQuery(relation, getDistanceFunction()); - RangeQuery<O, D> rangeQuery = database.getRangeQuery(distQuery); - D dfactory = distQuery.getDistanceFactory(); - Factory<O, ?> ofactory = RelationUtil.getNumberVectorFactory(relation); + DistanceQuery<O> distQuery = database.getDistanceQuery(relation, getDistanceFunction()); + RangeQuery<O> rangeQuery = database.getRangeQuery(distQuery); + NumberVector.Factory<O> ofactory = RelationUtil.getNumberVectorFactory(relation); int dim = RelationUtil.dimensionality(relation); // Separate query set. - TypeInformation res = new VectorFieldTypeInformation<NumberVector<?>>(NumberVector.class, dim + 1); + TypeInformation res = VectorFieldTypeInformation.typeRequest(NumberVector.class, dim + 1, dim + 1); MultipleObjectsBundle bundle = queries.loadData(); int col = -1; - for (int i = 0; i < bundle.metaLength(); i++) { - if (res.isAssignableFromType(bundle.meta(i))) { + for(int i = 0; i < bundle.metaLength(); i++) { + if(res.isAssignableFromType(bundle.meta(i))) { col = i; break; } } - if (col < 0) { + if(col < 0) { StringBuilder buf = new StringBuilder(); buf.append("No compatible data type in query input was found. Expected: "); buf.append(res.toString()); buf.append(" have: "); - for (int i = 0; i < bundle.metaLength(); i++) { - if (i > 0) { + for(int i = 0; i < bundle.metaLength(); i++) { + if(i > 0) { buf.append(' '); } buf.append(bundle.meta(i).toString()); @@ -232,12 +225,14 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N DBIDRange sids = DBIDUtil.generateStaticDBIDRange(bundle.dataLength()); final DBIDs sample; - if (sampling <= 0) { + if(sampling <= 0) { sample = sids; - } else if (sampling < 1.1) { + } + else if(sampling < 1.1) { int size = (int) Math.min(sampling * relation.size(), relation.size()); sample = DBIDUtil.randomSample(sids, size, random); - } else { + } + else { int size = (int) Math.min(sampling, sids.size()); sample = DBIDUtil.randomSample(sids, size, random); } @@ -245,30 +240,26 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N int hash = 0; MeanVariance mv = new MeanVariance(); double[] buf = new double[dim]; - for (DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { + for(DBIDIter iditer = sample.iter(); iditer.valid(); iditer.advance()) { int off = sids.binarySearch(iditer); assert (off >= 0); - NumberVector<?> o = (NumberVector<?>) bundle.data(off, col); - for (int i = 0; i < dim; i++) { + NumberVector o = (NumberVector) bundle.data(off, col); + for(int i = 0; i < dim; i++) { buf[i] = o.doubleValue(i); } O v = ofactory.newNumberVector(buf); - D r = dfactory.fromDouble(o.doubleValue(dim)); - DistanceDBIDList<D> rres = rangeQuery.getRangeForObject(v, r); + double r = o.doubleValue(dim); + DoubleDBIDList rres = rangeQuery.getRangeForObject(v, r); int ichecksum = 0; - for (DBIDIter it = rres.iter(); it.valid(); it.advance()) { - ichecksum += it.internalGetIndex(); + for(DBIDIter it = rres.iter(); it.valid(); it.advance()) { + ichecksum += DBIDUtil.asInteger(it); } hash = Util.mixHashCodes(hash, ichecksum); mv.put(rres.size()); - if (prog != null) { - prog.incrementProcessed(LOG); - } + LOG.incrementProcessed(prog); } - if (prog != null) { - prog.ensureCompleted(LOG); - } - if (LOG.isStatistics()) { + LOG.ensureCompleted(prog); + if(LOG.isStatistics()) { LOG.statistics("Result hashcode: " + hash); LOG.statistics("Mean number of results: " + mv.getMean() + " +- " + mv.getNaiveStddev()); } @@ -277,9 +268,10 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N @Override public TypeInformation[] getInputTypeRestriction() { - if (queries == null) { - return TypeUtil.array(getDistanceFunction().getInputTypeRestriction(), new VectorFieldTypeInformation<NumberVector<?>>(NumberVector.class, 1)); - } else { + if(queries == null) { + return TypeUtil.array(getDistanceFunction().getInputTypeRestriction(), TypeUtil.NUMBER_VECTOR_FIELD_1D); + } + else { return TypeUtil.array(getDistanceFunction().getInputTypeRestriction()); } } @@ -297,9 +289,8 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N * @author Erich Schubert * * @param <O> Object type - * @param <D> Distance type */ - public static class Parameterizer<O extends NumberVector<?>, D extends NumberDistance<D, ?>> extends AbstractDistanceBasedAlgorithm.Parameterizer<O, D> { + public static class Parameterizer<O extends NumberVector> extends AbstractDistanceBasedAlgorithm.Parameterizer<O> { /** * Parameter for the query dataset. */ @@ -335,22 +326,22 @@ public class RangeQueryBenchmarkAlgorithm<O extends NumberVector<?>, D extends N super.makeOptions(config); ObjectParameter<DatabaseConnection> queryP = new ObjectParameter<>(QUERY_ID, DatabaseConnection.class); queryP.setOptional(true); - if (config.grab(queryP)) { + if(config.grab(queryP)) { queries = queryP.instantiateClass(config); } DoubleParameter samplingP = new DoubleParameter(SAMPLING_ID); samplingP.setOptional(true); - if (config.grab(samplingP)) { + if(config.grab(samplingP)) { sampling = samplingP.doubleValue(); } RandomParameter randomP = new RandomParameter(RANDOM_ID, RandomFactory.DEFAULT); - if (config.grab(randomP)) { + if(config.grab(randomP)) { random = randomP.getValue(); } } @Override - protected RangeQueryBenchmarkAlgorithm<O, D> makeInstance() { + protected RangeQueryBenchmarkAlgorithm<O> makeInstance() { return new RangeQueryBenchmarkAlgorithm<>(distanceFunction, queries, sampling, random); } } |