/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers.setup;

import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;

/**
 * Walks the layers of a {@link MultiLayerConfiguration.Builder} once, at construction time, and
 * fills in the values that depend on the spatial size of the data flowing through the network:
 * <ul>
 *   <li>{@code nIn} of convolution layers and of dense/output layers,</li>
 *   <li>a {@link CnnToFeedForwardPreProcessor} wherever a convolution or subsampling layer is
 *       followed by a dense/output layer,</li>
 *   <li>a {@link FeedForwardToCnnPreProcessor} on layer 0 when the first layer is a convolution
 *       or subsampling layer.</li>
 * </ul>
 * The running output size is tracked in {@code lastHeight}, {@code lastWidth} and
 * {@code lastOutChannels}; per-layer results are recorded in {@code outSizesEachLayer}
 * (layer index to {@code {height, width}}) and {@code nInForLayer} (layer index to the
 * {@code nIn} assigned to that layer).
 */
public class ConvolutionLayerSetup {
    private int lastHeight = -1;
    private int lastWidth = -1;
    private int lastOutChannels = -1;
    private int numLayers = -1;
    private Map<Integer, int[]> outSizesEachLayer = new HashMap<Integer, int[]>();
    private Map<Integer, Integer> nInForLayer = new HashMap<Integer, Integer>();

    /**
     * Infers layer input sizes and installs the required input pre-processors on {@code conf}.
     *
     * @param conf     the builder whose layers are inspected and mutated in place
     * @param height   height of the network input
     * @param width    width of the network input
     * @param channels number of channels of the network input; becomes {@code nIn} of a
     *                 convolution layer at index 0
     * @throws UnsupportedOperationException if a dense/output layer is followed by a convolution
     *                                       or subsampling layer, or if the final layer is a
     *                                       convolution or subsampling layer
     */
    public ConvolutionLayerSetup(MultiLayerConfiguration.Builder conf, int height, int width, int channels) {
        this.lastHeight = height;
        this.lastWidth = width;
        this.lastOutChannels = channels;
        this.numLayers = countLayers(conf);

        for (int i = 0; i < this.numLayers; ++i) {
            Layer curr = this.getLayer(i, conf);
            boolean hasNext = i < this.numLayers - 1;
            // True once the branch below has already advanced lastHeight/lastWidth for layer i.
            boolean alreadySet = false;

            // The type check must guard the whole condition: previously "i == 0" alone selected
            // this branch and a non-convolution first layer failed with a ClassCastException.
            // NOTE(review): because of the "numLayers - 2" bound, a convolution layer at i > 0
            // that directly precedes the final layer is not wired here and therefore gets no
            // CnnToFeedForwardPreProcessor (only its nIn is set after the loop) - confirm intent.
            if (curr instanceof ConvolutionLayer && (i == 0 || i < this.numLayers - 2)) {
                alreadySet = this.wireConvolution(i, conf, (ConvolutionLayer) curr, channels);
            } else if (hasNext && curr instanceof SubsamplingLayer) {
                alreadySet = this.wireSubsampling(i, conf, (SubsamplingLayer) curr);
            } else if (hasNext && (curr instanceof DenseLayer || curr instanceof OutputLayer)) {
                this.wireFeedForward(i, conf, (FeedForwardLayer) curr);
            }

            if (hasNext && !alreadySet) {
                this.recordOutputSize(i, curr);
            }
        }

        this.inferFinalLayerInput(conf);

        Layer first = this.getLayer(0, conf);
        if (first instanceof ConvolutionLayer || first instanceof SubsamplingLayer) {
            conf.inputPreProcessor(0, new FeedForwardToCnnPreProcessor(height, width, channels));
        }
    }

    /** Returns the number of layers held by {@code conf}, honouring the list-builder variant. */
    private static int countLayers(MultiLayerConfiguration.Builder conf) {
        if (conf instanceof NeuralNetConfiguration.ListBuilder) {
            return ((NeuralNetConfiguration.ListBuilder) conf).getLayerwise().size();
        }
        return conf.getConfs().size();
    }

