/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.eval;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.eval.BaseEvaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.transforms.Abs;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.serde.RowVectorDeserializer;
import org.nd4j.linalg.lossfunctions.serde.RowVectorSerializer;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/**
 * Evaluation for regression tasks.
 *
 * <p>Accumulates, per output column, the running sums needed to compute mean squared error, mean absolute
 * error, root mean squared error, relative squared error and the correlation between labels and
 * predictions. Statistics may be accumulated over many minibatches via
 * {@link #eval(INDArray, INDArray, INDArray)} and combined across instances via
 * {@link #merge(RegressionEvaluation)}.
 */
public class RegressionEvaluation
extends BaseEvaluation<RegressionEvaluation> {
    public static final int DEFAULT_PRECISION = 5;
    /** True once the accumulators have been allocated and not since invalidated by {@link #reset()}. */
    private boolean initialized;
    private List<String> columnNames;
    /** Number of digits after the decimal point used by {@link #stats()}. */
    private int precision;
    /** Per column: number of examples seen (with a mask: the sum of the mask values). */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray exampleCountPerColumn;
    /** Per column: sum of (masked) labels. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray labelsSumPerColumn;
    /** Per column: sum of (prediction - label)^2. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray sumSquaredErrorsPerColumn;
    /** Per column: sum of |prediction - label|. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray sumAbsErrorsPerColumn;
    /** Per column: running mean of the labels. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray currentMean;
    /** Per column: running mean of the predictions. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray currentPredictionMean;
    /** Per column: sum of label * prediction. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray sumOfProducts;
    /** Per column: sum of label^2. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray sumSquaredLabels;
    /** Per column: sum of prediction^2. */
    @JsonSerialize(using=RowVectorSerializer.class)
    @JsonDeserialize(using=RowVectorDeserializer.class)
    private INDArray sumSquaredPredicted;

    /** Creates an evaluation whose number of columns is inferred from the first call to eval. */
    public RegressionEvaluation() {
        this((List<String>) null, DEFAULT_PRECISION);
    }

    /** Creates an evaluation for {@code nColumns} columns named {@code col_0 .. col_(n-1)}. */
    public RegressionEvaluation(int nColumns) {
        this(createDefaultColumnNames(nColumns), DEFAULT_PRECISION);
    }

    /** Creates an evaluation for {@code nColumns} default-named columns, printed with {@code precision} digits. */
    public RegressionEvaluation(int nColumns, int precision) {
        this(createDefaultColumnNames(nColumns), precision);
    }

    /** Creates an evaluation with the given column names; no names means "infer from the first eval". */
    public RegressionEvaluation(String ... columnNames) {
        this(columnNames == null || columnNames.length == 0 ? null : Arrays.asList(columnNames), DEFAULT_PRECISION);
    }

    /** Creates an evaluation with the given column names; null/empty means "infer from the first eval". */
    public RegressionEvaluation(List<String> columnNames) {
        this(columnNames, DEFAULT_PRECISION);
    }

    /**
     * Creates an evaluation with the given column names and print precision.
     *
     * @param columnNames column names; when null or empty the accumulators are allocated lazily on the first eval
     * @param precision   number of digits after the decimal point used by {@link #stats()}
     */
    public RegressionEvaluation(List<String> columnNames, int precision) {
        this.precision = precision;
        if (columnNames == null || columnNames.size() == 0) {
            this.initialized = false;
        } else {
            this.columnNames = columnNames;
            this.initialize(columnNames.size());
        }
    }

    /**
     * Marks this evaluation as uninitialized. The accumulator arrays are not cleared here: they are
     * re-allocated (zeroed) on the next call to eval, and merge treats an uninitialized instance as empty.
     */
    @Override
    public void reset() {
        this.initialized = false;
    }

    /**
     * Allocates zeroed accumulators for {@code n} columns. Existing column names are kept only when their
     * count matches {@code n}; otherwise default names are generated.
     */
    private void initialize(int n) {
        if (this.columnNames == null || this.columnNames.size() != n) {
            this.columnNames = createDefaultColumnNames(n);
        }
        this.exampleCountPerColumn = Nd4j.zeros(n);
        this.labelsSumPerColumn = Nd4j.zeros(n);
        this.sumSquaredErrorsPerColumn = Nd4j.zeros(n);
        this.sumAbsErrorsPerColumn = Nd4j.zeros(n);
        this.currentMean = Nd4j.zeros(n);
        this.currentPredictionMean = Nd4j.zeros(n);
        this.sumOfProducts = Nd4j.zeros(n);
        this.sumSquaredLabels = Nd4j.zeros(n);
        this.sumSquaredPredicted = Nd4j.zeros(n);
        this.initialized = true;
    }

    /** Returns the names {@code col_0, col_1, ..., col_(nColumns-1)}. */
    private static List<String> createDefaultColumnNames(int nColumns) {
        ArrayList<String> list = new ArrayList<String>(nColumns);
        for (int i = 0; i < nColumns; ++i) {
            list.add("col_" + i);
        }
        return list;
    }

    /**
     * Returns a copy of {@code counts} in which every zero entry is replaced by one, for use as the divisor of
     * a running-mean update. A column with zero accumulated count also has a zero numerator (assuming
     * non-negative mask values), so it gets a mean of 0 instead of the NaN from 0/0 — a NaN would otherwise
     * persist through every later update, because NaN * 0 is still NaN.
     */
    private static INDArray nonZeroDivisor(INDArray counts) {
        return counts.add(counts.eq(0));
    }

    @Override
    public void eval(INDArray labels, INDArray predictions) {
        this.eval(labels, predictions, (INDArray) null);
    }

    /**
     * Accumulates statistics for one minibatch.
     *
     * @param labels      true values; rank-2 input has examples along dimension 0 and columns along dimension 1.
     *                    Rank-3 input is delegated to {@code evalTimeSeries} (defined in the base class).
     * @param predictions predicted values, same column count as {@code labels}
     * @param maskArray   optional per-output mask with exactly the shape of {@code labels}; may be null
     * @throws RuntimeException         if the mask shape differs from the labels shape
     * @throws IllegalArgumentException if the column counts disagree with this evaluation's column count
     */
    @Override
    public void eval(INDArray labels, INDArray predictions, INDArray maskArray) {
        if (labels.rank() == 3) {
            this.evalTimeSeries(labels, predictions, maskArray);
            return;
        }
        if (maskArray != null && !Arrays.equals(maskArray.shape(), labels.shape())) {
            throw new RuntimeException("Per output masking detected, but mask array and labels have different shapes: " + Arrays.toString(maskArray.shape()) + " vs. labels shape " + Arrays.toString(labels.shape()));
        }
        if (!this.initialized) {
            this.initialize(labels.size(1));
        }
        if (this.columnNames.size() != labels.size(1) || this.columnNames.size() != predictions.size(1)) {
            throw new IllegalArgumentException("Number of the columns of labels and predictions must match specification (" + this.columnNames.size() + "). Got " + labels.size(1) + " and " + predictions.size(1));
        }
        if (maskArray != null) {
            // Multiply by the mask so masked entries contribute nothing to any of the sums below.
            labels = labels.mul(maskArray);
            predictions = predictions.mul(maskArray);
        }

        INDArray labelsColumnSum = labels.sum(0);
        INDArray predictionsColumnSum = predictions.sum(0);
        INDArray error = predictions.sub(labels);
        // Abs is executed in place on its input, hence the dup() to keep 'error' intact for the squared error.
        INDArray absErrorSum = Nd4j.getExecutioner().execAndReturn((TransformOp) new Abs(error.dup())).sum(0);
        INDArray squaredErrorSum = error.mul(error).sum(0);

        this.labelsSumPerColumn.addi(labelsColumnSum);
        this.sumAbsErrorsPerColumn.addi(absErrorSum);
        this.sumSquaredErrorsPerColumn.addi(squaredErrorSum);
        this.sumOfProducts.addi(labels.mul(predictions).sum(0));
        this.sumSquaredLabels.addi(labels.mul(labels).sum(0));
        this.sumSquaredPredicted.addi(predictions.mul(predictions).sum(0));

        int nRows = labels.size(0);
        INDArray newExampleCountPerColumn = maskArray == null
                ? this.exampleCountPerColumn.add((Number) nRows)
                : this.exampleCountPerColumn.add(maskArray.sum(0));
        // Running means: (oldMean * oldCount + batchSum) / newCount, guarding columns whose count is still 0.
        INDArray divisor = nonZeroDivisor(newExampleCountPerColumn);
        this.currentMean.muliRowVector(this.exampleCountPerColumn).addi(labelsColumnSum).diviRowVector(divisor);
        this.currentPredictionMean.muliRowVector(this.exampleCountPerColumn).addi(predictionsColumnSum).diviRowVector(divisor);
        // Replaced (not updated in place) so that no other holder of the old array sees it change.
        this.exampleCountPerColumn = newExampleCountPerColumn;
    }

    /**
     * Adds the statistics accumulated by {@code other} into this evaluation. {@code other} is never modified.
     *
     * <p>An instance that is uninitialized (never evaluated, or {@link #reset()}) is treated as empty: merging
     * an empty {@code other} is a no-op, and merging into an empty {@code this} copies {@code other}'s state.
     *
     * @throws IllegalArgumentException if both instances hold data for different numbers of columns
     */
    @Override
    public void merge(RegressionEvaluation other) {
        if (other == null || !other.initialized || other.labelsSumPerColumn == null) {
            return;
        }
        if (!this.initialized || this.labelsSumPerColumn == null) {
            this.columnNames = other.columnNames == null ? null : new ArrayList<String>(other.columnNames);
            this.precision = other.precision;
            // dup() everything so later in-place updates of either instance never leak into the other.
            this.exampleCountPerColumn = other.exampleCountPerColumn.dup();
            this.labelsSumPerColumn = other.labelsSumPerColumn.dup();
            this.sumSquaredErrorsPerColumn = other.sumSquaredErrorsPerColumn.dup();
            this.sumAbsErrorsPerColumn = other.sumAbsErrorsPerColumn.dup();
            this.currentMean = other.currentMean.dup();
            this.currentPredictionMean = other.currentPredictionMean.dup();
            this.sumOfProducts = other.sumOfProducts.dup();
            this.sumSquaredLabels = other.sumSquaredLabels.dup();
            this.sumSquaredPredicted = other.sumSquaredPredicted.dup();
            this.initialized = true;
            return;
        }
        if (this.labelsSumPerColumn.length() != other.labelsSumPerColumn.length()) {
            throw new IllegalArgumentException("Cannot merge RegressionEvaluation instances with different numbers of columns: "
                    + this.labelsSumPerColumn.length() + " vs. " + other.labelsSumPerColumn.length());
        }

        this.labelsSumPerColumn.addi(other.labelsSumPerColumn);
        this.sumSquaredErrorsPerColumn.addi(other.sumSquaredErrorsPerColumn);
        this.sumAbsErrorsPerColumn.addi(other.sumAbsErrorsPerColumn);

        // Count-weighted combination of the two running means, guarding columns with a combined count of 0.
        INDArray combinedCount = this.exampleCountPerColumn.add(other.exampleCountPerColumn);
        INDArray divisor = nonZeroDivisor(combinedCount);
        this.currentMean.muliRowVector(this.exampleCountPerColumn)
                .addi(other.currentMean.mulRowVector(other.exampleCountPerColumn))
                .diviRowVector(divisor);
        this.currentPredictionMean.muliRowVector(this.exampleCountPerColumn)
                .addi(other.currentPredictionMean.mulRowVector(other.exampleCountPerColumn))
                .diviRowVector(divisor);

        this.sumOfProducts.addi(other.sumOfProducts);
        this.sumSquaredLabels.addi(other.sumSquaredLabels);
        this.sumSquaredPredicted.addi(other.sumSquaredPredicted);
        this.exampleCountPerColumn = combinedCount;
    }

    /**
     * Returns a table with one row per column containing MSE, MAE, RMSE, RSE and the value of
     * {@link #correlationR2(int)} (printed under the "R^2" header), formatted with {@link #getPrecision()} digits.
     */
    @Override
    public String stats() {
        if (!this.initialized) {
            return "RegressionEvaluation: No Data";
        }
        if (this.columnNames == null) {
            this.columnNames = createDefaultColumnNames(this.numColumns());
        }
        int maxLabelLength = 0;
        for (String s : this.columnNames) {
            maxLabelLength = Math.max(maxLabelLength, s.length());
        }
        int labelWidth = maxLabelLength + 5;
        int columnWidth = this.precision + 10;
        String format = "%-" + labelWidth + "s%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e%-" + columnWidth + "." + this.precision + "e";
        StringBuilder sb = new StringBuilder();
        String headerFormat = "%-" + labelWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s%-" + columnWidth + "s";
        sb.append(String.format(headerFormat, "Column", "MSE", "MAE", "RMSE", "RSE", "R^2"));
        sb.append("\n");
        for (int i = 0; i < this.columnNames.size(); ++i) {
            double mse = this.meanSquaredError(i);
            double mae = this.meanAbsoluteError(i);
            double rmse = this.rootMeanSquaredError(i);
            double rse = this.relativeSquaredError(i);
            double corr = this.correlationR2(i);
            sb.append(String.format(format, this.columnNames.get(i), mse, mae, rmse, rse, corr));
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Returns the number of columns: the number of column names if set, otherwise the width of the count
     * accumulator, or 0 when neither exists.
     */
    public int numColumns() {
        if (this.columnNames == null) {
            if (this.exampleCountPerColumn == null) {
                return 0;
            }
            return this.exampleCountPerColumn.size(1);
        }
        return this.columnNames.size();
    }

    /** Mean of (prediction - label)^2 for the given column. */
    public double meanSquaredError(int column) {
        return this.sumSquaredErrorsPerColumn.getDouble(column) / this.exampleCountPerColumn.getDouble(column);
    }

    /** Mean of |prediction - label| for the given column. */
    public double meanAbsoluteError(int column) {
        return this.sumAbsErrorsPerColumn.getDouble(column) / this.exampleCountPerColumn.getDouble(column);
    }

    /** Square root of {@link #meanSquaredError(int)} for the given column. */
    public double rootMeanSquaredError(int column) {
        return Math.sqrt(this.sumSquaredErrorsPerColumn.getDouble(column) / this.exampleCountPerColumn.getDouble(column));
    }

    /**
     * Despite its name, returns the (unsquared) Pearson correlation coefficient r between labels and
     * predictions for the given column:
     * {@code (sum(x*y) - n*meanX*meanY) / (sqrt(sum(x^2) - n*meanX^2) * sqrt(sum(y^2) - n*meanY^2))}.
     * The result may be non-finite when the labels or predictions of the column are constant.
     */
    public double correlationR2(int column) {
        double sumOfLabelTimesPrediction = this.sumOfProducts.getDouble(column);
        double predictionMean = this.currentPredictionMean.getDouble(column);
        double labelMean = this.currentMean.getDouble(column);
        double sumLabelsSquared = this.sumSquaredLabels.getDouble(column);
        double sumPredictionsSquared = this.sumSquaredPredicted.getDouble(column);
        double exampleCount = this.exampleCountPerColumn.getDouble(column);

        double covarianceTerm = sumOfLabelTimesPrediction - exampleCount * predictionMean * labelMean;
        double labelSpread = Math.sqrt(sumLabelsSquared - exampleCount * labelMean * labelMean);
        double predictionSpread = Math.sqrt(sumPredictionsSquared - exampleCount * predictionMean * predictionMean);
        return covarianceTerm / (labelSpread * predictionSpread);
    }

    /**
     * Relative squared error for the given column: {@code sum((pred - label)^2) / sum((label - meanLabel)^2)},
     * computed from the running sums. Returns positive infinity when the denominator's magnitude does not
     * exceed {@code Nd4j.EPS_THRESHOLD}.
     */
    public double relativeSquaredError(int column) {
        // sum(p^2) - 2*sum(l*p) + sum(l^2) == sum((p - l)^2)
        double numerator = this.sumSquaredPredicted.getDouble(column) - 2.0 * this.sumOfProducts.getDouble(column) + this.sumSquaredLabels.getDouble(column);
        // sum(l^2) - n*mean^2 == sum((l - mean)^2)
        double denominator = this.sumSquaredLabels.getDouble(column) - this.exampleCountPerColumn.getDouble(column) * this.currentMean.getDouble(column) * this.currentMean.getDouble(column);
        if (Math.abs(denominator) > Nd4j.EPS_THRESHOLD) {
            return numerator / denominator;
        }
        return Double.POSITIVE_INFINITY;
    }

    /** Arithmetic mean of {@link #meanSquaredError(int)} over all columns (NaN when there are no columns). */
    public double averageMeanSquaredError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.meanSquaredError(i);
        }
        return ret / (double) this.numColumns();
    }

    /** Arithmetic mean of {@link #meanAbsoluteError(int)} over all columns (NaN when there are no columns). */
    public double averageMeanAbsoluteError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.meanAbsoluteError(i);
        }
        return ret / (double) this.numColumns();
    }

    /** Arithmetic mean of {@link #rootMeanSquaredError(int)} over all columns (NaN when there are no columns). */
    public double averagerootMeanSquaredError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.rootMeanSquaredError(i);
        }
        return ret / (double) this.numColumns();
    }

    /** Arithmetic mean of {@link #relativeSquaredError(int)} over all columns (NaN when there are no columns). */
    public double averagerelativeSquaredError() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.relativeSquaredError(i);
        }
        return ret / (double) this.numColumns();
    }

    /** Arithmetic mean of {@link #correlationR2(int)} over all columns (NaN when there are no columns). */
    public double averagecorrelationR2() {
        double ret = 0.0;
        for (int i = 0; i < this.numColumns(); ++i) {
            ret += this.correlationR2(i);
        }
        return ret / (double) this.numColumns();
    }

    public boolean isInitialized() {
        return this.initialized;
    }

    public List<String> getColumnNames() {
        return this.columnNames;
    }

    public int getPrecision() {
        return this.precision;
    }

    public INDArray getExampleCountPerColumn() {
        return this.exampleCountPerColumn;
    }

    public INDArray getLabelsSumPerColumn() {
        return this.labelsSumPerColumn;
    }

    public INDArray getSumSquaredErrorsPerColumn() {
        return this.sumSquaredErrorsPerColumn;
    }

    public INDArray getSumAbsErrorsPerColumn() {
        return this.sumAbsErrorsPerColumn;
    }

    public INDArray getCurrentMean() {
        return this.currentMean;
    }

    public INDArray getCurrentPredictionMean() {
        return this.currentPredictionMean;
    }

    public INDArray getSumOfProducts() {
        return this.sumOfProducts;
    }

    public INDArray getSumSquaredLabels() {
        return this.sumSquaredLabels;
    }

    public INDArray getSumSquaredPredicted() {
        return this.sumSquaredPredicted;
    }

    public void setInitialized(boolean initialized) {
        this.initialized = initialized;
    }

    public void setColumnNames(List<String> columnNames) {
        this.columnNames = columnNames;
    }

    public void setPrecision(int precision) {
        this.precision = precision;
    }

    public void setExampleCountPerColumn(INDArray exampleCountPerColumn) {
        this.exampleCountPerColumn = exampleCountPerColumn;
    }

    public void setLabelsSumPerColumn(INDArray labelsSumPerColumn) {
        this.labelsSumPerColumn = labelsSumPerColumn;
    }

    public void setSumSquaredErrorsPerColumn(INDArray sumSquaredErrorsPerColumn) {
        this.sumSquaredErrorsPerColumn = sumSquaredErrorsPerColumn;
    }

    public void setSumAbsErrorsPerColumn(INDArray sumAbsErrorsPerColumn) {
        this.sumAbsErrorsPerColumn = sumAbsErrorsPerColumn;
    }

    public void setCurrentMean(INDArray currentMean) {
        this.currentMean = currentMean;
    }

    public void setCurrentPredictionMean(INDArray currentPredictionMean) {
        this.currentPredictionMean = currentPredictionMean;
    }

    public void setSumOfProducts(INDArray sumOfProducts) {
        this.sumOfProducts = sumOfProducts;
    }

    public void setSumSquaredLabels(INDArray sumSquaredLabels) {
        this.sumSquaredLabels = sumSquaredLabels;
    }

    public void setSumSquaredPredicted(INDArray sumSquaredPredicted) {
        this.sumSquaredPredicted = sumSquaredPredicted;
    }

    @Override
    public String toString() {
        return "RegressionEvaluation(initialized=" + this.isInitialized() + ", columnNames=" + this.getColumnNames() + ", precision=" + this.getPrecision() + ", exampleCountPerColumn=" + this.getExampleCountPerColumn() + ", labelsSumPerColumn=" + this.getLabelsSumPerColumn() + ", sumSquaredErrorsPerColumn=" + this.getSumSquaredErrorsPerColumn() + ", sumAbsErrorsPerColumn=" + this.getSumAbsErrorsPerColumn() + ", currentMean=" + this.getCurrentMean() + ", currentPredictionMean=" + this.getCurrentPredictionMean() + ", sumOfProducts=" + this.getSumOfProducts() + ", sumSquaredLabels=" + this.getSumSquaredLabels() + ", sumSquaredPredicted=" + this.getSumSquaredPredicted() + ")";
    }

    /** Field-by-field equality (including the base class state), in the style generated by Lombok. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof RegressionEvaluation)) {
            return false;
        }
        RegressionEvaluation other = (RegressionEvaluation) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (this.isInitialized() != other.isInitialized()) {
            return false;
        }
        List<String> this$columnNames = this.getColumnNames();
        List<String> other$columnNames = other.getColumnNames();
        if (this$columnNames == null ? other$columnNames != null : !this$columnNames.equals(other$columnNames)) {
            return false;
        }
        if (this.getPrecision() != other.getPrecision()) {
            return false;
        }
        INDArray this$exampleCountPerColumn = this.getExampleCountPerColumn();
        INDArray other$exampleCountPerColumn = other.getExampleCountPerColumn();
        if (this$exampleCountPerColumn == null ? other$exampleCountPerColumn != null : !this$exampleCountPerColumn.equals(other$exampleCountPerColumn)) {
            return false;
        }
        INDArray this$labelsSumPerColumn = this.getLabelsSumPerColumn();
        INDArray other$labelsSumPerColumn = other.getLabelsSumPerColumn();
        if (this$labelsSumPerColumn == null ? other$labelsSumPerColumn != null : !this$labelsSumPerColumn.equals(other$labelsSumPerColumn)) {
            return false;
        }
        INDArray this$sumSquaredErrorsPerColumn = this.getSumSquaredErrorsPerColumn();
        INDArray other$sumSquaredErrorsPerColumn = other.getSumSquaredErrorsPerColumn();
        if (this$sumSquaredErrorsPerColumn == null ? other$sumSquaredErrorsPerColumn != null : !this$sumSquaredErrorsPerColumn.equals(other$sumSquaredErrorsPerColumn)) {
            return false;
        }
        INDArray this$sumAbsErrorsPerColumn = this.getSumAbsErrorsPerColumn();
        INDArray other$sumAbsErrorsPerColumn = other.getSumAbsErrorsPerColumn();
        if (this$sumAbsErrorsPerColumn == null ? other$sumAbsErrorsPerColumn != null : !this$sumAbsErrorsPerColumn.equals(other$sumAbsErrorsPerColumn)) {
            return false;
        }
        INDArray this$currentMean = this.getCurrentMean();
        INDArray other$currentMean = other.getCurrentMean();
        if (this$currentMean == null ? other$currentMean != null : !this$currentMean.equals(other$currentMean)) {
            return false;
        }
        INDArray this$currentPredictionMean = this.getCurrentPredictionMean();
        INDArray other$currentPredictionMean = other.getCurrentPredictionMean();
        if (this$currentPredictionMean == null ? other$currentPredictionMean != null : !this$currentPredictionMean.equals(other$currentPredictionMean)) {
            return false;
        }
        INDArray this$sumOfProducts = this.getSumOfProducts();
        INDArray other$sumOfProducts = other.getSumOfProducts();
        if (this$sumOfProducts == null ? other$sumOfProducts != null : !this$sumOfProducts.equals(other$sumOfProducts)) {
            return false;
        }
        INDArray this$sumSquaredLabels = this.getSumSquaredLabels();
        INDArray other$sumSquaredLabels = other.getSumSquaredLabels();
        if (this$sumSquaredLabels == null ? other$sumSquaredLabels != null : !this$sumSquaredLabels.equals(other$sumSquaredLabels)) {
            return false;
        }
        INDArray this$sumSquaredPredicted = this.getSumSquaredPredicted();
        INDArray other$sumSquaredPredicted = other.getSumSquaredPredicted();
        return !(this$sumSquaredPredicted == null ? other$sumSquaredPredicted != null : !this$sumSquaredPredicted.equals(other$sumSquaredPredicted));
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof RegressionEvaluation;
    }

    /** Hash code consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int PRIME = 59;
        int result = 1;
        result = result * PRIME + super.hashCode();
        result = result * PRIME + (this.isInitialized() ? 79 : 97);
        List<String> $columnNames = this.getColumnNames();
        result = result * PRIME + ($columnNames == null ? 43 : $columnNames.hashCode());
        result = result * PRIME + this.getPrecision();
        INDArray $exampleCountPerColumn = this.getExampleCountPerColumn();
        result = result * PRIME + ($exampleCountPerColumn == null ? 43 : $exampleCountPerColumn.hashCode());
        INDArray $labelsSumPerColumn = this.getLabelsSumPerColumn();
        result = result * PRIME + ($labelsSumPerColumn == null ? 43 : $labelsSumPerColumn.hashCode());
        INDArray $sumSquaredErrorsPerColumn = this.getSumSquaredErrorsPerColumn();
        result = result * PRIME + ($sumSquaredErrorsPerColumn == null ? 43 : $sumSquaredErrorsPerColumn.hashCode());
        INDArray $sumAbsErrorsPerColumn = this.getSumAbsErrorsPerColumn();
        result = result * PRIME + ($sumAbsErrorsPerColumn == null ? 43 : $sumAbsErrorsPerColumn.hashCode());
        INDArray $currentMean = this.getCurrentMean();
        result = result * PRIME + ($currentMean == null ? 43 : $currentMean.hashCode());
        INDArray $currentPredictionMean = this.getCurrentPredictionMean();
        result = result * PRIME + ($currentPredictionMean == null ? 43 : $currentPredictionMean.hashCode());
        INDArray $sumOfProducts = this.getSumOfProducts();
        result = result * PRIME + ($sumOfProducts == null ? 43 : $sumOfProducts.hashCode());
        INDArray $sumSquaredLabels = this.getSumSquaredLabels();
        result = result * PRIME + ($sumSquaredLabels == null ? 43 : $sumSquaredLabels.hashCode());
        INDArray $sumSquaredPredicted = this.getSumSquaredPredicted();
        result = result * PRIME + ($sumSquaredPredicted == null ? 43 : $sumSquaredPredicted.hashCode());
        return result;
    }
}

