/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.convolution;

import java.util.Arrays;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.ConvolutionMode;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionHelper;
import org.deeplearning4j.util.ConvolutionUtils;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.convolution.Convolution;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ConvolutionLayer
extends BaseLayer<org.deeplearning4j.nn.conf.layers.ConvolutionLayer> {
    protected static final Logger log = LoggerFactory.getLogger(ConvolutionLayer.class);
    protected ConvolutionHelper helper = null;
    protected ConvolutionMode convolutionMode;

    public ConvolutionLayer(NeuralNetConfiguration conf) {
        super(conf);
        this.initializeHelper();
        this.convolutionMode = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.conf().getLayer()).getConvolutionMode();
    }

    public ConvolutionLayer(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
        this.initializeHelper();
    }

    void initializeHelper() {
        block2: {
            try {
                this.helper = Class.forName("org.deeplearning4j.nn.layers.convolution.CudnnConvolutionHelper").asSubclass(ConvolutionHelper.class).newInstance();
                log.debug("CudnnConvolutionHelper successfully loaded");
            }
            catch (Throwable t) {
                if (t instanceof ClassNotFoundException) break block2;
                log.warn("Could not load CudnnConvolutionHelper", t);
            }
        }
    }

    @Override
    public double calcL2() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL2() <= 0.0) {
            return 0.0;
        }
        double l2Norm = this.getParam("W").norm2Number().doubleValue();
        return 0.5 * this.conf.getLayer().getL2() * l2Norm * l2Norm;
    }

    @Override
    public double calcL1() {
        if (!this.conf.isUseRegularization() || this.conf.getLayer().getL1() <= 0.0) {
            return 0.0;
        }
        return this.conf.getLayer().getL1() * this.getParam("W").norm1Number().doubleValue();
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        Pair<Gradient, INDArray> ret;
        INDArray delta;
        int[] pad;
        int[] outSize;
        INDArray weights = this.getParam("W");
        int miniBatch = this.input.size(0);
        int inH = this.input.size(2);
        int inW = this.input.size(3);
        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        int kH = weights.size(2);
        int kW = weights.size(3);
        int[] kernel = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getKernelSize();
        int[] strides = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getStride();
        if (this.convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(this.input, kernel, strides, null, this.convolutionMode);
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{inH, inW}, kernel, strides);
        } else {
            pad = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getPadding();
            outSize = ConvolutionUtils.getOutputSize(this.input, kernel, strides, pad, this.convolutionMode);
        }
        int outH = outSize[0];
        int outW = outSize[1];
        INDArray biasGradView = (INDArray)this.gradientViews.get("b");
        INDArray weightGradView = (INDArray)this.gradientViews.get("W");
        INDArray weightGradView2df = Shape.newShapeNoCopy((INDArray)weightGradView, (int[])new int[]{outDepth, inDepth * kH * kW}, (boolean)false).transpose();
        String afn = this.conf.getLayer().getActivationFunction();
        if ("identity".equals(afn)) {
            delta = epsilon;
        } else {
            INDArray sigmaPrimeZ = this.preOutput(true);
            Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(afn, sigmaPrimeZ, this.conf.getExtraArgs()).derivative());
            delta = sigmaPrimeZ.muli(epsilon);
        }
        if (this.helper != null && Nd4j.dataType() != DataBuffer.Type.HALF && (ret = this.helper.backpropGradient(this.input, weights, delta, kernel, strides, pad, biasGradView, weightGradView, afn, ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getCudnnAlgoMode(), this.convolutionMode)) != null) {
            return ret;
        }
        delta = delta.permute(new int[]{1, 0, 2, 3});
        INDArray delta2d = delta.reshape('c', new int[]{outDepth, miniBatch * outH * outW});
        INDArray col = Nd4j.createUninitialized((int[])new int[]{miniBatch, outH, outW, inDepth, kH, kW}, (char)'c');
        INDArray col2 = col.permute(new int[]{0, 3, 4, 5, 1, 2});
        Convolution.im2col((INDArray)this.input, (int)kH, (int)kW, (int)strides[0], (int)strides[1], (int)pad[0], (int)pad[1], (this.convolutionMode == ConvolutionMode.Same ? 1 : 0) != 0, (INDArray)col2);
        INDArray im2col2d = col.reshape('c', miniBatch * outH * outW, inDepth * kH * kW);
        Nd4j.gemm((INDArray)im2col2d, (INDArray)delta2d, (INDArray)weightGradView2df, (boolean)true, (boolean)true, (double)1.0, (double)0.0);
        INDArray wPermuted = weights.permute(new int[]{3, 2, 1, 0});
        INDArray w2d = wPermuted.reshape('f', inDepth * kH * kW, outDepth);
        INDArray epsNext2d = w2d.mmul(delta2d);
        INDArray eps6d = Shape.newShapeNoCopy((INDArray)epsNext2d, (int[])new int[]{kW, kH, inDepth, outW, outH, miniBatch}, (boolean)true);
        eps6d = eps6d.permute(new int[]{5, 2, 1, 0, 4, 3});
        INDArray epsNextOrig = Nd4j.create((int[])new int[]{inDepth, miniBatch, inH, inW}, (char)'c');
        INDArray epsNext = epsNextOrig.permute(new int[]{1, 0, 2, 3});
        Convolution.col2im((INDArray)eps6d, (INDArray)epsNext, (int)strides[0], (int)strides[1], (int)pad[0], (int)pad[1], (int)inH, (int)inW);
        DefaultGradient retGradient = new DefaultGradient();
        INDArray biasGradTemp = delta2d.sum(new int[]{1});
        biasGradView.assign(biasGradTemp);
        retGradient.setGradientFor("b", biasGradView);
        retGradient.setGradientFor("W", weightGradView, Character.valueOf('c'));
        return new Pair<Gradient, INDArray>(retGradient, epsNext);
    }

    @Override
    public INDArray preOutput(boolean training) {
        INDArray ret;
        int[] pad;
        int[] outSize;
        INDArray weights = this.getParam("W");
        INDArray bias = this.getParam("b");
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0) {
            weights = Dropout.applyDropConnect(this, "W");
        }
        if (this.input.rank() != 4) {
            String layerName = this.conf.getLayer().getLayerName();
            if (layerName == null) {
                layerName = "(not named)";
            }
            throw new DL4JInvalidInputException("Got rank " + this.input.rank() + " array as input to ConvolutionLayer (layer name = " + layerName + ", layer index = " + this.index + ") with shape " + Arrays.toString(this.input.shape()) + ". Expected rank 4 array with shape [minibatchSize, layerInputDepth, inputHeight, inputWidth]." + (this.input.rank() == 2 ? " (Wrong input type (see InputType.convolutionalFlat()) or wrong data type?)" : ""));
        }
        int miniBatch = this.input.size(0);
        int outDepth = weights.size(0);
        int inDepth = weights.size(1);
        if (this.input.size(1) != inDepth) {
            String layerName = this.conf.getLayer().getLayerName();
            if (layerName == null) {
                layerName = "(not named)";
            }
            throw new DL4JInvalidInputException("Cannot do forward pass in Convolution layer (layer name = " + layerName + ", layer index = " + this.index + "): input array depth does not match CNN layer configuration (data input depth = " + this.input.size(1) + ", [minibatch,inputDepth,height,width]=" + Arrays.toString(this.input.shape()) + "; expected input depth = " + inDepth + ")");
        }
        int kH = weights.size(2);
        int kW = weights.size(3);
        int[] kernel = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getKernelSize();
        int[] strides = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getStride();
        if (this.convolutionMode == ConvolutionMode.Same) {
            outSize = ConvolutionUtils.getOutputSize(this.input, kernel, strides, null, this.convolutionMode);
            pad = ConvolutionUtils.getSameModeTopLeftPadding(outSize, new int[]{this.input.size(2), this.input.size(3)}, kernel, strides);
        } else {
            pad = ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getPadding();
            outSize = ConvolutionUtils.getOutputSize(this.input, kernel, strides, pad, this.convolutionMode);
        }
        int outH = outSize[0];
        int outW = outSize[1];
        if (this.helper != null && Nd4j.dataType() != DataBuffer.Type.HALF && (ret = this.helper.preOutput(this.input, weights, bias, kernel, strides, pad, ((org.deeplearning4j.nn.conf.layers.ConvolutionLayer)this.layerConf()).getCudnnAlgoMode(), this.convolutionMode)) != null) {
            return ret;
        }
        INDArray col = Nd4j.createUninitialized((int[])new int[]{miniBatch, outH, outW, inDepth, kH, kW}, (char)'c');
        INDArray col2 = col.permute(new int[]{0, 3, 4, 5, 1, 2});
        Convolution.im2col((INDArray)this.input, (int)kH, (int)kW, (int)strides[0], (int)strides[1], (int)pad[0], (int)pad[1], (this.convolutionMode == ConvolutionMode.Same ? 1 : 0) != 0, (INDArray)col2);
        INDArray reshapedCol = Shape.newShapeNoCopy((INDArray)col, (int[])new int[]{miniBatch * outH * outW, inDepth * kH * kW}, (boolean)false);
        INDArray permutedW = weights.permute(new int[]{3, 2, 1, 0});
        INDArray reshapedW = permutedW.reshape('f', kW * kH * inDepth, outDepth);
        INDArray z = reshapedCol.mmul(reshapedW);
        z.addiRowVector(bias);
        z = Shape.newShapeNoCopy((INDArray)z, (int[])new int[]{outW, outH, miniBatch, outDepth}, (boolean)true);
        return z.permute(new int[]{2, 3, 1, 0});
    }

    @Override
    public INDArray activate(boolean training) {
        INDArray ret;
        if (this.input == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.applyDropOutIfNecessary(training);
        INDArray z = this.preOutput(training);
        String afn = this.conf.getLayer().getActivationFunction();
        if ("identity".equals(afn)) {
            return z;
        }
        if (this.helper != null && Nd4j.dataType() != DataBuffer.Type.HALF && (ret = this.helper.activate(z, this.conf.getLayer().getActivationFunction())) != null) {
            return ret;
        }
        INDArray activation = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform(afn, z));
        return activation;
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void fit(INDArray input) {
    }

    @Override
    public void merge(Layer layer, int batchSize) {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray params() {
        return Nd4j.toFlattened((char)'c', this.params.values());
    }

    @Override
    public void setParams(INDArray params) {
        this.setParams(params, 'c');
    }
}

