package org.deeplearning4j.nn.graph.vertex.impl;

import java.util.Arrays;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.api.TrainingConfig;
import org.deeplearning4j.nn.api.layers.IOutputLayer;
import org.deeplearning4j.nn.api.layers.RecurrentLayer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.graph.vertex.VertexIndices;
import org.deeplearning4j.nn.layers.BaseOutputLayer;
import org.deeplearning4j.nn.layers.FrozenLayer;
import org.deeplearning4j.nn.layers.FrozenLayerWithBackprop;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/graph/vertex/impl/LayerVertex.class */
/**
 * Graph vertex that wraps a single {@link Layer}, optionally preceded by an {@link InputPreProcessor}.
 *
 * <p>The vertex accepts exactly one input (index 0). Setting that input runs the preprocessor (if present)
 * and hands the result to the wrapped layer. The backward pass runs the layer's backprop, then the
 * preprocessor's backprop (if present) on the resulting epsilon.
 */
public class LayerVertex extends BaseGraphVertex {
    /** The wrapped layer; replaced by a {@link FrozenLayer} wrapper by {@link #setLayerAsFrozen()}. */
    private Layer layer;
    /** Optional preprocessor applied to the input before it reaches the layer; may be {@code null}. */
    private final InputPreProcessor layerPreProcessor;
    /** {@code true} once the current input has been preprocessed and passed to the layer. */
    private boolean setLayerInput;

    /**
     * Creates a vertex without recorded input/output connections. The input array therefore starts empty.
     *
     * @param graph             owning computation graph
     * @param name              vertex name
     * @param vertexIndex       index of this vertex within the graph
     * @param layer             layer wrapped by this vertex
     * @param layerPreProcessor optional input preprocessor; may be {@code null}
     * @param outputVertex      whether this vertex is a network output
     * @param dataType          data type forwarded to the base vertex
     */
    public LayerVertex(ComputationGraph graph, String name, int vertexIndex, Layer layer,
                       InputPreProcessor layerPreProcessor, boolean outputVertex, DataType dataType) {
        this(graph, name, vertexIndex, null, null, layer, layerPreProcessor, outputVertex, dataType);
    }

    /**
     * Creates a vertex with explicit input/output connections.
     *
     * @param graph             owning computation graph
     * @param name              vertex name
     * @param vertexIndex       index of this vertex within the graph
     * @param inputVertices     vertices feeding this one; when {@code null}, the input array is empty
     * @param outputVertices    vertices this one feeds; may be {@code null}
     * @param layer             layer wrapped by this vertex
     * @param layerPreProcessor optional input preprocessor; may be {@code null}
     * @param outputVertex      whether this vertex is a network output
     * @param dataType          data type forwarded to the base vertex
     */
    public LayerVertex(ComputationGraph graph, String name, int vertexIndex, VertexIndices[] inputVertices,
                       VertexIndices[] outputVertices, Layer layer, InputPreProcessor layerPreProcessor,
                       boolean outputVertex, DataType dataType) {
        super(graph, name, vertexIndex, inputVertices, outputVertices, dataType);
        this.graph = graph;
        this.vertexName = name;
        this.vertexIndex = vertexIndex;
        this.inputVertices = inputVertices;
        this.outputVertices = outputVertices;
        this.layer = layer;
        this.layerPreProcessor = layerPreProcessor;
        this.outputVertex = outputVertex;
        this.inputs = new INDArray[inputVertices != null ? inputVertices.length : 0];
    }

    /** Always {@code true}: this vertex type always wraps a layer. */
    @Override
    public boolean hasLayer() {
        return true;
    }

    /**
     * Wraps the current layer in a {@link FrozenLayer}, unless it is already one. The vertex name is copied
     * onto the wrapper's layer configuration.
     */
    @Override
    public void setLayerAsFrozen() {
        if (this.layer instanceof FrozenLayer) {
            return;
        }
        this.layer = new FrozenLayer(this.layer);
        this.layer.conf().getLayer().setLayerName(this.vertexName);
    }

    /** Delegates to the wrapped layer's {@code paramTable}. */
    @Override
    public Map<String, INDArray> paramTable(boolean backpropParamsOnly) {
        return this.layer.paramTable(backpropParamsOnly);
    }

    /** A vertex is an output if it was flagged as one or if its layer is a {@link BaseOutputLayer}. */
    @Override
    public boolean isOutputVertex() {
        return this.outputVertex || (this.layer instanceof BaseOutputLayer);
    }

    @Override
    public Layer getLayer() {
        return this.layer;
    }

