summaryrefslogtreecommitdiff
path: root/src/de/lmu/ifi/dbs/elki/application/ClassifierHoldoutEvaluationTask.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/application/ClassifierHoldoutEvaluationTask.java')
-rw-r--r--src/de/lmu/ifi/dbs/elki/application/ClassifierHoldoutEvaluationTask.java253
1 files changed, 253 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/application/ClassifierHoldoutEvaluationTask.java b/src/de/lmu/ifi/dbs/elki/application/ClassifierHoldoutEvaluationTask.java
new file mode 100644
index 00000000..c4e364e6
--- /dev/null
+++ b/src/de/lmu/ifi/dbs/elki/application/ClassifierHoldoutEvaluationTask.java
@@ -0,0 +1,253 @@
+package de.lmu.ifi.dbs.elki.application;
+
+/*
+ This file is part of ELKI:
+ Environment for Developing KDD-Applications Supported by Index-Structures
+
+ Copyright (C) 2014
+ Ludwig-Maximilians-Universität München
+ Lehr- und Forschungseinheit für Datenbanksysteme
+ ELKI Development Team
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+
+import de.lmu.ifi.dbs.elki.algorithm.classification.Classifier;
+import de.lmu.ifi.dbs.elki.data.ClassLabel;
+import de.lmu.ifi.dbs.elki.data.type.TypeUtil;
+import de.lmu.ifi.dbs.elki.database.AbstractDatabase;
+import de.lmu.ifi.dbs.elki.database.Database;
+import de.lmu.ifi.dbs.elki.database.StaticArrayDatabase;
+import de.lmu.ifi.dbs.elki.database.relation.Relation;
+import de.lmu.ifi.dbs.elki.datasource.DatabaseConnection;
+import de.lmu.ifi.dbs.elki.datasource.FileBasedDatabaseConnection;
+import de.lmu.ifi.dbs.elki.datasource.MultipleObjectsBundleDatabaseConnection;
+import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
+import de.lmu.ifi.dbs.elki.evaluation.classification.ConfusionMatrix;
+import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.AbstractHoldout;
+import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.Holdout;
+import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.StratifiedCrossValidation;
+import de.lmu.ifi.dbs.elki.evaluation.classification.holdout.TrainingAndTestSet;
+import de.lmu.ifi.dbs.elki.index.IndexFactory;
+import de.lmu.ifi.dbs.elki.logging.Logging;
+import de.lmu.ifi.dbs.elki.logging.statistics.Duration;
+/*
+ This file is part of ELKI:
+ Environment for Developing KDD-Applications Supported by Index-Structures
+
+ Copyright (C) 2014
+ Ludwig-Maximilians-Universität München
+ Lehr- und Forschungseinheit für Datenbanksysteme
+ ELKI Development Team
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
+
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+import de.lmu.ifi.dbs.elki.utilities.exceptions.UnableToComplyException;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectListParameter;
+import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.ObjectParameter;
+import de.lmu.ifi.dbs.elki.workflow.AlgorithmStep;
+
+/**
+ * Evaluate a classifier.
+ *
+ * TODO: split into application and task.
+ *
+ * TODO: add support for predefined test and training pairs!
+ *
+ * @author Erich Schubert
+ *
+ * @param <O> Object type
+ */
+public class ClassifierHoldoutEvaluationTask<O> extends AbstractApplication {
+ /**
+ * Class logger.
+ */
+ private static final Logging LOG = Logging.getLogger(ClassifierHoldoutEvaluationTask.class);
+
+ /**
+ * Holds the database connection to get the initial data from.
+ */
+ protected DatabaseConnection databaseConnection = null;
+
+ /**
+ * Indexes to add.
+ */
+ protected Collection<IndexFactory<?, ?>> indexFactories;
+
+ /**
+ * Classifier to evaluate.
+ */
+ protected Classifier<O> algorithm;
+
+ /**
+ * Holds the holdout.
+ */
+ protected Holdout holdout;
+
+ /**
+ * Constructor.
+ *
+ * @param databaseConnection Data source
+ * @param indexFactories Data indexes
+ * @param algorithm Classification algorithm
+ * @param holdout Evaluation holdout
+ */
+ public ClassifierHoldoutEvaluationTask(DatabaseConnection databaseConnection, Collection<IndexFactory<?, ?>> indexFactories, Classifier<O> algorithm, Holdout holdout) {
+ this.databaseConnection = databaseConnection;
+ this.indexFactories = indexFactories;
+ this.algorithm = algorithm;
+ this.holdout = holdout;
+ }
+
+ @Override
+ public void run() throws UnableToComplyException {
+ Duration ptime = LOG.newDuration("evaluation.time.load").begin();
+ MultipleObjectsBundle allData = databaseConnection.loadData();
+ holdout.initialize(allData);
+ LOG.statistics(ptime.end());
+
+ Duration time = LOG.newDuration("evaluation.time.total").begin();
+ ArrayList<ClassLabel> labels = holdout.getLabels();
+ int[][] confusion = new int[labels.size()][labels.size()];
+ for(int p = 0; p < holdout.numberOfPartitions(); p++) {
+ TrainingAndTestSet partition = holdout.nextPartitioning();
+ // Load the data set into a database structure (for indexing)
+ Duration dur = LOG.newDuration(this.getClass().getName() + ".fold-" + (p + 1) + ".init.time").begin();
+ Database db = new StaticArrayDatabase(new MultipleObjectsBundleDatabaseConnection(partition.getTraining()), indexFactories);
+ db.initialize();
+ LOG.statistics(dur.end());
+ // Train the classifier
+ dur = LOG.newDuration(this.getClass().getName() + ".fold-" + (p + 1) + ".train.time").begin();
+ Relation<ClassLabel> lrel = db.getRelation(TypeUtil.CLASSLABEL);
+ algorithm.buildClassifier(db, lrel);
+ LOG.statistics(dur.end());
+ // Evaluate the test set
+ dur = LOG.newDuration(this.getClass().getName() + ".fold-" + (p + 1) + ".evaluation.time").begin();
+ // FIXME: this part is still a big hack, unfortunately!
+ MultipleObjectsBundle test = partition.getTest();
+ int lcol = AbstractHoldout.findClassLabelColumn(test);
+ int tcol = (lcol == 0) ? 1 : 0;
+ for(int i = 0, l = test.dataLength(); i < l; ++i) {
+ @SuppressWarnings("unchecked")
+ O obj = (O) test.data(i, tcol);
+ ClassLabel truelbl = (ClassLabel) test.data(i, lcol);
+ ClassLabel predlbl = algorithm.classify(obj);
+ int pred = Collections.binarySearch(labels, predlbl);
+ int real = Collections.binarySearch(labels, truelbl);
+ confusion[pred][real]++;
+ }
+ LOG.statistics(dur.end());
+ }
+ LOG.statistics(time.end());
+ ConfusionMatrix m = new ConfusionMatrix(labels, confusion);
+ LOG.statistics(m.toString());
+ }
+
+ /**
+ * Parameterization class.
+ *
+ * @author Erich Schubert
+ *
+ * @apiviz.exclude
+ */
+ public static class Parameterizer<O> extends AbstractApplication.Parameterizer {
+ /**
+ * Parameter to specify the holdout for evaluation, must extend
+ * {@link de.lmu.ifi.dbs.elki.evaluation.classification.holdout.Holdout}.
+ * <p>
+ * Key: {@code -classifier.holdout}
+ * </p>
+ * <p>
+ * Default value: {@link StratifiedCrossValidation}
+ * </p>
+ */
+ public static final OptionID HOLDOUT_ID = new OptionID("evaluation.holdout", "Holdout class used in evaluation.");
+
+ /**
+ * Holds the database connection to get the initial data from.
+ */
+ protected DatabaseConnection databaseConnection = null;
+
+ /**
+ * Indexes to add.
+ */
+ protected Collection<IndexFactory<?, ?>> indexFactories;
+
+ /**
+ * Classifier to evaluate.
+ */
+ protected Classifier<O> algorithm;
+
+ /**
+ * Holds the holdout.
+ */
+ protected Holdout holdout;
+
+ @Override
+ protected void makeOptions(Parameterization config) {
+ super.makeOptions(config);
+ // Get database connection.
+ final ObjectParameter<DatabaseConnection> dbcP = new ObjectParameter<>(AbstractDatabase.Parameterizer.DATABASE_CONNECTION_ID, DatabaseConnection.class, FileBasedDatabaseConnection.class);
+ if(config.grab(dbcP)) {
+ databaseConnection = dbcP.instantiateClass(config);
+ }
+ // Get indexes.
+ final ObjectListParameter<IndexFactory<?, ?>> indexFactoryP = new ObjectListParameter<>(AbstractDatabase.Parameterizer.INDEX_ID, IndexFactory.class, true);
+ if(config.grab(indexFactoryP)) {
+ indexFactories = indexFactoryP.instantiateClasses(config);
+ }
+ ObjectParameter<Classifier<O>> algorithmP = new ObjectParameter<>(AlgorithmStep.Parameterizer.ALGORITHM_ID, Classifier.class);
+ if(config.grab(algorithmP)) {
+ algorithm = algorithmP.instantiateClass(config);
+ }
+
+ ObjectParameter<Holdout> holdoutP = new ObjectParameter<>(HOLDOUT_ID, Holdout.class, StratifiedCrossValidation.class);
+ if(config.grab(holdoutP)) {
+ holdout = holdoutP.instantiateClass(config);
+ }
+ }
+
+ @Override
+ protected ClassifierHoldoutEvaluationTask<O> makeInstance() {
+ return new ClassifierHoldoutEvaluationTask<O>(databaseConnection, indexFactories, algorithm, holdout);
+ }
+ }
+
+ /**
+ * Runs the classifier evaluation task accordingly to the specified
+ * parameters.
+ *
+ * @param args parameter list according to description
+ */
+ public static void main(String[] args) {
+ runCLIApplication(ClassifierHoldoutEvaluationTask.class, args);
+ }
+}