/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.transforms.clip;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.Shape;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * SameDiff custom op registered under the name {@code "clipbynorm"}.
 *
 * <p>The op is configured with a clip threshold (passed to the backend as a
 * floating-point "T" argument) and an optional list of reduction dimensions
 * (passed as integer "I" arguments). The forward kernel is not part of this
 * file; the gradient in {@link #doDiff(List)} assumes the forward pass is
 * {@code out = x} where the L2 norm is below the threshold and
 * {@code out = clipValue * x / norm2(x)} otherwise.
 */
public class ClipByNorm
extends DynamicCustomOp {
    /** Norm threshold used by {@link #doDiff(List)} to decide which slices are treated as clipped. */
    private double clipValue;

    /** No-arg constructor; leaves {@code clipValue} at 0.0 and registers no arguments. */
    public ClipByNorm() {
    }

    /**
     * Creates the op for use inside a SameDiff graph.
     *
     * @param sameDiff   graph instance the op is attached to
     * @param x          input variable
     * @param clipValue  norm threshold; also registered as the op's T argument
     * @param dimensions dimensions the norm is computed along; also registered as the op's I arguments
     */
    public ClipByNorm(SameDiff sameDiff, SDVariable x, double clipValue, int ... dimensions) {
        super(null, sameDiff, new SDVariable[]{x});
        this.clipValue = clipValue;
        this.dimensions = dimensions;
        this.addIArgument(dimensions);
        this.addTArgument(clipValue);
    }

    /** @return the op name, {@code "clipbynorm"} */
    @Override
    public String opName() {
        return "clipbynorm";
    }

    /**
     * TensorFlow import is not supported for this op.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * ONNX import is not supported for this op.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /**
     * Builds the gradient of the loss with respect to the op's input.
     *
     * <p>For slices whose L2 norm is below {@code clipValue} the upstream gradient is
     * passed through unchanged. For slices at or above the threshold, with
     * {@code n = norm2(x)}, {@code c = clipValue} and {@code g} the upstream gradient,
     * the forward value {@code c * x / n} has Jacobian
     * {@code c * (delta_ij / n - x_i * x_j / n^3)}, so the vector-Jacobian product is
     * {@code c * (g / n - x * sum(x * g) / n^3)} with the sum taken over the reduction
     * dimensions. Using only the diagonal term {@code c * (1/n - x^2/n^3) * g} would
     * ignore the coupling between elements that share a norm.
     *
     * @param grad upstream gradients; only element 0 is used
     * @return single-element list holding the gradient for the input
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> grad) {
        SDVariable input = this.arg();
        SDVariable gradOut = grad.get(0);
        int origRank = Shape.rankFromShape(input.getShape());

        // Norm along the configured dimensions, reshaped so it broadcasts against the input.
        SDVariable l2norm = this.f().norm2(input, this.dimensions);
        SDVariable broadcastableNorm = this.f().reductionBroadcastableWithOrigShape(origRank, this.dimensions, l2norm);

        // Masks selecting clipped / unclipped slices (same gte comparison as before).
        SDVariable isClippedBC = this.f().gte(broadcastableNorm, this.clipValue);
        SDVariable notClippedBC = isClippedBC.rsub(1.0);

        // sum(x * g) over the reduction dimensions: the off-diagonal Jacobian contribution.
        SDVariable xDotGrad = this.f().sum(input.mul(gradOut), this.dimensions);
        SDVariable broadcastableDot = this.f().reductionBroadcastableWithOrigShape(origRank, this.dimensions, xDotGrad);

        // NOTE(review): a zero norm yields 0/0 in the masked-out clipped term, as in the
        // previous formulation — presumably NaN; confirm whether all-zero slices can occur.
        SDVariable offDiagonal = this.f().neg(input.mul(broadcastableDot).div(this.f().cube(broadcastableNorm)));
        SDVariable clippedGrad = gradOut.div(broadcastableNorm)
                .add(offDiagonal)
                .mul(this.clipValue)
                .mul(isClippedBC);

        SDVariable ret = notClippedBC.mul(gradOut).add(clippedGrad);
        return Arrays.asList(ret);
    }
}

