diff options
Diffstat (limited to 'src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout')
9 files changed, 908 insertions, 0 deletions
diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/AbstractHoldout.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/AbstractHoldout.java new file mode 100644 index 00000000..af10eac0 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/AbstractHoldout.java @@ -0,0 +1,118 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; + +import de.lmu.ifi.dbs.elki.data.ClassLabel; +import de.lmu.ifi.dbs.elki.data.type.TypeUtil; +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; +import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException; + +/** + * Base class for holdout strategies: stores the input bundle, the class + * label column and the sorted class labels for use by subclasses. + * + * @author Erich Schubert + */ +public abstract class AbstractHoldout implements Holdout { + /** + * Sorted, distinct class labels found in the current data set. + */ + protected ArrayList<ClassLabel> labels; + + /** + * Column containing the class labels. + */ + protected int labelcol; + + /** + * Input data bundle. 
+ */ + protected MultipleObjectsBundle bundle; + + @Override + public void initialize(MultipleObjectsBundle bundle) { + this.bundle = bundle; + this.labelcol = findClassLabelColumn(bundle); + this.labels = allClassLabels(bundle); + } + + @Override + public ArrayList<ClassLabel> getLabels() { + return labels; + } + + /** + * Find the first column of the given data set whose type is compatible + * with class labels. + * + * @param bundle Bundle + * @return Index of the class label column, or -1 if there is none + */ + public static int findClassLabelColumn(MultipleObjectsBundle bundle) { + for(int i = 0, l = bundle.metaLength(); i < l; ++i) { + if(TypeUtil.CLASSLABEL.isAssignableFromType(bundle.meta(i))) { + return i; + } + } + return -1; + } + + /** + * Get a sorted list of all distinct class labels in a given data set. + * + * @param bundle Bundle + * @return Class labels. + * @throws AbortException if the bundle has no class label column + */ + public static ArrayList<ClassLabel> allClassLabels(MultipleObjectsBundle bundle) { + int col = findClassLabelColumn(bundle); + // TODO: automatically infer class labels? + if(col < 0) { + throw new AbortException("No class label found (try using ClassLabelFilter)."); + } + return allClassLabels(bundle, col); + } + + /** + * Get a sorted list of all distinct class labels in one column of a data + * set. Entries that are not class labels (including null) are skipped. + * + * @param bundle Bundle + * @param col Column + * @return Class labels. 
+ */ + public static ArrayList<ClassLabel> allClassLabels(MultipleObjectsBundle bundle, int col) { + HashSet<ClassLabel> labels = new HashSet<ClassLabel>(); + for(int i = 0, l = bundle.dataLength(); i < l; ++i) { + Object o = bundle.data(i, col); + if(o == null || !(o instanceof ClassLabel)) { + continue; + } + labels.add((ClassLabel) o); + } + ArrayList<ClassLabel> ret = new ArrayList<>(labels); + Collections.sort(ret); + return ret; + } +} diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/DisjointCrossValidation.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/DisjointCrossValidation.java new file mode 100644 index 00000000..75503018 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/DisjointCrossValidation.java @@ -0,0 +1,144 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. 
+ */ + +import java.util.ArrayList; +import java.util.Random; + +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; +import de.lmu.ifi.dbs.elki.math.random.RandomFactory; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter; + +/** + * DisjointCrossValidationHoldout provides a set of partitions of a database to + * perform cross-validation. The test sets are guaranteed to be disjoint. + * + * @author Arthur Zimek + */ +public class DisjointCrossValidation extends RandomizedHoldout { + /** + * Holds the number of folds, current fold. + */ + protected int nfold, fold; + + /** + * Partition assignment and size. + */ + protected int[] assignment, sizes; + + /** + * Constructor. + * + * @param random Random seeding + * @param nfold Number of folds. + */ + public DisjointCrossValidation(RandomFactory random, int nfold) { + super(random); + this.nfold = nfold; + } + + @Override + public void initialize(MultipleObjectsBundle bundle) { + super.initialize(bundle); + fold = 0; + + Random rnd = random.getSingleThreadedRandom(); + sizes = new int[nfold]; + assignment = new int[bundle.dataLength()]; + for(int i = 0; i < assignment.length; ++i) { + int p = rnd.nextInt(nfold); + assignment[i] = p; + ++sizes[p]; + } + } + + @Override + public int numberOfPartitions() { + return nfold; + } + + @Override + public TrainingAndTestSet nextPartitioning() { + if(fold >= nfold) { + return null; + } + final int tesize = sizes[fold], trsize = bundle.dataLength() - tesize; + MultipleObjectsBundle training = new MultipleObjectsBundle(); + MultipleObjectsBundle test = new MultipleObjectsBundle(); + // Process column-wise. 
+ for(int c = 0, cs = bundle.metaLength(); c < cs; ++c) { + ArrayList<Object> tr = new ArrayList<>(trsize), te = new ArrayList<>(tesize); + for(int i = 0; i < bundle.dataLength(); ++i) { + ((assignment[i] != fold) ? tr : te).add(bundle.data(i, c)); + } + training.appendColumn(bundle.meta(c), tr); + test.appendColumn(bundle.meta(c), te); + } + + ++fold; + return new TrainingAndTestSet(training, test, labels); + } + + /** + * Parameterization class + * + * @author Erich Schubert + * + * @apiviz.exclude + */ + public static class Parameterizer extends RandomizedHoldout.Parameterizer { + /** + * Default number of folds. + */ + public static final int N_DEFAULT = 10; + + /** + * Parameter for number of folds. + */ + public static final OptionID NFOLD_ID = new OptionID("nfold", "Number of folds for cross-validation."); + + /** + * Holds the number of folds. + */ + protected int nfold = N_DEFAULT; + + @Override + protected void makeOptions(Parameterization config) { + super.makeOptions(config); + IntParameter nfoldP = new IntParameter(NFOLD_ID, N_DEFAULT)// + .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT); + if(config.grab(nfoldP)) { + nfold = nfoldP.intValue(); + } + } + + @Override + protected DisjointCrossValidation makeInstance() { + return new DisjointCrossValidation(random, nfold); + } + } +}
\ No newline at end of file diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/Holdout.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/Holdout.java new file mode 100644 index 00000000..17cd56ac --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/Holdout.java @@ -0,0 +1,67 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +import java.util.ArrayList; + +import de.lmu.ifi.dbs.elki.data.ClassLabel; +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; + +/** + * A holdout procedure splits a data set into a sequence of partitionings, + * each consisting of a training set and a test set. + * + * @author Erich Schubert + */ +public interface Holdout { + /** + * Initialize the holdout procedure for a data set. + * + * @param bundle Data set bundle + */ + void initialize(MultipleObjectsBundle bundle); + + /** + * Get the next partitioning of the given holdout. 
+ * + * @return Next partitioning of the data set, or {@code null} when no + * partitioning is left (as the implementations in this package do) + */ + TrainingAndTestSet nextPartitioning(); + + /** + * Get the <i>sorted</i> class labels present in this data set. + * + * For indexing into assignment arrays. + * + * @return Class labels + */ + ArrayList<ClassLabel> getLabels(); + + /** + * How many partitions to test. + * + * @return Number of partitions. + */ + int numberOfPartitions(); +} diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/LeaveOneOut.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/LeaveOneOut.java new file mode 100644 index 00000000..1059f0a1 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/LeaveOneOut.java @@ -0,0 +1,82 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. 
+ */ + +import java.util.ArrayList; + +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; + +/** + * A leave-one-out-holdout is to provide a set of partitions of a database where + * each instances once hold out as a test instance while the respectively + * remaining instances are training instances. + * + * @author Arthur Zimek + */ +public class LeaveOneOut extends AbstractHoldout { + /** + * Size of the data set. + */ + private int len, pos; + + /** + * Constructor. + */ + public LeaveOneOut() { + super(); + } + + @Override + public void initialize(MultipleObjectsBundle bundle) { + super.initialize(bundle); + len = bundle.dataLength(); + pos = 0; + } + + @Override + public int numberOfPartitions() { + return len; + } + + @Override + public TrainingAndTestSet nextPartitioning() { + if(pos >= len) { + return null; + } + MultipleObjectsBundle training = new MultipleObjectsBundle(); + MultipleObjectsBundle test = new MultipleObjectsBundle(); + // Process column-wise. + for(int c = 0, cs = bundle.metaLength(); c < cs; ++c) { + ArrayList<Object> tr = new ArrayList<>(len - 1), te = new ArrayList<>(1); + for(int i = 0; i < bundle.dataLength(); ++i) { + ((i != pos) ? tr : te).add(bundle.data(i, c)); + } + training.appendColumn(bundle.meta(c), tr); + test.appendColumn(bundle.meta(c), te); + } + + ++pos; + return new TrainingAndTestSet(training, test, labels); + } +}
\ No newline at end of file diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/RandomizedCrossValidation.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/RandomizedCrossValidation.java new file mode 100644 index 00000000..99607398 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/RandomizedCrossValidation.java @@ -0,0 +1,141 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +import java.util.ArrayList; +import java.util.Random; + +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; +import de.lmu.ifi.dbs.elki.math.random.RandomFactory; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter; + +/** + * RandomizedCrossValidationHoldout provides a set of partitions of a database + * to perform cross-validation. 
The test sets are not guaranteed to be disjoint. + * + * @author Arthur Zimek + */ +public class RandomizedCrossValidation extends RandomizedHoldout { + /** + * Holds the number of folds, current fold. + */ + protected int nfold, fold; + + /** + * Constructor for n-fold cross-validation. + * + * @param random Random seed + * @param nfold Number of folds + */ + public RandomizedCrossValidation(RandomFactory random, int nfold) { + super(random); + this.nfold = nfold; + } + + @Override + public void initialize(MultipleObjectsBundle bundle) { + super.initialize(bundle); + this.fold = 0; + } + + @Override + public int numberOfPartitions() { + return nfold; + } + + @Override + public TrainingAndTestSet nextPartitioning() { + if(fold >= nfold) { + return null; + } + MultipleObjectsBundle training = new MultipleObjectsBundle(); + MultipleObjectsBundle test = new MultipleObjectsBundle(); + Random rnd = random.getRandom(); + int datalen = bundle.dataLength(); + + boolean[] assignment = new boolean[datalen]; + int trsize = 0, tesize = 0; + for(int i = 0; i < assignment.length; ++i) { + boolean p = rnd.nextInt(nfold) < nfold - 1; + assignment[i] = p; + @SuppressWarnings("unused") + int discard = p ? ++trsize : ++tesize; + } + // Process column-wise. + for(int c = 0, cs = bundle.metaLength(); c < cs; ++c) { + ArrayList<Object> tr = new ArrayList<>(trsize), te = new ArrayList<>(tesize); + for(int i = 0; i < datalen; ++i) { + (assignment[i] ? tr : te).add(bundle.data(i, c)); + } + training.appendColumn(bundle.meta(c), tr); + test.appendColumn(bundle.meta(c), te); + } + + ++fold; + return new TrainingAndTestSet(training, test, labels); + } + + /** + * Parameterization class + * + * @author Erich Schubert + * + * @apiviz.exclude + */ + public static class Parameterizer extends RandomizedHoldout.Parameterizer { + /** + * Parameter for number of folds. 
+ */ + public static final OptionID NFOLD_ID = new OptionID("nfold", "positive number of folds for cross-validation"); + + /** + * Default number of folds. + */ + public static final int N_DEFAULT = 10; + + /** + * Holds the number of folds. + */ + protected int nfold; + + @Override + protected void makeOptions(Parameterization config) { + super.makeOptions(config); + IntParameter nfoldP = new IntParameter(NFOLD_ID)// + .setDefaultValue(N_DEFAULT) // + .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT); + if(config.grab(nfoldP)) { + nfold = nfoldP.intValue(); + } + } + + @Override + protected RandomizedCrossValidation makeInstance() { + return new RandomizedCrossValidation(random, nfold); + } + } +}
\ No newline at end of file diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/RandomizedHoldout.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/RandomizedHoldout.java new file mode 100644 index 00000000..6202af48 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/RandomizedHoldout.java @@ -0,0 +1,78 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +import de.lmu.ifi.dbs.elki.math.random.RandomFactory; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.RandomParameter; + +/** + * Base class for holdout strategies that need a source of randomness. + * + * @author Arthur Zimek + */ +public abstract class RandomizedHoldout extends AbstractHoldout { + /** + * Factory for the random generators used by subclasses. 
+ */ + protected RandomFactory random; + + /** + * Constructor. + * + * @param random Random generator factory + */ + public RandomizedHoldout(RandomFactory random) { + super(); + this.random = random; + } + + /** + * Parameterization class + * + * @author Erich Schubert + * + * @apiviz.exclude + */ + public static abstract class Parameterizer extends AbstractParameterizer { + /** + * Random seeding for holdout evaluation. + */ + public static final OptionID SEED_ID = new OptionID("holdout.seed", "Random generator seed for holdout evaluation."); + + /** + * Random generator factory obtained from the seed option. + */ + protected RandomFactory random; + + @Override + protected void makeOptions(Parameterization config) { + super.makeOptions(config); + RandomParameter seedP = new RandomParameter(SEED_ID); + if(config.grab(seedP)) { + random = seedP.getValue(); + } + } + } +}
\ No newline at end of file diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/StratifiedCrossValidation.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/StratifiedCrossValidation.java new file mode 100644 index 00000000..40de7e88 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/StratifiedCrossValidation.java @@ -0,0 +1,162 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. 
+ */ + +import gnu.trove.list.array.TIntArrayList; + +import java.util.ArrayList; +import java.util.Collections; + +import de.lmu.ifi.dbs.elki.data.ClassLabel; +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; +import de.lmu.ifi.dbs.elki.utilities.exceptions.AbortException; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.AbstractParameterizer; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization; +import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter; + +/** + * A stratified n-fold crossvalidation to distribute the data to n buckets where + * each bucket exhibits approximately the same distribution of classes as does + * the complete data set. The buckets are disjoint. The distribution is + * deterministic. + * + * @author Arthur Zimek + */ +public class StratifiedCrossValidation extends AbstractHoldout { + /** + * Holds the number of folds, current fold. + */ + protected int nfold, fold; + + /** + * Partition assignment, sizes + */ + protected int[] assignment, sizes; + + /** + * Provides a stratified crossvalidation. Setting parameter N_P to the + * OptionHandler. 
+ */ + public StratifiedCrossValidation(int nfold) { + super(); + this.nfold = nfold; + } + + @Override + public int numberOfPartitions() { + return nfold; + } + + @Override + public void initialize(MultipleObjectsBundle bundle) { + super.initialize(bundle); + fold = 0; + TIntArrayList[] classBuckets = new TIntArrayList[this.labels.size()]; + for(int i = 0; i < this.labels.size(); i++) { + classBuckets[i] = new TIntArrayList(); + } + for(int i = 0, l = bundle.dataLength(); i < l; ++i) { + ClassLabel label = (ClassLabel) bundle.data(i, labelcol); + if(label == null) { + throw new AbortException("Unlabeled instances currently not supported."); + } + int classIndex = Collections.binarySearch(labels, label); + if(classIndex < 0) { + throw new AbortException("Label not in label list: " + label); + } + classBuckets[classIndex].add(i); + } + // TODO: shuffle the class buckets? + sizes = new int[nfold]; + assignment = new int[bundle.dataLength()]; + for(TIntArrayList bucket : classBuckets) { + for(int i = 0; i < bucket.size(); i++) { + assignment[bucket.get(i)] = i % nfold; + } + } + } + + @Override + public TrainingAndTestSet nextPartitioning() { + if(fold >= nfold) { + return null; + } + final int tesize = sizes[fold], trsize = bundle.dataLength() - tesize; + MultipleObjectsBundle training = new MultipleObjectsBundle(); + MultipleObjectsBundle test = new MultipleObjectsBundle(); + // Process column-wise. + for(int c = 0, cs = bundle.metaLength(); c < cs; ++c) { + ArrayList<Object> tr = new ArrayList<>(trsize), te = new ArrayList<>(tesize); + for(int i = 0; i < bundle.dataLength(); ++i) { + ((assignment[i] != fold) ? 
tr : te).add(bundle.data(i, c)); + } + training.appendColumn(bundle.meta(c), tr); + test.appendColumn(bundle.meta(c), te); + } + + ++fold; + return new TrainingAndTestSet(training, test, labels); + } + + /** + * Parameterization class + * + * @author Erich Schubert + * + * @apiviz.exclude + */ + public static class Parameterizer extends AbstractParameterizer { + /** + * Default number of folds. + */ + public static final int N_DEFAULT = 10; + + /** + * Parameter for number of folds. + */ + public static final OptionID NFOLD_ID = new OptionID("nfold", "Number of folds for cross-validation"); + + /** + * Holds the number of folds. + */ + protected int nfold; + + @Override + protected void makeOptions(Parameterization config) { + super.makeOptions(config); + IntParameter nfoldP = new IntParameter(NFOLD_ID, N_DEFAULT)// + .addConstraint(CommonConstraints.GREATER_EQUAL_ONE_INT); + if(config.grab(nfoldP)) { + nfold = nfoldP.intValue(); + } + } + + @Override + protected StratifiedCrossValidation makeInstance() { + return new StratifiedCrossValidation(nfold); + } + } +}
\ No newline at end of file diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/TrainingAndTestSet.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/TrainingAndTestSet.java new file mode 100644 index 00000000..a8725e1c --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/TrainingAndTestSet.java @@ -0,0 +1,89 @@ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ + +import java.util.ArrayList; + +import de.lmu.ifi.dbs.elki.data.ClassLabel; +import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle; + +/** + * Wrapper to hold a pair of training and test data sets. The labels of both + * training and test set are provided in labels. + * + * @author Arthur Zimek + */ +public class TrainingAndTestSet { + /** + * The overall labels. + */ + private ArrayList<ClassLabel> labels; + + /** + * The training data. + */ + private MultipleObjectsBundle training; + + /** + * The test data. 
+ */ + private MultipleObjectsBundle test; + + /** + * Bundles the given training and test data sets together with the overall + * class labels. The arguments are stored as given, without copying. + * + * @param training Training data + * @param test Test data + * @param labels Overall class labels + */ + public TrainingAndTestSet(MultipleObjectsBundle training, MultipleObjectsBundle test, ArrayList<ClassLabel> labels) { + this.training = training; + this.test = test; + this.labels = labels; + } + + /** + * Returns the test data set. + * + * @return the test data set + */ + public MultipleObjectsBundle getTest() { + return test; + } + + /** + * Returns the training data set. + * + * @return the training data set + */ + public MultipleObjectsBundle getTraining() { + return training; + } + + /** + * Returns all labels present in the data set. + * + * @return all labels + */ + public ArrayList<ClassLabel> getLabels() { + return labels; + } +} diff --git a/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/package-info.java b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/package-info.java new file mode 100644 index 00000000..a469dfc3 --- /dev/null +++ b/src/de/lmu/ifi/dbs/elki/evaluation/classification/holdout/package-info.java @@ -0,0 +1,27 @@ +/** + * Holdout and cross-validation strategies for evaluating classifiers. + */ +/* + This file is part of ELKI: + Environment for Developing KDD-Applications Supported by Index-Structures + + Copyright (C) 2014 + Ludwig-Maximilians-Universität München + Lehr- und Forschungseinheit für Datenbanksysteme + ELKI Development Team + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. 
+ + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. 
+ + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see <http://www.gnu.org/licenses/>. + */ +package de.lmu.ifi.dbs.elki.evaluation.classification.holdout; + |