    /**
     * Connects convolution layer {@code i} to layer {@code i + 1}.
     *
     * @return {@code true} if the running output size was already advanced past layer {@code i}
     */
    private boolean wireConvolution(int i, MultiLayerConfiguration.Builder conf, ConvolutionLayer conv, int channels) {
        if (i == 0) {
            conv.setNIn(channels);
        }
        Layer next = this.getLayer(i + 1, conf);
        if (next instanceof DenseLayer || next instanceof OutputLayer) {
            int[] outSize = this.convolutionOutputSize(conv);
            int outRows = outSize[0];
            int outCols = outSize[1];
            int outChannels = conv.getNOut();
            // Use the convolution OUTPUT size for every index, including i == 0, so that the
            // pre-processor dimensions agree with the nIn assigned to the next layer below.
            conf.inputPreProcessor(i + 1, new CnnToFeedForwardPreProcessor(outRows, outCols, outChannels));
            this.outSizesEachLayer.put(i, outSize);
            this.lastHeight = outRows;
            this.lastWidth = outCols;
            this.lastOutChannels = outChannels;
            int nIn = outCols * outRows * outChannels;
            // nIn belongs to the receiving layer, so it is keyed by i + 1 like every other entry.
            this.nInForLayer.put(i + 1, nIn);
            ((FeedForwardLayer) next).setNIn(nIn);
            return true;
        }
        if (next instanceof SubsamplingLayer) {
            SubsamplingLayer pooling = (SubsamplingLayer) next;
            // A subsampling layer without its own padding inherits the convolution's padding.
            if (pooling.getPadding() == null) {
                pooling.setPadding(conv.getPadding());
            }
        } else if (next instanceof ConvolutionLayer) {
            ((ConvolutionLayer) next).setNIn(conv.getNOut());
        }
        return false;
    }

    /**
     * Connects subsampling layer {@code i} to layer {@code i + 1}.
     *
     * @return {@code true} if the running output size was already advanced past layer {@code i}
     */
    private boolean wireSubsampling(int i, MultiLayerConfiguration.Builder conf, SubsamplingLayer pooling) {
        Layer next = this.getLayer(i + 1, conf);
        if (next instanceof DenseLayer || next instanceof OutputLayer) {
            int[] outSize = this.subsamplingOutputSize(pooling);
            this.outSizesEachLayer.put(i, outSize);
            int outRows = outSize[0];
            int outCols = outSize[1];
            this.lastHeight = outRows;
            this.lastWidth = outCols;
            conf.inputPreProcessor(i + 1, new CnnToFeedForwardPreProcessor(outRows, outCols, this.lastOutChannels));
            FeedForwardLayer receiver = (FeedForwardLayer) next;
            int nIn = outCols * outRows * this.lastOutChannels;
            receiver.setNIn(nIn);
            this.nInForLayer.put(i + 1, nIn);
            this.setFourDtoTwoD(i, conf, receiver);
            return true;
        }
        if (next instanceof ConvolutionLayer) {
            // Subsampling does not touch lastOutChannels, so the channel count passes through.
            ((ConvolutionLayer) next).setNIn(this.lastOutChannels);
        }
        return false;
    }

    /**
     * Connects dense/output layer {@code i} to layer {@code i + 1}.
     *
     * @throws UnsupportedOperationException if layer {@code i + 1} is a convolution or
     *                                       subsampling layer
     */
    private void wireFeedForward(int i, MultiLayerConfiguration.Builder conf, FeedForwardLayer forwardLayer) {
        Layer next = this.getLayer(i + 1, conf);
        if (next instanceof ConvolutionLayer || next instanceof SubsamplingLayer) {
            throw new UnsupportedOperationException("2d to 4d needs to be implemented");
        }
        if (next instanceof OutputLayer || next instanceof DenseLayer) {
            ((FeedForwardLayer) next).setNIn(forwardLayer.getNOut());
            this.nInForLayer.put(i + 1, forwardLayer.getNOut());
        }
        this.setFourDtoTwoD(i, conf, forwardLayer);
    }

    /**
     * Advances the running output size past layer {@code i} and records it, for convolution
     * and subsampling layers; other layer types leave the running size untouched.
     */
    private void recordOutputSize(int i, Layer curr) {
        if (curr instanceof ConvolutionLayer) {
            ConvolutionLayer conv = (ConvolutionLayer) curr;
            int[] outSize = this.convolutionOutputSize(conv);
            this.lastHeight = outSize[0];
            this.lastWidth = outSize[1];
            this.lastOutChannels = conv.getNOut();
            this.outSizesEachLayer.put(i, outSize);
        } else if (curr instanceof SubsamplingLayer) {
            int[] outSize = this.subsamplingOutputSize((SubsamplingLayer) curr);
            this.lastHeight = outSize[0];
            this.lastWidth = outSize[1];
            this.outSizesEachLayer.put(i, outSize);
        }
    }