    /**
     * Runs the wrapped layer's forward pass on the input previously passed to it.
     *
     * @throws IllegalStateException if {@link #canDoForward()} reports that inputs are missing
     */
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (!canDoForward()) {
            throw new IllegalStateException("Cannot do forward pass: all inputs not set");
        }
        return this.layer.activate(training, workspaceMgr);
    }

    /**
     * Applies the preprocessor (if any) to input 0 and sets the result as the wrapped layer's input.
     * The batch size passed to the preprocessor is taken from the owning graph.
     */
    public void applyPreprocessorAndSetInput(LayerWorkspaceMgr workspaceMgr) {
        INDArray input = this.inputs[0];
        if (this.layerPreProcessor != null) {
            input = this.layerPreProcessor.preProcess(input, this.graph.batchSize(), workspaceMgr);
        }
        this.layer.setInput(input, workspaceMgr);
        this.setLayerInput = true;
    }

    /**
     * Runs the backward pass through the layer and, if present, the preprocessor.
     *
     * @param tbptt        when {@code true} and the layer is a {@link RecurrentLayer}, truncated BPTT is used
     *                     with the graph's configured backward length
     * @param workspaceMgr workspace manager for intermediate arrays
     * @return the layer gradient, plus a one-element array holding the epsilon for this vertex's input
     * @throws IllegalStateException if inputs or epsilons required for backprop are not set
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        if (!canDoBackward()) {
            // Length check keeps the diagnostic path from throwing ArrayIndexOutOfBoundsException itself.
            if (this.inputs == null || this.inputs.length == 0 || this.inputs[0] == null) {
                throw new IllegalStateException("Cannot do backward pass: inputs not set. Layer: \"" + this.vertexName
                        + "\" (idx " + this.vertexIndex + "), numInputs: " + getNumInputArrays());
            }
            throw new IllegalStateException("Cannot do backward pass: all epsilons not set. Layer \"" + this.vertexName
                    + "\" (idx " + this.vertexIndex + "), numInputs :" + getNumInputArrays() + "; numOutputs: "
                    + getNumOutputConnections());
        }
        if (!this.setLayerInput) {
            applyPreprocessorAndSetInput(workspaceMgr);
        }

        Pair<Gradient, INDArray> backprop;
        if (tbptt && this.layer instanceof RecurrentLayer) {
            backprop = ((RecurrentLayer) this.layer).tbpttBackpropGradient(this.epsilon,
                    this.graph.getConfiguration().getTbpttBackLength(), workspaceMgr);
        } else {
            backprop = this.layer.backpropGradient(this.epsilon, workspaceMgr);
        }

        // The layer produced an epsilon w.r.t. the preprocessed input; map it back to the raw input.
        if (this.layerPreProcessor != null) {
            backprop.setSecond(this.layerPreProcessor.backprop(backprop.getSecond(), this.graph.batchSize(), workspaceMgr));
        }
        return new Pair<>(backprop.getFirst(), new INDArray[] {backprop.getSecond()});
    }

    /**
     * Sets the single input of this vertex, then immediately preprocesses it and passes it to the layer.
     *
     * @throws IllegalArgumentException if {@code inputNumber} is not 0
     */
    @Override
    public void setInput(int inputNumber, INDArray input, LayerWorkspaceMgr workspaceMgr) {
        if (inputNumber != 0) {
            throw new IllegalArgumentException(
                    "Invalid input number: LayerVertex instances have only 1 input (got inputNumber = " + inputNumber + ")");
        }
        this.inputs[inputNumber] = input;
        this.setLayerInput = false;
        applyPreprocessorAndSetInput(workspaceMgr);
    }

    /** Delegates to the wrapped layer. */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        this.layer.setBackpropGradientsViewArray(backpropGradientsViewArray);
    }

    /**
     * Propagates the mask for input 0 through the preprocessor (if any), then through the layer.
     * The caller's {@code maskArrays} array is not modified.
     *
     * @return a {@code (null, currentMaskState)} pair when no masks are supplied; otherwise the layer's result
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState,
                                                           int minibatchSize) {
        if (maskArrays == null || maskArrays.length == 0) {
            return new Pair<>(null, currentMaskState);
        }
        INDArray mask = maskArrays[0];
        MaskState state = currentMaskState;
        if (this.layerPreProcessor != null) {
            Pair<INDArray, MaskState> processed = this.layerPreProcessor.feedForwardMaskArray(mask, state, minibatchSize);
            // A null result from the preprocessor clears both the mask and its state.
            if (processed == null) {
                mask = null;
                state = null;
            } else {
                mask = processed.getFirst();
                state = processed.getSecond();
            }
        }
        return this.layer.feedForwardMaskArray(mask, state, minibatchSize);
    }

    @Override
    public String toString() {
        return "LayerVertex(id=" + this.vertexIndex + ",name=\"" + this.vertexName + "\",inputs="
                + Arrays.toString(this.inputVertices) + ",outputs=" + Arrays.toString(this.outputVertices) + ")";
    }

    /**
     * Reports whether the backward pass can run.
     *
     * <p>Non-output vertices wrapping a {@link FrozenLayer} are always considered ready. Other non-output
     * vertices defer to the base class. Output vertices need every input set, and they also need an epsilon
     * unless the (unwrapped) layer is an {@link IOutputLayer}.
     */
    @Override
    public boolean canDoBackward() {
        if (!isOutputVertex()) {
            return getLayer() instanceof FrozenLayer || super.canDoBackward();
        }
        for (INDArray input : this.inputs) {
            if (input == null) {
                return false;
            }
        }
        // Look through a FrozenLayerWithBackprop wrapper to find the real layer type.
        Layer resolved = this.layer instanceof FrozenLayerWithBackprop
                ? ((FrozenLayerWithBackprop) this.layer).getInsideLayer()
                : this.layer;
        return (resolved instanceof IOutputLayer) || this.epsilon != null;
    }

    /**
     * Computes the score of the wrapped output layer.
     *
     * @param fullNetRegTerm passed unchanged to {@link IOutputLayer#computeScore}
     * @param training       passed unchanged to {@link IOutputLayer#computeScore}
     * @param workspaceMgr   workspace manager for the score computation
     * @throws UnsupportedOperationException if the layer is not an {@link IOutputLayer}
     */
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
        if (!(this.layer instanceof IOutputLayer)) {
            throw new UnsupportedOperationException("Cannot compute score: layer is not an output layer (layer class: "
                    + this.layer.getClass().getSimpleName() + ")");
        }
        if (!this.setLayerInput) {
            // NOTE(review): uses noWorkspaces() instead of workspaceMgr, unlike computeScoreForExamples —
            // confirm whether this is intentional before changing it.
            applyPreprocessorAndSetInput(LayerWorkspaceMgr.noWorkspaces());
        }
        return ((IOutputLayer) this.layer).computeScore(fullNetRegTerm, training, workspaceMgr);
    }

    /**
     * Computes per-example scores of the wrapped output layer.
     *
     * @param fullNetRegTerm passed unchanged to {@link IOutputLayer#computeScoreForExamples}
     * @param workspaceMgr   workspace manager used for input preprocessing and scoring
     * @throws UnsupportedOperationException if the layer is not an {@link IOutputLayer}
     */
    public INDArray computeScoreForExamples(double fullNetRegTerm, LayerWorkspaceMgr workspaceMgr) {
        if (!(this.layer instanceof IOutputLayer)) {
            throw new UnsupportedOperationException("Cannot compute score: layer is not an output layer (layer class: "
                    + this.layer.getClass().getSimpleName() + ")");
        }
        if (!this.setLayerInput) {
            applyPreprocessorAndSetInput(workspaceMgr);
        }
        return ((IOutputLayer) this.layer).computeScoreForExamples(fullNetRegTerm, workspaceMgr);
    }

    @Override
    public TrainingConfig getConfig() {
        return getLayer().getConfig();
    }

    @Override
    public INDArray params() {
        return this.layer.params();
    }

    @Override
    public INDArray getGradientsViewArray() {
        return this.layer.getGradientsViewArray();
    }

    public InputPreProcessor getLayerPreProcessor() {
        return this.layerPreProcessor;
    }

    public boolean isSetLayerInput() {
        return this.setLayerInput;
    }

    public void setLayer(Layer layer) {
        this.layer = layer;
    }

    public void setSetLayerInput(boolean setLayerInput) {
        this.setLayerInput = setLayerInput;
    }

    /**
     * Equality covers base-class state, the layer, the preprocessor and the {@code setLayerInput} flag.
     * Because that flag changes during forward/backward passes, equality and hash code are not stable
     * across passes.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LayerVertex)) {
            return false;
        }
        LayerVertex other = (LayerVertex) obj;
        if (!other.canEqual(this) || !super.equals(obj)) {
            return false;
        }
        return Objects.equals(getLayer(), other.getLayer())
                && Objects.equals(getLayerPreProcessor(), other.getLayerPreProcessor())
                && isSetLayerInput() == other.isSetLayerInput();
    }

    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof LayerVertex;
    }

    @Override
    public int hashCode() {
        Layer currentLayer = getLayer();
        InputPreProcessor preProcessor = getLayerPreProcessor();
        int result = super.hashCode();
        result = result * 59 + (currentLayer == null ? 43 : currentLayer.hashCode());
        result = result * 59 + (preProcessor == null ? 43 : preProcessor.hashCode());
        result = result * 59 + (isSetLayerInput() ? 79 : 97);
        return result;
    }
}
