/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.preprocessor;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import java.util.Arrays;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Input pre-processor placed between a convolutional layer and a recurrent layer.
 *
 * <p>{@link #preProcess} reshapes rank-4 CNN activations of shape
 * {@code [miniBatchSize * timeSeriesLength, numChannels, inputHeight, inputWidth]} into rank-3 RNN
 * input of shape {@code [miniBatchSize, product, timeSeriesLength]}, where
 * {@code product = inputHeight * inputWidth * numChannels}. {@link #backprop} maps rank-3 epsilons
 * of shape {@code [miniBatchSize, product, timeSeriesLength]} back to rank-4 CNN shape.
 *
 * <p>{@code product} is cached and kept in sync with the three dimensions by the constructor and the
 * dimension setters.
 */
public class CnnToRnnPreProcessor
implements InputPreProcessor {
    private int inputHeight;
    private int inputWidth;
    private int numChannels;
    private int product;

    /**
     * Creates a pre-processor for CNN activations with the given spatial size and channel count.
     *
     * @param inputHeight height (rows) of each CNN activation map
     * @param inputWidth width (columns) of each CNN activation map
     * @param numChannels number of channels (depth) of the CNN activations
     */
    @JsonCreator
    public CnnToRnnPreProcessor(@JsonProperty(value="inputHeight") int inputHeight, @JsonProperty(value="inputWidth") int inputWidth, @JsonProperty(value="numChannels") int numChannels) {
        this.inputHeight = inputHeight;
        this.inputWidth = inputWidth;
        this.numChannels = numChannels;
        this.product = this.computeProduct();
    }

    /** Returns {@code inputHeight * inputWidth * numChannels} for the current field values. */
    private int computeProduct() {
        return this.inputHeight * this.inputWidth * this.numChannels;
    }

    /**
     * Reshapes rank-4 CNN activations into rank-3 RNN input.
     *
     * @param input CNN activations with shape
     *     {@code [miniBatchSize * timeSeriesLength, numChannels, inputHeight, inputWidth]}
     * @param layer layer whose {@link Layer#getInputMiniBatchSize()} gives the mini-batch size
     * @return a permuted view/array of shape {@code [miniBatchSize, product, timeSeriesLength]}
     * @throws IllegalArgumentException if {@code input} is not rank 4, if its per-example size
     *     {@code size(1) * size(2) * size(3)} differs from {@code product}, or if {@code size(0)} is
     *     not a multiple of a positive mini-batch size
     */
    @Override
    public INDArray preProcess(INDArray input, Layer layer) {
        if (input.rank() != 4) {
            throw new IllegalArgumentException("Invalid input: expect CNN activations with rank 4 (received input with shape " + Arrays.toString(input.shape()) + ")");
        }
        int[] shape = input.shape();
        int perExample = shape[1] * shape[2] * shape[3];
        if (perExample != this.product) {
            throw new IllegalArgumentException("Invalid input: expected CNN activations with rows " + this.inputHeight + " x columns " + this.inputWidth + " x depth " + this.numChannels + " = " + this.product + " values per example, received shape " + Arrays.toString(shape));
        }
        int miniBatchSize = layer.getInputMiniBatchSize();
        // Guard the integer division below: a zero/negative or non-dividing mini-batch size would
        // either throw ArithmeticException or silently produce a mis-sized reshape.
        if (miniBatchSize <= 0 || shape[0] % miniBatchSize != 0) {
            throw new IllegalArgumentException("Invalid input: size(0)=" + shape[0] + " of CNN activations is not a multiple of the mini-batch size " + miniBatchSize);
        }
        int timeSeriesLength = shape[0] / miniBatchSize;
        INDArray reshaped = input.reshape(new int[]{miniBatchSize, timeSeriesLength, this.product});
        return reshaped.permute(new int[]{0, 2, 1});
    }

    /**
     * Maps rank-3 RNN epsilons back to rank-4 CNN shape.
     *
     * @param output epsilons with shape {@code [miniBatchSize, product, timeSeriesLength]}
     * @param layer unused by this implementation
     * @return an array of shape {@code [miniBatchSize * timeSeriesLength, numChannels, inputHeight, inputWidth]}
     * @throws IllegalArgumentException if {@code output} is not rank 3 or {@code size(1) != product}
     */
    @Override
    public INDArray backprop(INDArray output, Layer layer) {
        if (output.rank() != 3) {
            throw new IllegalArgumentException("Invalid input: expect RNN epsilons with rank 3 (received input with shape " + Arrays.toString(output.shape()) + ")");
        }
        int[] shape = output.shape();
        // Validate before doing any tensor work.
        if (shape[1] != this.product) {
            throw new IllegalArgumentException("Invalid input: expected output size(1) to equal rows " + this.inputHeight + " x columns " + this.inputWidth + " x depth " + this.numChannels + " = " + this.product + ", received: " + shape[1]);
        }
        INDArray output2d;
        if (shape[0] == 1) {
            // Edge case miniBatchSize == 1: the tensor along dims {1, 2} is [product, timeSeriesLength]
            // (NOTE(review): relies on ND4J TAD ordering, as upstream DL4J does); transpose it so rows
            // are time steps, matching the [timeSeriesLength, product] layout of the general branch.
            output2d = output.tensorAlongDimension(0, new int[]{1, 2}).permute(new int[]{1, 0});
        } else if (shape[2] == 1) {
            // Edge case timeSeriesLength == 1.
            output2d = output.tensorAlongDimension(0, new int[]{1, 0});
        } else {
            // [mb, product, T] -> [mb, T, product] -> [mb * T, product]
            INDArray permuted3d = output.permute(new int[]{0, 2, 1});
            output2d = permuted3d.reshape(shape[0] * shape[2], shape[1]);
        }
        return output2d.reshape(new int[]{output2d.size(0), this.numChannels, this.inputHeight, this.inputWidth});
    }

    /** Returns a new pre-processor with the same dimensions; {@code product} is recomputed from them. */
    @Override
    public CnnToRnnPreProcessor clone() {
        return new CnnToRnnPreProcessor(this.inputHeight, this.inputWidth, this.numChannels);
    }

    /** @return the height (rows) of the CNN activations */
    public int getInputHeight() {
        return this.inputHeight;
    }

    /** @return the width (columns) of the CNN activations */
    public int getInputWidth() {
        return this.inputWidth;
    }

    /** @return the number of channels (depth) of the CNN activations */
    public int getNumChannels() {
        return this.numChannels;
    }

    /** @return the cached per-example size {@code inputHeight * inputWidth * numChannels} */
    public int getProduct() {
        return this.product;
    }

    /** Sets the input height and recomputes {@code product} so it stays consistent. */
    public void setInputHeight(int inputHeight) {
        this.inputHeight = inputHeight;
        this.product = this.computeProduct();
    }

    /** Sets the input width and recomputes {@code product} so it stays consistent. */
    public void setInputWidth(int inputWidth) {
        this.inputWidth = inputWidth;
        this.product = this.computeProduct();
    }

    /** Sets the channel count and recomputes {@code product} so it stays consistent. */
    public void setNumChannels(int numChannels) {
        this.numChannels = numChannels;
        this.product = this.computeProduct();
    }

    /**
     * Overrides the cached per-example size. Kept for (de)serialization compatibility; callers
     * normally should not need it, because the dimension setters keep {@code product} in sync.
     */
    public void setProduct(int product) {
        this.product = product;
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof CnnToRnnPreProcessor)) {
            return false;
        }
        CnnToRnnPreProcessor other = (CnnToRnnPreProcessor)o;
        if (!other.canEqual(this)) {
            return false;
        }
        return this.getInputHeight() == other.getInputHeight()
                && this.getInputWidth() == other.getInputWidth()
                && this.getNumChannels() == other.getNumChannels()
                && this.getProduct() == other.getProduct();
    }

    /** Symmetry hook for {@link #equals(Object)} so subclasses can opt out of equality with this type. */
    protected boolean canEqual(Object other) {
        return other instanceof CnnToRnnPreProcessor;
    }

    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + this.getInputHeight();
        result = result * prime + this.getInputWidth();
        result = result * prime + this.getNumChannels();
        result = result * prime + this.getProduct();
        return result;
    }

    @Override
    public String toString() {
        return "CnnToRnnPreProcessor(inputHeight=" + this.getInputHeight() + ", inputWidth=" + this.getInputWidth() + ", numChannels=" + this.getNumChannels() + ", product=" + this.getProduct() + ")";
    }
}

