/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.trainer;

import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.earlystopping.EarlyStoppingConfiguration;
import org.deeplearning4j.earlystopping.listener.EarlyStoppingListener;
import org.deeplearning4j.earlystopping.trainer.BaseEarlyStoppingTrainer;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;

/**
 * Early-stopping trainer for {@link ComputationGraph} networks.
 *
 * <p>Training data may be supplied either as a {@link DataSetIterator} or as a
 * {@link MultiDataSetIterator}. A {@link DataSetIterator} can only feed a graph with exactly
 * one input array and one output array; that is checked at construction time.
 *
 * <p>This class only supplies the per-minibatch {@code fit(...)} hooks overridden from
 * {@link BaseEarlyStoppingTrainer}. Each hook forwards the minibatch directly to the
 * wrapped graph's {@code fit} method.
 */
public class EarlyStoppingGraphTrainer
extends BaseEarlyStoppingTrainer<ComputationGraph> {
    /** The graph being trained; the same instance that is handed to the superclass. */
    private final ComputationGraph net;

    /**
     * Creates a trainer that reads single-input/single-output data from a
     * {@link DataSetIterator}, with no listener.
     *
     * @param esConfig early-stopping configuration
     * @param net      graph to train; must have exactly 1 input and 1 output array
     * @param train    training data iterator
     * @throws IllegalStateException if {@code net} does not have exactly 1 input and 1 output array
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, DataSetIterator train) {
        this(esConfig, net, train, null);
    }

    /**
     * Creates a trainer that reads single-input/single-output data from a
     * {@link DataSetIterator}.
     *
     * @param esConfig early-stopping configuration
     * @param net      graph to train; must have exactly 1 input and 1 output array
     * @param train    training data iterator
     * @param listener early-stopping listener; may be {@code null} (the three-argument
     *                 constructor passes {@code null})
     * @throws IllegalStateException if {@code net} does not have exactly 1 input and 1 output array
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, DataSetIterator train, EarlyStoppingListener<ComputationGraph> listener) {
        // The shape check runs inside the super(...) argument list so an unsupported graph is
        // rejected before any superclass initialization takes place.
        super(esConfig, requireSingleInputOutput(net), train, null, listener);
        this.net = net;
    }

    /**
     * Creates a trainer that reads (possibly multi-input/multi-output) data from a
     * {@link MultiDataSetIterator}.
     *
     * @param esConfig early-stopping configuration
     * @param net      graph to train
     * @param train    training data iterator
     * @param listener early-stopping listener
     */
    public EarlyStoppingGraphTrainer(EarlyStoppingConfiguration<ComputationGraph> esConfig, ComputationGraph net, MultiDataSetIterator train, EarlyStoppingListener<ComputationGraph> listener) {
        super(esConfig, net, null, train, listener);
        this.net = net;
    }

    /**
     * Verifies that {@code net} can be fed from a {@link DataSetIterator}.
     *
     * @param net graph to check
     * @return {@code net}, unchanged, so the call can be used inline in {@code super(...)}
     * @throws IllegalStateException if the graph does not have exactly 1 input and 1 output array
     */
    private static ComputationGraph requireSingleInputOutput(ComputationGraph net) {
        int numInputs = net.getNumInputArrays();
        int numOutputs = net.getNumOutputArrays();
        if (numInputs != 1 || numOutputs != 1) {
            throw new IllegalStateException("Cannot do early stopping training on ComputationGraph with DataSetIterator: graph does not have 1 input and 1 output array (inputs=" + numInputs + ", outputs=" + numOutputs + ")");
        }
        return net;
    }

    /** Fits the wrapped graph on one {@link DataSet} minibatch. */
    @Override
    protected void fit(DataSet ds) {
        this.net.fit(ds);
    }

    /** Fits the wrapped graph on one {@link MultiDataSet} minibatch. */
    @Override
    protected void fit(MultiDataSet mds) {
        this.net.fit(mds);
    }
}

