Generating cross-validation folds (Java approach)


This article describes how to generate train/test splits for cross-validation using the Weka API directly.

The following variables are given:
code format="java"
Instances data = ...;   // the full dataset we want to create train/test sets from
int seed = ...;         // the seed for randomizing the data
int folds = ...;        // the number of folds to generate, >=2
code

= Randomize the data =
First, randomize your data:
code format="java"
Random rand = new Random(seed);            // create seeded number generator
Instances randData = new Instances(data);  // create copy of original data
randData.randomize(rand);                  // randomize data with number generator
code
In case your data has a nominal class and you want to perform stratified cross-validation:
code format="java"
randData.stratify(folds);
code
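To make the idea behind seeded randomization and stratification more concrete, here is a minimal plain-Java sketch that does not use Weka at all. It shuffles a copy of the instance indices with a fixed seed and then deals the members of each class round-robin across the folds, so every fold ends up with roughly the same class proportions. The class `StratifiedShuffleSketch`, the method `assignFolds` and the `labels` array are hypothetical names made up for this illustration; this is one simple way to stratify and is not claimed to be how `Instances.stratify` works internally.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

class StratifiedShuffleSketch {

    /**
     * Returns, for every instance index, the fold it is assigned to.
     * labels[i] is the (hypothetical) nominal class label of instance i.
     */
    public static int[] assignFolds(String[] labels, int folds, long seed) {
        // Shuffle a list of indices, so the caller's data stays untouched.
        List<Integer> order = new ArrayList<>();
        for (int i = 0; i < labels.length; i++) {
            order.add(i);
        }
        Collections.shuffle(order, new Random(seed));

        // Group the shuffled indices by class label.
        Map<String, List<Integer>> byClass = new LinkedHashMap<>();
        for (int idx : order) {
            byClass.computeIfAbsent(labels[idx], k -> new ArrayList<>()).add(idx);
        }

        // Deal the members of each class round-robin across the folds.
        int[] foldOf = new int[labels.length];
        int next = 0;
        for (List<Integer> members : byClass.values()) {
            for (int idx : members) {
                foldOf[idx] = next;
                next = (next + 1) % folds;
            }
        }
        return foldOf;
    }

    public static void main(String[] args) {
        String[] labels = {"yes", "yes", "yes", "yes", "yes", "yes", "no", "no", "no"};
        int[] foldOf = assignFolds(labels, 3, 1L);
        for (int i = 0; i < labels.length; i++) {
            System.out.println("instance " + i + " (" + labels[i] + ") -> fold " + foldOf[i]);
        }
    }
}
```

Calling `assignFolds` twice with the same seed gives the same assignment, which is the reason for using a seeded number generator: the folds can be reproduced later.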

= Generate the folds =

== Single run ==
Next, create the train and the test set for each fold:
code format="java"
for (int n = 0; n < folds; n++) {
  Instances train = randData.trainCV(folds, n);
  Instances test = randData.testCV(folds, n);
  // further processing, classification, etc.
  ...
}
code
**Note:**
 * the above code is used by the filter
 * the class and the Explorer/Experimenter would use the following method for obtaining the train set, which additionally randomizes the train set with the given random number generator:
code format="java"
Instances train = randData.trainCV(folds, n, rand);
code
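The relationship between the train and test set of a fold can be sketched without Weka: each test fold is a contiguous block of the (already randomized) data, and the train set is everything outside that block. The plain-Java sketch below computes such blocks for a hypothetical dataset size. The class `FoldRangeSketch` and its methods are made-up names, and giving the first `n % k` folds one extra instance is an assumption of this sketch, not a statement about the exact arithmetic inside `trainCV` and `testCV`.

```java
class FoldRangeSketch {

    /** Number of instances in test fold `fold` when n instances are split into k folds. */
    public static int testSize(int n, int k, int fold) {
        // The first (n % k) folds receive one extra instance.
        return n / k + (fold < n % k ? 1 : 0);
    }

    /** Index of the first instance of test fold `fold`. */
    public static int testStart(int n, int k, int fold) {
        return fold * (n / k) + Math.min(fold, n % k);
    }

    public static void main(String[] args) {
        int n = 23;  // hypothetical dataset size
        int k = 5;   // number of folds
        for (int fold = 0; fold < k; fold++) {
            int start = testStart(n, k, fold);
            int size = testSize(n, k, fold);
            // Everything outside [start, start + size) forms the train set.
            System.out.println("fold " + fold + ": test = [" + start + ", "
                    + (start + size) + "), train size = " + (n - size));
        }
    }
}
```

The test blocks are disjoint and together cover the whole dataset, so every instance is used for testing exactly once per run.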

== Multiple runs ==
The example above only performs one run of a cross-validation. In case you want to perform 10 runs of 10-fold cross-validation, use the following loop:
code format="java"
Instances data = ...;  // our dataset again, obtained from somewhere
int runs = 10;
for (int i = 0; i < runs; i++) {
  seed = i + 1;  // every run gets a new, but defined seed value

  // see: randomize the data
  ...

  // see: generate the folds
  ...
}
code
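The nesting of the two loops (runs on the outside, folds on the inside) can be illustrated with a small plain-Java sketch that works on instance indices instead of a Weka dataset. It counts how often each index lands in a test fold; with one seed per run, every instance should be tested exactly once per run. The class `RepeatedRunsSketch`, the method `countTestUses` and the simple `f * n / folds` boundaries are assumptions made for this illustration only.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class RepeatedRunsSketch {

    /** Counts how often each instance index ends up in a test fold. */
    public static int[] countTestUses(int n, int folds, int runs) {
        int[] timesTested = new int[n];
        for (int run = 0; run < runs; run++) {
            long seed = run + 1;  // every run gets a new, but defined seed value

            // Randomize a fresh list of the indices for this run.
            List<Integer> order = new ArrayList<>();
            for (int i = 0; i < n; i++) {
                order.add(i);
            }
            Collections.shuffle(order, new Random(seed));

            // Generate the folds: fold f tests positions [f*n/folds, (f+1)*n/folds).
            for (int f = 0; f < folds; f++) {
                int from = f * n / folds;
                int to = (f + 1) * n / folds;
                for (int idx : order.subList(from, to)) {
                    timesTested[idx]++;
                }
            }
        }
        return timesTested;
    }

    public static void main(String[] args) {
        int[] counts = countTestUses(25, 10, 10);
        for (int i = 0; i < counts.length; i++) {
            System.out.println("instance " + i + " tested " + counts[i] + " times");
        }
    }
}
```

Because the seed depends only on the run number, repeating the whole experiment reproduces the same folds in every run.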

= See also =
 * Use Weka in your Java code - for general use of the Weka API

= Downloads =
The following code runs with Weka 3.5.6 or newer:
 * [[file:CrossValidationSingleRun.java]] (book, stable-3.6, developer): simulates a single run of 10-fold cross-validation
 * [[file:CrossValidationSingleRunVariant.java]] (book, stable-3.6, developer): simulates a single run of 10-fold cross-validation, but outputs the confusion matrices for each single train/test pair as well
 * [[file:CrossValidationMultipleRuns.java]] (book, stable-3.6, developer): simulates 10 runs of 10-fold cross-validation
 * [[file:CrossValidationAddPrediction.java]] (stable-3.6, developer): simulates a single run of 10-fold cross-validation, but also adds the classification/distribution/error flag to the test data (uses the filter)