/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.mkldnn;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.mkldnn.BaseMKLDNNHelper;
import org.deeplearning4j.nn.layers.recurrent.FwdPassReturn;
import org.deeplearning4j.nn.layers.recurrent.LSTMHelper;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationHardSigmoid;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
import org.nd4j.linalg.activations.impl.ActivationSoftSign;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.activations.impl.ActivationThresholdedReLU;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;

/**
 * {@link LSTMHelper} that runs the LSTM forward pass through the {@code lstmLayer} custom op when
 * MKL-DNN is enabled ({@link BaseMKLDNNHelper#mklDnnEnabled()}).
 *
 * <p>Limitations: only the forward pass is implemented ({@link #backpropGradient} returns
 * {@code null}), and peephole connections are not supported.
 */
public class MKLDNNLSTMHelper implements LSTMHelper {

    /**
     * Creates the helper.
     *
     * @param dataType network data type; not used by this helper
     */
    public MKLDNNLSTMHelper(DataType dataType) {
        // No state to initialise: everything activate() needs is passed in per call.
    }

    /**
     * Reports whether this helper can run an LSTM with the given configuration.
     *
     * @param gateActivationFn       activation of the gates; must be sigmoid
     * @param activationFn           layer activation; must be tanh
     * @param hasPeepholeConnections must be {@code false}; activate() rejects peephole connections
     * @return {@code true} only if all of the above hold and MKL-DNN is enabled
     */
    @Override
    public boolean checkSupported(IActivation gateActivationFn, IActivation activationFn, boolean hasPeepholeConnections) {
        // activate() throws for peephole connections, so report them as unsupported here instead of
        // letting the caller pick this helper and fail at runtime.
        return !hasPeepholeConnections
                && gateActivationFn instanceof ActivationSigmoid
                && activationFn instanceof ActivationTanH
                && BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Backpropagation is not implemented by this helper.
     *
     * <p>NOTE(review): presumably the calling layer falls back to its built-in implementation when
     * this returns {@code null}; confirm against the caller.
     *
     * @return always {@code null}
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray inputWeights, INDArray epsilon, boolean truncatedBPTT, int tbpttBackwardLength, FwdPassReturn fwdPass, boolean forwards, String inputWeightKey, String recurrentWeightKey, String biasWeightKey, Map<String, INDArray> gradientViews, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
        return null;
    }

