/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.conf.dropout;

import lombok.NonNull;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.RandomOp;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.schedule.ISchedule;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;
import org.nd4j.shade.jackson.annotation.JsonProperty;

@JsonIgnoreProperties(value={"lastPValue", "alphaPrime", "a", "b", "mask"})
/**
 * Alpha dropout: a dropout variant designed to be paired with SELU activations.
 * <p>
 * Instead of zeroing dropped activations, each dropped unit is set to {@code alphaPrime = -lambda * alpha}.
 * The affine transform {@code a * (x * d + alphaPrime * (1 - d)) + b} is then applied, where {@code d} is a
 * Bernoulli mask with retain probability {@code p}. The coefficients {@code a} and {@code b} are derived from
 * {@code p} by {@link #a(double)} and {@link #b(double)}.
 * <p>
 * Reference: Klambauer et al., "Self-Normalizing Neural Networks", https://arxiv.org/abs/1706.02515
 * <p>
 * Instances are stateful: {@link #applyDropout} stores the sampled mask, which {@link #backprop} consumes and
 * clears. No synchronization is performed. NOTE(review): a single instance therefore should not be shared across
 * concurrent forward/backward passes.
 */
public class AlphaDropout implements IDropout {
    /** Default alpha (the SELU activation's alpha constant). */
    public static final double DEFAULT_ALPHA = 1.6732632423543772;
    /** Default lambda (the SELU activation's scale constant). */
    public static final double DEFAULT_LAMBDA = 1.0507009873554805;

    /** Fixed retain probability. The schedule-based public constructor sets it to NaN. */
    private final double p;
    /** Optional retain-probability schedule; when non-null it takes precedence over {@link #p}. */
    private final ISchedule pSchedule;
    private final double alpha;
    private final double lambda;
    /** Retain probability that {@link #a} and {@link #b} were last computed for; used to skip recomputation. */
    private transient double lastPValue;
    /** Value given to dropped activations before the affine transform: {@code -lambda * alpha}. */
    private double alphaPrime;
    /** Affine scale applied after masking, derived from the current retain probability. */
    private double a;
    /** Affine shift applied after masking, derived from the current retain probability. */
    private double b;
    /** Dropout mask sampled by the most recent forward pass; consumed and cleared by {@link #backprop}. */
    private transient INDArray mask;

    /**
     * Creates alpha dropout with a fixed retain probability and the default alpha/lambda.
     *
     * @param activationRetainProbability probability of retaining each activation; must be in (0, 1]
     * @throws IllegalArgumentException if the probability is not in (0, 1] (NaN is rejected as well)
     */
    public AlphaDropout(double activationRetainProbability) {
        this(activationRetainProbability, null, DEFAULT_ALPHA, DEFAULT_LAMBDA);
        if (activationRetainProbability < 0.0) {
            throw new IllegalArgumentException("Activation retain probability must be > 0. Got: " + activationRetainProbability);
        }
        if (activationRetainProbability == 0.0) {
            throw new IllegalArgumentException("Invalid probability value: Dropout with 0.0 probability of retaining activations is not supported");
        }
        // Written as !(p <= 1) so NaN is rejected too. A retain probability above 1 is not a probability,
        // and large enough values make a(p) NaN.
        if (!(activationRetainProbability <= 1.0)) {
            throw new IllegalArgumentException("Activation retain probability must be in range (0, 1]. Got: " + activationRetainProbability);
        }
    }

    /**
     * Creates alpha dropout whose retain probability follows a schedule, with the default alpha/lambda.
     *
     * @param activationRetainProbabilitySchedule schedule for the retain probability, evaluated each forward pass
     * @throws NullPointerException if the schedule is null
     */
    public AlphaDropout(@NonNull ISchedule activationRetainProbabilitySchedule) {
        this(Double.NaN, activationRetainProbabilitySchedule, DEFAULT_ALPHA, DEFAULT_LAMBDA);
        if (activationRetainProbabilitySchedule == null) {
            throw new NullPointerException("activationRetainProbabilitySchedule is marked @NonNull but is null");
        }
    }