    /**
     * Sets {@code nIn} of a final dense/output layer from its predecessor.
     *
     * @throws UnsupportedOperationException if the final layer is a convolution or subsampling
     *                                       layer
     */
    private void inferFinalLayerInput(MultiLayerConfiguration.Builder conf) {
        int lastIndex = this.numLayers - 1;
        Layer last = this.getLayer(lastIndex, conf);
        if (!(last instanceof OutputLayer || last instanceof DenseLayer)) {
            if (last instanceof ConvolutionLayer) {
                throw new UnsupportedOperationException("Unsupported path: final convolution layer");
            }
            if (last instanceof SubsamplingLayer) {
                throw new UnsupportedOperationException("Unsupported path: final subsampling layer");
            }
            return;
        }
        if (this.numLayers < 2) {
            // A lone feed-forward layer has no predecessor to infer nIn from.
            return;
        }
        FeedForwardLayer lastLayer = (FeedForwardLayer) last;
        Layer previous = this.getLayer(lastIndex - 1, conf);
        if (previous instanceof DenseLayer || previous instanceof OutputLayer) {
            FeedForwardLayer feedForwardLayer = (FeedForwardLayer) previous;
            lastLayer.setNIn(feedForwardLayer.getNOut());
            this.nInForLayer.put(lastIndex, feedForwardLayer.getNOut());
        } else if (previous instanceof SubsamplingLayer || previous instanceof ConvolutionLayer) {
            int nIn = this.lastHeight * this.lastWidth * this.lastOutChannels;
            lastLayer.setNIn(nIn);
            this.nInForLayer.put(lastIndex, nIn);
        }
    }

    /** Output size of {@code conv} when applied to the current running size. */
    private int[] convolutionOutputSize(ConvolutionLayer conv) {
        return this.getConvolutionOutputSize(new int[]{this.lastHeight, this.lastWidth}, conv.getKernelSize(), conv.getPadding(), conv.getStride());
    }

    /** Output size of {@code pooling} when applied to the current running size. */
    private int[] subsamplingOutputSize(SubsamplingLayer pooling) {
        return this.getSubSamplingOutputSize(new int[]{this.lastHeight, this.lastWidth}, pooling.getKernelSize(), pooling.getStride());
    }

    /**
     * Computes the per-dimension output size of a subsampling layer using integer arithmetic:
     * {@code input / stride} when the kernel is 1, otherwise
     * {@code (input - kernel) / stride + 1}.
     *
     * @param inputSize  input size per dimension (this class passes {@code {height, width}})
     * @param kernelSize kernel size per dimension, indexed like {@code inputSize}
     * @param stride     stride per dimension, indexed like {@code inputSize}
     * @return a new array of the same length as {@code inputSize}
     */
    private int[] getSubSamplingOutputSize(int[] inputSize, int[] kernelSize, int[] stride) {
        int[] ret = new int[inputSize.length];
        for (int i = 0; i < ret.length; ++i) {
            if (kernelSize[i] == 1) {
                ret[i] = inputSize[i] / stride[i];
            } else {
                ret[i] = (inputSize[i] - kernelSize[i]) / stride[i] + 1;
            }
        }
        return ret;
    }

