/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.embedding;

import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.custom.ScatterUpdate;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class EmbeddingSequenceLayer
extends BaseLayer<org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer> {
    private static final Logger log = LoggerFactory.getLogger(EmbeddingSequenceLayer.class);
    private static final int[] WEIGHT_DIM = new int[]{1};

    public EmbeddingSequenceLayer(NeuralNetConfiguration conf) {
        super(conf);
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        INDArray z = this.preOutput(true, workspaceMgr);
        INDArray delta = (INDArray)((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getActivationFn().backprop(z, epsilon).getFirst();
        int inputLength = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getInputLength();
        int numSamples = this.input.rows();
        int nOut = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getNOut();
        delta = delta.permute(new int[]{2, 0, 1});
        delta = delta.reshape(inputLength * numSamples, nOut);
        if (this.maskArray != null) {
            INDArray maskDelta = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, delta.shape(), 'f');
            delta = Broadcast.mul((INDArray)delta, (INDArray)this.maskArray, (INDArray)maskDelta, (int[])new int[]{0, 2});
        }
        INDArray weightGradients = (INDArray)this.gradientViews.get("W");
        weightGradients.assign((Number)0);
        if (!Shape.hasDefaultStridesForShape((INDArray)this.input)) {
            this.input = workspaceMgr.dup(ArrayType.ACTIVATIONS, this.input, 'f');
        }
        int[] indexes = this.input.data().asInt();
        ScatterUpdate op = new ScatterUpdate(weightGradients, delta, indexes, WEIGHT_DIM, ScatterUpdate.UpdateOp.ADD);
        Nd4j.getExecutioner().exec((CustomOp)op);
        DefaultGradient ret = new DefaultGradient();
        ret.gradientForVariable().put("W", weightGradients);
        if (this.hasBias()) {
            INDArray biasGradientsView = (INDArray)this.gradientViews.get("b");
            delta.sum(biasGradientsView, new int[]{0});
            ret.gradientForVariable().put("b", biasGradientsView);
        }
        return new Pair((Object)ret, null);
    }

    @Override
    protected INDArray preOutput(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        boolean inferInputLength = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).isInferInputLength();
        if (inferInputLength) {
            ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).setInputLength(this.input.columns());
        }
        if (this.input.columns() != ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getInputLength()) {
            throw new DL4JInvalidInputException("Sequence length of embedding input has to be equal to the specified input length: " + ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getInputLength() + " i.e. we expect input shape [numExamples, inputDim] with each entry being an integer index,  got [" + this.input.rows() + ", " + this.input.columns() + "] instead, for layer with id: " + this.layerId());
        }
        int nIn = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getNIn();
        int numRows = this.input.rows();
        int inputLength = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getInputLength();
        if (!Shape.hasDefaultStridesForShape((INDArray)this.input)) {
            this.input = workspaceMgr.dup(ArrayType.ACTIVATIONS, this.input, 'f');
        }
        int[] indexes = this.input.data().asInt();
        for (int i = 0; i < indexes.length; ++i) {
            indexes[i] = this.input.getInt(new int[]{i % numRows, i / numRows});
            if (indexes[i] >= 0 && indexes[i] < nIn) continue;
            throw new DL4JInvalidInputException("Invalid index for embedding layer: got index " + indexes[i] + " for entry " + i + " in minibatch; indexes must be between 0 and nIn-1 inclusive (0 to " + (nIn - 1) + ")");
        }
        INDArray weights = this.getParam("W");
        int nOut = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getNOut();
        INDArray destination = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, new int[]{numRows * inputLength, nOut});
        INDArray rows = Nd4j.pullRows((INDArray)weights, (INDArray)destination, (int)1, (int[])indexes);
        if (this.hasBias()) {
            INDArray bias = this.getParam("b");
            rows.addiRowVector(bias);
        }
        int[] shape = new int[]{inputLength, numRows, nOut};
        INDArray ret = workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, rows.reshape('c', shape));
        ret = ret.permute(new int[]{1, 2, 0});
        return ret;
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray rows = this.preOutput(training, workspaceMgr);
        INDArray ret = ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).getActivationFn().getActivation(rows, training);
        if (this.maskArray != null) {
            ret.muliColumnVector(this.maskArray);
        }
        return ret;
    }

    @Override
    public boolean hasBias() {
        return ((org.deeplearning4j.nn.conf.layers.EmbeddingSequenceLayer)this.layerConf()).hasBias();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    protected void applyDropOutIfNecessary(boolean training, LayerWorkspaceMgr workspaceMgr) {
        throw new UnsupportedOperationException("Dropout not supported with EmbeddingLayer " + this.layerId());
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.RECURRENT;
    }
}

