/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.graph.vertex.impl.rnn;

import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.graph.vertex.BaseGraphVertex;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.primitives.Pair;

public class ReverseTimeSeriesVertex
extends BaseGraphVertex {
    /** Name of the network input whose mask array is used to skip padded steps; may be null. */
    private final String inputName;
    /** Index of {@link #inputName} in the network inputs, or -1 when no input name was given. */
    private final int inputIdx;

    /**
     * Creates a vertex that reverses the time axis of its single time-series input.
     *
     * @param graph       the graph this vertex belongs to
     * @param name        name of this vertex
     * @param vertexIndex index of this vertex in the graph
     * @param inputName   name of the network input whose mask array should be used to skip padded
     *                    time steps, or {@code null} to reverse the full time axis without masking
     * @throws IllegalArgumentException if {@code inputName} is non-null and is not one of the
     *                                  network inputs
     */
    public ReverseTimeSeriesVertex(ComputationGraph graph, String name, int vertexIndex, String inputName) {
        super(graph, name, vertexIndex, null, null);
        this.inputName = inputName;
        if (inputName == null) {
            this.inputIdx = -1;
        } else {
            this.inputIdx = graph.getConfiguration().getNetworkInputs().indexOf(inputName);
            if (this.inputIdx == -1) {
                throw new IllegalArgumentException("Invalid input name: \"" + inputName + "\" not found in list of network inputs (" + graph.getConfiguration().getNetworkInputs() + ")");
            }
        }
    }

    @Override
    public boolean hasLayer() {
        return false;
    }

    @Override
    public boolean isOutputVertex() {
        return false;
    }

    @Override
    public Layer getLayer() {
        return null;
    }

    /**
     * Forward pass: returns a copy of the first input with its time steps reversed.
     * The mask of the configured network input (if any) is used to skip padded steps.
     */
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        return reverseTimeSeries(this.inputs[0], getMask(), workspaceMgr, ArrayType.ACTIVATIONS);
    }

    /**
     * Backward pass. The forward pass only permutes time steps, and that permutation is its own
     * inverse (the j-th unmasked step from the front is swapped with the j-th unmasked step from
     * the back). The input epsilon is therefore the output epsilon with the same reversal applied.
     * This vertex has no parameters, so the returned Gradient is null.
     */
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        INDArray epsilonsOut = reverseTimeSeries(this.epsilon, getMask(), workspaceMgr, ArrayType.ACTIVATION_GRAD);
        return new Pair<Gradient, INDArray[]>(null, new INDArray[]{epsilonsOut});
    }

    /**
     * Returns the mask array of the configured network input.
     * Returns null if no input name was configured, or if the graph currently has no input masks.
     */
    private INDArray getMask() {
        if (this.inputIdx < 0) {
            return null;
        }
        INDArray[] inputMaskArrays = this.graph.getInputMaskArrays();
        return inputMaskArrays != null ? inputMaskArrays[this.inputIdx] : null;
    }

    /**
     * Reverses the time axis (dimension 2) of a rank-3 array, separately for every entry along
     * dimension 0. For each example {@code s}, the vector {@code input[s, :, src]} is copied to
     * {@code out[s, :, dst]}, with {@code src} walking forward and {@code dst} walking backward.
     * <p>
     * If {@code mask} is non-null, it is indexed as {@code mask[s, t]}. Steps whose mask value is
     * {@code 0.0} are skipped by both pointers. Only the unmasked steps of each example are
     * reversed among themselves, and masked positions of {@code out} are never written. Those
     * positions keep whatever {@code workspaceMgr.create} initialised them to (presumably zeros;
     * confirm).
     *
     * @param input        array to reverse, indexed as {@code [example, :, timeStep]}
     * @param mask         optional per-example/per-time-step mask; may be null
     * @param workspaceMgr workspace manager used to allocate the output
     * @param type         array type used for the output allocation
     * @return a new 'f'-ordered array of the same shape as {@code input}
     */
    private static INDArray reverseTimeSeries(INDArray input, INDArray mask, LayerWorkspaceMgr workspaceMgr, ArrayType type) {
        int numExamples = input.size(0);
        int numTimeSteps = input.size(2);
        INDArray out = workspaceMgr.create(type, input.shape(), 'f');
        for (int s = 0; s < numExamples; s++) {
            int src = 0;                // next time step to read from, moving forward
            int dst = numTimeSteps - 1; // next time step to write to, moving backward
            while (src < numTimeSteps && dst >= 0) {
                if (mask != null) {
                    while (src < numTimeSteps && mask.getDouble(s, src) == 0.0) {
                        src++;
                    }
                    while (dst >= 0 && mask.getDouble(s, dst) == 0.0) {
                        dst--;
                    }
                    // Both pointers consume the same number of unmasked steps, so they run out
                    // together. This happens after the last valid pair, or immediately if the
                    // example is fully masked. Stop here rather than index past either end of
                    // the time axis.
                    if (src >= numTimeSteps || dst < 0) {
                        break;
                    }
                }
                INDArray vec = input.get(NDArrayIndex.point(s), NDArrayIndex.all(), NDArrayIndex.point(src));
                out.put(new INDArrayIndex[]{NDArrayIndex.point(s), NDArrayIndex.all(), NDArrayIndex.point(dst)}, vec);
                src++;
                dst--;
            }
        }
        return out;
    }

    /**
     * This vertex has no parameters, so only a null gradients view is accepted.
     *
     * @throws RuntimeException if a non-null view array is supplied
     */
    @Override
    public void setBackpropGradientsViewArray(INDArray backpropGradientsViewArray) {
        if (backpropGradientsViewArray != null) {
            throw new RuntimeException("Vertex does not have gradients; gradients view array cannot be set here");
        }
    }

    /**
     * Passes the single input mask through unchanged. With an input name configured, reversal
     * happens only within unmasked steps, so the mask layout is preserved.
     * <p>
     * NOTE(review): when no input name is configured, {@link #doForward} reverses the full time
     * axis but the incoming mask is still forwarded as-is. Padded inputs would then end up with a
     * misaligned mask; confirm this configuration is not used with masked data.
     *
     * @throws IllegalArgumentException if more than one mask array is supplied
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) {
        if (maskArrays == null || maskArrays.length == 0) {
            // No mask information available: nothing to propagate.
            return new Pair<INDArray, MaskState>(null, currentMaskState);
        }
        if (maskArrays.length > 1) {
            throw new IllegalArgumentException("This vertex can only handle one input and hence only one mask");
        }
        return new Pair<INDArray, MaskState>(maskArrays[0], currentMaskState);
    }

    @Override
    public String toString() {
        String paramStr = this.inputName == null ? "" : "inputName=" + this.inputName;
        return "ReverseTimeSeriesVertex(" + paramStr + ")";
    }
}