    /**
     * Computes the per-dimension output size of a convolution layer using integer arithmetic:
     * {@code (input - kernel + 2 * padding) / stride + 1}.
     *
     * @param inputSize  input size per dimension (this class passes {@code {height, width}})
     * @param kernelSize kernel size per dimension, indexed like {@code inputSize}
     * @param padding    padding per dimension, indexed like {@code inputSize}
     * @param stride     stride per dimension, indexed like {@code inputSize}
     * @return a new array of the same length as {@code inputSize}
     */
    private int[] getConvolutionOutputSize(int[] inputSize, int[] kernelSize, int[] padding, int[] stride) {
        int[] ret = new int[inputSize.length];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = (inputSize[i] - kernelSize[i] + 2 * padding[i]) / stride[i] + 1;
        }
        return ret;
    }

    /**
     * Looks up the layer at index {@code i} in {@code builder}.
     *
     * @param i       layer index
     * @param builder the builder to read from
     * @return the layer configured at that index
     * @throws IllegalStateException if {@code builder} is a list builder with no entry for
     *                               {@code i}
     */
    public Layer getLayer(int i, MultiLayerConfiguration.Builder builder) {
        if (builder instanceof NeuralNetConfiguration.ListBuilder) {
            NeuralNetConfiguration.ListBuilder listBuilder = (NeuralNetConfiguration.ListBuilder)builder;
            if (listBuilder.getLayerwise().get(i) == null) {
                throw new IllegalStateException("Undefined layer " + i);
            }
            return listBuilder.getLayerwise().get(i).getLayer();
        }
        return builder.getConfs().get(i).getLayer();
    }

    /**
     * Installs a {@link CnnToFeedForwardPreProcessor} on layer {@code i + 1} when layer
     * {@code i} is a convolution or subsampling layer, using the current running size.
     * Does nothing when {@code d} is itself a convolution layer or when layer {@code i} is
     * neither a convolution nor a subsampling layer.
     *
     * @param i    index of the layer producing the input of {@code d}
     * @param conf the builder to install the pre-processor on
     * @param d    the feed-forward layer being configured
     * @throws IllegalStateException if layer {@code i} is a convolution layer without a kernel
     *                               size
     */
    private void setFourDtoTwoD(int i, MultiLayerConfiguration.Builder conf, FeedForwardLayer d) {
        if (d instanceof ConvolutionLayer) {
            return;
        }
        Layer currFourdLayer = this.getLayer(i, conf);
        if (currFourdLayer instanceof ConvolutionLayer) {
            ConvolutionLayer convolutionLayer = (ConvolutionLayer)currFourdLayer;
            if (convolutionLayer.getKernelSize() == null) {
                throw new IllegalStateException("Unable to infer width and height without convolution layer kernel size");
            }
            // NOTE(review): this overwrites nOut (not nIn) of d; no call site in this class
            // reaches this branch, so the intent could not be confirmed from this file.
            d.setNOut(this.lastHeight * this.lastWidth * convolutionLayer.getNOut());
            conf.inputPreProcessor(i + 1, new CnnToFeedForwardPreProcessor(this.lastHeight, this.lastWidth, this.lastOutChannels));
        } else if (currFourdLayer instanceof SubsamplingLayer) {
            conf.inputPreProcessor(i + 1, new CnnToFeedForwardPreProcessor(this.lastHeight, this.lastWidth, this.lastOutChannels));
        }
    }

    /** Returns the most recently tracked output height. */
    public int getLastHeight() {
        return this.lastHeight;
    }

    /** Overrides the tracked output height. */
    public void setLastHeight(int lastHeight) {
        this.lastHeight = lastHeight;
    }

    /** Returns the most recently tracked output width. */
    public int getLastWidth() {
        return this.lastWidth;
    }

    /** Overrides the tracked output width. */
    public void setLastWidth(int lastWidth) {
        this.lastWidth = lastWidth;
    }

    /** Returns the most recently tracked number of output channels. */
    public int getLastOutChannels() {
        return this.lastOutChannels;
    }

    /** Overrides the tracked number of output channels. */
    public void setLastOutChannels(int lastOutChannels) {
        this.lastOutChannels = lastOutChannels;
    }

    /**
     * Returns the recorded output sizes, keyed by layer index, as {@code {height, width}}.
     * The internal map itself is returned, not a copy.
     */
    public Map<Integer, int[]> getOutSizesEachLayer() {
        return this.outSizesEachLayer;
    }

    /** Replaces the map of recorded output sizes. */
    public void setOutSizesEachLayer(Map<Integer, int[]> outSizesEachLayer) {
        this.outSizesEachLayer = outSizesEachLayer;
    }

    /**
     * Returns the recorded {@code nIn} values, keyed by the index of the layer they were
     * assigned to. The internal map itself is returned, not a copy.
     */
    public Map<Integer, Integer> getnInForLayer() {
        return this.nInForLayer;
    }

    /** Replaces the map of recorded {@code nIn} values. */
    public void setnInForLayer(Map<Integer, Integer> nInForLayer) {
        this.nInForLayer = nInForLayer;
    }
}