    /**
     * Runs the LSTM forward pass with the {@code lstmLayer} custom op.
     *
     * <p>Outputs are allocated as {@link ArrayType#ACTIVATIONS} arrays from {@code workspaceMgr}.
     * The returned {@link FwdPassReturn} holds op outputs 0, 1 and 2 as the full output sequence,
     * the last output and the last memory-cell state respectively.
     *
     * <p>NOTE(review): {@code forwards} is not consulted; the op is always run with direction mode 0.
     * Confirm that this helper is only used for forward-direction layers.
     *
     * @throws IllegalStateException if {@code hasPeepholeConnections} is {@code true}
     * @throws ClassCastException    if {@code conf.getLayer()} is not an {@link LSTM} configuration
     */
    @Override
    public FwdPassReturn activate(Layer layer, NeuralNetConfiguration conf, IActivation gateActivationFn, INDArray input, INDArray recurrentWeights, INDArray inputWeights, INDArray biases, boolean training, INDArray prevOutputActivations, INDArray prevMemCellState, boolean forBackprop, boolean forwards, String inputWeightKey, INDArray maskArray, boolean hasPeepholeConnections, LayerWorkspaceMgr workspaceMgr) {
        if (hasPeepholeConnections) {
            throw new IllegalStateException("Not yet implemented");
        }

        // Flatten the biases to a rank-1 vector before handing them to the op.
        INDArray b1d = biases.reshape(biases.length());

        INDArray seqLen = null;
        if (maskArray != null) {
            // Per-example sequence length = number of unmasked steps along dimension 1.
            // Counting the ones (rather than searching for the first zero) also yields the full
            // length for examples that have no padding at all, where a first-zero search finds
            // no match.
            // NOTE(review): assumes a 0/1 mask padded at the end (ones followed by zeros); confirm.
            seqLen = maskArray.sum(1).castTo(DataType.LONG);
        }

        // Optional inputs are appended only when present; the boolean arguments below tell the op
        // which ones were supplied.
        List<INDArray> args = new ArrayList<>();
        args.add(input);
        args.add(inputWeights);
        args.add(recurrentWeights);
        args.add(b1d);
        if (seqLen != null) {
            args.add(seqLen);
        }
        if (prevOutputActivations != null) {
            args.add(prevOutputActivations);
        }
        if (prevMemCellState != null) {
            args.add(prevMemCellState);
        }

        // The layer's own activation function is used for both the cell and the output activation.
        IActivation activationFn = ((LSTM) conf.getLayer()).getActivationFn();

        DynamicCustomOp op = DynamicCustomOp.builder("lstmLayer")
                .addInputs(args.toArray(new INDArray[0]))
                .addBooleanArguments(
                        true,                                  // biases supplied
                        seqLen != null,                        // sequence lengths supplied
                        prevOutputActivations != null,         // initial output supplied
                        prevMemCellState != null,              // initial cell state supplied
                        hasPeepholeConnections,                // peephole weights (always false here)
                        true,                                  // presumably: return full sequence (output 0)
                        true,                                  // presumably: return last output (output 1)
                        true)                                  // presumably: return last cell state (output 2)
                .addIntegerArguments(
                        2,                                     // data format code; presumably [minibatch, nIn, time] - confirm
                        0,                                     // direction mode; presumably forward - confirm
                        activationToArg(gateActivationFn),     // gate activation
                        activationToArg(activationFn),         // cell activation
                        activationToArg(activationFn))         // output activation
                // NOTE(review): the native lstmLayer op reads a cell-clipping value from its first
                // floating-point argument (0 disables clipping); confirm against the op declaration.
                .addFloatingPointArguments(0.0)
                .build();

        List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
        for (LongShapeDescriptor lsd : outShapes) {
            INDArray out = workspaceMgr.createUninitialized(ArrayType.ACTIVATIONS, lsd.dataType(), lsd.getShape(), lsd.getOrder());
            op.addOutputArgument(out);
        }

        // Execute the op: the output arrays above are uninitialized until it runs.
        Nd4j.exec(op);

        FwdPassReturn result = new FwdPassReturn();
        result.fwdPassOutput = op.getOutputArgument(0);
        result.lastAct = op.getOutputArgument(1);
        result.lastMemCell = op.getOutputArgument(2);
        return result;
    }

    /**
     * This helper reports no additional memory use.
     *
     * @return an empty, immutable map
     */
    @Override
    public Map<String, Long> helperMemoryUse() {
        return Collections.emptyMap();
    }

    /**
     * @return {@code true} if MKL-DNN is enabled
     */
    @Override
    public boolean checkSupported() {
        return BaseMKLDNNHelper.mklDnnEnabled();
    }

    /**
     * Maps a DL4J activation function to the integer activation code passed to the
     * {@code lstmLayer} op.
     *
     * <p>Code 6 is never produced. NOTE(review): presumably it denotes an activation with no DL4J
     * equivalent here; confirm against the op's documentation.
     *
     * @throws IllegalStateException if the activation has no mapping
     */
    private int activationToArg(IActivation a) {
        if (a instanceof ActivationTanH) {
            return 0;
        }
        if (a instanceof ActivationReLU) {
            return 1;
        }
        if (a instanceof ActivationSigmoid) {
            return 2;
        }
        if (a instanceof ActivationIdentity) {
            return 3;
        }
        if (a instanceof ActivationLReLU) {
            return 4;
        }
        if (a instanceof ActivationThresholdedReLU) {
            return 5;
        }
        if (a instanceof ActivationHardSigmoid) {
            return 7;
        }
        if (a instanceof ActivationELU) {
            return 8;
        }
        if (a instanceof ActivationSoftSign) {
            return 9;
        }
        if (a instanceof ActivationSoftPlus) {
            return 10;
        }
        throw new IllegalStateException("Unknown or not supported activation function: " + a);
    }
}

