diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/algorithm/KNNJoin.java')
-rw-r--r-- | src/de/lmu/ifi/dbs/elki/algorithm/KNNJoin.java | 187 |
1 files changed, 58 insertions, 129 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/algorithm/KNNJoin.java b/src/de/lmu/ifi/dbs/elki/algorithm/KNNJoin.java index 0f5078fb..d1acf675 100644 --- a/src/de/lmu/ifi/dbs/elki/algorithm/KNNJoin.java +++ b/src/de/lmu/ifi/dbs/elki/algorithm/KNNJoin.java @@ -4,7 +4,7 @@ package de.lmu.ifi.dbs.elki.algorithm; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -38,15 +38,11 @@ import de.lmu.ifi.dbs.elki.database.datastore.WritableDataStore; import de.lmu.ifi.dbs.elki.database.ids.DBID; import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; -import de.lmu.ifi.dbs.elki.database.ids.distance.DoubleDistanceKNNHeap; -import de.lmu.ifi.dbs.elki.database.ids.distance.KNNHeap; -import de.lmu.ifi.dbs.elki.database.ids.distance.KNNList; +import de.lmu.ifi.dbs.elki.database.ids.KNNHeap; +import de.lmu.ifi.dbs.elki.database.ids.KNNList; import de.lmu.ifi.dbs.elki.database.relation.Relation; -import de.lmu.ifi.dbs.elki.distance.DistanceUtil; import de.lmu.ifi.dbs.elki.distance.distancefunction.DistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancefunction.SpatialPrimitiveDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancefunction.SpatialPrimitiveDoubleDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.Distance; import de.lmu.ifi.dbs.elki.index.tree.LeafEntry; import de.lmu.ifi.dbs.elki.index.tree.spatial.SpatialEntry; import de.lmu.ifi.dbs.elki.index.tree.spatial.SpatialIndexTree; @@ -78,25 +74,18 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter; * @author Erich Schubert * * @param <V> the type of FeatureVector handled by this Algorithm - * @param <D> the type of Distance used by this Algorithm * @param <N> the type of node used in the spatial index structure * @param <E> the type of entry used in the spatial node */ @Title("K-Nearest Neighbor Join") @Description("Algorithm to find the k-nearest neighbors of each object in a spatial database") -public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends SpatialNode<N, E>, E extends SpatialEntry> extends AbstractDistanceBasedAlgorithm<V, D, DataStore<KNNList<D>>> { +public class KNNJoin<V extends NumberVector, N extends SpatialNode<N, E>, E extends SpatialEntry> extends AbstractDistanceBasedAlgorithm<V, DataStore<KNNList>> { /** * The logger for this class. */ private static final Logging LOG = Logging.getLogger(KNNJoin.class); /** - * Parameter that specifies the k-nearest neighbors to be assigned, must be an - * integer greater than 0. Default value: 1. - */ - public static final OptionID K_ID = new OptionID("knnjoin.k", "Specifies the k-nearest neighbors to be assigned."); - - /** * The k parameter. */ int k; @@ -107,7 +96,7 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * @param distanceFunction Distance function * @param k k parameter */ - public KNNJoin(DistanceFunction<? super V, D> distanceFunction, int k) { + public KNNJoin(DistanceFunction<? super V> distanceFunction, int k) { super(distanceFunction); this.k = k; } @@ -120,7 +109,7 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * @return result */ @SuppressWarnings("unchecked") - public WritableDataStore<KNNList<D>> run(Database database, Relation<V> relation) { + public WritableDataStore<KNNList> run(Database database, Relation<V> relation) { if(!(getDistanceFunction() instanceof SpatialPrimitiveDistanceFunction)) { throw new IllegalStateException("Distance Function must be an instance of " + SpatialPrimitiveDistanceFunction.class.getName()); } @@ -130,13 +119,13 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends } // FIXME: Ensure were looking at the right relation! SpatialIndexTree<N, E> index = indexes.iterator().next(); - SpatialPrimitiveDistanceFunction<V, D> distFunction = (SpatialPrimitiveDistanceFunction<V, D>) getDistanceFunction(); + SpatialPrimitiveDistanceFunction<V> distFunction = (SpatialPrimitiveDistanceFunction<V>) getDistanceFunction(); DBIDs ids = relation.getDBIDs(); // data pages List<E> ps_candidates = new ArrayList<>(index.getLeaves()); // knn heaps - List<List<KNNHeap<D>>> heaps = new ArrayList<>(ps_candidates.size()); + List<List<KNNHeap>> heaps = new ArrayList<>(ps_candidates.size()); ComparableMinHeap<Task> pq = new ComparableMinHeap<>(ps_candidates.size() * ps_candidates.size() / 10); // Initialize with the page self-pairing @@ -154,81 +143,65 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends FiniteProgress mprogress = LOG.isVerbose() ? new FiniteProgress("Comparing leaf MBRs", sqsize, LOG) : null; for(int i = 0; i < ps_candidates.size(); i++) { E pr_entry = ps_candidates.get(i); - List<KNNHeap<D>> pr_heaps = heaps.get(i); - D pr_knn_distance = computeStopDistance(pr_heaps); + List<KNNHeap> pr_heaps = heaps.get(i); + double pr_knn_distance = computeStopDistance(pr_heaps); for(int j = i + 1; j < ps_candidates.size(); j++) { E ps_entry = ps_candidates.get(j); - List<KNNHeap<D>> ps_heaps = heaps.get(j); - D ps_knn_distance = computeStopDistance(ps_heaps); - D minDist = distFunction.minDist(pr_entry, ps_entry); + List<KNNHeap> ps_heaps = heaps.get(j); + double ps_knn_distance = computeStopDistance(ps_heaps); + double minDist = distFunction.minDist(pr_entry, ps_entry); // Resolve immediately: - if(minDist.isNullDistance()) { + if(minDist <= 0.) { N pr = index.getNode(ps_candidates.get(i)); N ps = index.getNode(ps_candidates.get(j)); - processDataPagesOptimize(distFunction, pr_heaps, ps_heaps, pr, ps); + processDataPages(distFunction, pr_heaps, ps_heaps, pr, ps); } - else if(minDist.compareTo(pr_knn_distance) <= 0 || minDist.compareTo(ps_knn_distance) <= 0) { + else if(minDist <= pr_knn_distance || minDist <= ps_knn_distance) { pq.add(new Task(minDist, i, j)); } - if(mprogress != null) { - mprogress.incrementProcessed(LOG); - } + LOG.incrementProcessed(mprogress); } } - if(mprogress != null) { - mprogress.ensureCompleted(LOG); - } + LOG.ensureCompleted(mprogress); // Process the queue FiniteProgress qprogress = LOG.isVerbose() ? new FiniteProgress("Processing queue", pq.size(), LOG) : null; IndefiniteProgress fprogress = LOG.isVerbose() ? new IndefiniteProgress("Full comparisons", LOG) : null; while(!pq.isEmpty()) { Task task = pq.poll(); - List<KNNHeap<D>> pr_heaps = heaps.get(task.i); - List<KNNHeap<D>> ps_heaps = heaps.get(task.j); - D pr_knn_distance = computeStopDistance(pr_heaps); - D ps_knn_distance = computeStopDistance(ps_heaps); - boolean dor = task.mindist.compareTo(pr_knn_distance) <= 0; - boolean dos = task.mindist.compareTo(ps_knn_distance) <= 0; + List<KNNHeap> pr_heaps = heaps.get(task.i); + List<KNNHeap> ps_heaps = heaps.get(task.j); + double pr_knn_distance = computeStopDistance(pr_heaps); + double ps_knn_distance = computeStopDistance(ps_heaps); + boolean dor = task.mindist <= pr_knn_distance; + boolean dos = task.mindist <= ps_knn_distance; if(dor || dos) { N pr = index.getNode(ps_candidates.get(task.i)); N ps = index.getNode(ps_candidates.get(task.j)); if(dor && dos) { - processDataPagesOptimize(distFunction, pr_heaps, ps_heaps, pr, ps); + processDataPages(distFunction, pr_heaps, ps_heaps, pr, ps); } else { if(dor) { - processDataPagesOptimize(distFunction, pr_heaps, null, pr, ps); + processDataPages(distFunction, pr_heaps, null, pr, ps); } else /* dos */{ - processDataPagesOptimize(distFunction, ps_heaps, null, ps, pr); + processDataPages(distFunction, ps_heaps, null, ps, pr); } } - if(fprogress != null) { - fprogress.incrementProcessed(LOG); - } + LOG.incrementProcessed(fprogress); } - if(qprogress != null) { - qprogress.incrementProcessed(LOG); - } - } - if(qprogress != null) { - qprogress.ensureCompleted(LOG); - } - if(fprogress != null) { - fprogress.setCompleted(LOG); + LOG.incrementProcessed(qprogress); } + LOG.ensureCompleted(qprogress); + LOG.setCompleted(fprogress); - WritableDataStore<KNNList<D>> knnLists = DataStoreUtil.makeStorage(ids, DataStoreFactory.HINT_STATIC, KNNList.class); - // FiniteProgress progress = logger.isVerbose() ? new - // FiniteProgress(this.getClass().getName(), relation.size(), logger) : - // null; + WritableDataStore<KNNList> knnLists = DataStoreUtil.makeStorage(ids, DataStoreFactory.HINT_STATIC, KNNList.class); FiniteProgress pageprog = LOG.isVerbose() ? new FiniteProgress("Number of processed data pages", ps_candidates.size(), LOG) : null; - // int processed = 0; for(int i = 0; i < ps_candidates.size(); i++) { N pr = index.getNode(ps_candidates.get(i)); - List<KNNHeap<D>> pr_heaps = heaps.get(i); + List<KNNHeap> pr_heaps = heaps.get(i); // Finalize lists for(int j = 0; j < pr.getNumEntries(); j++) { @@ -236,21 +209,9 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends } // Forget heaps and pq heaps.set(i, null); - // processed += pr.getNumEntries(); - - // if(progress != null) { - // progress.setProcessed(processed, logger); - // } - if(pageprog != null) { - pageprog.incrementProcessed(LOG); - } - } - // if(progress != null) { - // progress.ensureCompleted(logger); - // } - if(pageprog != null) { - pageprog.ensureCompleted(LOG); + LOG.incrementProcessed(pageprog); } + LOG.ensureCompleted(pageprog); return knnLists; } @@ -261,15 +222,15 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * @param pr Node to initialize for * @return List of heaps */ - private List<KNNHeap<D>> initHeaps(SpatialPrimitiveDistanceFunction<V, D> distFunction, N pr) { - List<KNNHeap<D>> pr_heaps = new ArrayList<>(pr.getNumEntries()); + private List<KNNHeap> initHeaps(SpatialPrimitiveDistanceFunction<V> distFunction, N pr) { + List<KNNHeap> pr_heaps = new ArrayList<>(pr.getNumEntries()); // Create for each data object a knn heap for(int j = 0; j < pr.getNumEntries(); j++) { - pr_heaps.add(DBIDUtil.newHeap(distFunction.getDistanceFactory(), k)); + pr_heaps.add(DBIDUtil.newHeap(k)); } // Self-join first, as this is expected to improve most and cannot be // pruned. - processDataPagesOptimize(distFunction, pr_heaps, null, pr, pr); + processDataPages(distFunction, pr_heaps, null, pr, pr); return pr_heaps; } @@ -277,53 +238,20 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * Processes the two data pages pr and ps and determines the k-nearest * neighbors of pr in ps. * - * @param distFunction the distance to use - * @param pr the first data page - * @param ps the second data page - * @param pr_heaps the knn lists for each data object in pr - * @param ps_heaps the knn lists for each data object in ps (if ps != pr) - */ - @SuppressWarnings("unchecked") - private void processDataPagesOptimize(SpatialPrimitiveDistanceFunction<V, D> distFunction, List<? extends KNNHeap<D>> pr_heaps, List<? extends KNNHeap<D>> ps_heaps, N pr, N ps) { - if(DistanceUtil.isDoubleDistanceFunction(distFunction)) { - List<?> khp = (List<?>) pr_heaps; - List<?> khs = (List<?>) ps_heaps; - processDataPagesDouble((SpatialPrimitiveDoubleDistanceFunction<? super V>) distFunction, pr, ps, (List<DoubleDistanceKNNHeap>) khp, (List<DoubleDistanceKNNHeap>) khs); - } - else { - for(int j = 0; j < ps.getNumEntries(); j++) { - final SpatialPointLeafEntry s_e = (SpatialPointLeafEntry) ps.getEntry(j); - DBID s_id = s_e.getDBID(); - for(int i = 0; i < pr.getNumEntries(); i++) { - final SpatialPointLeafEntry r_e = (SpatialPointLeafEntry) pr.getEntry(i); - D distance = distFunction.minDist(s_e, r_e); - pr_heaps.get(i).insert(distance, s_id); - if(pr != ps && ps_heaps != null) { - ps_heaps.get(j).insert(distance, r_e.getDBID()); - } - } - } - } - } - - /** - * Processes the two data pages pr and ps and determines the k-nearest - * neighbors of pr in ps. - * * @param df the distance function to use * @param pr the first data page * @param ps the second data page * @param pr_heaps the knn lists for each data object * @param ps_heaps the knn lists for each data object in ps */ - private void processDataPagesDouble(SpatialPrimitiveDoubleDistanceFunction<? super V> df, N pr, N ps, List<DoubleDistanceKNNHeap> pr_heaps, List<DoubleDistanceKNNHeap> ps_heaps) { + private void processDataPages(SpatialPrimitiveDistanceFunction<? super V> df, List<KNNHeap> pr_heaps, List<KNNHeap> ps_heaps, N pr, N ps) { // Compare pairwise for(int j = 0; j < ps.getNumEntries(); j++) { final SpatialPointLeafEntry s_e = (SpatialPointLeafEntry) ps.getEntry(j); DBID s_id = s_e.getDBID(); for(int i = 0; i < pr.getNumEntries(); i++) { final SpatialPointLeafEntry r_e = (SpatialPointLeafEntry) pr.getEntry(i); - double distance = df.doubleMinDist(s_e, r_e); + double distance = df.minDist(s_e, r_e); pr_heaps.get(i).insert(distance, s_id); if(pr != ps && ps_heaps != null) { ps_heaps.get(j).insert(distance, r_e.getDBID()); @@ -338,20 +266,16 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * @param heaps Heaps list * @return the k-nearest neighbor distance of pr in ps */ - private D computeStopDistance(List<KNNHeap<D>> heaps) { + private double computeStopDistance(List<KNNHeap> heaps) { // Update pruning distance - D pr_knn_distance = null; - for(KNNHeap<D> knnList : heaps) { + double pr_knn_distance = Double.NaN; + for(KNNHeap knnList : heaps) { // set kNN distance of r - if(pr_knn_distance == null) { - pr_knn_distance = knnList.getKNNDistance(); - } - else { - pr_knn_distance = DistanceUtil.max(knnList.getKNNDistance(), pr_knn_distance); - } + double kdist = knnList.getKNNDistance(); + pr_knn_distance = (kdist < pr_knn_distance) ? pr_knn_distance : kdist; } - if(pr_knn_distance == null) { - return getDistanceFunction().getDistanceFactory().infiniteDistance(); + if(pr_knn_distance != pr_knn_distance) { + return Double.POSITIVE_INFINITY; } return pr_knn_distance; } @@ -377,7 +301,7 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends /** * Minimum distance. */ - final D mindist; + final double mindist; /** * First offset. @@ -396,7 +320,7 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * @param i First offset * @param j Second offset */ - public Task(D mindist, int i, int j) { + public Task(double mindist, int i, int j) { super(); this.mindist = mindist; this.i = i; @@ -405,7 +329,7 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends @Override public int compareTo(Task o) { - return mindist.compareTo(o.mindist); + return Double.compare(mindist, o.mindist); } } @@ -416,7 +340,12 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends * * @apiviz.exclude */ - public static class Parameterizer<V extends NumberVector<?>, D extends Distance<D>, N extends SpatialNode<N, E>, E extends SpatialEntry> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<V, D> { + public static class Parameterizer<V extends NumberVector, N extends SpatialNode<N, E>, E extends SpatialEntry> extends AbstractPrimitiveDistanceBasedAlgorithm.Parameterizer<V> { + /** + * Parameter that specifies the k-nearest neighbors to be assigned, must be an + * integer greater than 0. Default value: 1. + */ + public static final OptionID K_ID = new OptionID("knnjoin.k", "Specifies the k-nearest neighbors to be assigned."); /** * K parameter. */ @@ -433,7 +362,7 @@ public class KNNJoin<V extends NumberVector<?>, D extends Distance<D>, N extends } @Override - protected KNNJoin<V, D, N, E> makeInstance() { + protected KNNJoin<V, N, E> makeInstance() { return new KNNJoin<>(distanceFunction, k); } } |