/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.loss;

import org.nd4j.autodiff.loss.LossInfo;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;

/**
 * Static helpers that add loss-function operations to a {@link SameDiff} graph and
 * describe the result as a {@link LossInfo}.
 *
 * <p>Every public method follows the same recipe: validate the arguments, substitute
 * a default weights variable when none is supplied, build an element-wise
 * ("pre-reduce") loss, and then hand that to {@code doReduce}, which applies the
 * requested {@link Reduction}.
 */
public class LossFunctions {
    /** Shape argument passed to {@code SameDiff.one} when the caller supplies no weights. */
    private static final int[] SCALAR = new int[]{1, 1};

    /** Utility class; not instantiable. */
    private LossFunctions() {
    }

    /**
     * Checks that the mandatory arguments are present and starts a {@link LossInfo} builder.
     *
     * @param lossName    name of the loss function; used in error messages and stored in the builder
     * @param predictions predictions variable, must not be null
     * @param label       label variable, must not be null
     * @param reduction   reduction mode, must not be null
     * @return a builder pre-populated with the loss name, reduction, label and predictions
     */
    private static LossInfo.Builder validate(String lossName, SDVariable predictions, SDVariable label, Reduction reduction) {
        Preconditions.checkNotNull(predictions, "Predictions variable cannot be null for loss function - %s", lossName);
        Preconditions.checkNotNull(label, "Label variable cannot be null for loss function - %s", lossName);
        Preconditions.checkNotNull(reduction, "Reduction enumeration cannot be null for loss function - %s", lossName);
        return LossInfo.builder().lossName(lossName).reduction(reduction).label(label).predictions(predictions);
    }

    /**
     * Mean squared error: {@code square(predictions - label) * weights}, reduced with
     * mean semantics (see {@code doReduce}).
     *
     * @param outputName  name given to the final loss variable
     * @param predictions predictions variable, must not be null
     * @param label       label variable, must not be null
     * @param weights     optional weights; when null a variable named {@code "mse_loss_weights"} is created via {@code sd.one}
     * @param reduction   reduction mode, must not be null
     * @param dimensions  dimensions forwarded to the reduction
     * @return description of the constructed loss
     */
    public static LossInfo mse(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int ... dimensions) {
        LossInfo.Builder b = LossFunctions.validate("mse", predictions, label, reduction);
        SameDiff sd = predictions.getSameDiff();
        if (weights == null) {
            weights = sd.one("mse_loss_weights", SCALAR);
        }
        SDVariable diff = predictions.sub(label);
        // With Reduction.NONE the pre-reduce variable IS the output, so it takes the output name.
        String name = reduction == Reduction.NONE ? outputName : null;
        SDVariable preReduceLoss = sd.square(diff).mul(name, weights);
        return LossFunctions.doReduce(sd, outputName, true, b, reduction, preReduceLoss, label, weights, dimensions);
    }

    /**
     * L1 loss: {@code abs(predictions - label) * weights}, reduced with sum semantics
     * (see {@code doReduce}).
     *
     * @param outputName  name given to the final loss variable
     * @param predictions predictions variable, must not be null
     * @param label       label variable, must not be null
     * @param weights     optional weights; when null a variable named {@code "l1_loss_weights"} is created via {@code sd.one}
     * @param reduction   reduction mode, must not be null
     * @param dimensions  dimensions forwarded to the reduction
     * @return description of the constructed loss
     */
    public static LossInfo l1(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int ... dimensions) {
        LossInfo.Builder b = LossFunctions.validate("l1", predictions, label, reduction);
        SameDiff sd = predictions.getSameDiff();
        if (weights == null) {
            weights = sd.one("l1_loss_weights", SCALAR);
        }
        // With Reduction.NONE the pre-reduce variable IS the output, so it takes the output name.
        String name = reduction == Reduction.NONE ? outputName : null;
        SDVariable preReduceLoss = sd.abs(predictions.sub(label)).mul(name, weights);
        return LossFunctions.doReduce(sd, outputName, false, b, reduction, preReduceLoss, label, weights, dimensions);
    }

