/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.pooling;

import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.MaskedReductionUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.TransformOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

/**
 * Global pooling layer: reduces activations over entire dimensions using the configured
 * {@link PoolingType} (MAX, AVG, SUM or PNORM).
 *
 * <p>Only rank 3 (time series) and rank 4 (CNN) input is supported. Unless the layer
 * configuration supplies explicit pooling dimensions, rank 3 input is pooled over dimension
 * {@code [2]} and rank 4 input over dimensions {@code [2, 3]}.
 *
 * <p>If {@code collapseDimensions} is false, the pooled output is reshaped to keep size-1
 * trailing dimensions ({@code [d0, d1, 1]} or {@code [d0, d1, 1, 1]}); otherwise the reduced
 * array is returned directly.
 *
 * <p>When a mask array has been set, the forward and backward passes are delegated to
 * {@link MaskedReductionUtil}. The layer holds no parameters.
 */
public class GlobalPoolingLayer
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer> {
    /** Default pooling dimensions for rank 3 (time series) input. */
    private static final int[] DEFAULT_TIMESERIES_POOL_DIMS = new int[]{2};
    /** Default pooling dimensions for rank 4 (CNN) input. */
    private static final int[] DEFAULT_CNN_POOL_DIMS = new int[]{2, 3};
    /** Configured pooling dimensions; {@code null} means "use the default for the input rank". */
    private final int[] poolingDimensions;
    /** If false, pooled output keeps size-1 trailing dimensions instead of being collapsed. */
    private final boolean collapseDimensions;
    private final PoolingType poolingType;
    /** The exponent p used by {@link PoolingType#PNORM} pooling. */
    private final int pNorm;

    /**
     * Creates the layer and caches the pooling settings from its layer configuration.
     *
     * @param conf network configuration whose layer is a GlobalPoolingLayer configuration
     */
    public GlobalPoolingLayer(NeuralNetConfiguration conf) {
        super(conf);
        org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer layerConf =
                (org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer) conf.getLayer();
        this.poolingDimensions = layerConf.getPoolingDimensions();
        this.collapseDimensions = layerConf.isCollapseDimensions();
        this.poolingType = layerConf.getPoolingType();
        this.pNorm = layerConf.getPnorm();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    /** No-op: this layer holds no parameters. */
    @Override
    public void clearNoiseWeightParams() {
    }

    /** Always 0: the layer has no parameters to regularize. */
    @Override
    public double calcL2(boolean backpropParamsOnly) {
        return 0.0;
    }

    /** Always 0: the layer has no parameters to regularize. */
    @Override
    public double calcL1(boolean backpropParamsOnly) {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.SUBSAMPLING;
    }

    /**
     * Applies global pooling to the current input (masked if a mask array is set).
     *
     * @param training     not read by this layer
     * @param workspaceMgr used to move the result into the ACTIVATIONS workspace
     * @return the pooled activations; when {@code collapseDimensions} is false they are reshaped
     *         to {@code [d0, d1, 1]} (rank 3 input) or {@code [d0, d1, 1, 1]} (rank 4 input)
     * @throws UnsupportedOperationException if the input is not rank 3 or 4, or if a mask is set
     *         on CNN input in an unsupported configuration
     */
    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(false);
        int[] poolDim = this.resolvePoolingDimensions();
        int rank = this.input.rank();

        INDArray reduced;
        if (this.maskArray == null) {
            reduced = this.activateHelperFullArray(this.input, poolDim);
        } else if (rank == 3) {
            reduced = MaskedReductionUtil.maskedPoolingTimeSeries(this.poolingType, this.input, this.maskArray, this.pNorm);
        } else {
            // resolvePoolingDimensions() has already rejected anything other than rank 3 or 4
            reduced = this.activateMaskedCnn(poolDim);
        }

        if (this.collapseDimensions) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reduced);
        }
        // Keep the pooled dimensions as size-1 axes so the output rank matches the input rank
        int[] inputShape = this.input.shape();
        int[] keptShape = rank == 3
                ? new int[]{inputShape[0], inputShape[1], 1}
                : new int[]{inputShape[0], inputShape[1], 1, 1};
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reduced.reshape(reduced.ordering(), keptShape));
    }

    /**
     * Returns the dimensions to pool over for the current input: the configured pooling
     * dimensions if set, otherwise the default for the input's rank.
     *
     * @throws UnsupportedOperationException if the input is neither rank 3 nor rank 4
     */
    private int[] resolvePoolingDimensions() {
        int rank = this.input.rank();
        if (rank == 3) {
            return this.poolingDimensions == null ? DEFAULT_TIMESERIES_POOL_DIMS : this.poolingDimensions;
        }
        if (rank == 4) {
            return this.poolingDimensions == null ? DEFAULT_CNN_POOL_DIMS : this.poolingDimensions;
        }
        throw new UnsupportedOperationException("Received rank " + rank + " input (shape = " + Arrays.toString(this.input.shape()) + "). Only rank 3 (time series) and rank 4 (images/CNN data) are currently supported for global pooling " + this.layerId());
    }

    /**
     * Masked global pooling for rank 4 (CNN) input.
     *
     * <p>Only supported for a 2d mask, with input whose height or width is 1 and whose other
     * spatial size equals the mask length, pooled over dimensions [2, 3].
     *
     * @param poolDim the resolved pooling dimensions
     * @return the pooled activations
     * @throws UnsupportedOperationException if any of the above requirements is violated
     */
    private INDArray activateMaskedCnn(int[] poolDim) {
        if (this.maskArray.rank() != 2) {
            throw new UnsupportedOperationException("Only 2d mask arrays are currently supported for masked global reductions on CNN data. Got 4d activations array (shape " + Arrays.toString(this.input.shape()) + ") and " + this.maskArray.rank() + "d mask array (shape " + Arrays.toString(this.maskArray.shape()) + ") " + this.layerId());
        }
        int h = this.input.size(2);
        int w = this.input.size(3);
        int maskLength = this.maskArray.size(1);
        if ((h != 1 && w != 1) || (h != maskLength && w != maskLength)) {
            throw new UnsupportedOperationException("Masked global pooling with on CNN data currently only supports data with h=1 or w=1:\n input activations must have shape [minibatchSize,channels,height=1,width] or [minibatchSize,channels,height,width=1] with  mask array of shape [minibatchSize,width] or [minibatchSize,height] respectively.\n Got 4d activations array (shape " + Arrays.toString(this.input.shape()) + ") and " + this.maskArray.rank() + "d mask array (shape " + Arrays.toString(this.maskArray.shape()) + ") " + this.layerId());
        }
        // Reference comparison first: the common case is the shared default array itself
        if (DEFAULT_CNN_POOL_DIMS != poolDim && !Arrays.equals(DEFAULT_CNN_POOL_DIMS, poolDim)) {
            throw new UnsupportedOperationException("Masked global pooling with on CNN data currently only supports poolling over dimensions [2,3] (i.e., width and height - both required). Got pooling dimensions " + Arrays.toString(poolDim) + ") " + this.layerId());
        }
        // The mask runs along whichever spatial dimension matches its length
        boolean maskAlongHeight = h == maskLength;
        return MaskedReductionUtil.maskedPoolingConvolution(this.poolingType, this.input, this.maskArray, maskAlongHeight, this.pNorm);
    }

    @Override
    public Layer clone() {
        return new GlobalPoolingLayer(this.conf);
    }

    /**
     * Unmasked global pooling of {@code inputArray} over {@code poolDim}.
     *
     * @param inputArray array to reduce; not modified
     * @param poolDim    dimensions to reduce over
     * @return the reduced array
     * @throws RuntimeException if the pooling type is not MAX, AVG, SUM or PNORM
     */
    private INDArray activateHelperFullArray(INDArray inputArray, int[] poolDim) {
        switch (this.poolingType) {
            case MAX:
                return inputArray.max(poolDim);
            case AVG:
                return inputArray.mean(poolDim);
            case SUM:
                return inputArray.sum(poolDim);
            case PNORM: {
                // (sum |x|^p)^(1/p); abs(..., true) works on a copy so the input is left untouched
                INDArray powered = Transforms.abs(inputArray, true);
                Transforms.pow(powered, this.pNorm, false);
                INDArray summed = powered.sum(poolDim);
                return Transforms.pow(summed, 1.0 / this.pNorm, false);
            }
            default:
                throw new RuntimeException("Unknown or not supported pooling type: " + this.poolingType + " " + this.layerId());
        }
    }

    /**
     * Computes the gradient with respect to the layer input. The layer has no parameters, so the
     * returned gradient is empty.
     *
     * @param epsilon      gradient w.r.t. this layer's output
     * @param workspaceMgr used to move the result into the ACTIVATION_GRAD workspace
     * @return an empty gradient paired with the gradient w.r.t. the input
     * @throws UnsupportedOperationException if the input is neither rank 3 nor rank 4
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        this.assertInputSet(true);
        if (!this.collapseDimensions && epsilon.rank() != 2) {
            // activate() kept size-1 trailing dimensions; drop them to get back to 2d
            int[] origShape = epsilon.shape();
            epsilon = epsilon.reshape(epsilon.ordering(), origShape[0], origShape[1]);
        }
        Gradient retGradient = new DefaultGradient();

        // Validates the input rank up front, so an unsupported rank is reported clearly instead
        // of surfacing as a NullPointerException inside epsilonHelperFullArray.
        int[] poolDim = this.resolvePoolingDimensions();

        INDArray epsilonNd;
        if (this.maskArray == null) {
            epsilonNd = this.epsilonHelperFullArray(this.input, epsilon, poolDim);
        } else if (this.input.rank() == 3) {
            epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonTimeSeries(this.poolingType, this.input, this.maskArray, epsilon, this.pNorm);
        } else {
            // Rank 4; mirrors the mask-orientation logic used in the forward pass
            int h = this.input.size(2);
            boolean maskAlongHeight = h == this.maskArray.size(1);
            epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonCnn(this.poolingType, this.input, this.maskArray, epsilon, maskAlongHeight, this.pNorm);
        }
        epsilonNd = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilonNd);
        return new Pair<>(retGradient, epsilonNd);
    }

    /**
     * Unmasked backward pass: expands {@code epsilon} back to the shape of {@code inputArray}
     * according to the pooling type.
     *
     * @param inputArray the forward-pass input
     * @param epsilon    gradient w.r.t. the pooled output (non-pooled dimensions only)
     * @param poolDim    dimensions that were pooled over
     * @return gradient w.r.t. {@code inputArray}
     * @throws RuntimeException if the pooling type is not MAX, AVG, SUM or PNORM
     */
    private INDArray epsilonHelperFullArray(INDArray inputArray, INDArray epsilon, int[] poolDim) {
        // The dimensions that were NOT pooled over; epsilon is broadcast along these
        int[] broadcastDims = new int[inputArray.rank() - poolDim.length];
        int count = 0;
        for (int i = 0; i < inputArray.rank(); i++) {
            if (!ArrayUtils.contains(poolDim, i)) {
                broadcastDims[count++] = i;
            }
        }
        switch (this.poolingType) {
            case MAX: {
                // IsMax marks the maximum along the pooled dimensions; epsilon is multiplied onto it
                INDArray isMax = Nd4j.getExecutioner().execAndReturn((TransformOp) new IsMax(inputArray.dup(), poolDim));
                return Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(isMax, epsilon, isMax, broadcastDims));
            }
            case AVG: {
                // Each pooled element receives epsilon / (number of elements averaged)
                int n = 1;
                for (int d : poolDim) {
                    n *= inputArray.size(d);
                }
                INDArray ret = Nd4j.create(inputArray.shape());
                Nd4j.getExecutioner().exec((Op) new BroadcastCopyOp(ret, epsilon, ret, broadcastDims));
                ret.divi(n);
                return ret;
            }
            case SUM: {
                // Each pooled element receives epsilon unchanged
                INDArray retSum = Nd4j.create(inputArray.shape());
                Nd4j.getExecutioner().exec((Op) new BroadcastCopyOp(retSum, epsilon, retSum, broadcastDims));
                return retSum;
            }
            case PNORM: {
                // d||x||_p / dx = x * |x|^(p-2) / ||x||_p^(p-1), then scaled by epsilon
                INDArray absPow = Transforms.abs(inputArray, true);
                Transforms.pow(absPow, this.pNorm, false);
                INDArray norm = Transforms.pow(absPow.sum(poolDim), 1.0 / this.pNorm);
                INDArray numerator;
                if (this.pNorm == 2) {
                    // |x|^0 == 1, so the numerator is just x
                    numerator = inputArray.dup();
                } else {
                    INDArray absPowPMinus2 = Transforms.pow(Transforms.abs(inputArray, true), this.pNorm - 2, false);
                    numerator = inputArray.mul(absPowPMinus2);
                }
                INDArray denom = Transforms.pow(norm, this.pNorm - 1, false);
                denom.rdivi(epsilon);
                Nd4j.getExecutioner().execAndReturn((BroadcastOp) new BroadcastMulOp(numerator, denom, numerator, broadcastDims));
                return numerator;
            }
            default:
                throw new RuntimeException("Unknown or not supported pooling type: " + this.poolingType + " " + this.layerId());
        }
    }

    /**
     * Stores the mask for use in the masked pooling paths. Returns {@code null}, which presumably
     * signals that no mask is propagated past global pooling (NOTE(review): confirm against the
     * {@code Layer} contract).
     */
    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.maskArray = maskArray;
        this.maskState = null;
        return null;
    }
}

