/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.lossfunctions.impl;

import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Multi-label ranking loss computed on the activated network output.
 *
 * <p>For example {@code i}, let {@code o} be the activated output row. The positive weights are the
 * label row itself, and the negative indicators are {@code labels == 0}. The score of the example is
 * <pre>
 *   sum_{j positive, k negative} exp(o[k] - o[j]) / (sum(positive) * sum(negative))
 * </pre>
 * so each negative label that is scored close to, or above, a positive label is penalised.
 * NOTE(review): this looks like the BP-MLL loss of Zhang &amp; Zhou (2006); confirm. Labels are
 * presumably 0/1 indicators. Non-binary label values are used unchanged as positive weights.
 *
 * <p>A row whose labels are all positive, or all zero, contains no (positive, negative) pair. Its
 * score and gradient are defined as 0, instead of the NaN that the 0/0 normalisation would produce.
 *
 * <p>SameDiff graph construction and TensorFlow/ONNX import are not supported by this class.
 */
@JsonInclude(value=JsonInclude.Include.NON_NULL)
public class LossMultiLabel
extends DifferentialFunction
implements ILossFunction {

    /**
     * Computes the per-example score and/or the gradient of the loss with respect to
     * {@code preOutput}.
     *
     * @param labels         label array; {@code size(1)} must equal {@code preOutput.size(1)}
     * @param preOutput      layer output before {@code activationFn} is applied (not modified)
     * @param activationFn   activation applied to {@code preOutput} to obtain the scored output
     * @param mask           optional mask, or {@code null} for no masking
     * @param scoreOutput    if non-null, row {@code i} receives the score of example {@code i}
     * @param gradientOutput if non-null, overwritten with dLoss/dPreOutput (same shape as labels)
     * @throws IllegalArgumentException if both outputs are null or the column counts differ
     */
    private void calculate(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, INDArray scoreOutput, INDArray gradientOutput) {
        if (scoreOutput == null && gradientOutput == null) {
            throw new IllegalArgumentException("You have to provide at least one of scoreOutput or gradientOutput!");
        }
        if (labels.size(1) != preOutput.size(1)) {
            throw new IllegalArgumentException("Labels array numColumns (size(1) = " + labels.size(1) + ") does not match output layer number of outputs (nOut = " + preOutput.size(1) + ") ");
        }
        // dup() so that the activation cannot modify the caller's preOutput in place.
        INDArray postOutput = activationFn.getActivation(preOutput.dup(), true);
        INDArray positive = labels;
        INDArray negative = labels.eq(0.0);
        // Per-example normaliser: (number of negatives) * (sum of positive weights).
        INDArray normFactor = negative.sum(1).muli(positive.sum(1));
        int examples = positive.size(0);
        for (int i = 0; i < examples; ++i) {
            double locNormFactor = normFactor.getDouble(i);
            if (locNormFactor == 0.0) {
                // No (positive, negative) pair exists in this row, so the loss is identically zero.
                // Dividing the all-zero pair matrix by this normaliser would give 0/0 = NaN.
                if (scoreOutput != null) {
                    scoreOutput.getRow(i).assign(0.0);
                }
                if (gradientOutput != null) {
                    gradientOutput.getRow(i).assign(0.0);
                }
                continue;
            }
            INDArray locCfn = postOutput.getRow(i);
            int[] shape = locCfn.shape();
            INDArray locPositive = positive.getRow(i);
            INDArray locNegative = negative.getRow(i);
            // operandA[j][k] = o[k]; operandB[j][k] = o[j].
            INDArray operandA = Nd4j.ones(shape[1], shape[0]).mmul(locCfn);
            INDArray operandB = operandA.transpose();
            // pairwiseSub[j][k] = exp(o[k] - o[j]).
            INDArray pairwiseSub = Transforms.exp(operandA.sub(operandB));
            // selection[j][k] = positive[j] * negative[k]: keeps only (positive j, negative k) pairs.
            INDArray selection = locPositive.transpose().mmul(locNegative);
            INDArray classificationDifferences = pairwiseSub.muli(selection).divi(locNormFactor);
            if (scoreOutput != null) {
                if (mask != null) {
                    INDArray perLabel = classificationDifferences.sum(0);
                    LossUtil.applyMask(perLabel, mask.getRow(i));
                    perLabel.sum(scoreOutput.getRow(i), 0);
                } else {
                    classificationDifferences.sum(scoreOutput.getRow(i), 0, 1);
                }
            }
            if (gradientOutput != null) {
                // Column sums give d/do[k] for negative k; negated row sums give d/do[j] for positive j.
                gradientOutput.getRow(i).assign(classificationDifferences.sum(0).addi(classificationDifferences.sum(1).transposei().negi()));
            }
        }
        if (gradientOutput != null) {
            // Chain rule through the activation: convert dLoss/dPostOutput into dLoss/dPreOutput.
            gradientOutput.assign(activationFn.backprop(preOutput.dup(), gradientOutput).getFirst());
            if (mask != null) {
                LossUtil.applyMask(gradientOutput, mask);
            }
        }
    }

    /**
     * Computes the per-example scores.
     *
     * @return a {@code [labels.size(0), 1]} array whose row {@code i} is the score of example {@code i}
     */
    public INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = Nd4j.create(labels.size(0), 1);
        this.calculate(labels, preOutput, activationFn, mask, scoreArr, null);
        return scoreArr;
    }

    /**
     * Returns the total score over all examples, divided by the number of examples when
     * {@code average} is true.
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray scoreArr = this.scoreArray(labels, preOutput, activationFn, mask);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= (double)scoreArr.size(0);
        }
        return score;
    }

    /** Returns the per-example scores, summed along dimension 1. */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = this.scoreArray(labels, preOutput, activationFn, mask);
        return scoreArr.sum(1);
    }

    /**
     * Returns dLoss/dPreOutput, with the same shape as {@code labels}.
     *
     * @throws IllegalArgumentException if the label and output column counts differ
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        // Every row is overwritten by calculate(), so the initial contents do not matter.
        // calculate() also performs the label/output column-count validation.
        INDArray grad = Nd4j.create(labels.shape());
        this.calculate(labels, preOutput, activationFn, mask, null, grad);
        return grad;
    }

    /**
     * Computes the score and the gradient in a single pass.
     *
     * @return a pair of (total score, or the mean per example when {@code average} is true;
     *         dLoss/dPreOutput)
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray scoreArr = Nd4j.create(labels.size(0), 1);
        // Every row is overwritten by calculate(), so the initial contents do not matter.
        INDArray grad = Nd4j.create(labels.shape());
        this.calculate(labels, preOutput, activationFn, mask, scoreArr, grad);
        double score = scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= (double)scoreArr.size(0);
        }
        return new Pair<>(score, grad);
    }

    @Override
    public String name() {
        return this.toString();
    }

    @Override
    public String toString() {
        return "LossMultiLabel";
    }

    /** Not supported: this loss has no SameDiff representation. */
    @Override
    public SDVariable[] outputVariables() {
        throw new UnsupportedOperationException();
    }

    /** Not supported: this loss has no SameDiff representation. */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        throw new UnsupportedOperationException();
    }

    /** Not supported: this loss has no SameDiff representation. */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new UnsupportedOperationException();
    }

    @Override
    public String opName() {
        return this.name();
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /** Intentionally a no-op: there is no TensorFlow mapping (see {@link #tensorflowName()}). */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    /** Intentionally a no-op: there is no ONNX mapping (see {@link #onnxName()}). */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op name found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op name found for " + this.opName());
    }

    /** Instances are stateless, so any two {@code LossMultiLabel} objects compare equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof LossMultiLabel)) {
            return false;
        }
        LossMultiLabel other = (LossMultiLabel)o;
        return other.canEqual(this);
    }

    protected boolean canEqual(Object other) {
        return other instanceof LossMultiLabel;
    }

    /** Constant hash code, consistent with {@link #equals(Object)} treating all instances as equal. */
    @Override
    public int hashCode() {
        return 1;
    }
}

