/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.nd4j.linalg.dataset.DataSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Patience-based early-stopping evaluator that decides, from the score
 * reported by a {@link BaseMultiLayerNetwork}, whether training should stop.
 *
 * <p>On every validation epoch the network's score is compared against the
 * best loss seen so far. A sufficiently large improvement records the new
 * best loss and may extend the patience; training is reported as finished
 * once the epoch number exceeds the patience.
 *
 * <p>Instances are mutable ({@code bestLoss} and {@code patience} change as
 * {@link #shouldStop(int)} is called) and perform no synchronization.
 */
public class OutputLayerTrainingEvaluator implements TrainingEvaluator {
    private static final Logger log = LoggerFactory.getLogger(OutputLayerTrainingEvaluator.class);

    /** Multiplier applied to the mini batch size when no usable patience is supplied. */
    private static final double DEFAULT_PATIENCE_FACTOR = 4.0;

    private BaseMultiLayerNetwork network;
    private double patience;
    private double patienceIncrease;
    private double bestLoss;
    private int validationEpochs;
    private int miniBatchSize;
    // Stored for callers that configure it; never read by this class.
    private DataSet testSet;
    private double improvementThreshold;

    /**
     * Creates an evaluator.
     *
     * @param network              network whose {@code score()} is consulted in {@link #shouldStop(int)}
     * @param patience             initial patience, compared against the epoch number; when it is
     *                             not strictly positive (including the {@code 0} a {@link Builder}
     *                             passes when unset, or {@code NaN}) the value
     *                             {@code 4 * miniBatchSize} is used instead
     * @param patienceIncrease     factor multiplied by the epoch number to extend the patience
     *                             after a significant improvement
     * @param bestLoss             initial best loss to compare scores against
     * @param validationEpochs     evaluation happens only on epochs divisible by this value; must be
     *                             non-zero, otherwise {@link #shouldStop(int)} throws
     *                             {@link ArithmeticException}
     * @param miniBatchSize        mini batch size, used for the default patience
     * @param testSet              data set kept by this evaluator (not read by this class)
     * @param improvementThreshold factor applied to the best loss; a score must be below
     *                             {@code bestLoss * improvementThreshold} to count as an improvement
     */
    public OutputLayerTrainingEvaluator(BaseMultiLayerNetwork network, double patience, double patienceIncrease, double bestLoss, int validationEpochs, int miniBatchSize, DataSet testSet, double improvementThreshold) {
        this.network = network;
        // Previously the patience argument was silently discarded. Honor it when it is
        // usable; otherwise keep the historical default so existing callers are unaffected.
        // The floating-point multiply also avoids int overflow for very large batch sizes.
        this.patience = patience > 0 ? patience : DEFAULT_PATIENCE_FACTOR * miniBatchSize;
        this.patienceIncrease = patienceIncrease;
        this.bestLoss = bestLoss;
        this.validationEpochs = validationEpochs;
        this.miniBatchSize = miniBatchSize;
        this.testSet = testSet;
        this.improvementThreshold = improvementThreshold;
    }

    /**
     * Decides whether training should stop at the given epoch.
     *
     * <p>Returns {@code false} without querying the network when {@code epoch} is not a
     * multiple of {@code validationEpochs} or is less than 2. Otherwise the network's score is
     * read; if it is below both the best loss and {@code bestLoss * improvementThreshold},
     * the best loss is updated and the patience is raised to at least
     * {@code epoch * patienceIncrease}.
     *
     * @param epoch the current epoch number
     * @return {@code true} when the patience is smaller than {@code epoch}
     * @throws ArithmeticException if {@code validationEpochs} is zero
     */
    @Override
    public boolean shouldStop(int epoch) {
        if (epoch % this.validationEpochs != 0 || epoch < 2) {
            return false;
        }
        double score = this.network.score();
        if (score < this.bestLoss && score < this.bestLoss * this.improvementThreshold) {
            this.bestLoss = score;
            // Patience is only ever extended, never shortened.
            this.patience = Math.max(this.patience, (double) epoch * this.patienceIncrease);
        }
        boolean stop = this.patience < (double) epoch;
        if (stop) {
            log.info("Returning early on finetune");
        }
        return stop;
    }

    /** @return the factor used to extend the patience after an improvement */
    @Override
    public double patienceIncrease() {
        return this.patienceIncrease;
    }

    /** @return the factor applied to the best loss when testing for an improvement */
    @Override
    public double improvementThreshold() {
        return this.improvementThreshold;
    }

    /** @return the current patience, which {@link #shouldStop(int)} may have extended */
    @Override
    public double patience() {
        return this.patience;
    }

    /** @return the best (lowest) loss recorded so far */
    @Override
    public double bestLoss() {
        return this.bestLoss;
    }

    /** @return the epoch interval at which {@link #shouldStop(int)} evaluates the network */
    @Override
    public int validationEpochs() {
        return this.validationEpochs;
    }

    /** @return the configured mini batch size */
    @Override
    public int miniBatchSize() {
        return this.miniBatchSize;
    }

    /**
     * Fluent builder for {@link OutputLayerTrainingEvaluator}.
     *
     * <p>Unset values keep their Java defaults ({@code 0}, {@code 0.0} or {@code null}). An
     * unset patience therefore results in the constructor's {@code 4 * miniBatchSize} default,
     * and {@code validationEpochs} must be set to a non-zero value before the built evaluator's
     * {@code shouldStop} is called.
     */
    public static class Builder {
        private BaseMultiLayerNetwork network;
        private double patience;
        private double patienceIncrease;
        private double bestLoss;
        private int validationEpochs;
        private int miniBatchSize;
        private DataSet testSet;
        private double improvementThreshold;

        /**
         * @param network network whose score the evaluator will consult
         * @return this builder
         */
        public Builder withNetwork(BaseMultiLayerNetwork network) {
            this.network = network;
            return this;
        }

        /**
         * @param patience initial patience; non-positive values select the default
         * @return this builder
         */
        public Builder patience(double patience) {
            this.patience = patience;
            return this;
        }

        /**
         * @param patienceIncrease factor multiplied by the epoch to extend the patience
         * @return this builder
         */
        public Builder patienceIncrease(double patienceIncrease) {
            this.patienceIncrease = patienceIncrease;
            return this;
        }

        /**
         * @param bestLoss initial best loss
         * @return this builder
         */
        public Builder bestLoss(double bestLoss) {
            this.bestLoss = bestLoss;
            return this;
        }

        /**
         * @param validationEpochs epoch interval between evaluations; must be non-zero
         * @return this builder
         */
        public Builder validationEpochs(int validationEpochs) {
            this.validationEpochs = validationEpochs;
            return this;
        }

        /**
         * @param testSet data set handed to the evaluator
         * @return this builder
         */
        public Builder testSet(DataSet testSet) {
            this.testSet = testSet;
            return this;
        }

        /**
         * @param miniBatchSize mini batch size, used for the default patience
         * @return this builder
         */
        public Builder miniBatchSize(int miniBatchSize) {
            this.miniBatchSize = miniBatchSize;
            return this;
        }

        /**
         * @param improvementThreshold factor applied to the best loss when testing for improvement
         * @return this builder
         */
        public Builder improvementThreshold(double improvementThreshold) {
            this.improvementThreshold = improvementThreshold;
            return this;
        }

        /**
         * Creates the evaluator from the values collected so far.
         *
         * @return a new {@link OutputLayerTrainingEvaluator}
         */
        public OutputLayerTrainingEvaluator build() {
            return new OutputLayerTrainingEvaluator(this.network, this.patience, this.patienceIncrease, this.bestLoss, this.validationEpochs, this.miniBatchSize, this.testSet, this.improvementThreshold);
        }
    }
}

