diff options
Diffstat (limited to 'src/tutorial/clustering/SameSizeKMeansAlgorithm.java')
-rw-r--r-- | src/tutorial/clustering/SameSizeKMeansAlgorithm.java | 51 |
1 files changed, 23 insertions, 28 deletions
diff --git a/src/tutorial/clustering/SameSizeKMeansAlgorithm.java b/src/tutorial/clustering/SameSizeKMeansAlgorithm.java index 91da44c5..ad57885e 100644 --- a/src/tutorial/clustering/SameSizeKMeansAlgorithm.java +++ b/src/tutorial/clustering/SameSizeKMeansAlgorithm.java @@ -4,7 +4,7 @@ package tutorial.clustering; This file is part of ELKI: Environment for Developing KDD-Applications Supported by Index-Structures - Copyright (C) 2013 + Copyright (C) 2014 Ludwig-Maximilians-Universität München Lehr- und Forschungseinheit für Datenbanksysteme ELKI Development Team @@ -29,8 +29,8 @@ import java.util.Comparator; import java.util.List; import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.AbstractKMeans; -import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeansInitialization; -import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.KMeansPlusPlusInitialMeans; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansInitialization; +import de.lmu.ifi.dbs.elki.algorithm.clustering.kmeans.initialization.KMeansPlusPlusInitialMeans; import de.lmu.ifi.dbs.elki.data.Cluster; import de.lmu.ifi.dbs.elki.data.Clustering; import de.lmu.ifi.dbs.elki.data.NumberVector; @@ -48,12 +48,11 @@ import de.lmu.ifi.dbs.elki.database.ids.DBIDUtil; import de.lmu.ifi.dbs.elki.database.ids.DBIDs; import de.lmu.ifi.dbs.elki.database.ids.ModifiableDBIDs; import de.lmu.ifi.dbs.elki.database.relation.Relation; -import de.lmu.ifi.dbs.elki.database.relation.RelationUtil; -import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDoubleDistanceFunction; +import de.lmu.ifi.dbs.elki.distance.distancefunction.PrimitiveDistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.EuclideanDistanceFunction; import de.lmu.ifi.dbs.elki.distance.distancefunction.minkowski.SquaredEuclideanDistanceFunction; -import de.lmu.ifi.dbs.elki.distance.distancevalue.DoubleDistance; import de.lmu.ifi.dbs.elki.logging.Logging; +import de.lmu.ifi.dbs.elki.math.linearalgebra.Vector; import de.lmu.ifi.dbs.elki.utilities.datastructures.arrays.IntegerArrayQuickSort; import de.lmu.ifi.dbs.elki.utilities.datastructures.arrays.IntegerComparator; import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer; @@ -78,7 +77,7 @@ import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter; * * @param <V> Vector type */ -public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends AbstractKMeans<V, DoubleDistance, MeanModel<V>> { +public class SameSizeKMeansAlgorithm<V extends NumberVector> extends AbstractKMeans<V, MeanModel> { /** * Class logger */ @@ -92,7 +91,7 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract * @param maxiter Maximum number of iterations * @param initializer */ - public SameSizeKMeansAlgorithm(PrimitiveDoubleDistanceFunction<? super NumberVector<?>> distanceFunction, int k, int maxiter, KMeansInitialization<V> initializer) { + public SameSizeKMeansAlgorithm(PrimitiveDistanceFunction<? super NumberVector> distanceFunction, int k, int maxiter, KMeansInitialization<? super V> initializer) { super(distanceFunction, k, maxiter, initializer); } @@ -104,11 +103,11 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract * @return result */ @Override - public Clustering<MeanModel<V>> run(Database database, Relation<V> relation) { + public Clustering<MeanModel> run(Database database, Relation<V> relation) { // Database objects to process final DBIDs ids = relation.getDBIDs(); // Choose initial means - List<? extends NumberVector<?>> means = initializer.chooseInitialMeans(database, relation, k, getDistanceFunction()); + List<Vector> means = initializer.chooseInitialMeans(database, relation, k, getDistanceFunction(), Vector.FACTORY); // Setup cluster assignment store List<ModifiableDBIDs> clusters = new ArrayList<>(); for(int i = 0; i < k; i++) { @@ -125,11 +124,10 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract means = refineResult(relation, means, clusters, metas, tids); // Wrap result - Clustering<MeanModel<V>> result = new Clustering<>("k-Means Samesize Clustering", "kmeans-samesize-clustering"); - final NumberVector.Factory<V, ?> factory = RelationUtil.getNumberVectorFactory(relation); + Clustering<MeanModel> result = new Clustering<>("k-Means Samesize Clustering", "kmeans-samesize-clustering"); for(int i = 0; i < clusters.size(); i++) { - V mean = factory.newNumberVector(means.get(i).getColumnVector().getArrayRef()); - result.addToplevelCluster(new Cluster<>(clusters.get(i), new MeanModel<>(mean))); + Vector mean = means.get(i).getColumnVector(); + result.addToplevelCluster(new Cluster<>(clusters.get(i), new MeanModel(mean))); } return result; } @@ -141,10 +139,8 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract * @param means Mean vectors * @return Initialized storage */ - protected WritableDataStore<Meta> initializeMeta(Relation<V> relation, List<? extends NumberVector<?>> means) { - // This is a safe cast - see constructor. - @SuppressWarnings("unchecked") - PrimitiveDoubleDistanceFunction<NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<NumberVector<?>>) getDistanceFunction(); + protected WritableDataStore<Meta> initializeMeta(Relation<V> relation, List<? extends NumberVector> means) { + PrimitiveDistanceFunction<? super NumberVector> df = getDistanceFunction(); // The actual storage final WritableDataStore<Meta> metas = DataStoreUtil.makeStorage(relation.getDBIDs(), DataStoreFactory.HINT_HOT | DataStoreFactory.HINT_TEMP, Meta.class); // Build the metadata, track the two nearest cluster centers. @@ -152,7 +148,7 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract Meta c = new Meta(k); V fv = relation.get(id); for(int i = 0; i < k; i++) { - c.dists[i] = df.doubleDistance(fv, means.get(i)); + c.dists[i] = df.distance(fv, means.get(i)); if(i > 0) { if(c.dists[i] < c.dists[c.primary]) { c.primary = i; @@ -231,14 +227,14 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract * @param metas Metadata storage * @param df Distance function */ - protected void updateDistances(Relation<V> relation, List<? extends NumberVector<?>> means, final WritableDataStore<Meta> metas, PrimitiveDoubleDistanceFunction<NumberVector<?>> df) { + protected void updateDistances(Relation<V> relation, List<Vector> means, final WritableDataStore<Meta> metas, PrimitiveDistanceFunction<? super NumberVector> df) { for(DBIDIter id = relation.iterDBIDs(); id.valid(); id.advance()) { Meta c = metas.get(id); V fv = relation.get(id); // Update distances to means. c.secondary = -1; for(int i = 0; i < k; i++) { - c.dists[i] = df.doubleDistance(fv, means.get(i)); + c.dists[i] = df.distance(fv, means.get(i)); if(c.primary != i) { if(c.secondary < 0 || c.dists[i] < c.dists[c.secondary]) { c.secondary = i; @@ -259,10 +255,9 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract * @param tids DBIDs array * @return final means */ - protected List<? extends NumberVector<?>> refineResult(Relation<V> relation, List<? extends NumberVector<?>> means, List<ModifiableDBIDs> clusters, final WritableDataStore<Meta> metas, ArrayModifiableDBIDs tids) { + protected List<Vector> refineResult(Relation<V> relation, List<Vector> means, List<ModifiableDBIDs> clusters, final WritableDataStore<Meta> metas, ArrayModifiableDBIDs tids) { // This is a safe cast - see constructor. - @SuppressWarnings("unchecked") - PrimitiveDoubleDistanceFunction<NumberVector<?>> df = (PrimitiveDoubleDistanceFunction<NumberVector<?>>) getDistanceFunction(); + PrimitiveDistanceFunction<? super NumberVector> df = getDistanceFunction(); // Our desired cluster size: final int minsize = tids.size() / k; // rounded down final int maxsize = (tids.size() + k - 1) / k; // rounded up @@ -289,7 +284,7 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract transfers[i] = DBIDUtil.newArray(); } - for(int iter = 0; maxiter < 0 || iter < maxiter; iter++) { + for(int iter = 0; maxiter <= 0 || iter < maxiter; iter++) { updateDistances(relation, means, metas, df); tids.sort(comp); int active = 0; // Track if anything has changed @@ -461,7 +456,7 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract * * @apiviz.exclude */ - public static class Parameterizer<V extends NumberVector<?>> extends AbstractParameterizer { + public static class Parameterizer<V extends NumberVector> extends AbstractParameterizer { /** * k Parameter. */ @@ -480,12 +475,12 @@ public class SameSizeKMeansAlgorithm<V extends NumberVector<?>> extends Abstract /** * Distance function */ - protected PrimitiveDoubleDistanceFunction<? super NumberVector<?>> distanceFunction; + protected PrimitiveDistanceFunction<? super NumberVector> distanceFunction; @Override protected void makeOptions(Parameterization config) { super.makeOptions(config); - ObjectParameter<PrimitiveDoubleDistanceFunction<? super NumberVector<?>>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDoubleDistanceFunction.class); + ObjectParameter<PrimitiveDistanceFunction<? super NumberVector>> distanceFunctionP = makeParameterDistanceFunction(SquaredEuclideanDistanceFunction.class, PrimitiveDistanceFunction.class); if(config.grab(distanceFunctionP)) { distanceFunction = distanceFunctionP.instantiateClass(config); if(!(distanceFunction instanceof EuclideanDistanceFunction) && !(distanceFunction instanceof SquaredEuclideanDistanceFunction)) { |