package de.lmu.ifi.dbs.elki.evaluation.classification.holdout;
/*
This file is part of ELKI:
Environment for Developing KDD-Applications Supported by Index-Structures
Copyright (C) 2015
Ludwig-Maximilians-Universität München
Lehr- und Forschungseinheit für Datenbanksysteme
ELKI Development Team
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
import java.util.ArrayList;
import java.util.Random;
import de.lmu.ifi.dbs.elki.datasource.bundle.MultipleObjectsBundle;
import de.lmu.ifi.dbs.elki.math.random.RandomFactory;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.OptionID;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.constraints.CommonConstraints;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameterization.Parameterization;
import de.lmu.ifi.dbs.elki.utilities.optionhandling.parameters.IntParameter;
/**
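* RandomizedCrossValidation provides a set of partitions of a database
* to perform cross-validation. For each partitioning, every object is
* randomly assigned to the test set with probability 1/nfold (and to the
* training set otherwise), so the test sets are not guaranteed to be disjoint.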
*
* @author Arthur Zimek
*/
public class RandomizedCrossValidation extends RandomizedHoldout {
/**
* Holds the number of folds, current fold.
*/
protected int nfold, fold;
/**
 * Constructor for n-fold cross-validation.
 *
 * @param random Factory supplying the random generator used to assign
 *        objects to training or test sets
 * @param nfold Number of folds; must be at least 1
 * @throws IllegalArgumentException if {@code nfold < 1}
 */
public RandomizedCrossValidation(RandomFactory random, int nfold) {
  super(random);
  // A non-positive fold count would make numberOfPartitions() report a
  // meaningless count and nextPartitioning() yield nothing; reject it early
  // instead of failing silently (or inside Random.nextInt) later.
  if(nfold < 1) {
    throw new IllegalArgumentException("Number of folds must be at least 1, but got: " + nfold);
  }
  this.nfold = nfold;
}
/**
 * Initializes this holdout for the given data bundle (delegating to the
 * superclass) and resets the current fold index to zero.
 *
 * @param bundle Data bundle to be partitioned
 */
@Override
public void initialize(MultipleObjectsBundle bundle) {
  super.initialize(bundle);
  // Start over with the first fold for the newly supplied data.
  fold = 0;
}
/**
 * Reports how many partitionings this cross-validation provides.
 *
 * @return the configured number of folds
 */
@Override
public int numberOfPartitions() {
  return this.nfold;
}
@Override
public TrainingAndTestSet nextPartitioning() {
if(fold >= nfold) {
return null;
}
MultipleObjectsBundle training = new MultipleObjectsBundle();
MultipleObjectsBundle test = new MultipleObjectsBundle();
Random rnd = random.getRandom();
int datalen = bundle.dataLength();
boolean[] assignment = new boolean[datalen];
int trsize = 0, tesize = 0;
for(int i = 0; i < assignment.length; ++i) {
boolean p = rnd.nextInt(nfold) < nfold - 1;
assignment[i] = p;
@SuppressWarnings("unused")
int discard = p ? ++trsize : ++tesize;
}
// Process column-wise.
for(int c = 0, cs = bundle.metaLength(); c < cs; ++c) {
ArrayList