/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.samediff;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.layers.samediff.SDVertexParams;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffVertex;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Base class for a graph vertex whose computation is written as a single SameDiff "lambda".
 * Subclasses implement {@link #defineVertex(SameDiff, VertexInputs)} and request each input by
 * index through {@link VertexInputs#getInput(int)}. This base class declares only inputs, no
 * parameters.
 */
public abstract class SameDiffLambdaVertex
extends SameDiffVertex {
    /**
     * Input holder for the SameDiff instance currently being built. Created lazily by
     * {@link #getInputs(SameDiff)}; marked {@code transient}.
     */
    protected transient VertexInputs inputs;

    /**
     * Defines this vertex's output in terms of its inputs.
     *
     * @param sameDiff     the SameDiff instance to build the computation in
     * @param vertexInputs accessor for the vertex inputs; call {@link VertexInputs#getInput(int)}
     *                     for each input the vertex uses
     * @return the output variable of this vertex
     */
    public abstract SDVariable defineVertex(SameDiff sameDiff, VertexInputs vertexInputs);

    /**
     * Binds the supplied inputs to input indices {@code 0..n-1} (in {@code layerInput} iteration
     * order) and delegates to {@link #defineVertex(SameDiff, VertexInputs)}.
     * {@code paramTable} and {@code maskVars} are not used by this implementation.
     */
    @Override
    public SDVariable defineVertex(SameDiff sameDiff, Map<String, SDVariable> layerInput, Map<String, SDVariable> paramTable, Map<String, SDVariable> maskVars) {
        VertexInputs vi = this.getInputs(sameDiff);
        // Only bind when nothing is bound yet; an already-populated holder is left untouched.
        if (vi.map.isEmpty() && !layerInput.isEmpty()) {
            int index = 0;
            for (SDVariable v : layerInput.values()) {
                vi.map.put(index++, v);
            }
        }
        return this.defineVertex(sameDiff, vi);
    }

    /**
     * Discovers this vertex's inputs by running {@link #defineVertex(SameDiff, VertexInputs)} once
     * against a throwaway SameDiff instance. It then registers the variable name of every input the
     * lambda requested through {@link VertexInputs#getInput(int)}.
     */
    @Override
    public void defineParametersAndInputs(SDVertexParams params) {
        SameDiff temp = SameDiff.create();
        VertexInputs tempInputs = new VertexInputs(temp);
        this.defineVertex(temp, tempInputs);
        // NOTE(review): names are registered in the order getInput() was first called
        // (LinkedHashMap insertion order), not by ascending input index. Confirm this matches
        // how layerInput is ordered when defineVertex(SameDiff, Map, Map, Map) binds by position.
        List<String> names = new ArrayList<>(tempInputs.map.size());
        for (SDVariable v : tempInputs.map.values()) {
            names.add(v.getVarName());
        }
        params.defineInputs(names.toArray(new String[names.size()]));
    }

    /** No-op: this vertex declares no parameters to initialize. */
    @Override
    public void initializeParameters(Map<String, INDArray> params) {
    }

    /**
     * Returns the input holder bound to {@code sd}. A new holder is created if none exists yet, or
     * if the cached holder was created for a different SameDiff instance. This ensures variables
     * belonging to one SameDiff graph are never handed to a graph being built in another.
     *
     * @param sd the SameDiff instance the inputs belong to
     * @return the input holder for {@code sd}
     */
    protected VertexInputs getInputs(SameDiff sd) {
        if (this.inputs == null || this.inputs.sameDiff != sd) {
            this.inputs = new VertexInputs(sd);
        }
        return this.inputs;
    }

    /**
     * Mapping from input index to the SDVariable for that input, tied to one SameDiff instance.
     * Iteration order is the order in which indices were first bound.
     */
    public class VertexInputs {
        private final SameDiff sameDiff;
        private final Map<Integer, SDVariable> map = new LinkedHashMap<>();

        protected VertexInputs(SameDiff sd) {
            this.sameDiff = sd;
        }

        /**
         * Returns the variable for input {@code inputNum}. If nothing is bound to that index yet,
         * this method creates a new variable and binds it. The variable is created in this
         * holder's SameDiff instance, is named {@code "var_" + inputNum}, has shape {@code {-1}},
         * and uses the enclosing vertex's {@code dataType}.
         *
         * @param inputNum zero-based input index; must be non-negative (checked via
         *                 {@code Preconditions.checkArgument})
         * @return the variable bound to that input index
         */
        public SDVariable getInput(int inputNum) {
            Preconditions.checkArgument(inputNum >= 0, "Input number must be >= 0. Got: %s", inputNum);
            if (!this.map.containsKey(inputNum)) {
                SDVariable var = this.sameDiff.var("var_" + inputNum, SameDiffLambdaVertex.this.dataType, new int[]{-1});
                this.map.put(inputNum, var);
            }
            return this.map.get(inputNum);
        }
    }
}

