/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.evaluation;

import com.oracle.labs.mlrg.olcut.util.Pair;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.logging.Level;
import java.util.logging.Logger;
import org.tribuo.Dataset;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Trainer;
import org.tribuo.evaluation.Evaluation;
import org.tribuo.evaluation.Evaluator;
import org.tribuo.evaluation.KFoldSplitter;

/**
 * Runs k-fold cross-validation of a {@link Trainer} over a {@link Dataset}.
 * <p>
 * The dataset is partitioned by a {@link KFoldSplitter}. For each fold, a model is trained on the
 * fold's training split and then scored on that fold's test split with the supplied
 * {@link Evaluator}.
 *
 * @param <T> the output type of the dataset and the trained models
 * @param <E> the evaluation type produced by the evaluator
 */
public class CrossValidation<T extends Output<T>, E extends Evaluation<T>> {
    private static final Logger logger = Logger.getLogger(CrossValidation.class.getName());

    /** Splitter seed used by the four-argument constructor. */
    private static final long DEFAULT_SEED = 12345L;

    private final Trainer<T> trainer;
    private final int numFolds;
    private final Dataset<T> data;
    private final Evaluator<T, E> evaluator;
    private final KFoldSplitter<T> splitter;

    /**
     * Builds a cross-validation run that uses a fixed default seed (12345) for fold splitting.
     *
     * @param trainer   the trainer used to fit one model per fold
     * @param data      the dataset to split into folds
     * @param evaluator the evaluator applied to each fold's held-out split
     * @param k         the number of folds
     * @throws NullPointerException if {@code trainer}, {@code data} or {@code evaluator} is null
     */
    public CrossValidation(Trainer<T> trainer, Dataset<T> data, Evaluator<T, E> evaluator, int k) {
        this(trainer, data, evaluator, k, DEFAULT_SEED);
    }

    /**
     * Builds a cross-validation run with an explicit seed for fold splitting.
     *
     * @param trainer   the trainer used to fit one model per fold
     * @param data      the dataset to split into folds
     * @param evaluator the evaluator applied to each fold's held-out split
     * @param k         the number of folds (any range validation is delegated to
     *                  {@link KFoldSplitter})
     * @param seed      the seed passed to the {@link KFoldSplitter}
     * @throws NullPointerException if {@code trainer}, {@code data} or {@code evaluator} is null
     */
    public CrossValidation(Trainer<T> trainer, Dataset<T> data, Evaluator<T, E> evaluator, int k, long seed) {
        // Fail at construction time rather than deep inside evaluate().
        this.trainer = Objects.requireNonNull(trainer, "trainer");
        this.data = Objects.requireNonNull(data, "data");
        this.evaluator = Objects.requireNonNull(evaluator, "evaluator");
        this.numFolds = k;
        this.splitter = new KFoldSplitter<>(k, seed);
    }

    /**
     * Returns the number of folds this run was configured with.
     *
     * @return the k passed to the constructor
     */
    public int getK() {
        return this.numFolds;
    }

    /**
     * Trains and evaluates one model per fold.
     * <p>
     * Each call re-splits the dataset and retrains every model from scratch.
     *
     * @return a new mutable list with one entry per fold, in the order the splitter produced them.
     *         Each entry pairs the fold's evaluation with the model trained on that fold.
     */
    public List<Pair<E, Model<T>>> evaluate() {
        List<Pair<E, Model<T>>> evals = new ArrayList<>();
        // NOTE(review): the boolean argument is presumably the splitter's shuffle flag; confirm in KFoldSplitter.
        Iterator<KFoldSplitter.TrainTestFold<T>> iter = this.splitter.split(this.data, true);
        int ct = 0;
        while (iter.hasNext()) {
            logger.log(Level.INFO, "Training for fold " + ct);
            KFoldSplitter.TrainTestFold<T> fold = iter.next();
            Model<T> model = this.trainer.train(fold.train);
            E evaluation = this.evaluator.evaluate(model, fold.test);
            evals.add(new Pair<>(evaluation, model));
            ct++;
        }
        return evals;
    }
}