    /**
     * Full constructor, also used for JSON deserialization. Performs no range validation; invalid probabilities
     * are caught in {@link #applyDropout}.
     *
     * @param activationRetainProbability         fixed retain probability (ignored when a schedule is supplied)
     * @param activationRetainProbabilitySchedule optional retain-probability schedule
     * @param alpha                               alpha constant
     * @param lambda                              lambda constant
     */
    protected AlphaDropout(@JsonProperty(value="p") double activationRetainProbability, @JsonProperty(value="pSchedule") ISchedule activationRetainProbabilitySchedule, @JsonProperty(value="alpha") double alpha, @JsonProperty(value="lambda") double lambda) {
        this.p = activationRetainProbability;
        this.pSchedule = activationRetainProbabilitySchedule;
        this.alpha = alpha;
        this.lambda = lambda;
        this.alphaPrime = -lambda * alpha;
        // With a fixed probability, a and b can be precomputed. With a schedule they are computed lazily on the
        // first forward pass, because lastPValue starts at 0.0 and never matches a valid probability.
        // NOTE(review): a(..)/b(..) are overridable and are called from the constructor here.
        if (activationRetainProbabilitySchedule == null) {
            this.lastPValue = this.p;
            this.a = this.a(this.p);
            this.b = this.b(this.p);
        }
    }

    /**
     * Samples a Bernoulli mask {@code d} with the current retain probability and writes
     * {@code a * (x * d + alphaPrime * (1 - d)) + b} into {@code output}.
     * <p>
     * The mask is allocated via {@code workspaceMgr} as an {@link ArrayType#INPUT} array and kept for
     * {@link #backprop}.
     *
     * @throws IllegalStateException if the effective retain probability (fixed or scheduled) is not in (0, 1]
     */
    @Override
    public INDArray applyDropout(INDArray inputActivations, INDArray output, int iteration, int epoch, LayerWorkspaceMgr workspaceMgr) {
        double pValue = pSchedule != null ? pSchedule.valueAt(iteration, epoch) : p;
        // Schedules and deserialized configs bypass the public constructors' checks. At p == 0, a(p) is infinite and
        // the whole output silently becomes NaN, so fail fast instead.
        if (!(pValue > 0.0 && pValue <= 1.0)) {
            throw new IllegalStateException("Invalid activation retain probability at iteration " + iteration
                    + ", epoch " + epoch + ": must be in range (0, 1]. Got: " + pValue);
        }
        if (pValue != lastPValue) {
            a = a(pValue);
            b = b(pValue);
        }
        lastPValue = pValue;

        mask = workspaceMgr.createUninitialized(ArrayType.INPUT, output.dataType(), output.shape(), output.ordering());
        // Interface-typed locals pin the same executioner overloads the explicit casts selected.
        RandomOp bernoulli = new BernoulliDistribution(mask, pValue);
        Nd4j.getExecutioner().exec(bernoulli);

        // alphaPrime * (1 - d): the value contributed by dropped units.
        INDArray aPOneMinusD = mask.rsub(1.0).muli(alphaPrime);
        // output = x * d
        Op maskInput = new OldMulOp(inputActivations, mask, output);
        Nd4j.getExecutioner().exec(maskInput);
        // output = a * (x * d + alphaPrime * (1 - d)) + b
        output.addi(aPOneMinusD).muli(a).addi(b);
        return output;
    }

