/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.berkeley;

import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Counter;

public final class SloppyMath {
    /** Log-space gap beyond which the smaller term is ignored by the {@code double} log-sum routines. */
    public static final double LOGTOLERANCE = 30.0;
    /** Log-space gap beyond which the smaller term is ignored by the {@code float} log-sum routines. */
    static final float LOGTOLERANCE_F = 10.0f;

    /** Upper bound on fixed-point iterations in {@link #lambert(double, double)} so non-convergent inputs terminate. */
    private static final int LAMBERT_MAX_ITERATIONS = 1000;
    /** Stop {@link #lambert(double, double)} once successive iterates differ by at most this much. */
    private static final double LAMBERT_TOLERANCE = 1.0E-5;

    /** Slope (about 2^20 / ln 2) shared by the bit-level exp/log/pow approximations. */
    private static final double APPROX_SCALE = 1512775.0;
    /** Offset (1023 * 2^20 - 60801) shared by the bit-level exp/log/pow approximations. */
    private static final double APPROX_OFFSET = 1.072632447E9;
    /** High 32 bits of {@code +Infinity}; reconstructed high words at or above this overflow. */
    private static final int INFINITY_HIGH_WORD = 0x7ff00000;

    /**
     * Returns the absolute value of {@code x}; NaN is returned as NaN.
     */
    public static double abs(double x) {
        return x > 0.0 ? x : -x;
    }

    /**
     * Solves {@code w * e^w = -v * e^u} for {@code w} by fixed-point iteration of
     * {@code w <- log(-v) + u - log|w|}, starting from {@code w = log(-v) + u}.
     *
     * <p>The iteration converges locally when the solution exceeds 1 in magnitude. Otherwise it
     * may wander or diverge, and it is stopped after {@link #LAMBERT_MAX_ITERATIONS} steps.
     *
     * @param v expected to be negative (so {@code log(-v)} is defined)
     * @param u additive log-space offset
     * @return the last iterate
     */
    public static double lambert(double v, double u) {
        final double x = -(Math.log(-v) + u);
        double w = -x;
        double diff = 1.0;
        // The original loop condition was inverted (diff < tol), so the loop never executed.
        for (int iter = 0; iter < LAMBERT_MAX_ITERATIONS && Math.abs(diff) > LAMBERT_TOLERANCE; ++iter) {
            double z = -x - Math.log(Math.abs(w));
            diff = z - w;
            w = z;
        }
        return w;
    }

    /** Returns the largest of three ints. */
    public static int max(int a, int b, int c) {
        return Math.max(a, Math.max(b, c));
    }

    /** Returns the smallest of three ints. */
    public static int min(int a, int b, int c) {
        return Math.min(a, Math.min(b, c));
    }

    /**
     * Returns {@code a} if {@code a >= b}, else {@code b}. Unlike {@link Math#max(float, float)},
     * ties (including 0.0 vs -0.0) return {@code a}, and a NaN operand makes the comparison fail,
     * so {@code b} is returned.
     */
    public static float max(float a, float b) {
        return a >= b ? a : b;
    }

    /** Double analogue of {@link #max(float, float)} with the same tie and NaN behavior. */
    public static double max(double a, double b) {
        return a >= b ? a : b;
    }

    /**
     * Returns {@code a} if {@code a <= b}, else {@code b}. Ties return {@code a}, and a NaN operand
     * makes the comparison fail, so {@code b} is returned.
     */
    public static float min(float a, float b) {
        return a <= b ? a : b;
    }

    /** Double analogue of {@link #min(float, float)} with the same tie and NaN behavior. */
    public static double min(double a, double b) {
        return a <= b ? a : b;
    }

    /** Returns true if {@code d} is infinite, NaN, or exactly zero. */
    public static boolean isDangerous(double d) {
        return Double.isInfinite(d) || Double.isNaN(d) || d == 0.0;
    }

    /** Returns true if {@code d} is infinite, NaN, or exactly zero. */
    public static boolean isDangerous(float d) {
        return Float.isInfinite(d) || Float.isNaN(d) || d == 0.0f;
    }