    /**
     * L2 loss: {@code square(predictions - label) * weights}, reduced with sum semantics
     * (see {@code doReduce}). Differs from {@link #mse} only in using sum rather than
     * mean semantics for the reduction.
     *
     * @param outputName  name given to the final loss variable
     * @param predictions predictions variable, must not be null
     * @param label       label variable, must not be null
     * @param weights     optional weights; when null a variable named {@code "l2_loss_weights"} is created via {@code sd.one}
     * @param reduction   reduction mode, must not be null
     * @param dimensions  dimensions forwarded to the reduction
     * @return description of the constructed loss
     */
    public static LossInfo l2(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int ... dimensions) {
        LossInfo.Builder b = LossFunctions.validate("l2", predictions, label, reduction);
        SameDiff sd = predictions.getSameDiff();
        if (weights == null) {
            weights = sd.one("l2_loss_weights", SCALAR);
        }
        SDVariable diff = predictions.sub(label);
        // With Reduction.NONE the pre-reduce variable IS the output, so it takes the output name.
        String name = reduction == Reduction.NONE ? outputName : null;
        SDVariable preReduceLoss = sd.square(diff).mul(name, weights);
        return LossFunctions.doReduce(sd, outputName, false, b, reduction, preReduceLoss, label, weights, dimensions);
    }

    /**
     * Alias for {@link #mcxent}; all arguments are forwarded unchanged, so the resulting
     * {@link LossInfo} carries the loss name {@code "mcxent"}.
     *
     * @param outputName  name given to the final loss variable
     * @param predictions predictions variable, must not be null
     * @param label       label variable, must not be null
     * @param weights     optional weights; may be null
     * @param reduction   reduction mode, must not be null
     * @param dimensions  dimensions forwarded to the reduction
     * @return description of the constructed loss
     */
    public static LossInfo negativeLogLikelihood(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int ... dimensions) {
        return LossFunctions.mcxent(outputName, predictions, label, weights, reduction, dimensions);
    }

    /**
     * Multi-class cross entropy term: {@code log(predictions) * label * weights}, reduced
     * with sum semantics (see {@code doReduce}).
     *
     * @param outputName  name given to the final loss variable
     * @param predictions predictions variable, must not be null
     * @param label       label variable, must not be null
     * @param weights     optional weights; when null a variable named {@code "mcxent_loss_weights"} is created via {@code sd.one}
     * @param reduction   reduction mode, must not be null
     * @param dimensions  dimensions forwarded to the reduction
     * @return description of the constructed loss
     */
    public static LossInfo mcxent(String outputName, SDVariable predictions, SDVariable label, SDVariable weights, Reduction reduction, int ... dimensions) {
        LossInfo.Builder b = LossFunctions.validate("mcxent", predictions, label, reduction);
        SameDiff sd = predictions.getSameDiff();
        if (weights == null) {
            weights = sd.one("mcxent_loss_weights", SCALAR);
        }
        // With Reduction.NONE the pre-reduce variable IS the output, so it takes the output name.
        String name = reduction == Reduction.NONE ? outputName : null;
        // NOTE(review): no negation is applied here, although cross entropy is conventionally
        // -sum(label * log(p)). Left as-is to preserve existing behavior - confirm intent.
        SDVariable weightedLogProd = sd.log(predictions).mul(label).mul(name, weights);
        return LossFunctions.doReduce(sd, outputName, false, b, reduction, weightedLogProd, label, weights, dimensions);
    }

    /**
     * Builds a variable that sums, over all dimensions, the mask {@code weights != 0}
     * after adding it to {@code zerosLike(labels)}.
     *
     * @param weights weights variable; also supplies the {@link SameDiff} instance
     * @param labels  labels variable the mask is added to
     * @return the summed mask variable
     */
    private static SDVariable nonZeroCount(SDVariable weights, SDVariable labels) {
        SameDiff sd = weights.getSameDiff();
        SDVariable present = sd.neq(weights, 0.0);
        // Presumably expands the mask to the labels' shape so that scalar weights count
        // once per label element - relies on add() broadcasting; TODO confirm.
        SDVariable presentBroadcast = sd.zerosLike(labels).add(present);
        return sd.sum(presentBroadcast, new int[0]);
    }

