/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collection;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BaseUpsamplingLayer;
import org.deeplearning4j.nn.conf.layers.InputTypeUtil;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Configuration for a 3D upsampling layer.
 *
 * <p>For a CNN3D input, the output depth, height and width are the input depth, height and width
 * multiplied by {@code size[0]}, {@code size[1]} and {@code size[2]} respectively. The channel
 * count is passed through unchanged.
 */
public class Upsampling3D
extends BaseUpsamplingLayer {
    /** Upsampling factors in the order {depth, height, width}. */
    protected int[] size;

    /**
     * Creates the layer configuration from a builder.
     *
     * @param builder builder holding the upsampling factors. Its size array is copied, so later
     *                changes to the builder (or to the array) do not affect this configuration.
     */
    protected Upsampling3D(BaseUpsamplingLayer.UpsamplingBuilder<?> builder) {
        super(builder);
        this.size = builder.size == null ? null : builder.size.clone();
    }

    /**
     * Returns a copy of this configuration.
     *
     * <p>The {@code size} array is deep-copied, so mutating the clone's factors does not change
     * this instance (and vice versa).
     */
    @Override
    public Upsampling3D clone() {
        Upsampling3D copy = (Upsampling3D)super.clone();
        if (copy.size != null) {
            copy.size = copy.size.clone();
        }
        return copy;
    }

    /**
     * Creates the runtime upsampling layer for this configuration and wires in listeners, the
     * layer index, the parameter view and the parameter table produced by this layer's
     * initializer.
     */
    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> iterationListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams) {
        org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D ret = new org.deeplearning4j.nn.layers.convolution.upsampling.Upsampling3D(conf);
        ret.setListeners(iterationListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    /**
     * Computes the output type: each spatial dimension is scaled by its factor and the channel
     * count is preserved.
     *
     * @throws IllegalStateException if {@code inputType} is null or not a CNN3D type
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null || inputType.getType() != InputType.Type.CNN3D) {
            throw new IllegalStateException("Invalid input for Upsampling 3D layer (layer name=\"" + this.getLayerName() + "\"): Expected CNN3D input, got " + inputType);
        }
        InputType.InputTypeConvolutional3D i = (InputType.InputTypeConvolutional3D)inputType;
        int inHeight = i.getHeight();
        int inWidth = i.getWidth();
        int inDepth = i.getDepth();
        int inChannels = i.getChannels();
        return InputType.convolutional3D(this.size[0] * inDepth, this.size[1] * inHeight, this.size[2] * inWidth, inChannels);
    }

    /**
     * Returns the preprocessor (if any) needed to feed {@code inputType} into a CNN3D layer.
     *
     * @throws IllegalStateException if {@code inputType} is null
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input for Upsampling 3D layer (layer name=\"" + this.getLayerName() + "\"): input is null");
        }
        return InputTypeUtil.getPreProcessorForInputTypeCnn3DLayers(inputType, this.getLayerName());
    }

    /**
     * Estimates the memory use of this layer.
     *
     * <p>The layer has no parameters and no cached state. Per-example working memory is
     * {@code channels * outDepth * outHeight * outWidth * size[0] * size[1] * size[2]}. During
     * training, the per-example input size is added when dropout is configured.
     *
     * @throws IllegalStateException if {@code inputType} is null or not a CNN3D type
     */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        // getOutputType validates the input type, so call it before casting the input.
        InputType.InputTypeConvolutional3D outputType = (InputType.InputTypeConvolutional3D)this.getOutputType(-1, inputType);
        InputType.InputTypeConvolutional3D c = (InputType.InputTypeConvolutional3D)inputType;
        // Computed in long: this product can easily exceed Integer.MAX_VALUE for realistic volumes.
        long im2colSizePerEx = (long)c.getChannels() * outputType.getDepth() * outputType.getHeight() * outputType.getWidth() * this.size[0] * this.size[1] * this.size[2];
        long trainingWorkingSizePerEx = im2colSizePerEx;
        if (this.getIDropout() != null) {
            trainingWorkingSizePerEx += inputType.arrayElementsPerExample();
        }
        return new LayerMemoryReport.Builder(this.layerName, Upsampling3D.class, inputType, outputType).standardMemory(0L, 0L).workingMemory(0L, im2colSizePerEx, 0L, trainingWorkingSizePerEx).cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS).build();
    }

    /** Returns the upsampling factors {depth, height, width} (the internal array, not a copy). */
    @Override
    public int[] getSize() {
        return this.size;
    }

    /** Sets the upsampling factors {depth, height, width}. The array is stored as given. */
    @Override
    public void setSize(int[] size) {
        this.size = size;
    }

    /** No-arg constructor. {@code size} is left unset. */
    public Upsampling3D() {
    }

    @Override
    public String toString() {
        return "Upsampling3D(super=" + super.toString() + ", size=" + Arrays.toString(this.getSize()) + ")";
    }

    /** Two instances are equal when the superclass state and the size arrays are equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof Upsampling3D)) {
            return false;
        }
        Upsampling3D other = (Upsampling3D)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        return Arrays.equals(this.getSize(), other.getSize());
    }

    @Override
    protected boolean canEqual(Object other) {
        return other instanceof Upsampling3D;
    }

    /** Hash consistent with {@link #equals(Object)}: combines the superclass hash with the size array. */
    @Override
    public int hashCode() {
        return super.hashCode() * 59 + Arrays.hashCode(this.getSize());
    }

    /** Builder for {@link Upsampling3D}. */
    public static class Builder
    extends BaseUpsamplingLayer.UpsamplingBuilder<Builder> {
        /**
         * Creates a builder with the same upsampling factor on depth, height and width.
         *
         * @param size factor applied to all three spatial dimensions
         */
        public Builder(int size) {
            super(new int[]{size, size, size});
        }

        /**
         * Uses the same upsampling factor on depth, height and width.
         *
         * @param size factor applied to all three spatial dimensions
         * @return this builder
         */
        public Builder size(int size) {
            this.size = new int[]{size, size, size};
            return this;
        }

        /**
         * Sets per-dimension upsampling factors.
         *
         * @param size factors in the order {depth, height, width}. The array is copied.
         * @return this builder
         * @throws IllegalArgumentException if {@code size} is null or its length is not 3
         */
        public Builder size(int[] size) {
            if (size == null || size.length != 3) {
                throw new IllegalArgumentException("Upsampling3D size must be an array of length 3 {depth, height, width}, got " + Arrays.toString(size));
            }
            this.size = size.clone();
            return this;
        }

        @Override
        public Upsampling3D build() {
            return new Upsampling3D(this);
        }

        /** Creates a builder with the inherited default size. */
        public Builder() {
        }
    }
}

