/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.dropout;

import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MulOp;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Multiplicative Gaussian-noise dropout.
 * <p>
 * {@link #applyDropout} multiplies the input activations elementwise by noise sampled via
 * {@code GaussianDistribution(noise, 1.0, stdev)}, i.e. noise centred on 1.0, with
 * {@code stdev = sqrt(rate / (1 - rate))}. The rate is either the fixed value given at construction or,
 * when a schedule is configured, the schedule's value for the current iteration/epoch; it must lie in [0, 1).
 * <p>
 * The sampled noise array is kept so that {@link #backprop} can apply the same multiplier to the gradient.
 * It is excluded from JSON serialization and from {@link #equals}/{@link #hashCode}.
 */
@JsonIgnoreProperties(value={"noise"})
public class GaussianDropout implements IDropout {
    /** Fixed rate, used when {@link #rateSchedule} is null. The schedule-only constructor sets this to NaN. */
    private final double rate;
    /** Optional rate schedule; when non-null it takes precedence over {@link #rate}. */
    private final ISchedule rateSchedule;
    /** Noise from the most recent {@link #applyDropout} call; consumed by {@link #backprop}, cleared by {@link #clear}. */
    private transient INDArray noise;

    /**
     * Creates Gaussian dropout with a fixed rate.
     *
     * @param rate fixed rate; must lie in [0, 1) when dropout is applied
     */
    public GaussianDropout(double rate) {
        this(rate, null);
    }

    /**
     * Creates Gaussian dropout whose rate is taken from a schedule.
     *
     * @param rateSchedule schedule giving the rate per iteration/epoch; its values must lie in [0, 1)
     */
    public GaussianDropout(ISchedule rateSchedule) {
        this(Double.NaN, rateSchedule);
    }

    /**
     * Full constructor, also used for JSON deserialization.
     *
     * @param rate         fixed rate (ignored when {@code rateSchedule} is non-null)
     * @param rateSchedule optional rate schedule
     */
    protected GaussianDropout(@JsonProperty(value="rate") double rate, @JsonProperty(value="rateSchedule") ISchedule rateSchedule) {
        this.rate = rate;
        this.rateSchedule = rateSchedule;
    }

    /**
     * Samples fresh multiplicative noise, stores it for backprop, and writes {@code inputActivations * noise}
     * into {@code output}.
     *
     * @return the result array of the multiply op (first output of {@code MulOp})
     * @throws IllegalStateException if the resolved rate is not in [0, 1)
     */
    @Override
    public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
        double r = currentRate(iteration, epoch);
        // Noise is centred on 1.0; its spread grows without bound as r approaches 1.
        double stdev = Math.sqrt(r / (1.0 - r));
        // NOTE(review): allocated as ArrayType.INPUT, presumably so it outlives the forward pass until backprop;
        // confirm against LayerWorkspaceMgr semantics.
        this.noise = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), inputActivations.shape(), inputActivations.ordering());
        Nd4j.getExecutioner().exec((RandomOp) new GaussianDistribution(this.noise, 1.0, stdev));
        return Nd4j.getExecutioner().exec((CustomOp) new MulOp(inputActivations, this.noise, output))[0];
    }

    /**
     * Backpropagates through the noise multiplication: since {@code out = in * noise},
     * {@code dL/dIn = dL/dOut * noise}. The stored noise is released afterwards.
     *
     * @return {@code gradAtInput}, filled with {@code gradAtOutput * noise}
     * @throws IllegalStateException if no noise is stored (applyDropout not called, or already cleared)
     */
    @Override
    public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
        Preconditions.checkState(this.noise != null, "Cannot perform backprop: GaussianDropout noise array is absent (already cleared?)");
        Nd4j.getExecutioner().exec((CustomOp) new MulOp(gradAtOutput, this.noise, gradAtInput));
        this.noise = null;
        return gradAtInput;
    }

    /** Drops the stored noise array. */
    @Override
    public void clear() {
        this.noise = null;
    }

    /** Returns a copy with the same rate and a cloned schedule; the noise array is not copied. */
    @Override
    public GaussianDropout clone() {
        return new GaussianDropout(this.rate, this.rateSchedule == null ? null : this.rateSchedule.clone());
    }

    /**
     * Resolves the rate for the given iteration/epoch (schedule first, fixed rate otherwise) and checks it
     * lies in [0, 1).
     */
    private double currentRate(int iteration, int epoch) {
        double r = this.rateSchedule != null ? this.rateSchedule.valueAt(iteration, epoch) : this.rate;
        // Both comparisons are false for NaN, so a missing rate is rejected as well. Without this check,
        // r >= 1 or r < 0 would yield an Inf/NaN stdev and silently corrupt every activation.
        Preconditions.checkState(r >= 0.0 && r < 1.0, "GaussianDropout rate must be in range [0, 1), got %s", r);
        return r;
    }

    public double getRate() {
        return this.rate;
    }

    public ISchedule getRateSchedule() {
        return this.rateSchedule;
    }

    public INDArray getNoise() {
        return this.noise;
    }

    public void setNoise(INDArray noise) {
        this.noise = noise;
    }

    @Override
    public String toString() {
        return "GaussianDropout(rate=" + this.getRate() + ", rateSchedule=" + this.getRateSchedule() + ", noise=" + this.getNoise() + ")";
    }

    /** Equality compares {@code rate} (via {@link Double#compare}) and {@code rateSchedule}; noise is ignored. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof GaussianDropout)) {
            return false;
        }
        GaussianDropout other = (GaussianDropout) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(this.getRate(), other.getRate()) != 0) {
            return false;
        }
        ISchedule mine = this.getRateSchedule();
        ISchedule theirs = other.getRateSchedule();
        return mine == null ? theirs == null : mine.equals(theirs);
    }

    protected boolean canEqual(Object other) {
        return other instanceof GaussianDropout;
    }

    /** Hash over {@code rate} and {@code rateSchedule}, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Double.hashCode(this.getRate());
        ISchedule schedule = this.getRateSchedule();
        result = result * prime + (schedule == null ? 43 : schedule.hashCode());
        return result;
    }
}