    /**
     * Sloppy {@code x >= y}: for {@code x > 1} allows a 1% relative slack, otherwise an absolute
     * slack of {@code 1e-4}.
     */
    public static boolean isGreater(double x, double y) {
        if (x > 1.0) {
            return (x - y) / x > -0.01;
        }
        return x - y > -1.0E-4;
    }

    /** Returns true if {@code d} is infinite or NaN (zero is allowed, unlike {@link #isDangerous(double)}). */
    public static boolean isVeryDangerous(double d) {
        return Double.isInfinite(d) || Double.isNaN(d);
    }

    /**
     * Returns {@code ||a| - |b|| / min(|a|, |b|)}. The result is infinite or NaN when the smaller
     * magnitude is zero.
     */
    public static double relativeDifferance(double a, double b) {
        a = Math.abs(a);
        b = Math.abs(b);
        double absMin = Math.min(a, b);
        return Math.abs(a - b) / absMin;
    }

    /** Returns true if {@code d} lies in {@code [0, 1 + tol]}. */
    public static boolean isDiscreteProb(double d, double tol) {
        return d >= 0.0 && d <= 1.0 + tol;
    }

    /**
     * Returns {@code log(exp(lx) + exp(ly))} in float precision. If the smaller argument is more
     * than {@link #LOGTOLERANCE_F} below the larger, the larger is returned unchanged.
     */
    public static float logAdd(float lx, float ly) {
        float max;
        float negDiff;
        if (lx > ly) {
            max = lx;
            negDiff = ly - lx;
        } else {
            max = ly;
            negDiff = lx - ly;
        }
        if (max == Float.NEGATIVE_INFINITY || negDiff < -LOGTOLERANCE_F) {
            return max;
        }
        return max + (float) Math.log1p(Math.exp(negDiff));
    }

    /**
     * Returns {@code log(exp(lx) + exp(ly))}. If the smaller argument is more than
     * {@link #LOGTOLERANCE} below the larger, the larger is returned unchanged.
     */
    public static double logAdd(double lx, double ly) {
        double max;
        double negDiff;
        if (lx > ly) {
            max = lx;
            negDiff = ly - lx;
        } else {
            max = ly;
            negDiff = lx - ly;
        }
        if (max == Double.NEGATIVE_INFINITY || negDiff < -LOGTOLERANCE) {
            return max;
        }
        return max + Math.log1p(Math.exp(negDiff));
    }