    /**
     * Applies the requested reduction to an element-wise loss and finishes the builder.
     *
     * @param sd            graph the reduction ops are added to
     * @param outputName    name given to the final loss variable (for {@code NONE} the caller
     *                      has already named {@code preReduceLoss})
     * @param isMean        true if the loss itself uses mean semantics (e.g. MSE), false for sum semantics
     * @param b             builder to receive the final loss variable
     * @param reduction     reduction mode
     * @param preReduceLoss element-wise weighted loss
     * @param label         label variable, used only for {@code MEAN_BY_COUNT}
     * @param weights       weights variable, used for {@code MEAN_BY_WEIGHT} and {@code MEAN_BY_COUNT}
     * @param dimensions    dimensions to reduce along where the mode uses them
     * @return the built {@link LossInfo}
     * @throws RuntimeException if {@code reduction} is not a known constant
     */
    private static LossInfo doReduce(SameDiff sd, String outputName, boolean isMean, LossInfo.Builder b, Reduction reduction, SDVariable preReduceLoss, SDVariable label, SDVariable weights, int[] dimensions) {
        switch (reduction) {
            case NONE: {
                // No reduction: the element-wise loss is the result.
                b.loss(preReduceLoss);
                break;
            }
            case SPECIFIED_DIMS: {
                if (isMean) {
                    b.loss(sd.mean(outputName, preReduceLoss, dimensions));
                } else {
                    b.loss(sd.sum(outputName, preReduceLoss, dimensions));
                }
                // Without this break, control fell through into SUM, which replaced the loss
                // with a full sum and created a second variable under the same output name.
                break;
            }
            case SUM: {
                if (isMean) {
                    // Mean along the given dimensions, then sum what remains.
                    SDVariable m = sd.mean(preReduceLoss, dimensions);
                    b.loss(sd.sum(outputName, m, new int[0]));
                } else {
                    b.loss(sd.sum(outputName, preReduceLoss, new int[0]));
                }
                break;
            }
            case MEAN_BY_WEIGHT: {
                SDVariable weightSum = sd.sum(weights, new int[0]);
                if (isMean) {
                    // Uses the mean overload that takes no dimensions; "dimensions" is not consulted here.
                    SDVariable m2 = sd.mean(preReduceLoss);
                    b.loss(m2.div(outputName, weightSum));
                } else {
                    SDVariable sum = sd.sum(preReduceLoss, dimensions);
                    b.loss(sum.div(outputName, weightSum));
                }
                break;
            }
            case MEAN_BY_COUNT: {
                SDVariable nonZeroWeights = LossFunctions.nonZeroCount(weights, label);
                SDVariable r;
                if (isMean) {
                    r = sd.sum(preReduceLoss, new int[0]);
                } else {
                    SDVariable sum = sd.sum(preReduceLoss, dimensions);
                    r = sd.mean(sum);
                }
                b.loss(r.div(outputName, nonZeroWeights));
                break;
            }
            default: {
                throw new RuntimeException("Unknown reduction: " + reduction);
            }
        }
        return b.build();
    }

    /** How an element-wise loss is collapsed into the final loss variable. */
    public enum Reduction {
        /** No reduction; the element-wise loss is returned as-is. */
        NONE,
        /** Reduce (mean or sum, depending on the loss) along the caller-specified dimensions only. */
        SPECIFIED_DIMS,
        /** Reduce to a total sum over all dimensions. */
        SUM,
        /** Reduce, then divide by the sum of the weights. */
        MEAN_BY_WEIGHT,
        /** Reduce, then divide by the value produced by {@code nonZeroCount(weights, label)}. */
        MEAN_BY_COUNT;

    }
}

