/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.activations.impl;

import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.RectifedLinear;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties;

/**
 * Randomized leaky ReLU (RReLU) activation.
 *
 * <p>Non-negative inputs pass through unchanged; negative inputs are multiplied by a slope
 * {@code alpha}. In training mode a separate {@code alpha} is drawn for every element between
 * {@code l} and {@code u} (via {@link Nd4j#rand}) and stored so that {@link #backprop} can reuse
 * it. In inference mode the fixed slope {@code (l + u) / 2} is applied to every element.
 *
 * <p>NOTE(review): the sampled {@code alpha} is mutable per-instance state written by each
 * training-mode forward pass, so one instance should presumably not be shared between concurrent
 * forward/backward passes.
 */
@JsonIgnoreProperties(value={"alpha"})
public class ActivationRReLU extends BaseActivationFunction {
    /** Default lower bound of the negative-side slope range. */
    public static final double DEFAULT_L = 0.125;
    /** Default upper bound of the negative-side slope range (1/3). */
    public static final double DEFAULT_U = 0.3333333333333333;

    /** Lower bound of the slope range. */
    private double l;
    /** Upper bound of the slope range. */
    private double u;
    /**
     * Per-element slopes sampled by the most recent training-mode forward pass; {@code null}
     * after an inference-mode pass. Transient and excluded from JSON.
     */
    private transient INDArray alpha;

    /** Creates an RReLU with the default slope range [{@link #DEFAULT_L}, {@link #DEFAULT_U}]. */
    public ActivationRReLU() {
        this(DEFAULT_L, DEFAULT_U);
    }

    /**
     * Creates an RReLU with a custom slope range.
     *
     * @param l lower bound of the negative-side slope
     * @param u upper bound of the negative-side slope
     * @throws IllegalArgumentException if {@code l > u}
     */
    public ActivationRReLU(double l, double u) {
        if (l > u) {
            throw new IllegalArgumentException("Cannot have lower value (" + l + ") greater than upper (" + u + ")");
        }
        this.l = l;
        this.u = u;
    }

    /**
     * Applies RReLU to {@code in} in place and returns it.
     *
     * @param in       pre-activation values; modified in place
     * @param training if {@code true}, sample fresh per-element slopes and remember them for
     *                 {@link #backprop}; otherwise use the deterministic mean slope
     * @return {@code in}, holding the activations
     */
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        if (training) {
            // Allocated outside any active workspace, presumably so the array outlives the
            // current workspace scope: it is kept as a field and read later by backprop.
            try (MemoryWorkspace ignored = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
                this.alpha = Nd4j.rand(in.shape(), this.l, this.u, Nd4j.getRandom());
            }
            BooleanIndexing.replaceWhere(in, in.mul(this.alpha), Conditions.lessThan(0));
        } else {
            this.alpha = null;
            // Leaky ReLU with the mean slope. RectifedLinear(in, a) is deliberately not used:
            // its scalar argument is a cutoff threshold, not a negative-side slope.
            BooleanIndexing.replaceWhere(in, in.mul(inferenceSlope()), Conditions.lessThan(0));
        }
        return in;
    }

    /**
     * Computes dL/dz for this activation.
     *
     * <p>The derivative is 1 where {@code in > 0} and the slope {@code alpha} where
     * {@code in <= 0}. The slopes come from the most recent training-mode forward pass. If none
     * are stored, the inference-mode mean slope {@code (l + u) / 2} is used.
     *
     * @param in      pre-activation values (not modified)
     * @param epsilon gradient of the loss with respect to the activation output
     * @return pair of (dL/dz, {@code null}); this activation has no trainable parameters
     */
    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        // replaceWhere tests its condition against the target array, so build the derivative
        // from a copy of the input rather than from an all-ones array (which never matches <= 0).
        INDArray dLdz = in.dup();
        // Positive inputs: derivative 1. Afterwards every entry still <= 0 is an original
        // non-positive input, because all replaced entries are now 1.
        BooleanIndexing.replaceWhere(dLdz, 1.0, Conditions.greaterThan(0.0));
        if (this.alpha != null) {
            BooleanIndexing.replaceWhere(dLdz, this.alpha, Conditions.lessThanOrEqual(0.0));
        } else {
            // No sampled slopes available: be consistent with the inference-mode forward pass.
            BooleanIndexing.replaceWhere(dLdz, inferenceSlope(), Conditions.lessThanOrEqual(0.0));
        }
        dLdz.muli(epsilon);
        return new Pair<>(dLdz, null);
    }

    /** Deterministic slope used at inference time: the midpoint of [l, u]. */
    private double inferenceSlope() {
        return 0.5 * (this.l + this.u);
    }

    @Override
    public String toString() {
        return "rrelu(l=" + this.l + ", u=" + this.u + ")";
    }

    /** Two instances are equal when their slope bounds {@code l} and {@code u} are equal. */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof ActivationRReLU)) {
            return false;
        }
        ActivationRReLU other = (ActivationRReLU) o;
        if (!other.canEqual(this)) {
            return false;
        }
        if (Double.compare(this.getL(), other.getL()) != 0) {
            return false;
        }
        return Double.compare(this.getU(), other.getU()) == 0;
    }

    /** Symmetry hook for {@link #equals}: subclasses may override to refuse equality. */
    protected boolean canEqual(Object other) {
        return other instanceof ActivationRReLU;
    }

    /** Hash over {@code l} and {@code u}, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        int result = 1;
        long lBits = Double.doubleToLongBits(this.getL());
        result = result * 59 + (int) (lBits >>> 32 ^ lBits);
        long uBits = Double.doubleToLongBits(this.getU());
        result = result * 59 + (int) (uBits >>> 32 ^ uBits);
        return result;
    }

    /** @return lower bound of the slope range */
    public double getL() {
        return this.l;
    }

    /** @return upper bound of the slope range */
    public double getU() {
        return this.u;
    }

    /**
     * @return the live (not copied) per-element slopes from the last training-mode forward pass,
     *         or {@code null} if the last pass was in inference mode or none has run yet
     */
    public INDArray getAlpha() {
        return this.alpha;
    }
}