    /**
     * Returns {@code log(sum_i exp(logV[i]))}, computed in double precision. Terms more than
     * {@link #LOGTOLERANCE} below the maximum are ignored.
     *
     * @return {@code -Infinity} if the array is empty or no element exceeds {@code -Infinity}
     */
    public static double logAdd(float[] logV) {
        int maxIndex = 0;
        double max = Double.NEGATIVE_INFINITY;
        for (int i = 0; i < logV.length; ++i) {
            if (logV[i] > max) {
                max = logV[i];
                maxIndex = i;
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        double threshold = max - LOGTOLERANCE;
        double sumNegativeDifferences = 0.0;
        for (int i = 0; i < logV.length; ++i) {
            // The max term contributes exp(0) = 1, which is added implicitly via log1p.
            if (i != maxIndex && logV[i] > threshold) {
                sumNegativeDifferences += Math.exp(logV[i] - max);
            }
        }
        return sumNegativeDifferences > 0.0 ? max + Math.log1p(sumNegativeDifferences) : max;
    }

    /**
     * Normalizes {@code logV} in place so that {@code sum_i exp(logV[i]) == 1}.
     *
     * <p>Leaves the array untouched when its log-sum is already 0. If every entry is
     * {@code -Infinity}, every entry becomes NaN ({@code -inf - -inf}).
     *
     * @throws RuntimeException if the log-sum is NaN
     */
    public static void logNormalize(double[] logV) {
        double logSum = logAdd(logV);
        if (Double.isNaN(logSum)) {
            throw new RuntimeException("Bad log-sum");
        }
        if (logSum == 0.0) {
            return;
        }
        for (int i = 0; i < logV.length; ++i) {
            logV[i] -= logSum;
        }
    }

    /**
     * Returns {@code log(sum_i exp(logV[i]))}. Terms more than {@link #LOGTOLERANCE} below the
     * maximum are ignored.
     *
     * @return {@code -Infinity} for an empty array or when no element exceeds {@code -Infinity}
     */
    public static double logAdd(double[] logV) {
        return logAdd(logV, logV.length);
    }

    /**
     * List version of {@link #logAdd(double[])}. The list is traversed with iterators, so linked
     * lists are not penalized by repeated {@code get(i)}.
     *
     * @throws NullPointerException if the list contains a null element
     */
    public static double logAdd(List<Double> logV) {
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        int index = 0;
        for (double value : logV) {
            if (value > max) {
                max = value;
                maxIndex = index;
            }
            ++index;
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        double threshold = max - LOGTOLERANCE;
        double sumNegativeDifferences = 0.0;
        index = 0;
        for (double value : logV) {
            if (index != maxIndex && value > threshold) {
                sumNegativeDifferences += Math.exp(value - max);
            }
            ++index;
        }
        return sumNegativeDifferences > 0.0 ? max + Math.log1p(sumNegativeDifferences) : max;
    }

    /**
     * Legacy float log-sum: like {@link #logAdd(float[])}, but accumulates in float and ignores
     * terms more than {@link #LOGTOLERANCE_F} below the maximum.
     */
    public static float logAdd_Old(float[] logV) {
        float max = Float.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < logV.length; ++i) {
            if (logV[i] > max) {
                max = logV[i];
                maxIndex = i;
            }
        }
        if (max == Float.NEGATIVE_INFINITY) {
            return Float.NEGATIVE_INFINITY;
        }
        float threshold = max - LOGTOLERANCE_F;
        float sumNegativeDifferences = 0.0f;
        for (int i = 0; i < logV.length; ++i) {
            if (i != maxIndex && logV[i] > threshold) {
                sumNegativeDifferences = (float) (sumNegativeDifferences + Math.exp(logV[i] - max));
            }
        }
        if (sumNegativeDifferences > 0.0f) {
            return max + (float) Math.log(1.0f + sumNegativeDifferences);
        }
        return max;
    }

    /**
     * Returns {@code log(sum_{i < lastIndex} exp(logV[i]))} as a float. The sum is accumulated in
     * double, and terms more than {@link #LOGTOLERANCE_F} below the maximum are ignored.
     *
     * @param lastIndex number of leading entries of {@code logV} to include (exclusive end)
     */
    public static float logAdd(float[] logV, int lastIndex) {
        if (lastIndex == 0) {
            return Float.NEGATIVE_INFINITY;
        }
        float max = Float.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < lastIndex; ++i) {
            if (logV[i] > max) {
                max = logV[i];
                maxIndex = i;
            }
        }
        if (max == Float.NEGATIVE_INFINITY) {
            return Float.NEGATIVE_INFINITY;
        }
        float threshold = max - LOGTOLERANCE_F;
        double sumNegativeDifferences = 0.0;
        for (int i = 0; i < lastIndex; ++i) {
            if (i != maxIndex && logV[i] > threshold) {
                sumNegativeDifferences += Math.exp(logV[i] - max);
            }
        }
        return sumNegativeDifferences > 0.0 ? max + (float) Math.log1p(sumNegativeDifferences) : max;
    }

    /**
     * Returns {@code log(sum_{i < lastIndex} exp(logV[i]))}. Terms more than
     * {@link #LOGTOLERANCE} below the maximum are ignored.
     *
     * @param lastIndex number of leading entries of {@code logV} to include (exclusive end)
     */
    public static double logAdd(double[] logV, int lastIndex) {
        if (lastIndex == 0) {
            return Double.NEGATIVE_INFINITY;
        }
        double max = Double.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < lastIndex; ++i) {
            if (logV[i] > max) {
                max = logV[i];
                maxIndex = i;
            }
        }
        if (max == Double.NEGATIVE_INFINITY) {
            return Double.NEGATIVE_INFINITY;
        }
        double threshold = max - LOGTOLERANCE;
        double sumNegativeDifferences = 0.0;
        for (int i = 0; i < lastIndex; ++i) {
            if (i != maxIndex && logV[i] > threshold) {
                sumNegativeDifferences += Math.exp(logV[i] - max);
            }
        }
        return sumNegativeDifferences > 0.0 ? max + Math.log1p(sumNegativeDifferences) : max;
    }

