/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class KerasConvolution
extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasConvolution.class);
    public static final int NUM_TRAINABLE_PARAMS = 2;
    public static final String KERAS_PARAM_NAME_W = "W";
    public static final String KERAS_PARAM_NAME_B = "b";

    public KerasConvolution(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    public KerasConvolution(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(layerConfig, enforceTrainingConfig);
        ConvolutionLayer.Builder builder = new ConvolutionLayer.Builder().name(this.layerName).nOut(KerasConvolution.getNOutFromConfig(layerConfig)).dropOut(this.dropout).activation(this.getActivationFromConfig(layerConfig)).weightInit(this.getWeightInitFromConfig(layerConfig, enforceTrainingConfig)).biasInit(0.0).l1(this.weightL1Regularization).l2(this.weightL2Regularization).convolutionMode(KerasConvolution.getConvolutionModeFromConfig(layerConfig)).kernelSize(KerasConvolution.getKernelSizeFromConfig(layerConfig)).stride(KerasConvolution.getStrideFromConfig(layerConfig));
        int[] padding = this.getPaddingFromBorderModeConfig(layerConfig);
        if (padding != null) {
            builder.padding(padding);
        }
        this.layer = builder.build();
    }

    public ConvolutionLayer getConvolutionLayer() {
        return (ConvolutionLayer)this.layer;
    }

    @Override
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException {
        if (inputType.length > 1) {
            throw new InvalidKerasConfigurationException("Keras Convolution layer accepts only one input (received " + inputType.length + ")");
        }
        return this.getConvolutionLayer().getOutputType(-1, inputType[0]);
    }

    @Override
    public int getNumParams() {
        return 2;
    }

    @Override
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
        INDArray paramValue;
        this.weights = new HashMap();
        if (weights.containsKey(KERAS_PARAM_NAME_W)) {
            INDArray kerasParamValue = weights.get(KERAS_PARAM_NAME_W);
            switch (this.getDimOrder()) {
                case TENSORFLOW: {
                    paramValue = kerasParamValue.permute(new int[]{3, 2, 0, 1});
                    break;
                }
                case THEANO: {
                    paramValue = kerasParamValue.dup();
                    for (int i = 0; i < paramValue.tensorssAlongDimension(new int[]{2, 3}); ++i) {
                        INDArray copyFilter = paramValue.tensorAlongDimension(i, new int[]{2, 3}).dup();
                        double[] flattenedFilter = copyFilter.ravel().data().asDouble();
                        ArrayUtils.reverse((double[])flattenedFilter);
                        INDArray newFilter = Nd4j.create((double[])flattenedFilter, (int[])copyFilter.shape());
                        INDArray inPlaceFilter = paramValue.tensorAlongDimension(i, new int[]{2, 3});
                        inPlaceFilter.muli((Number)0).addi(newFilter);
                    }
                    break;
                }
                default: {
                    throw new InvalidKerasConfigurationException("Unknown keras backend " + (Object)((Object)this.getDimOrder()));
                }
            }
        } else {
            throw new InvalidKerasConfigurationException("Parameter W does not exist in weights");
        }
        this.weights.put(KERAS_PARAM_NAME_W, paramValue);
        if (!weights.containsKey(KERAS_PARAM_NAME_B)) {
            throw new InvalidKerasConfigurationException("Parameter b does not exist in weights");
        }
        this.weights.put(KERAS_PARAM_NAME_B, weights.get(KERAS_PARAM_NAME_B));
        if (weights.size() > 2) {
            Set<String> paramNames = weights.keySet();
            paramNames.remove(KERAS_PARAM_NAME_W);
            paramNames.remove(KERAS_PARAM_NAME_B);
            String unknownParamNames = paramNames.toString();
            log.warn("Attemping to set weights for unknown parameters: " + unknownParamNames.substring(1, unknownParamNames.length() - 1));
        }
    }
}

