/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.autodiff.samediff;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.AddOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.DivOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.RDivOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.RSubOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.SquaredDifferenceOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.SubOp;
import org.nd4j.linalg.api.ops.impl.transforms.arithmetic.TruncateDivOp;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * A named variable in a {@link SameDiff} graph.
 *
 * <p>Arrays, shapes and placeholder status are held by the owning {@link SameDiff} instance and
 * looked up by {@link #getVarName()}. This object itself stores only the name and the
 * {@link WeightInitScheme} used when an array has to be allocated for it. The arithmetic methods
 * create new graph nodes through {@code sameDiff.f()} and return the resulting variable, renamed
 * either to the supplied name or to a freshly generated one.
 */
public class SDVariable
extends DifferentialFunction
implements Serializable {
    private String varName;
    protected WeightInitScheme weightInitScheme;

    /**
     * Creates a variable and registers its shape information with {@code sameDiff}.
     *
     * <p>A {@code null} shape registers the variable as a placeholder. A shape that contains any
     * negative dimension also registers it as a placeholder and records that shape as the original
     * placeholder shape. Any other shape is stored as the variable's shape.
     *
     * @param varName          name of the variable
     * @param sameDiff         graph that owns the variable
     * @param shape            variable shape; may be {@code null}
     * @param weightInitScheme scheme used when allocating arrays; {@code null} falls back to
     *                         {@code new ZeroInitScheme('f')}
     */
    private SDVariable(String varName, SameDiff sameDiff, int[] shape, WeightInitScheme weightInitScheme) {
        super(sameDiff, new Object[0]);
        this.varName = varName;
        this.weightInitScheme = weightInitScheme != null ? weightInitScheme : new ZeroInitScheme('f');
        if (shape == null) {
            sameDiff.addAsPlaceHolder(varName);
        } else if (hasNegativeDimension(shape)) {
            sameDiff.addAsPlaceHolder(varName);
            sameDiff.setOriginalPlaceHolderShape(varName, shape);
        } else {
            sameDiff.putShapeForVarName(varName, shape);
        }
        this.sameDiff = sameDiff;
    }

    /** Returns {@code true} if at least one entry of {@code shape} is negative. */
    private static boolean hasNegativeDimension(int[] shape) {
        for (int dim : shape) {
            if (dim < 0) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether the owning graph has registered this variable as a placeholder. */
    public boolean isPlaceHolder() {
        return this.sameDiff.isPlaceHolder(this.varName);
    }

    @Override
    public String opName() {
        return "variable";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.RETURN;
    }

    /** A variable is its own single output. */
    @Override
    public SDVariable[] outputVariables() {
        return new SDVariable[]{this};
    }

    /** A variable is its own single argument. */
    @Override
    public SDVariable arg() {
        return this;
    }

    /** A variable is its own single argument. */
    @Override
    public SDVariable[] args() {
        return new SDVariable[]{this};
    }

    /** Ignores {@code baseName}; a variable is always its own single output. */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[]{this};
    }

    /** Intentionally a no-op for variables. */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
    }

    /** Intentionally a no-op for variables. */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    /**
     * Ensures an array matching the registered shape is stored for this variable.
     *
     * <p>Returns the current array (see {@link #getArr()}) if its shape equals the registered shape.
     * Otherwise it allocates a new array with the weight init scheme and stores it under this
     * variable's name.
     *
     * @return the existing or newly allocated array
     * @throws ND4JIllegalStateException if the variable name is {@code null}, or if a new array is
     *                                   needed but no shape is registered
     */
    public INDArray storeAndAllocateNewArray() {
        // Validate the name before it is used for any lookup.
        if (this.varName == null) {
            throw new ND4JIllegalStateException("Unable to store array for null variable name!");
        }
        int[] shape = this.sameDiff.getShapeForVarName(this.getVarName());
        INDArray existing = this.getArr();
        if (existing != null && Arrays.equals(existing.shape(), shape)) {
            return existing;
        }
        if (shape == null) {
            throw new ND4JIllegalStateException("Unable to allocate new array. No shape found for variable " + this.varName);
        }
        INDArray arr = this.getWeightInitScheme().create(shape);
        this.sameDiff.putArrayForVarName(this.getVarName(), arr);
        return arr;
    }

    /**
     * Returns the array associated with this variable, creating it lazily if needed.
     *
     * <p>If the graph already holds an array for this name, that array is returned. Otherwise, if a
     * shape is registered, a new array is created and associated with this variable:
     * <ul>
     *   <li>a constant array filled with {@link #getScalarValue()}, when a scalar value is set and
     *       the shape has exactly one element;</li>
     *   <li>otherwise an array created by the weight init scheme.</li>
     * </ul>
     *
     * @return the array, or {@code null} if none exists and no shape is registered
     */
    public INDArray getArr() {
        String name = this.getVarName();
        if (this.sameDiff.arrayAlreadyExistsForVarName(name)) {
            return this.sameDiff.getArrForVarName(name);
        }
        // Read the registered shape directly instead of calling getShape(). getShape() falls back
        // to getArr() when no shape is registered, which would recurse without bound.
        int[] shape = this.sameDiff.getShapeForVarName(name);
        if (shape == null) {
            return null;
        }
        INDArray created;
        if (this.getScalarValue() != null && ArrayUtil.prod(shape) == 1) {
            created = Nd4j.valueArrayOf(shape, this.getScalarValue().doubleValue());
        } else {
            created = this.getWeightInitScheme().create(shape);
        }
        this.sameDiff.associateArrayWithVariable(created, this);
        return this.sameDiff.getArrForVarName(name);
    }

    /** Alias for {@link #getGradient()}. */
    public SDVariable gradient() {
        return this.getGradient();
    }

    /** Returns the gradient variable the owning graph holds for this variable. */
    public SDVariable getGradient() {
        return this.sameDiff.getGradForVariable(this.getVarName());
    }

    /**
     * Variables cannot be differentiated directly.
     *
     * @throws ND4JIllegalStateException always
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        throw new ND4JIllegalStateException("Unable to differentiate a variable! Must be a function.");
    }

    /**
     * Returns the shape registered in the graph. If none is registered, returns the shape of the
     * array from {@link #getArr()}, or {@code null} if there is no array either.
     */
    public int[] getShape() {
        int[] initialShape = this.sameDiff.getShapeForVarName(this.getVarName());
        if (initialShape != null) {
            return initialShape;
        }
        INDArray arr = this.getArr();
        return arr != null ? arr.shape() : null;
    }

    /** Duplicates this variable via {@code sameDiff.var(this)}. */
    @Override
    public SDVariable dup() {
        return this.sameDiff.var(this);
    }

    /** Generates a fresh variable name in the owning graph, derived from {@code opName}. */
    private String newVarNameFor(String opName) {
        return this.sameDiff.generateNewVarName(opName, 0);
    }

    // ---- Unnamed overloads: generate a fresh output name, then delegate to the named overload ----

    public SDVariable rsub(double sameDiffVariable) {
        return this.rsub(this.newVarNameFor(new RSubOp().opName()), sameDiffVariable);
    }

    public SDVariable rdiv(double sameDiffVariable) {
        return this.rdiv(this.newVarNameFor(new RDivOp().opName()), sameDiffVariable);
    }

    public SDVariable add(double sameDiffVariable) {
        return this.add(this.newVarNameFor(new AddOp().opName()), sameDiffVariable);
    }

    public SDVariable sub(double sameDiffVariable) {
        return this.sub(this.newVarNameFor(new SubOp().opName()), sameDiffVariable);
    }

    public SDVariable squaredDifference(SDVariable sameDiffVariable) {
        return this.squaredDifference(this.newVarNameFor(new SquaredDifferenceOp().opName()), sameDiffVariable);
    }

    public SDVariable div(double sameDiffVariable) {
        return this.div(this.newVarNameFor(new DivOp().opName()), sameDiffVariable);
    }

    public SDVariable mul(double sameDiffVariable) {
        return this.mul(this.newVarNameFor(new MulOp().opName()), sameDiffVariable);
    }

    public SDVariable rsubi(double sameDiffVariable) {
        return this.rsubi(this.newVarNameFor(new RSubOp().opName()), sameDiffVariable);
    }

    public SDVariable rdivi(double sameDiffVariable) {
        return this.rdivi(this.newVarNameFor(new RDivOp().opName()), sameDiffVariable);
    }

    public SDVariable addi(double sameDiffVariable) {
        return this.addi(this.newVarNameFor(new AddOp().opName()), sameDiffVariable);
    }

    public SDVariable subi(double sameDiffVariable) {
        return this.subi(this.newVarNameFor(new SubOp().opName()), sameDiffVariable);
    }

    public SDVariable divi(double sameDiffVariable) {
        return this.divi(this.newVarNameFor(new DivOp().opName()), sameDiffVariable);
    }

    public SDVariable muli(double sameDiffVariable) {
        return this.muli(this.newVarNameFor(new MulOp().opName()), sameDiffVariable);
    }

    public SDVariable rsub(SDVariable sameDiffVariable) {
        return this.rsub(this.newVarNameFor(new RSubOp().opName()), sameDiffVariable);
    }

    public SDVariable rdiv(SDVariable sameDiffVariable) {
        return this.rdiv(this.newVarNameFor(new RDivOp().opName()), sameDiffVariable);
    }

    public SDVariable truncatedDiv(SDVariable sameDiffVariable) {
        return this.truncatedDiv(this.newVarNameFor(new TruncateDivOp().opName()), sameDiffVariable);
    }

    public SDVariable add(SDVariable sameDiffVariable) {
        return this.add(this.newVarNameFor(new AddOp().opName()), sameDiffVariable);
    }

    public SDVariable sub(SDVariable sameDiffVariable) {
        return this.sub(this.newVarNameFor(new SubOp().opName()), sameDiffVariable);
    }

    public SDVariable div(SDVariable sameDiffVariable) {
        return this.div(this.newVarNameFor(new DivOp().opName()), sameDiffVariable);
    }

    public SDVariable mul(SDVariable sameDiffVariable) {
        return this.mul(this.newVarNameFor(new MulOp().opName()), sameDiffVariable);
    }

    public SDVariable rsubi(SDVariable sameDiffVariable) {
        return this.rsubi(this.newVarNameFor(new RSubOp().opName()), sameDiffVariable);
    }

    public SDVariable rdivi(SDVariable sameDiffVariable) {
        return this.rdivi(this.newVarNameFor(new RDivOp().opName()), sameDiffVariable);
    }

    public SDVariable addi(SDVariable sameDiffVariable) {
        return this.addi(this.newVarNameFor(new AddOp().opName()), sameDiffVariable);
    }

    public SDVariable subi(SDVariable sameDiffVariable) {
        return this.subi(this.newVarNameFor(new SubOp().opName()), sameDiffVariable);
    }

    public SDVariable divi(SDVariable sameDiffVariable) {
        return this.divi(this.newVarNameFor(new DivOp().opName()), sameDiffVariable);
    }

    public SDVariable muli(SDVariable sameDiffVariable) {
        return this.muli(this.newVarNameFor(new MulOp().opName()), sameDiffVariable);
    }

    // ---- Named overloads with a scalar operand: build the op, then rename its output to varName ----

    public SDVariable rsub(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().rsub(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable rdiv(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().rdiv(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable truncatedDiv(String varName, SDVariable sameDiffVariable) {
        SDVariable function = this.sameDiff.f().truncatedDiv(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable add(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().add(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable sub(String varName, double sameDiffVariable) {
        // 'this' is the left operand: computes this - sameDiffVariable.
        SDVariable result = this.sameDiff.f().sub(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable div(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().div(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable mul(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().mul(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable rsubi(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().rsubi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable rdivi(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().rdivi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable addi(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().addi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable subi(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().subi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable divi(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().divi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    public SDVariable muli(String varName, double sameDiffVariable) {
        SDVariable function = this.sameDiff.f().muli(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(function, varName);
    }

    // ---- Named overloads with a variable operand ('this' is always the left operand) ----

    public SDVariable rsub(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().rsub(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable rdiv(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().rdiv(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable add(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().add(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable sub(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().sub(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable squaredDifference(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().squaredDifference(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable div(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().div(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable mul(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        Preconditions.checkNotNull(this, "Left input is null!");
        Preconditions.checkNotNull(sameDiffVariable, "Right input is null!");
        SDVariable result = this.sameDiff.f().mul(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable rsubi(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().rsubi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable rdivi(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().rdivi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable addi(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().addi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable subi(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().subi(this, sameDiffVariable);
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable divi(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().divi(this, sameDiffVariable);
        // Rename through the graph, like every other named overload, rather than only
        // calling setVarName on the result.
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    public SDVariable muli(String varName, SDVariable sameDiffVariable) {
        this.assertShapeEquals(sameDiffVariable);
        SDVariable result = this.sameDiff.f().muli(this, sameDiffVariable);
        // Rename through the graph, like every other named overload, rather than only
        // calling setVarName on the result.
        return this.sameDiff.updateVariableNameAndReference(result, varName);
    }

    /**
     * Evaluates this variable. Works on a duplicate of the owning graph: defines an "output"
     * function that returns this variable, invokes it, and executes the resulting graph.
     *
     * @return the result of {@code execAndEndResult()} on the invoked function's graph
     */
    public INDArray eval() {
        SameDiff exec = this.sameDiff.dup();
        exec.defineFunction("output", new SameDiff.SameDiffFunctionDefinition(){

            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                return new SDVariable[]{SDVariable.this};
            }
        });
        SDVariable output = exec.invokeFunctionOn("output", exec);
        return output.getSameDiff().execAndEndResult();
    }

    /** Placeholder for an operand shape check; it currently performs no validation. */
    private void assertShapeEquals(SDVariable variable) {
    }

    /** Returns the variable name. */
    @Override
    public String toString() {
        return this.varName;
    }

    /**
     * Two variables are equal when they have the same runtime class, the superclass
     * {@code equals} agrees, and both the variable name and the weight init scheme are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        SDVariable that = (SDVariable)o;
        if (this.varName != null ? !this.varName.equals(that.varName) : that.varName != null) {
            return false;
        }
        return this.weightInitScheme != null ? this.weightInitScheme.equals(that.weightInitScheme) : that.weightInitScheme == null;
    }

    /** Consistent with {@link #equals(Object)}: combines the superclass hash, the name and the weight init scheme. */
    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + (this.varName != null ? this.varName.hashCode() : 0);
        result = 31 * result + (this.weightInitScheme != null ? this.weightInitScheme.hashCode() : 0);
        return result;
    }

    /** @throws NoOpNameFoundException always; variables have no ONNX op name */
    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    /** @throws NoOpNameFoundException always; variables have no TensorFlow op name */
    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No tensorflow op opName found for " + this.opName());
    }

    /** Returns a new builder; {@link SDVariableBuilder#build()} invokes the private constructor. */
    public static SDVariableBuilder builder() {
        return new SDVariableBuilder();
    }

    /** No-argument constructor; leaves all fields unset. */
    public SDVariable() {
    }

    public String getVarName() {
        return this.varName;
    }

    public void setVarName(String varName) {
        this.varName = varName;
    }

    public WeightInitScheme getWeightInitScheme() {
        return this.weightInitScheme;
    }

    public void setWeightInitScheme(WeightInitScheme weightInitScheme) {
        this.weightInitScheme = weightInitScheme;
    }

    /** Fluent builder for {@link SDVariable}; the constructor's side effects happen in {@link #build()}. */
    public static class SDVariableBuilder {
        private String varName;
        private SameDiff sameDiff;
        private int[] shape;
        private WeightInitScheme weightInitScheme;

        SDVariableBuilder() {
        }

        public SDVariableBuilder varName(String varName) {
            this.varName = varName;
            return this;
        }

        public SDVariableBuilder sameDiff(SameDiff sameDiff) {
            this.sameDiff = sameDiff;
            return this;
        }

        public SDVariableBuilder shape(int[] shape) {
            this.shape = shape;
            return this;
        }

        public SDVariableBuilder weightInitScheme(WeightInitScheme weightInitScheme) {
            this.weightInitScheme = weightInitScheme;
            return this;
        }

        /** Creates the variable, which also registers it with the configured {@link SameDiff}. */
        public SDVariable build() {
            return new SDVariable(this.varName, this.sameDiff, this.shape, this.weightInitScheme);
        }

        @Override
        public String toString() {
            return "SDVariable.SDVariableBuilder(varName=" + this.varName + ", sameDiff=" + this.sameDiff + ", shape=" + Arrays.toString(this.shape) + ", weightInitScheme=" + this.weightInitScheme + ")";
        }
    }
}

