/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.earlystopping.scorecalc;

import org.deeplearning4j.earlystopping.scorecalc.base.BaseScoreCalculator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Early-stopping score calculator that reports the loss of a network over a data set.
 *
 * <p>Supports {@link MultiLayerNetwork} and {@link ComputationGraph} models; any other
 * {@link Model} implementation is rejected with a {@link RuntimeException}. Depending on
 * the {@code average} flag, the final score is either the accumulated score sum or that
 * sum divided by the example count.
 */
public class DataSetLossCalculator extends BaseScoreCalculator<Model> {
    /** When {@code true}, {@link #finalScore} divides the score sum by the example count. */
    @JsonProperty
    private boolean average;

    /**
     * Creates a calculator backed by a {@link DataSetIterator}.
     *
     * @param dataSetIterator iterator handed to the base class
     * @param average         whether the final score is divided by the example count
     */
    public DataSetLossCalculator(DataSetIterator dataSetIterator, boolean average) {
        super(dataSetIterator);
        this.average = average;
    }

    /**
     * Creates a calculator backed by a {@link MultiDataSetIterator}.
     *
     * @param dataSetIterator iterator handed to the base class
     * @param average         whether the final score is divided by the example count
     */
    public DataSetLossCalculator(MultiDataSetIterator dataSetIterator, boolean average) {
        super(dataSetIterator);
        this.average = average;
    }

    /** @return a short description of this calculator including the {@code average} flag */
    @Override
    public String toString() {
        return "DataSetLossCalculator(average=" + average + ")";
    }

    /** Zeroes the example count, minibatch count and score sum held by the base class. */
    @Override
    protected void reset() {
        exampleCount = 0;
        minibatchCount = 0;
        scoreSum = 0.0;
    }

    /**
     * Single-array variant: wraps each argument with {@code arr(...)}, delegates to the
     * array-based overload and returns the first element of its result.
     *
     * @param network model to evaluate
     * @param input   input features
     * @param fMask   features mask
     * @param lMask   labels mask
     * @return the first output array produced by the array-based overload
     */
    @Override
    protected INDArray output(Model network, INDArray input, INDArray fMask, INDArray lMask) {
        INDArray[] allOutputs = output(network, arr(input), arr(fMask), arr(lMask));
        return allOutputs[0];
    }

    /**
     * Runs the network's {@code output} method, dispatching on the concrete model type.
     *
     * <p>The {@code false} flag is forwarded to the network's {@code output} call
     * (presumably the train/inference switch; confirm against the network API).
     *
     * @param network model to evaluate
     * @param input   input feature arrays; only element 0 is used for a {@link MultiLayerNetwork}
     * @param fMask   feature mask arrays
     * @param lMask   label mask arrays
     * @return the network outputs; a one-element array for a {@link MultiLayerNetwork}
     * @throws RuntimeException if the model is neither of the two supported types
     */
    @Override
    protected INDArray[] output(Model network, INDArray[] input, INDArray[] fMask, INDArray[] lMask) {
        if (network instanceof MultiLayerNetwork) {
            MultiLayerNetwork mln = (MultiLayerNetwork) network;
            INDArray single = mln.output(input[0], false, get0(fMask), get0(lMask));
            return new INDArray[]{single};
        } else if (network instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) network;
            return graph.output(false, input, fMask, lMask);
        } else {
            throw new RuntimeException("Unknown model type: " + network.getClass());
        }
    }

    /**
     * Scores one minibatch: asks the network for its score on the batch and multiplies it
     * by {@code features[0].size(0)} (presumably the number of examples in the minibatch,
     * turning a per-example score into a batch total; confirm against the network API).
     *
     * @param network  model to score
     * @param features feature arrays
     * @param labels   label arrays
     * @param fMask    feature mask arrays
     * @param lMask    label mask arrays
     * @param output   precomputed network outputs; not used by this implementation
     * @return the network score scaled by the size of dimension 0 of the first feature array
     * @throws RuntimeException if the model is neither of the two supported types
     */
    @Override
    protected double scoreMinibatch(Model network, INDArray[] features, INDArray[] labels, INDArray[] fMask, INDArray[] lMask, INDArray[] output) {
        if (network instanceof MultiLayerNetwork) {
            MultiLayerNetwork mln = (MultiLayerNetwork) network;
            DataSet batch = new DataSet(get0(features), get0(labels), get0(fMask), get0(lMask));
            double batchScore = mln.score(batch, false);
            // Dimension-0 size is read only after scoring, matching the original order.
            return batchScore * (double) features[0].size(0);
        } else if (network instanceof ComputationGraph) {
            ComputationGraph graph = (ComputationGraph) network;
            // Declared as the API interface type so the MultiDataSet overload of score() is chosen.
            org.nd4j.linalg.dataset.api.MultiDataSet batch = new MultiDataSet(features, labels, fMask, lMask);
            double batchScore = graph.score(batch);
            return batchScore * (double) features[0].size(0);
        } else {
            throw new RuntimeException("Unknown model type: " + network.getClass());
        }
    }

    /**
     * Produces the final score from the accumulated totals.
     *
     * @param scoreSum       accumulated score
     * @param minibatchCount number of minibatches; not used by this implementation
     * @param exampleCount   number of examples, used as the divisor when averaging
     * @return {@code scoreSum / exampleCount} when averaging is enabled, otherwise {@code scoreSum}
     */
    @Override
    protected double finalScore(double scoreSum, int minibatchCount, int exampleCount) {
        return average ? scoreSum / (double) exampleCount : scoreSum;
    }
}

