/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.RNNFormat;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.Convolution1DUtils;
import org.deeplearning4j.util.ConvolutionUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv1DDerivative;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Broadcast;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Runtime implementation of a 1D convolution layer.
 *
 * <p>The forward pass executes a {@link Conv1D} op and the backward pass a
 * {@link Conv1DDerivative} op. Both ops are always configured with the {@code "NCW"}
 * data format; when the layer configuration reports {@link RNNFormat#NWC}, arrays are
 * permuted on the way into and out of the ops.
 */
public class Convolution1DLayer
extends ConvolutionLayer {
    public Convolution1DLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
    }

    /**
     * Computes the parameter gradients and the gradient with respect to the layer input.
     *
     * <p>The weight (and, if present, bias) gradients are written by the op into the arrays
     * obtained from {@code gradientViews}; those same view arrays are returned in the
     * {@link Gradient}.
     *
     * @param epsilon      gradient arriving from the layer above; must be rank 3
     * @param workspaceMgr workspace manager used to allocate the returned input gradient
     * @return pair of (parameter gradients, gradient with respect to the input)
     * @throws DL4JInvalidInputException if {@code epsilon} is not rank 3
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        if (epsilon.rank() != 3) {
            throw new DL4JInvalidInputException("Got rank " + epsilon.rank() + " array as epsilon for Convolution1DLayer backprop with shape " + Arrays.toString(epsilon.shape()) + ". Expected rank 3 array with shape [minibatchSize, features, length]. " + this.layerId());
        }

        // Re-run the forward pass to obtain the pre-activations, then push epsilon back
        // through the activation function.
        Pair<INDArray, INDArray> forward = this.preOutput(false, true, workspaceMgr);
        IActivation activationFn = this.layerConf().getActivationFn();
        INDArray delta = activationFn.backprop(forward.getFirst(), epsilon).getFirst();

        Conv1DConfig opConfig = this.buildConv1DConfig();
        INDArray weights = Convolution1DUtils.reshapeWeightArrayOrGradientForFormat(this.getParam("W"), RNNFormat.NCW);
        // NOTE(review): the gradient view is reshaped using the layer's own format while the
        // weights above always use NCW - confirm this asymmetry is intended.
        INDArray weightGrad = Convolution1DUtils.reshapeWeightArrayOrGradientForFormat(this.gradientViews.get("W"), this.getRnnDataFormat());
        INDArray epsOut = workspaceMgr.createUninitialized(ArrayType.ACTIVATION_GRAD, this.input.dataType(), this.input.shape());

        INDArray opInput = this.input.castTo(this.dataType);
        if (this.getRnnDataFormat() == RNNFormat.NWC) {
            opInput = opInput.permute(0, 2, 1);
        }

        boolean hasBias = this.layerConf().hasBias();
        INDArray[] opInputs;
        INDArray[] opOutputs;
        if (hasBias) {
            INDArray bias = flatten(this.getParam("b"));
            INDArray biasGrad = flatten(this.gradientViews.get("b"));
            opInputs = new INDArray[]{opInput, weights, bias, delta};
            opOutputs = new INDArray[]{epsOut, weightGrad, biasGrad};
        } else {
            opInputs = new INDArray[]{opInput, weights, delta};
            opOutputs = new INDArray[]{epsOut, weightGrad};
        }
        Nd4j.exec(new Conv1DDerivative(opInputs, opOutputs, opConfig));

        Gradient gradient = new DefaultGradient();
        if (hasBias) {
            gradient.setGradientFor("b", this.gradientViews.get("b"));
        }
        gradient.setGradientFor("W", this.gradientViews.get("W"), Character.valueOf('c'));

        if (this.getRnnDataFormat() == RNNFormat.NWC) {
            epsOut = epsOut.permute(0, 2, 1);
        }
        return new Pair<>(gradient, epsOut);
    }

    /**
     * Returns the superclass pre-output with its first element reshaped to rank 4 by
     * appending a trailing dimension of size 1.
     *
     * <p>NOTE(review): {@code training} is not forwarded; the superclass call is always made
     * with {@code true} - confirm this is intended.
     *
     * @param training     unused by this implementation
     * @param forBackprop  forwarded to the superclass pre-output computation
     * @param workspaceMgr forwarded to the superclass pre-output computation
     * @return the superclass result pair, with the first element replaced by the reshaped array
     */
    @Override
    protected Pair<INDArray, INDArray> preOutput4d(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        Pair<INDArray, INDArray> result = super.preOutput(true, forBackprop, workspaceMgr);
        INDArray preOut3d = result.getFirst();
        result.setFirst(preOut3d.reshape(preOut3d.size(0), preOut3d.size(1), preOut3d.size(2), 1L));
        return result;
    }

    /**
     * Computes the pre-activation output by executing a {@link Conv1D} op on the layer input.
     *
     * <p>When the layer format is {@link RNNFormat#NWC}, a rank 3 or rank 4 input is permuted
     * before the op runs and the op output is permuted with {@code (0, 2, 1)} afterwards.
     *
     * @param training     not read by this implementation
     * @param forBackprop  not read by this implementation
     * @param workspaceMgr not read by this implementation
     * @return pair whose first element is the op output; the second element is {@code null}
     */
    @Override
    protected Pair<INDArray, INDArray> preOutput(boolean training, boolean forBackprop, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);

        INDArray opInput = this.input.castTo(this.dataType);
        boolean nwc = this.getRnnDataFormat() == RNNFormat.NWC;
        if (nwc) {
            int rank = opInput.rank();
            if (rank == 3) {
                opInput = opInput.permute(0, 2, 1);
            } else if (rank == 4) {
                opInput = opInput.permute(0, 2, 3, 1);
            }
        }

        Conv1DConfig opConfig = this.buildConv1DConfig();
        INDArray weights = Convolution1DUtils.reshapeWeightArrayOrGradientForFormat(this.getParam("W"), RNNFormat.NCW);
        INDArray[] opInputs = this.layerConf().hasBias()
                ? new INDArray[]{opInput, weights, flatten(this.getParam("b"))}
                : new INDArray[]{opInput, weights};

        // The op is created without outputs; the output array is allocated from the shape
        // the op itself reports and attached before execution.
        Conv1D op = new Conv1D(opInputs, null, opConfig);
        List<LongShapeDescriptor> outShapes = op.calculateOutputShape();
        op.setOutputArgument(0, Nd4j.create(outShapes.get(0), false));
        Nd4j.exec(op);

        INDArray output = op.getOutputArgument(0);
        if (nwc) {
            output = output.permute(0, 2, 1);
        }
        return new Pair<>(output, null);
    }

    /**
     * Computes the layer activations as a rank 3 array and applies the mask, if one is set.
     *
     * <p>A superclass result of rank greater than 3 is reshaped to its first three dimensions.
     * When {@code maskArray} is non-null, it is reduced via
     * {@link #feedForwardMaskArray(INDArray, MaskState, int)} and multiplied into the
     * activations in place, broadcast along dimensions 0 and 2.
     *
     * @param training     forwarded to the superclass
     * @param workspaceMgr workspace manager; the result is leveraged to {@link ArrayType#ACTIVATIONS}
     * @return the (possibly masked) activations
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray activations = super.activate(training, workspaceMgr);
        if (activations.rank() > 3) {
            activations = activations.reshape(activations.size(0), activations.size(1), activations.size(2));
        }
        if (this.maskArray != null) {
            INDArray reducedMask = this.feedForwardMaskArray(this.maskArray, MaskState.Active, (int) activations.size(0)).getFirst();
            Preconditions.checkState(
                    activations.size(0) == reducedMask.size(0) && activations.size(2) == reducedMask.size(1),
                    "Activations dimensions (0,2) and mask dimensions (0,1) don't match: Activations %s, Mask %s",
                    activations.shape(), reducedMask.shape());
            Broadcast.mul(activations, reducedMask, activations, 0, 2);
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, activations);
    }

    /**
     * Reduces the given mask using this layer's kernel size, stride, padding, dilation and
     * convolution mode, and passes the mask state through unchanged.
     *
     * @param maskArray        mask to reduce
     * @param currentMaskState mask state, returned as-is
     * @param minibatchSize    not read by this implementation
     * @return pair of (reduced mask, {@code currentMaskState})
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        org.deeplearning4j.nn.conf.layers.Convolution1DLayer layer = this.layerConf();
        INDArray reducedMask = ConvolutionUtils.cnn1dMaskReduction(
                maskArray,
                layer.getKernelSize()[0],
                layer.getStride()[0],
                layer.getPadding()[0],
                layer.getDilation()[0],
                layer.getConvolutionMode());
        return new Pair<>(reducedMask, currentMaskState);
    }

    /**
     * Returns the layer configuration, cast to the 1D convolution configuration type.
     */
    @Override
    public org.deeplearning4j.nn.conf.layers.Convolution1DLayer layerConf() {
        return (org.deeplearning4j.nn.conf.layers.Convolution1DLayer)this.conf().getLayer();
    }

    /** Returns the data format reported by the layer configuration. */
    private RNNFormat getRnnDataFormat() {
        return this.layerConf().getRnnDataFormat();
    }

    /**
     * Builds the op configuration shared by the forward and backward ops from the first
     * element of the configured kernel size, stride, dilation and padding. The data format
     * is always {@code "NCW"}.
     */
    private Conv1DConfig buildConv1DConfig() {
        org.deeplearning4j.nn.conf.layers.Convolution1DLayer layer = this.layerConf();
        return Conv1DConfig.builder()
                .k((long) layer.getKernelSize()[0])
                .s((long) layer.getStride()[0])
                .d((long) layer.getDilation()[0])
                .p((long) layer.getPadding()[0])
                .dataFormat("NCW")
                .paddingMode(ConvolutionUtils.paddingModeForConvolutionMode(this.convolutionMode))
                .build();
    }

    /** Reshapes the given array to rank 1 with its total element count as the only dimension. */
    private static INDArray flatten(INDArray array) {
        return array.reshape(array.length());
    }
}