    /**
     * Backpropagates through the last forward pass: {@code gradAtInput = gradAtOutput * a * d}.
     * The stored mask is scaled in place and then cleared.
     *
     * @throws IllegalStateException if no mask is available (no forward pass, or already cleared)
     */
    @Override
    public INDArray backprop(INDArray gradAtOutput, INDArray gradAtInput, int iteration, int epoch) {
        Preconditions.checkState(mask != null, "Cannot perform backprop: Dropout mask array is absent (already cleared?)");
        // d(output)/d(input) = a * d, so scale the mask in place and multiply it into the incoming gradient.
        mask.muli(a);
        Op scaledGradient = new OldMulOp(gradAtOutput, mask, gradAtInput);
        Nd4j.getExecutioner().exec(scaledGradient);
        // The mask has been consumed (and mutated above); drop it so it cannot be reused.
        mask = null;
        return gradAtInput;
    }

    /** Discards any stored dropout mask. */
    @Override
    public void clear() {
        mask = null;
    }

    /** Returns a copy with the same configuration (the schedule is cloned). Runtime state such as the mask is not copied. */
    @Override
    public AlphaDropout clone() {
        ISchedule scheduleCopy = pSchedule == null ? null : pSchedule.clone();
        return new AlphaDropout(p, scheduleCopy, alpha, lambda);
    }

    /**
     * Computes the affine scale {@code a = 1 / sqrt(p + alphaPrime^2 * p * (1 - p))}.
     *
     * @param p retain probability
     * @return the scale coefficient for {@code p}
     */
    public double a(double p) {
        return 1.0 / Math.sqrt(p + alphaPrime * alphaPrime * p * (1.0 - p));
    }

    /**
     * Computes the affine shift {@code b = -a(p) * (1 - p) * alphaPrime}.
     *
     * @param p retain probability
     * @return the shift coefficient for {@code p}
     */
    public double b(double p) {
        return -a(p) * (1.0 - p) * alphaPrime;
    }

    public double getP() {
        return p;
    }

    public ISchedule getPSchedule() {
        return pSchedule;
    }

    public double getAlpha() {
        return alpha;
    }

    public double getLambda() {
        return lambda;
    }

    public double getLastPValue() {
        return lastPValue;
    }

    public double getAlphaPrime() {
        return alphaPrime;
    }

    public double getA() {
        return a;
    }

    public double getB() {
        return b;
    }

    public INDArray getMask() {
        return mask;
    }

    public void setLastPValue(double lastPValue) {
        this.lastPValue = lastPValue;
    }

    /** Sets alphaPrime. Does not recompute {@code a}/{@code b}; they are only refreshed when the retain probability changes. */
    public void setAlphaPrime(double alphaPrime) {
        this.alphaPrime = alphaPrime;
    }

    public void setA(double a) {
        this.a = a;
    }

    public void setB(double b) {
        this.b = b;
    }

    public void setMask(INDArray mask) {
        this.mask = mask;
    }

    /** Equality covers configuration only (p, pSchedule, alpha, lambda); derived and runtime state are ignored. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof AlphaDropout)) {
            return false;
        }
        AlphaDropout other = (AlphaDropout) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(getP(), other.getP()) != 0) {
            return false;
        }
        ISchedule thisSchedule = getPSchedule();
        ISchedule otherSchedule = other.getPSchedule();
        if (thisSchedule == null ? otherSchedule != null : !thisSchedule.equals(otherSchedule)) {
            return false;
        }
        if (Double.compare(getAlpha(), other.getAlpha()) != 0) {
            return false;
        }
        return Double.compare(getLambda(), other.getLambda()) == 0;
    }

    protected boolean canEqual(Object other) {
        return other instanceof AlphaDropout;
    }

    /** Hash over the same configuration fields used by {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        result = result * prime + Double.hashCode(getP());
        ISchedule schedule = getPSchedule();
        result = result * prime + (schedule == null ? 43 : schedule.hashCode());
        result = result * prime + Double.hashCode(getAlpha());
        result = result * prime + Double.hashCode(getLambda());
        return result;
    }

    @Override
    public String toString() {
        return "AlphaDropout(p=" + getP() + ", pSchedule=" + getPSchedule() + ", alpha=" + getAlpha() + ", lambda=" + getLambda() + ", mask=" + getMask() + ")";
    }
}

