/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras.preprocessors;

import java.util.Arrays;

import org.deeplearning4j.nn.conf.preprocessor.CnnToFeedForwardPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.shade.jackson.annotation.JsonCreator;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class TensorFlowCnnToFeedForwardPreProcessor
extends CnnToFeedForwardPreProcessor {
    private static final Logger log = LoggerFactory.getLogger(TensorFlowCnnToFeedForwardPreProcessor.class);

    @JsonCreator
    public TensorFlowCnnToFeedForwardPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        super(inputHeight, inputWidth, numChannels);
    }

    public TensorFlowCnnToFeedForwardPreProcessor(int inputHeight, int inputWidth) {
        super(inputHeight, inputWidth);
    }

    public TensorFlowCnnToFeedForwardPreProcessor() {
    }

    public INDArray preProcess(INDArray input, int miniBatchSize) {
        if (input.rank() == 2) {
            return input;
        }
        INDArray flatInput = super.preProcess(input, miniBatchSize);
        INDArray permuted = input.permute(new int[]{0, 2, 3, 1});
        INDArray flatPermuted = super.preProcess(permuted, miniBatchSize);
        return flatPermuted;
    }

    public INDArray backprop(INDArray epsilons, int miniBatchSize) {
        INDArray epsilonsReshaped = super.backprop(epsilons, miniBatchSize);
        return epsilonsReshaped.permute(new int[]{0, 3, 1, 2});
    }

    public TensorFlowCnnToFeedForwardPreProcessor clone() {
        return (TensorFlowCnnToFeedForwardPreProcessor)super.clone();
    }
}