    /**
     * Legacy: returns {@code sum_i exp(logV[i])} in linear space. Terms more than
     * {@link #LOGTOLERANCE_F} below the maximum are ignored.
     *
     * @return 0 when the array is empty or every entry is {@code -Infinity} (an empty sum)
     */
    public static float addExp_Old(float[] logV) {
        float max = Float.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < logV.length; ++i) {
            if (logV[i] > max) {
                max = logV[i];
                maxIndex = i;
            }
        }
        if (max == Float.NEGATIVE_INFINITY) {
            // Result is in linear space, so the sum of exp(-inf) terms is 0, not -inf.
            return 0.0f;
        }
        float threshold = max - LOGTOLERANCE_F;
        float sumNegativeDifferences = 0.0f;
        for (int i = 0; i < logV.length; ++i) {
            if (i != maxIndex && logV[i] > threshold) {
                sumNegativeDifferences = (float) (sumNegativeDifferences + Math.exp(logV[i] - max));
            }
        }
        return (float) Math.exp(max) * (1.0f + sumNegativeDifferences);
    }

    /**
     * Returns {@code sum_{i < lastIndex} exp(logV[i])} in linear space. Terms more than
     * {@link #LOGTOLERANCE_F} below the maximum are ignored.
     *
     * @param lastIndex number of leading entries of {@code logV} to include (exclusive end)
     * @return 0 when no entries are included or all included entries are {@code -Infinity}
     */
    public static float addExp(float[] logV, int lastIndex) {
        if (lastIndex == 0) {
            return 0.0f;
        }
        float max = Float.NEGATIVE_INFINITY;
        int maxIndex = 0;
        for (int i = 0; i < lastIndex; ++i) {
            if (logV[i] > max) {
                max = logV[i];
                maxIndex = i;
            }
        }
        if (max == Float.NEGATIVE_INFINITY) {
            // Result is in linear space, so the sum of exp(-inf) terms is 0, not -inf.
            return 0.0f;
        }
        float threshold = max - LOGTOLERANCE_F;
        float sumNegativeDifferences = 0.0f;
        for (int i = 0; i < lastIndex; ++i) {
            if (i != maxIndex && logV[i] > threshold) {
                sumNegativeDifferences = (float) (sumNegativeDifferences + Math.exp(logV[i] - max));
            }
        }
        return (float) Math.exp(max) * (1.0f + sumNegativeDifferences);
    }

    /**
     * Returns the binomial coefficient {@code C(n, k)}.
     *
     * <p>Uses the exact multiplicative recurrence {@code C(n, i+1) = C(n, i) * (n - i) / (i + 1)}
     * in {@code long}, so results that fit in an int are computed without intermediate overflow.
     *
     * @return 0 when {@code k < 0} or {@code k > n}
     * @throws ArithmeticException if the result does not fit in an int
     */
    public static int nChooseK(int n, int k) {
        if (k < 0 || k > n) {
            return 0;
        }
        k = Math.min(k, n - k);
        long result = 1L;
        for (int i = 0; i < k; ++i) {
            // result == C(n, i) here; the product is divisible by (i + 1).
            result = Math.multiplyExact(result, (long) (n - i)) / (i + 1);
        }
        return Math.toIntExact(result);
    }

    /**
     * Returns {@code b^e} by binary exponentiation (int arithmetic, wraps on overflow).
     * {@code e} is expected to be non-negative; negative exponents produce meaningless results.
     */
    public static int intPow(int b, int e) {
        if (e == 0) {
            return 1;
        }
        int result = 1;
        int currPow = b;
        do {
            if ((e & 1) == 1) {
                result *= currPow;
            }
            currPow *= currPow;
        } while ((e >>= 1) > 0);
        return result;
    }

    /**
     * Returns {@code b^e} by binary exponentiation in float precision.
     * {@code e} is expected to be non-negative; negative exponents produce meaningless results.
     */
    public static float intPow(float b, int e) {
        if (e == 0) {
            return 1.0f;
        }
        float result = 1.0f;
        float currPow = b;
        do {
            if ((e & 1) == 1) {
                result *= currPow;
            }
            currPow *= currPow;
        } while ((e >>= 1) > 0);
        return result;
    }

    /**
     * Returns {@code b^e} by binary exponentiation in double precision.
     * {@code e} is expected to be non-negative; negative exponents produce meaningless results.
     */
    public static double intPow(double b, int e) {
        if (e == 0) {
            return 1.0;
        }
        // Accumulate in double; a float accumulator silently discarded precision.
        double result = 1.0;
        double currPow = b;
        do {
            if ((e & 1) == 1) {
                result *= currPow;
            }
            currPow *= currPow;
        } while ((e >>= 1) > 0);
        return result;
    }

    /**
     * Returns the hypergeometric probability {@code C(r, k) * C(n - r, m - k) / C(n, m)}. This is
     * the chance of exactly {@code k} marked items when drawing {@code m} of {@code n} items, of
     * which {@code r} are marked.
     *
     * @return 0 when {@code k} is outside the support {@code [max(0, m + r - n), min(m, r)]}
     * @throws IllegalArgumentException if {@code k < 0}, {@code n <= 0}, {@code m} or {@code r}
     *     is negative, or {@code m} or {@code r} exceeds {@code n}
     */
    public static double hypergeometric(int k, int n, int r, int m) {
        if (k < 0 || r > n || m > n || n <= 0 || m < 0 || r < 0) {
            throw new IllegalArgumentException("Invalid hypergeometric");
        }
        // Exploit the symmetries of the distribution so that m <= r <= n / 2.
        if (m > n / 2) {
            m = n - m;
            k = r - k;
        }
        if (r > n / 2) {
            r = n - r;
            k = m - k;
        }
        if (m > r) {
            int temp = m;
            m = r;
            r = temp;
        }
        // The remapping above can make k negative when the original k exceeded r or m;
        // every k outside the support has probability 0.
        if (k < 0 || k < m + r - n || k > m) {
            return 0.0;
        }
        if (r == n) {
            return k == m ? 1.0 : 0.0;
        }
        if (r == n - 1) {
            if (k == m) {
                return (double) (n - m) / (double) n;
            }
            if (k == m - 1) {
                return (double) m / (double) n;
            }
            return 0.0;
        }
        if (m == 1) {
            if (k == 0) {
                return (double) (n - r) / (double) n;
            }
            if (k == 1) {
                return (double) r / (double) n;
            }
            return 0.0;
        }
        if (m == 0) {
            return k == 0 ? 1.0 : 0.0;
        }
        if (k == 0) {
            double ans = 1.0;
            for (int m0 = 0; m0 < m; ++m0) {
                ans *= (double) (n - r - m0);
                ans /= (double) (n - m0);
            }
            return ans;
        }
        // Interleave multiplications and divisions to keep the running product near 1.
        double ans = 1.0;
        int nr = n - r;
        int n0 = n;
        while (nr > n - r - (m - k)) {
            ans *= (double) nr;
            ans /= (double) n0;
            --nr;
            --n0;
        }
        for (int k0 = 0; k0 < k; ++k0) {
            ans *= (double) (m - k0);
            ans /= (double) (n - (m - k0) + 1);
            ans *= (double) (r - k0);
            ans /= (double) (k0 + 1);
        }
        return ans;
    }

    /**
     * Returns the upper binomial tail {@code P(X >= k)} for {@code X ~ Binomial(n, p)}.
     *
     * <p>Each term is evaluated in log space, with the coefficient updated incrementally. This
     * avoids the {@code inf * 0 = NaN} that a direct {@code C(n, m) * p^m * (1-p)^(n-m)} gives for
     * large {@code n}. {@code p} is expected to lie in {@code [0, 1]}.
     *
     * @return 0 when {@code k > n}; a negative {@code k} is treated as 0
     */
    public static double exactBinomial(int k, int n, double p) {
        final int start = Math.max(k, 0);
        if (start > n) {
            return 0.0;
        }
        final double logP = Math.log(p);
        final double logQ = Math.log1p(-p);
        // log C(n, start)
        double logChoose = 0.0;
        for (int j = 1; j <= start; ++j) {
            logChoose += Math.log(n - j + 1) - Math.log(j);
        }
        double total = 0.0;
        for (int m = start; m <= n; ++m) {
            if (m > start) {
                logChoose += Math.log(n - m + 1) - Math.log(m);
            }
            // Skip zero exponents so p == 0 or p == 1 does not produce 0 * -inf = NaN.
            double logTerm = logChoose
                    + (m == 0 ? 0.0 : m * logP)
                    + (m == n ? 0.0 : (n - m) * logQ);
            total += Math.exp(logTerm);
        }
        return total;
    }

    /**
     * Returns the one-tailed Fisher's exact p-value {@code P(X >= k)}, where {@code X} follows
     * {@link #hypergeometric(int, int, int, int)} with parameters {@code n, r, m}.
     *
     * <p>Sums directly over the support {@code [max(0, m + r - n), min(m, r)]}, using whichever
     * side (upper tail, or 1 minus the lower tail) has fewer terms. Remapping {@code k} through the
     * distribution's symmetries is avoided here, because those remaps reverse the tail direction.
     *
     * @throws IllegalArgumentException if {@code k} is outside the support or {@code r} or
     *     {@code m} exceeds {@code n}
     */
    public static double oneTailedFishersExact(int k, int n, int r, int m) {
        if (k < 0 || k < m + r - n || k > r || k > m || r > n || m > n) {
            throw new IllegalArgumentException("Invalid Fisher's exact: k=" + k + " n=" + n + " r=" + r + " m=" + m + " k<0=" + (k < 0) + " k<(m+r)-n=" + (k < m + r - n) + " k>r=" + (k > r) + " k>m=" + (k > m) + " r>n=" + (r > n) + " m>n=" + (m > n));
        }
        final int lo = Math.max(0, m + r - n);
        final int hi = Math.min(m, r);
        double total = 0.0;
        if (hi - k < k - lo) {
            for (int k0 = k; k0 <= hi; ++k0) {
                total += hypergeometric(k0, n, r, m);
            }
            return total;
        }
        for (int k0 = lo; k0 < k; ++k0) {
            total += hypergeometric(k0, n, r, m);
        }
        return 1.0 - total;
    }

    /**
     * Returns Pearson's chi-square statistic for the 2x2 table
     * {@code [[k, r - k], [m - k, n - r - m + k]]}. That table has row totals {@code r, n - r} and
     * column totals {@code m, n - m}. The result is NaN or infinite when any marginal is zero.
     */
    public static double chiSquare2by2(int k, int n, int r, int m) {
        int[][] cg = new int[][]{{k, r - k}, {m - k, n - (k + (r - k) + (m - k))}};
        int[] cgr = new int[]{r, n - r};
        int[] cgc = new int[]{m, n - m};
        double total = 0.0;
        for (int i = 0; i < 2; ++i) {
            for (int j = 0; j < 2; ++j) {
                double exp = (double) cgr[i] * (double) cgc[j] / (double) n;
                double delta = (double) cg[i][j] - exp;
                total += delta * delta / exp;
            }
        }
        return total;
    }

    /**
     * Returns {@code e^logX}, using the first-order approximation {@code 1 + logX} when
     * {@code |logX| < 0.001}.
     */
    public static double exp(double logX) {
        if (Math.abs(logX) < 0.001) {
            return 1.0 + logX;
        }
        return Math.exp(logX);
    }

    /** Debug entry point: prints {@code approxLog(0.0)}. */
    public static void main(String[] args) {
        System.out.println(SloppyMath.approxLog(0.0));
    }

    /** Returns {@code num / denom}, or 0 when {@code denom == 0}. */
    public static double noNaNDivide(double num, double denom) {
        return denom == 0.0 ? 0.0 : num / denom;
    }

    /**
     * Approximate natural log. Near 1 ({@code |val - 1| < 0.3}) it uses a cubic Taylor
     * expansion; elsewhere it reads the IEEE-754 high word of {@code val} as a linear function of
     * its logarithm.
     *
     * @return NaN for negative or NaN input, {@code -Infinity} for 0
     */
    public static double approxLog(double val) {
        if (val < 0.0 || Double.isNaN(val)) {
            return Double.NaN;
        }
        if (val == 0.0) {
            return Double.NEGATIVE_INFINITY;
        }
        double r = val - 1.0;
        if (Math.abs(r) < 0.3) {
            double rSquared = r * r;
            return r - rSquared / 2.0 + rSquared * r / 3.0;
        }
        double highWord = Double.doubleToLongBits(val) >> 32;
        return (highWord - APPROX_OFFSET) / APPROX_SCALE;
    }

    /**
     * Approximate {@code e^val}. It returns {@code 1 + val} when {@code |val| < 0.1}; otherwise it
     * builds the IEEE-754 high word directly from {@code val}. Results that would underflow become
     * 0, and results that would overflow become {@code +Infinity}, instead of garbage bit patterns.
     */
    public static double approxExp(double val) {
        if (Math.abs(val) < 0.1) {
            return 1.0 + val;
        }
        return fromApproxHighWord(APPROX_SCALE * val + APPROX_OFFSET);
    }

    /**
     * Approximate {@code a^b} computed on the IEEE-754 high word of {@code a}. {@code a} is
     * expected to be non-negative. Underflow is clamped to 0 and overflow to {@code +Infinity}.
     */
    public static double approxPow(double a, double b) {
        if (Double.isNaN(a)) {
            return Double.NaN;
        }
        int highWord = (int) (Double.doubleToLongBits(a) >> 32);
        return fromApproxHighWord(b * (highWord - APPROX_OFFSET) + APPROX_OFFSET);
    }

    /**
     * Converts a real-valued approximate high word into a double, clamping out-of-range words.
     * Shifting a negative word, or one at or above {@link #INFINITY_HIGH_WORD}, would otherwise
     * produce negative numbers or NaN.
     */
    private static double fromApproxHighWord(double highWord) {
        if (Double.isNaN(highWord)) {
            return Double.NaN;
        }
        if (highWord < 1.0) {
            return 0.0;
        }
        if (highWord >= INFINITY_HIGH_WORD) {
            return Double.POSITIVE_INFINITY;
        }
        return Double.longBitsToDouble((long) highWord << 32);
    }

    /**
     * Returns {@code log(exp(a) - exp(b))}. The result is NaN when {@code a < b} and
     * {@code -Infinity} when {@code a == b}. {@code log1p}/{@code expm1} are used for accuracy
     * when the arguments are close.
     */
    public static double logSubtract(double a, double b) {
        if (b == Double.NEGATIVE_INFINITY) {
            // Subtracting exp(-inf) = 0; also keeps logSubtract(-inf, -inf) at -inf instead of NaN.
            return a;
        }
        if (a > b) {
            return a + Math.log1p(-Math.exp(b - a));
        }
        return b + Math.log(Math.expm1(a - b));
    }

    /**
     * Returns {@code a - b}, except that equal arguments (including {@code inf - inf}) give 0 and
     * {@code a == -Infinity} gives {@code a}.
     */
    public static double unsafeSubtract(double a, double b) {
        if (a == b) {
            return 0.0;
        }
        if (a == Double.NEGATIVE_INFINITY) {
            return a;
        }
        return a - b;
    }

    /**
     * Returns {@code a + b}, except that opposite arguments (including {@code inf + -inf}) give 0
     * and {@code a == +Infinity} gives {@code a}.
     */
    public static double unsafeAdd(double a, double b) {
        // Guard against inf + -inf = NaN; the former check (a == b) made e.g. 2 + 2 return 0.
        if (a == -b) {
            return 0.0;
        }
        if (a == Double.POSITIVE_INFINITY) {
            return a;
        }
        return a + b;
    }

    /**
     * Returns {@link #logAdd(double[])} of the values in {@code counts}; presumably those values
     * are log-space quantities (verify against callers).
     */
    public static <T> double logAdd(Counter<T> counts) {
        double[] arr = new double[counts.size()];
        int index = 0;
        for (Map.Entry<T, Double> entry : counts.entrySet()) {
            arr[index++] = entry.getValue();
        }
        return logAdd(arr);
    }
}

