/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.controlflow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicInteger;
import onnx.OnnxProto3;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.controlflow.WhileDerivative;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * Control-flow op representing a {@code while} loop in a {@link SameDiff} graph.
 *
 * <p>A loop is described by two sub-graphs: {@link #predicateExecution}, which evaluates the
 * loop condition and yields {@link #targetBoolean}, and {@link #loopBodyExecution}, which holds
 * the loop body. Instances are built either programmatically (via {@link #builder()} or the
 * {@code (blockName, parent, ...)} constructor) or by importing a TensorFlow loop frame
 * ({@code Enter / Merge / LoopCond / Switch / Identity / NextIteration / Exit}) through
 * {@link #initFromTensorFlow}.
 *
 * <p>The op exposes no raw {@link INDArray} arguments through the {@link CustomOp} interface;
 * the argument accessors/mutators are no-ops.
 */
public class While
extends DifferentialFunction
implements CustomOp {
    private static final Logger log = LoggerFactory.getLogger(While.class);
    /** Cursor into the TF graph node list at which this loop's import started (import path only). */
    private AtomicInteger startPosition;
    /** Sub-graph holding the loop body. */
    protected SameDiff loopBodyExecution;
    /** Sub-graph that evaluates the loop condition. */
    protected SameDiff predicateExecution;
    protected SameDiff.SameDiffConditional predicate;
    protected SameDiff.SameDiffFunctionDefinition trueBody;
    protected String blockName;
    /** Name under which {@link #trueBody} is registered in the parent graph ("true-body-" + UUID). */
    protected String trueBodyName;
    protected SDVariable[] inputVars;
    /** Output of the condition sub-graph (presumably checked by the executor before each iteration). */
    protected SDVariable targetBoolean;
    /** 1x1 zero-initialised placeholder registered as this op's only output. */
    protected SDVariable dummyResult;
    protected SDVariable[] outputVars;
    /** Number of iterations executed so far; advanced via {@link #incrementLoopCounter()}. */
    protected int numLooped = 0;

    /**
     * Creates an otherwise empty loop whose TensorFlow import will use the given node cursor.
     * Used for nested loops encountered while importing an enclosing loop.
     *
     * @param startPosition shared cursor into the TF graph node list
     */
    public While(AtomicInteger startPosition) {
        this.startPosition = startPosition;
    }

    /**
     * Shallow copy constructor.
     *
     * <p>Shares the sub-graphs, predicate and input/output variables with {@code whileStatement},
     * but registers a fresh dummy result variable in the shared {@link SameDiff} instance.
     * NOTE(review): {@code trueBody}, {@code blockName}, {@code trueBodyName} and
     * {@code targetBoolean} are not copied — confirm this is intentional.
     *
     * @param whileStatement loop to copy from
     */
    public While(While whileStatement) {
        this.sameDiff = whileStatement.sameDiff;
        this.outputVars = whileStatement.outputVars;
        this.loopBodyExecution = whileStatement.loopBodyExecution;
        this.numLooped = whileStatement.numLooped;
        this.predicate = whileStatement.predicate;
        this.predicateExecution = whileStatement.predicateExecution;
        this.inputVars = whileStatement.inputVars;
        // The copy gets its own output placeholder instead of sharing the original's.
        this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(), new int[]{1, 1}, new ZeroInitScheme('f'));
    }

    /**
     * Builds a loop programmatically and registers it in {@code parent}.
     *
     * @param blockName name under which the condition function is defined in {@code parent}
     * @param parent    graph that owns this op
     * @param inputVars loop inputs
     * @param predicate evaluates {@code condition} into the boolean loop condition
     * @param condition function definition of the loop condition
     * @param trueBody  function definition of the loop body
     */
    public While(String blockName, SameDiff parent, SDVariable[] inputVars, SameDiff.SameDiffConditional predicate, SameDiff.SameDiffFunctionDefinition condition, SameDiff.SameDiffFunctionDefinition trueBody) {
        this.init(blockName, parent, inputVars, predicate, condition, trueBody);
    }

    /** Shared initialisation for the programmatic constructor (see its documentation). */
    private void init(String blockName, SameDiff parent, SDVariable[] inputVars, SameDiff.SameDiffConditional predicate, SameDiff.SameDiffFunctionDefinition condition, SameDiff.SameDiffFunctionDefinition trueBody) {
        this.sameDiff = parent;
        this.inputVars = inputVars;
        this.predicate = predicate;
        this.trueBody = trueBody;
        this.blockName = blockName;
        this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(), new int[]{1, 1}, new ZeroInitScheme('f'));
        // Register the op and wire inputs -> this op -> dummy output in the parent graph.
        parent.putFunctionForId(this.getOwnName(), this);
        parent.addArgsFor(inputVars, (DifferentialFunction)this);
        parent.addOutgoingFor(new SDVariable[]{this.dummyResult}, (DifferentialFunction)this);
        // The predicate is evaluated in its own graph, which becomes the condition sub-graph.
        SameDiff conditionGraph = SameDiff.create();
        this.targetBoolean = predicate.eval(conditionGraph, condition, inputVars);
        this.predicateExecution = conditionGraph;
        this.trueBodyName = "true-body-" + UUID.randomUUID().toString();
        parent.defineFunction(this.trueBodyName, trueBody, inputVars);
        parent.defineFunction(blockName, condition, inputVars);
        // NOTE(review): fixed key — a second While in the same parent presumably overwrites this entry.
        parent.putSubFunction("predicate-eval-body", conditionGraph);
        this.loopBodyExecution = parent.getFunction(this.trueBodyName);
    }

    /**
     * Returns the single placeholder output of this op; {@code baseName} is ignored.
     */
    @Override
    public SDVariable[] outputVariables(String baseName) {
        return new SDVariable[]{this.dummyResult};
    }

    /**
     * Returns the outputs of a {@link WhileDerivative} built from this loop; {@code f1} is ignored.
     */
    @Override
    public List<SDVariable> doDiff(List<SDVariable> f1) {
        return new ArrayList<>(Arrays.asList(new WhileDerivative(this).outputVariables()));
    }

    /** Increments the executed-iteration counter. */
    public void incrementLoopCounter() {
        ++this.numLooped;
    }

    /**
     * Imports the TensorFlow loop frame that starts at the beginning of {@code graph}'s node list.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        this.doImport(nodeDef, initWith, attributesForNode, graph, new LinkedHashSet<String>(), new AtomicInteger(0));
    }

    /**
     * Imports one TensorFlow loop frame, consuming nodes from {@code graph} starting at
     * {@code currIndex}. The cursor is advanced past every consumed node.
     *
     * <p>Phases, each consuming a contiguous run of nodes:
     * <ol>
     *   <li>{@code Enter}: loop inputs, added to both sub-graphs;</li>
     *   <li>{@code Merge}: loop-carried values;</li>
     *   <li>condition nodes up to and including {@code LoopCond};</li>
     *   <li>{@code Switch} nodes (skipped);</li>
     *   <li>{@code Identity} nodes;</li>
     *   <li>body nodes up to the first {@code NextIteration}, recursing on nested {@code Enter};</li>
     *   <li>{@code NextIteration}: become this op's inputs;</li>
     *   <li>{@code Exit} nodes.</li>
     * </ol>
     * Finally {@link #targetBoolean} is taken from the last function in the condition sub-graph.
     *
     * @param nodeDef           node that triggered this import (recorded in {@code skipSet})
     * @param initWith          graph being populated
     * @param attributesForNode attributes of {@code nodeDef}; each imported node uses its own
     *                          attribute map instead
     * @param graph             source TF graph
     * @param skipSet           names of nodes already handled (shared with nested imports)
     * @param currIndex         shared cursor into {@code graph}'s node list
     * @throws ND4JIllegalArgumentException if the condition sub-graph ends up without functions
     */
    private void doImport(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph, Set<String> skipSet, AtomicInteger currIndex) {
        String uniqueId = UUID.randomUUID().toString();
        skipSet.add(nodeDef.getName());
        SameDiff scopeCondition = SameDiff.create();
        SameDiff scopeLoop = SameDiff.create();
        initWith.putSubFunction("condition-" + uniqueId, scopeCondition);
        initWith.putSubFunction("loopbody-" + uniqueId, scopeLoop);
        this.loopBodyExecution = scopeLoop;
        this.predicateExecution = scopeCondition;
        this.startPosition = currIndex;
        log.info("Adding 2 new scopes for WHILE {}", nodeDef.getName());
        List<NodeDef> nodes = graph.getNodeList();

        // 1. Enter nodes: loop inputs, shared with both sub-graphs.
        while (currIndex.get() < nodes.size() && nodes.get(currIndex.get()).getOp().equalsIgnoreCase("enter")) {
            NodeDef enterNode = nodes.get(currIndex.get());
            skipSet.add(enterNode.getName());
            SDVariable[] vars = new SDVariable[enterNode.getInputCount()];
            for (int e = 0; e < vars.length; ++e) {
                String input = TFGraphMapper.getInstance().getNodeName(enterNode.getInput(e));
                vars[e] = getOrCreateVariable(initWith, input);
                scopeCondition.var(vars[e]);
                scopeLoop.var(vars[e]);
            }
            this.inputVars = vars;
            currIndex.incrementAndGet();
        }

        // 2. Merge nodes: loop-carried values. The first non-Merge node also gets a body
        //    variable here but is not consumed (it is handled by the condition phase).
        while (currIndex.get() < nodes.size()) {
            NodeDef mergeNode = nodes.get(currIndex.get());
            String mergeName = TFGraphMapper.getInstance().getNodeName(mergeNode.getName());
            if (!mergeNode.getOp().equalsIgnoreCase("merge")) {
                scopeLoop.var(mergeName, null, new ZeroInitScheme());
                break;
            }
            skipSet.add(mergeNode.getName());
            SDVariable var = scopeLoop.var(mergeName, null, new ZeroInitScheme());
            scopeCondition.var(var);
            initWith.var(var);
            currIndex.incrementAndGet();
        }

        // 3. Condition nodes, up to and including LoopCond.
        while (currIndex.get() < nodes.size()) {
            NodeDef condNode = nodes.get(currIndex.get());
            if (condNode.getOp().equalsIgnoreCase("LoopCond")) {
                skipSet.add(condNode.getName());
                currIndex.incrementAndGet();
                break;
            }
            if (isVariableLike(condNode)) {
                SDVariable var = scopeCondition.var(condNode.getName(), null, new ZeroInitScheme());
                scopeLoop.var(var);
                initWith.var(var);
                log.info("Adding condition var [{}]", var.getVarName());
            } else if (!skipSet.contains(condNode.getName())) {
                DifferentialFunction func = newFunctionFor(condNode);
                func.initFromTensorFlow(condNode, scopeCondition, condNode.getAttrMap(), graph);
                // NOTE(review): condition ops are re-pointed at the loop scope (and body ops at the
                // condition scope below) — looks swapped; confirm before changing.
                func.setSameDiff(scopeLoop);
            }
            skipSet.add(condNode.getName());
            currIndex.incrementAndGet();
        }

        // 4. Switch nodes carry no extra information for this representation: skip them.
        while (currIndex.get() < nodes.size() && nodes.get(currIndex.get()).getOp().equalsIgnoreCase("Switch")) {
            skipSet.add(nodes.get(currIndex.get()).getName());
            currIndex.incrementAndGet();
        }

        // 5. Identity nodes: imported into the body scope, inputs shared with both sub-graphs.
        while (currIndex.get() < nodes.size() && nodes.get(currIndex.get()).getOp().equalsIgnoreCase("Identity")) {
            NodeDef identityNode = nodes.get(currIndex.get());
            DifferentialFunction func = newFunctionFor(identityNode);
            func.initFromTensorFlow(identityNode, initWith, identityNode.getAttrMap(), graph);
            func.setSameDiff(scopeLoop);
            SDVariable[] variables = new SDVariable[identityNode.getInputCount()];
            for (int i = 0; i < variables.length; ++i) {
                // Look up and create under the same normalised name so a second reference
                // finds the variable instead of creating a duplicate.
                String inputName = TFGraphMapper.getInstance().getNodeName(identityNode.getInput(i));
                variables[i] = getOrCreateVariable(initWith, inputName);
                scopeCondition.var(variables[i]);
                scopeLoop.var(variables[i]);
            }
            scopeLoop.addArgsFor(variables, func);
            skipSet.add(identityNode.getName());
            currIndex.incrementAndGet();
        }

        // 6. Loop body: everything up to the first NextIteration node.
        while (currIndex.get() < nodes.size()) {
            NodeDef bodyNode = nodes.get(currIndex.get());
            if (skipSet.contains(bodyNode.getName())) {
                log.info("Skipping: {}", bodyNode.getName());
                currIndex.incrementAndGet();
                continue;
            }
            if (bodyNode.getOp().equalsIgnoreCase("NextIteration")) {
                break;
            }
            if (isVariableLike(bodyNode)) {
                SDVariable var = scopeLoop.var(bodyNode.getName(), null, new ZeroInitScheme());
                log.info("Adding body var [{}]", var.getVarName());
            } else {
                log.info("starting on [{}]: {}", bodyNode.getName(), bodyNode.getOp());
                if (bodyNode.getOp().equalsIgnoreCase("enter")) {
                    log.info("NEW LOOP ----------------------------------------");
                    While nestedLoop = new While(currIndex);
                    nestedLoop.doImport(bodyNode, initWith, bodyNode.getAttrMap(), graph, skipSet, currIndex);
                    nestedLoop.setSameDiff(initWith);
                    log.info("END LOOP ----------------------------------------");
                    // The nested import leaves currIndex on the first node it did not consume;
                    // do not advance again, otherwise that node would be silently skipped.
                    continue;
                }
                DifferentialFunction func = newFunctionFor(bodyNode);
                func.initFromTensorFlow(bodyNode, initWith, bodyNode.getAttrMap(), graph);
                func.setSameDiff(scopeCondition);
                SDVariable[] variables = new SDVariable[bodyNode.getInputCount()];
                for (int i = 0; i < variables.length; ++i) {
                    String name = TFGraphMapper.getInstance().getNodeName(bodyNode.getInput(i));
                    // Resolution order: condition scope, then loop scope, else copy from the parent.
                    SDVariable resolved = scopeCondition.getVariable(name);
                    if (resolved == null) {
                        SDVariable fromLoop = scopeLoop.getVariable(name);
                        resolved = fromLoop != null ? fromLoop : scopeCondition.var(initWith.getVariable(name));
                    }
                    variables[i] = resolved;
                }
                scopeLoop.addArgsFor(variables, func);
            }
            skipSet.add(bodyNode.getName());
            currIndex.incrementAndGet();
        }

        // 7. NextIteration nodes become this op's inputs.
        List<SDVariable> returnInputs = new ArrayList<>();
        // NOTE(review): nothing is ever added to returnOutputs, so outputVars is always empty
        // after import (the Exit variables below are created but not attached) — confirm intent.
        List<SDVariable> returnOutputs = new ArrayList<>();
        while (currIndex.get() < nodes.size() && nodes.get(currIndex.get()).getOp().equalsIgnoreCase("NextIteration")) {
            NodeDef nextNode = nodes.get(currIndex.get());
            skipSet.add(nextNode.getName());
            String inputName = TFGraphMapper.getInstance().getNodeName(nextNode.getName());
            returnInputs.add(getOrCreateVariable(initWith, inputName));
            currIndex.incrementAndGet();
        }
        this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
        this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
        initWith.addArgsFor(this.inputVars, (DifferentialFunction)this);
        initWith.addOutgoingFor(this.outputVars, (DifferentialFunction)this);

        // 8. Exit nodes: make sure a variable exists in the parent graph for each.
        while (currIndex.get() < nodes.size() && nodes.get(currIndex.get()).getOp().equalsIgnoreCase("Exit")) {
            NodeDef exitNode = nodes.get(currIndex.get());
            skipSet.add(exitNode.getName());
            getOrCreateVariable(initWith, TFGraphMapper.getInstance().getNodeName(exitNode.getName()));
            currIndex.incrementAndGet();
        }

        // 9. The loop condition is the first output of the last function in the condition scope.
        DifferentialFunction[] conditionFunctions = scopeCondition.functions();
        if (conditionFunctions.length < 1) {
            throw new ND4JIllegalArgumentException("No functions found!");
        }
        this.targetBoolean = conditionFunctions[conditionFunctions.length - 1].outputVariables()[0];
        log.info("-------------------------------------------");
    }

    /** Returns true for TF nodes imported as plain variables: Const, VariableV* and Placeholder*. */
    private static boolean isVariableLike(NodeDef node) {
        String op = node.getOp();
        return op.equalsIgnoreCase("const") || op.startsWith("VariableV") || op.startsWith("Placeholder");
    }

    /** Instantiates the ND4J function mapped to {@code node}'s TF op type. */
    private static DifferentialFunction newFunctionFor(NodeDef node) {
        return DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(node.getOp()).opName());
    }

    /** Returns the variable named {@code name} in {@code graph}, creating a zero-initialised one if absent. */
    private static SDVariable getOrCreateVariable(SameDiff graph, String name) {
        SDVariable existing = graph.getVariable(name);
        return existing != null ? existing : graph.var(name, null, new ZeroInitScheme());
    }

    /** ONNX import is not supported for this op; intentionally a no-op. */
    @Override
    public void initFromOnnx(OnnxProto3.NodeProto node, SameDiff initWith, Map<String, OnnxProto3.AttributeProto> attributesForNode, OnnxProto3.GraphProto graph) {
    }

    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public String opName() {
        return "while";
    }

    @Override
    public long opHash() {
        return this.opName().hashCode();
    }

    @Override
    public boolean isInplaceCall() {
        return false;
    }

    // ---- CustomOp argument accessors: this op carries no raw arrays or scalar arguments. ----

    @Override
    public INDArray[] outputArguments() {
        return new INDArray[0];
    }

    @Override
    public INDArray[] inputArguments() {
        return new INDArray[0];
    }

    @Override
    public int[] iArgs() {
        return new int[0];
    }

    @Override
    public double[] tArgs() {
        return new double[0];
    }

    @Override
    public void addIArgument(int ... arg) {
    }

    @Override
    public void removeIArgument(Integer arg) {
    }

    @Override
    public Integer getIArgument(int index) {
        return null;
    }

    @Override
    public int numIArguments() {
        return 0;
    }

    @Override
    public void addTArgument(double ... arg) {
    }

    @Override
    public void removeTArgument(Double arg) {
    }

    @Override
    public Double getTArgument(int index) {
        return null;
    }

    @Override
    public int numTArguments() {
        return 0;
    }

    @Override
    public void addInputArgument(INDArray ... arg) {
    }

    @Override
    public void removeInputArgument(INDArray arg) {
    }

    @Override
    public INDArray getInputArgument(int index) {
        return null;
    }

    @Override
    public int numInputArguments() {
        return 0;
    }

    @Override
    public void addOutputArgument(INDArray ... arg) {
    }

    @Override
    public void removeOutputArgument(INDArray arg) {
    }

    @Override
    public INDArray getOutputArgument(int index) {
        return null;
    }

    @Override
    public int numOutputArguments() {
        return 0;
    }

    /**
     * Reports one shape per argument of this op, looked up by variable name in {@link #sameDiff}.
     */
    @Override
    public List<int[]> calculateOutputShape() {
        List<int[]> ret = new ArrayList<>();
        for (SDVariable var : this.args()) {
            ret.add(this.sameDiff.getShapeForVarName(var.getVarName()));
        }
        return ret;
    }

    @Override
    public CustomOpDescriptor getDescriptor() {
        return CustomOpDescriptor.builder().build();
    }

    /** No validation is performed for this op. */
    @Override
    public void assertValidForExecution() {
    }

    /** Inputs/outputs are wired explicitly during construction/import; nothing to do here. */
    @Override
    public void populateInputsAndOutputsFromSameDiff() {
    }

    @Override
    public String onnxName() {
        throw new NoOpNameFoundException("No onnx op opName found for " + this.opName());
    }

    @Override
    public String tensorflowName() {
        throw new NoOpNameFoundException("No *singular (eg: use tensorflowNames() found for this op " + this.opName());
    }

    @Override
    public String[] tensorflowNames() {
        throw new NoOpNameFoundException("This operation has no TF counterpart");
    }

    @Override
    public Op.Type opType() {
        return Op.Type.LOOP;
    }

    /** Returns a builder for the programmatic constructor. */
    public static WhileBuilder builder() {
        return new WhileBuilder();
    }

    /** No-arg constructor; leaves all state unset. */
    public While() {
    }

    public SameDiff getLoopBodyExecution() {
        return this.loopBodyExecution;
    }

    public SameDiff getPredicateExecution() {
        return this.predicateExecution;
    }

    public SameDiff.SameDiffConditional getPredicate() {
        return this.predicate;
    }

    public SameDiff.SameDiffFunctionDefinition getTrueBody() {
        return this.trueBody;
    }

    public String getBlockName() {
        return this.blockName;
    }

    public String getTrueBodyName() {
        return this.trueBodyName;
    }

    public SDVariable[] getInputVars() {
        return this.inputVars;
    }

    public SDVariable getTargetBoolean() {
        return this.targetBoolean;
    }

    public SDVariable[] getOutputVars() {
        return this.outputVars;
    }

    public void setOutputVars(SDVariable[] outputVars) {
        this.outputVars = outputVars;
    }

    public int getNumLooped() {
        return this.numLooped;
    }

    /** Fluent builder collecting the arguments of {@link While#While(String, SameDiff, SDVariable[], SameDiff.SameDiffConditional, SameDiff.SameDiffFunctionDefinition, SameDiff.SameDiffFunctionDefinition)}. */
    public static class WhileBuilder {
        private String blockName;
        private SameDiff parent;
        private SDVariable[] inputVars;
        private SameDiff.SameDiffConditional predicate;
        private SameDiff.SameDiffFunctionDefinition condition;
        private SameDiff.SameDiffFunctionDefinition trueBody;

        WhileBuilder() {
        }

        public WhileBuilder blockName(String blockName) {
            this.blockName = blockName;
            return this;
        }

        public WhileBuilder parent(SameDiff parent) {
            this.parent = parent;
            return this;
        }

        public WhileBuilder inputVars(SDVariable[] inputVars) {
            this.inputVars = inputVars;
            return this;
        }

        public WhileBuilder predicate(SameDiff.SameDiffConditional predicate) {
            this.predicate = predicate;
            return this;
        }

        public WhileBuilder condition(SameDiff.SameDiffFunctionDefinition condition) {
            this.condition = condition;
            return this;
        }

        public WhileBuilder trueBody(SameDiff.SameDiffFunctionDefinition trueBody) {
            this.trueBody = trueBody;
            return this;
        }

        /** Builds the loop; this registers it in {@code parent} as a side effect. */
        public While build() {
            return new While(this.blockName, this.parent, this.inputVars, this.predicate, this.condition, this.trueBody);
        }

        @Override
        public String toString() {
            return "While.WhileBuilder(blockName=" + this.blockName + ", parent=" + this.parent + ", inputVars=" + Arrays.deepToString(this.inputVars) + ", predicate=" + this.predicate + ", condition=" + this.condition + ", trueBody=" + this.trueBody + ")";
        }
    }
}

