/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.layers;

import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.CNN2DFormat;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.LayerValidation;
import org.deeplearning4j.nn.conf.memory.LayerMemoryReport;
import org.deeplearning4j.nn.conf.memory.MemoryReport;
import org.deeplearning4j.nn.conf.preprocessor.FeedForwardToCnnPreProcessor;
import org.deeplearning4j.nn.params.BatchNormalizationParamInitializer;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.config.IUpdater;
import org.nd4j.linalg.learning.config.NoOp;
import org.nd4j.linalg.learning.regularization.Regularization;

/**
 * Configuration for a batch normalization layer.
 *
 * <p>Accepts feed-forward, recurrent, 2D convolutional (including flattened 2D convolutional) and
 * 3D convolutional input. The output type is always identical to the input type (see
 * {@link #getOutputType(int, InputType)}), and {@code nOut} is always set equal to {@code nIn}
 * (see {@link #setNIn(InputType, boolean)}).
 *
 * <p>Instances are created through {@link Builder}; the no-argument constructor is equivalent to
 * {@code new Builder().build()}.
 */
public class BatchNormalization
extends FeedForwardLayer {
    /** Decay value; copied from the builder (builder default 0.9). */
    protected double decay = 0.9;
    /** Epsilon value; copied from the builder (builder default 1e-5). */
    protected double eps = 1.0E-5;
    /** Minibatch flag; copied from the builder (builder default true). */
    protected boolean isMinibatch = true;
    /** Gamma value, used when gamma/beta are locked; copied from the builder (builder default 1.0). */
    protected double gamma = 1.0;
    /** Beta value, used when gamma/beta are locked; copied from the builder (builder default 0.0). */
    protected double beta = 0.0;
    /** Whether gamma/beta are fixed values rather than learned; copied from the builder (default false). */
    protected boolean lockGammaBeta = false;
    /** Whether a helper implementation may fall back to the built-in one; builder default true. */
    protected boolean cudnnAllowFallback = true;
    // NOTE: the field initializer is false, but every constructor copies Builder.useLogStd,
    // whose default is true - so the effective default is true.
    /** Whether the variance is stored as a log10 standard deviation parameter. */
    protected boolean useLogStd = false;
    /** Data format for 2D convolutional input; overwritten by setNIn for CNN input types. */
    protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW;

    private BatchNormalization(Builder builder) {
        super(builder);
        this.decay = builder.decay;
        this.eps = builder.eps;
        this.isMinibatch = builder.isMinibatch;
        this.gamma = builder.gamma;
        this.beta = builder.beta;
        this.lockGammaBeta = builder.lockGammaBeta;
        this.cudnnAllowFallback = builder.cudnnAllowFallback;
        this.useLogStd = builder.useLogStd;
        this.cnn2DFormat = builder.cnn2DFormat;
        // NOTE(review): builder.betaConstraints / builder.gammaConstraints are not read here;
        // whether they take effect depends on initializeConstraints(...) in a superclass - verify.
        this.initializeConstraints(builder);
    }

    /** Creates a configuration with all builder defaults. */
    public BatchNormalization() {
        this(new Builder());
    }

    /**
     * Returns a copy of this configuration.
     *
     * <p>All fields declared in this class are primitives or an enum, so the shallow copy made by
     * {@code super.clone()} is sufficient for them.
     */
    @Override
    public BatchNormalization clone() {
        return (BatchNormalization) super.clone();
    }

    /**
     * Creates the runtime batch normalization layer for this configuration.
     *
     * <p>Validates that {@code nOut} is set, then constructs the layer implementation and attaches
     * the listeners, layer index, parameter view, the parameter table produced by
     * {@link #initializer()}, and {@code conf}.
     *
     * @param conf               network configuration for this layer
     * @param trainingListeners  listeners to attach to the created layer
     * @param layerIndex         index of this layer in the network
     * @param layerParamsView    view of the parameter array reserved for this layer
     * @param initializeParams   forwarded to the parameter initializer
     * @param networkDataType    data type passed to the layer implementation
     * @return the instantiated layer
     */
    @Override
    public Layer instantiate(NeuralNetConfiguration conf, Collection<TrainingListener> trainingListeners, int layerIndex, INDArray layerParamsView, boolean initializeParams, DataType networkDataType) {
        LayerValidation.assertNOutSet("BatchNormalization", this.getLayerName(), layerIndex, this.getNOut());
        org.deeplearning4j.nn.layers.normalization.BatchNormalization ret = new org.deeplearning4j.nn.layers.normalization.BatchNormalization(conf, networkDataType);
        ret.setListeners(trainingListeners);
        ret.setIndex(layerIndex);
        ret.setParamsViewArray(layerParamsView);
        Map<String, INDArray> paramTable = this.initializer().init(conf, layerParamsView, initializeParams);
        ret.setParamTable(paramTable);
        ret.setConf(conf);
        return ret;
    }

    /** Returns the shared {@link BatchNormalizationParamInitializer} instance. */
    @Override
    public ParamInitializer initializer() {
        return BatchNormalizationParamInitializer.getInstance();
    }

    /**
     * Returns the output type for the given input type, which is always the input type itself.
     *
     * @throws IllegalStateException if {@code inputType} is null or is not FF, CNN, CNNFlat,
     *                               CNN3D or RNN
     */
    @Override
    public InputType getOutputType(int layerIndex, InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input type: Batch norm layer expected input of type CNN, got null for layer \"" + this.getLayerName() + "\"");
        }
        switch (inputType.getType()) {
            case FF: 
            case CNN: 
            case CNNFlat: 
            case CNN3D: 
            case RNN: {
                return inputType;
            }
        }
        throw new IllegalStateException("Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got " + inputType + " for layer index " + layerIndex + ", layer name = " + this.getLayerName());
    }

    /**
     * Sets {@code nIn} from the input type, and sets {@code nOut} equal to it.
     *
     * <p>Does nothing unless {@code nIn} is not yet set ({@code <= 0}) or {@code override} is true.
     * For CNN input the channel count is used and the input's data format is also recorded; for
     * CNN3D the channel count; for CNNFlat the depth; for FF and RNN the size.
     *
     * @throws IllegalStateException if an update is required and {@code inputType} is null or of
     *                               an unsupported type
     */
    @Override
    public void setNIn(InputType inputType, boolean override) {
        if (this.nIn <= 0L || override) {
            if (inputType == null) {
                throw new IllegalStateException("Invalid input type: Batch norm layer expected input of type CNN, got null for layer \"" + this.getLayerName() + "\"");
            }
            switch (inputType.getType()) {
                case FF: {
                    this.nIn = ((InputType.InputTypeFeedForward)inputType).getSize();
                    break;
                }
                case CNN: {
                    InputType.InputTypeConvolutional conv = (InputType.InputTypeConvolutional)inputType;
                    this.nIn = conv.getChannels();
                    this.cnn2DFormat = conv.getFormat();
                    break;
                }
                case CNN3D: {
                    this.nIn = ((InputType.InputTypeConvolutional3D)inputType).getChannels();
                    break;
                }
                case CNNFlat: {
                    this.nIn = ((InputType.InputTypeConvolutionalFlat)inputType).getDepth();
                    break;
                }
                case RNN: {
                    this.nIn = ((InputType.InputTypeRecurrent)inputType).getSize();
                    break;
                }
                default: {
                    // Fixed: the original message closed a quote it never opened.
                    throw new IllegalStateException("Invalid input type: Batch norm layer expected input of type CNN, CNN Flat or FF, got " + inputType + " for layer \"" + this.getLayerName() + "\"");
                }
            }
            // Batch normalization does not change the number of features/channels.
            this.nOut = this.nIn;
        }
    }

    /**
     * Returns a {@link FeedForwardToCnnPreProcessor} for flattened convolutional input, or null
     * (no preprocessor) for every other input type.
     *
     * @throws IllegalStateException if {@code inputType} is null
     */
    @Override
    public InputPreProcessor getPreProcessorForInputType(InputType inputType) {
        if (inputType == null) {
            throw new IllegalStateException("Invalid input type: Batch norm layer expected input of type CNN, got null for layer \"" + this.getLayerName() + "\"");
        }
        if (inputType.getType() == InputType.Type.CNNFlat) {
            InputType.InputTypeConvolutionalFlat i = (InputType.InputTypeConvolutionalFlat)inputType;
            return new FeedForwardToCnnPreProcessor(i.getHeight(), i.getWidth(), i.getDepth());
        }
        return null;
    }

    /**
     * Returns null for every parameter: no regularization is configured for batch norm
     * parameters. (Callers presumably treat null as "no regularization" - confirm before changing
     * this to an empty list.)
     */
    @Override
    public List<Regularization> getRegularizationByParam(String paramName) {
        return null;
    }

    /**
     * Returns the updater for the named parameter.
     *
     * <p>The trainable parameters {@code "beta"} and {@code "gamma"} use this layer's configured
     * updater; {@code "mean"}, {@code "var"} and {@code "log10stdev"} get a {@link NoOp} updater.
     *
     * @throws IllegalArgumentException if {@code paramName} is none of the above
     */
    @Override
    public IUpdater getUpdaterByParam(String paramName) {
        switch (paramName) {
            case "beta": 
            case "gamma": {
                return this.iUpdater;
            }
            case "mean": 
            case "var": 
            case "log10stdev": {
                return new NoOp();
            }
        }
        throw new IllegalArgumentException("Unknown parameter: \"" + paramName + "\"");
    }

    /**
     * Estimates the memory requirements of this layer for the given input type.
     *
     * @throws IllegalStateException if {@code inputType} is null or unsupported (via
     *                               {@link #getOutputType(int, InputType)})
     */
    @Override
    public LayerMemoryReport getMemoryReport(InputType inputType) {
        InputType outputType = this.getOutputType(-1, inputType);
        long numParams = this.initializer().numParams(this);
        // Accumulate in a long: the original summed long state sizes into an int via a cast,
        // which could silently truncate for large layers.
        long updaterStateSize = 0L;
        for (String s : BatchNormalizationParamInitializer.getInstance().paramKeys(this)) {
            updaterStateSize += this.getUpdaterByParam(s).stateSize(this.nOut);
        }
        long inferenceWorkingSize = 2L * inputType.arrayElementsPerExample();
        long trainWorkFixed = 2L * this.nOut;
        long trainWorkingSizePerExample = inferenceWorkingSize + (outputType.arrayElementsPerExample() + 2L * this.nOut);
        // NOTE(review): inference working memory is reported as 0 even though inferenceWorkingSize
        // is computed above (it only feeds the training estimate) - confirm this is intended.
        return new LayerMemoryReport.Builder(this.layerName, BatchNormalization.class, inputType, outputType)
                .standardMemory(numParams, updaterStateSize)
                .workingMemory(0L, 0L, trainWorkFixed, trainWorkingSizePerExample)
                .cacheMemory(MemoryReport.CACHE_MODE_ALL_ZEROS, MemoryReport.CACHE_MODE_ALL_ZEROS)
                .build();
    }

    /** Always false: batch normalization has no pretrain-only parameters. */
    @Override
    public boolean isPretrainParam(String paramName) {
        return false;
    }

    /** @return the decay value */
    public double getDecay() {
        return this.decay;
    }

    /** @return the epsilon value */
    public double getEps() {
        return this.eps;
    }

    /** @return the minibatch flag */
    public boolean isMinibatch() {
        return this.isMinibatch;
    }

    /** @return the gamma value */
    public double getGamma() {
        return this.gamma;
    }

    /** @return the beta value */
    public double getBeta() {
        return this.beta;
    }

    /** @return whether gamma and beta are locked */
    public boolean isLockGammaBeta() {
        return this.lockGammaBeta;
    }

    /** @return whether helper fallback is allowed */
    public boolean isCudnnAllowFallback() {
        return this.cudnnAllowFallback;
    }

    /** @return whether the log10 standard deviation parameterization is used */
    public boolean isUseLogStd() {
        return this.useLogStd;
    }

    /** @return the 2D convolutional data format */
    public CNN2DFormat getCnn2DFormat() {
        return this.cnn2DFormat;
    }

    /** @param decay the decay value */
    public void setDecay(double decay) {
        this.decay = decay;
    }

    /** @param eps the epsilon value */
    public void setEps(double eps) {
        this.eps = eps;
    }

    /** @param isMinibatch the minibatch flag */
    public void setMinibatch(boolean isMinibatch) {
        this.isMinibatch = isMinibatch;
    }

    /** @param gamma the gamma value */
    public void setGamma(double gamma) {
        this.gamma = gamma;
    }

    /** @param beta the beta value */
    public void setBeta(double beta) {
        this.beta = beta;
    }

    /** @param lockGammaBeta whether gamma and beta are locked */
    public void setLockGammaBeta(boolean lockGammaBeta) {
        this.lockGammaBeta = lockGammaBeta;
    }

    /** @param cudnnAllowFallback whether helper fallback is allowed */
    public void setCudnnAllowFallback(boolean cudnnAllowFallback) {
        this.cudnnAllowFallback = cudnnAllowFallback;
    }

    /** @param useLogStd whether to use the log10 standard deviation parameterization */
    public void setUseLogStd(boolean useLogStd) {
        this.useLogStd = useLogStd;
    }

    /** @param cnn2DFormat the 2D convolutional data format */
    public void setCnn2DFormat(CNN2DFormat cnn2DFormat) {
        this.cnn2DFormat = cnn2DFormat;
    }

    @Override
    public String toString() {
        return "BatchNormalization(super=" + super.toString() + ", decay=" + this.getDecay() + ", eps=" + this.getEps() + ", isMinibatch=" + this.isMinibatch() + ", gamma=" + this.getGamma() + ", beta=" + this.getBeta() + ", lockGammaBeta=" + this.isLockGammaBeta() + ", cudnnAllowFallback=" + this.isCudnnAllowFallback() + ", useLogStd=" + this.isUseLogStd() + ", cnn2DFormat=" + this.getCnn2DFormat() + ")";
    }

    /**
     * Compares all configuration fields of this class plus the superclass state. Uses the
     * {@code canEqual} pattern so that subclasses can keep equality symmetric.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof BatchNormalization)) {
            return false;
        }
        BatchNormalization other = (BatchNormalization)o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        if (Double.compare(this.getDecay(), other.getDecay()) != 0) {
            return false;
        }
        if (Double.compare(this.getEps(), other.getEps()) != 0) {
            return false;
        }
        if (this.isMinibatch() != other.isMinibatch()) {
            return false;
        }
        if (Double.compare(this.getGamma(), other.getGamma()) != 0) {
            return false;
        }
        if (Double.compare(this.getBeta(), other.getBeta()) != 0) {
            return false;
        }
        if (this.isLockGammaBeta() != other.isLockGammaBeta()) {
            return false;
        }
        if (this.isCudnnAllowFallback() != other.isCudnnAllowFallback()) {
            return false;
        }
        if (this.isUseLogStd() != other.isUseLogStd()) {
            return false;
        }
        CNN2DFormat thisFormat = this.getCnn2DFormat();
        CNN2DFormat otherFormat = other.getCnn2DFormat();
        return thisFormat == null ? otherFormat == null : thisFormat.equals(otherFormat);
    }

    /** @return true if {@code other} is a {@code BatchNormalization} (supports subclass-safe equals) */
    @Override
    protected boolean canEqual(Object other) {
        return other instanceof BatchNormalization;
    }

    /** Hash consistent with {@link #equals(Object)}; the values are identical to the original. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = super.hashCode();
        long decayBits = Double.doubleToLongBits(this.getDecay());
        result = result * prime + (int)(decayBits >>> 32 ^ decayBits);
        long epsBits = Double.doubleToLongBits(this.getEps());
        result = result * prime + (int)(epsBits >>> 32 ^ epsBits);
        result = result * prime + (this.isMinibatch() ? 79 : 97);
        long gammaBits = Double.doubleToLongBits(this.getGamma());
        result = result * prime + (int)(gammaBits >>> 32 ^ gammaBits);
        long betaBits = Double.doubleToLongBits(this.getBeta());
        result = result * prime + (int)(betaBits >>> 32 ^ betaBits);
        result = result * prime + (this.isLockGammaBeta() ? 79 : 97);
        result = result * prime + (this.isCudnnAllowFallback() ? 79 : 97);
        result = result * prime + (this.isUseLogStd() ? 79 : 97);
        CNN2DFormat format = this.getCnn2DFormat();
        result = result * prime + (format == null ? 43 : format.hashCode());
        return result;
    }

    /** Builder for {@link BatchNormalization} configurations. */
    public static class Builder
    extends FeedForwardLayer.Builder<Builder> {
        protected double decay = 0.9;
        protected double eps = 1.0E-5;
        protected boolean isMinibatch = true;
        protected boolean lockGammaBeta = false;
        protected double gamma = 1.0;
        protected double beta = 0.0;
        // NOTE(review): these two lists are not read by BatchNormalization's constructor.
        protected List<LayerConstraint> betaConstraints;
        protected List<LayerConstraint> gammaConstraints;
        protected boolean cudnnAllowFallback = true;
        protected boolean useLogStd = true;
        protected CNN2DFormat cnn2DFormat = CNN2DFormat.NCHW;

        /** Creates a builder with the given decay and minibatch flag. */
        public Builder(double decay, boolean isMinibatch) {
            this.setDecay(decay);
            this.setMinibatch(isMinibatch);
        }

        /** Creates a builder with the given gamma and beta values. */
        public Builder(double gamma, double beta) {
            this.setGamma(gamma);
            this.setBeta(beta);
        }

        /** Creates a builder with the given gamma and beta values and lock flag. */
        public Builder(double gamma, double beta, boolean lockGammaBeta) {
            this.setGamma(gamma);
            this.setBeta(beta);
            this.setLockGammaBeta(lockGammaBeta);
        }

        /** Creates a builder with the given gamma/beta lock flag. */
        public Builder(boolean lockGammaBeta) {
            this.setLockGammaBeta(lockGammaBeta);
        }

        /** Creates a builder with all defaults. */
        public Builder() {
        }

        /** Sets the 2D convolutional data format. */
        public Builder dataFormat(CNN2DFormat format) {
            this.setCnn2DFormat(format);
            return this;
        }

        /** Sets the minibatch flag. */
        public Builder minibatch(boolean minibatch) {
            this.setMinibatch(minibatch);
            return this;
        }

        /** Sets the gamma value. */
        public Builder gamma(double gamma) {
            this.setGamma(gamma);
            return this;
        }

        /** Sets the beta value. */
        public Builder beta(double beta) {
            this.setBeta(beta);
            return this;
        }

        /** Sets the epsilon value. */
        public Builder eps(double eps) {
            this.setEps(eps);
            return this;
        }

        /** Sets the decay value. */
        public Builder decay(double decay) {
            this.setDecay(decay);
            return this;
        }

        /** Sets whether gamma and beta are locked. */
        public Builder lockGammaBeta(boolean lockGammaBeta) {
            this.setLockGammaBeta(lockGammaBeta);
            return this;
        }

        /**
         * Sets constraints for the beta parameter. The varargs array is wrapped with
         * {@link Arrays#asList}, producing a fixed-size list view.
         */
        public Builder constrainBeta(LayerConstraint ... constraints) {
            this.setBetaConstraints(Arrays.asList(constraints));
            return this;
        }

        /**
         * Sets constraints for the gamma parameter. The varargs array is wrapped with
         * {@link Arrays#asList}, producing a fixed-size list view.
         */
        public Builder constrainGamma(LayerConstraint ... constraints) {
            this.setGammaConstraints(Arrays.asList(constraints));
            return this;
        }

        /**
         * Sets whether helper fallback is allowed.
         *
         * @deprecated use {@link #helperAllowFallback(boolean)}, which sets the same flag
         */
        @Deprecated
        public Builder cudnnAllowFallback(boolean allowFallback) {
            this.setCudnnAllowFallback(allowFallback);
            return this;
        }

        /** Sets whether a helper implementation may fall back to the built-in implementation. */
        public Builder helperAllowFallback(boolean allowFallback) {
            this.setCudnnAllowFallback(allowFallback);
            return this;
        }

        /** Sets whether to use the log10 standard deviation parameterization. */
        public Builder useLogStd(boolean useLogStd) {
            this.setUseLogStd(useLogStd);
            return this;
        }

        /** Builds the {@link BatchNormalization} configuration from this builder's state. */
        @Override
        public BatchNormalization build() {
            return new BatchNormalization(this);
        }

        /** Creates a builder with every field specified explicitly. */
        public Builder(double decay, double eps, boolean isMinibatch, boolean lockGammaBeta, double gamma, double beta, List<LayerConstraint> betaConstraints, List<LayerConstraint> gammaConstraints, boolean cudnnAllowFallback, boolean useLogStd, CNN2DFormat cnn2DFormat) {
            this.decay = decay;
            this.eps = eps;
            this.isMinibatch = isMinibatch;
            this.lockGammaBeta = lockGammaBeta;
            this.gamma = gamma;
            this.beta = beta;
            this.betaConstraints = betaConstraints;
            this.gammaConstraints = gammaConstraints;
            this.cudnnAllowFallback = cudnnAllowFallback;
            this.useLogStd = useLogStd;
            this.cnn2DFormat = cnn2DFormat;
        }

        public double getDecay() {
            return this.decay;
        }

        public double getEps() {
            return this.eps;
        }

        public boolean isMinibatch() {
            return this.isMinibatch;
        }

        public boolean isLockGammaBeta() {
            return this.lockGammaBeta;
        }

        public double getGamma() {
            return this.gamma;
        }

        public double getBeta() {
            return this.beta;
        }

        public List<LayerConstraint> getBetaConstraints() {
            return this.betaConstraints;
        }

        public List<LayerConstraint> getGammaConstraints() {
            return this.gammaConstraints;
        }

        public boolean isCudnnAllowFallback() {
            return this.cudnnAllowFallback;
        }

        public boolean isUseLogStd() {
            return this.useLogStd;
        }

        public CNN2DFormat getCnn2DFormat() {
            return this.cnn2DFormat;
        }

        public void setDecay(double decay) {
            this.decay = decay;
        }

        public void setEps(double eps) {
            this.eps = eps;
        }

        public void setMinibatch(boolean isMinibatch) {
            this.isMinibatch = isMinibatch;
        }

        public void setLockGammaBeta(boolean lockGammaBeta) {
            this.lockGammaBeta = lockGammaBeta;
        }

        public void setGamma(double gamma) {
            this.gamma = gamma;
        }

        public void setBeta(double beta) {
            this.beta = beta;
        }

        public void setBetaConstraints(List<LayerConstraint> betaConstraints) {
            this.betaConstraints = betaConstraints;
        }

        public void setGammaConstraints(List<LayerConstraint> gammaConstraints) {
            this.gammaConstraints = gammaConstraints;
        }

        public void setCudnnAllowFallback(boolean cudnnAllowFallback) {
            this.cudnnAllowFallback = cudnnAllowFallback;
        }

        public void setUseLogStd(boolean useLogStd) {
            this.useLogStd = useLogStd;
        }

        public void setCnn2DFormat(CNN2DFormat cnn2DFormat) {
            this.cnn2DFormat = cnn2DFormat;
        }
    }
}

