/*
 * Decompiled with CFR 0.152.
 */
package smile.math;

import java.util.ArrayList;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.DifferentiableMultivariateFunction;
import smile.math.MathEx;
import smile.math.MultivariateFunction;
import smile.math.blas.UPLO;
import smile.math.matrix.Matrix;
import smile.sort.QuickSort;

public class BFGS {
    /** SLF4J logger used for progress, convergence and warning messages. */
    private static final Logger logger = LoggerFactory.getLogger(BFGS.class);
    /**
     * Base numerical tolerance. Read once, at class initialization, from the
     * system property {@code smile.bfgs.epsilon}; defaults to {@code 1E-8}.
     */
    private static final double EPSILON = Double.parseDouble(System.getProperty("smile.bfgs.epsilon", "1E-8"));
    /** Threshold on the scaled change in x below which the BFGS and L-BFGS loops stop. */
    private static final double TOLX = 4.0 * EPSILON;
    /** Threshold on the absolute change in f(x) below which the L-BFGS-B loop stops. */
    private static final double TOLF = 4.0 * EPSILON;
    /**
     * Scale factor for the maximum line-search step.
     * NOTE(review): not referenced in the visible code, which uses the literal
     * {@code 100.0} when computing the step bound - confirm before removing.
     */
    private static final double STPMX = 100.0;

    /**
     * Minimizes a differentiable function with the BFGS quasi-Newton method,
     * maintaining a dense n-by-n approximation of the inverse Hessian.
     *
     * @param func the objective; {@code func.g(x, g)} is used as "return f(x)
     *             and write the gradient into g".
     * @param x on input the starting point; overwritten in place with each
     *          accepted iterate, so on return it holds the final point.
     * @param gtol the threshold on the scaled gradient test; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the function value reported by the last line search.
     * @throws IllegalArgumentException if {@code gtol} or {@code maxIter} is not positive.
     */
    public static double minimize(DifferentiableMultivariateFunction func, double[] x, double gtol, int maxIter) {
        if (gtol <= 0.0) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + gtol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }

        final int dim = x.length;
        double[] gradDelta = new double[dim];       // g_new - g_old, later reused as the BFGS "u" vector
        double[] grad = new double[dim];
        double[] hessGradDelta = new double[dim];   // invHessian * gradDelta
        double[] candidate = new double[dim];       // point produced by the line search
        double[] direction = new double[dim];       // search direction, then the accepted step
        double[][] invHessian = new double[dim][dim];

        double fx = func.g(x, grad);
        logger.info(String.format("BFGS: initial function value: %.5f", fx));

        // Start from the identity, i.e. the first direction is steepest descent.
        for (int k = 0; k < dim; ++k) {
            invHessian[k][k] = 1.0;
            direction[k] = -grad[k];
        }

        final double maxStep = 100.0 * Math.max(MathEx.norm(x), (double) dim);

        for (int iter = 1; iter <= maxIter; ++iter) {
            fx = linesearch(func, x, fx, grad, direction, candidate, maxStep);
            if (iter % 100 == 0) {
                logger.info(String.format("BFGS: the function value after %3d iterations: %.5f", iter, fx));
            }

            // Accept the candidate and measure the largest scaled coordinate change.
            double moveTest = 0.0;
            for (int k = 0; k < dim; ++k) {
                direction[k] = candidate[k] - x[k];
                x[k] = candidate[k];
                double rel = Math.abs(direction[k]) / Math.max(Math.abs(x[k]), 1.0);
                if (rel > moveTest) {
                    moveTest = rel;
                }
            }
            if (moveTest < TOLX) {
                logger.info(String.format("BFGS converges on x after %d iterations: %.5f", iter, fx));
                return fx;
            }

            // Keep the old gradient, then refresh the gradient at the new point.
            System.arraycopy(grad, 0, gradDelta, 0, dim);
            func.g(x, grad);

            final double scale = Math.max(fx, 1.0);
            double gradTest = 0.0;
            for (int k = 0; k < dim; ++k) {
                double rel = Math.abs(grad[k]) * Math.max(Math.abs(x[k]), 1.0) / scale;
                if (rel > gradTest) {
                    gradTest = rel;
                }
            }
            if (gradTest < gtol) {
                logger.info(String.format("BFGS converges on gradient after %d iterations: %.5f", iter, fx));
                return fx;
            }

            for (int k = 0; k < dim; ++k) {
                gradDelta[k] = grad[k] - gradDelta[k];
            }

            // One pass computes invHessian * gradDelta and the four dot products.
            double sumStepSq = 0.0;
            double sumDeltaSq = 0.0;
            double fae = 0.0;
            double fac = 0.0;
            for (int r = 0; r < dim; ++r) {
                double acc = 0.0;
                double[] row = invHessian[r];
                for (int c = 0; c < dim; ++c) {
                    acc += row[c] * gradDelta[c];
                }
                hessGradDelta[r] = acc;

                fac += gradDelta[r] * direction[r];
                fae += gradDelta[r] * acc;
                sumDeltaSq += gradDelta[r] * gradDelta[r];
                sumStepSq += direction[r] * direction[r];
            }

            // Skip the update when the curvature term is not sufficiently positive.
            if (fac > Math.sqrt(EPSILON * sumDeltaSq * sumStepSq)) {
                fac = 1.0 / fac;
                final double fad = 1.0 / fae;
                for (int k = 0; k < dim; ++k) {
                    gradDelta[k] = fac * direction[k] - fad * hessGradDelta[k];
                }
                // Update the upper triangle and mirror it to keep the matrix symmetric.
                for (int r = 0; r < dim; ++r) {
                    final double stepTerm = fac * direction[r];
                    final double hessTerm = fad * hessGradDelta[r];
                    final double corrTerm = fae * gradDelta[r];
                    for (int c = r; c < dim; ++c) {
                        invHessian[r][c] += stepTerm * direction[c] - hessTerm * hessGradDelta[c] + corrTerm * gradDelta[c];
                        invHessian[c][r] = invHessian[r][c];
                    }
                }
            }

            // Next direction: -invHessian * grad.
            for (int r = 0; r < dim; ++r) {
                double acc = 0.0;
                double[] row = invHessian[r];
                for (int c = 0; c < dim; ++c) {
                    acc -= row[c] * grad[c];
                }
                direction[r] = acc;
            }
        }

        logger.warn(String.format("BFGS reaches maximum %d iterations: %.5f", maxIter, fx));
        return fx;
    }

    /**
     * Minimizes a differentiable function with the limited-memory BFGS
     * (L-BFGS) method, keeping the last {@code m} step / gradient-difference
     * pairs in a ring buffer instead of a dense inverse Hessian.
     *
     * @param func the objective; {@code func.g(x, g)} is used as "return f(x)
     *             and write the gradient into g".
     * @param m the number of correction pairs to keep; must be positive.
     * @param x on input the starting point; overwritten in place with each
     *          accepted iterate.
     * @param gtol the threshold on the scaled gradient test; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the function value at the last evaluated iterate.
     * @throws IllegalArgumentException if {@code gtol}, {@code maxIter} or {@code m} is not positive.
     */
    public static double minimize(DifferentiableMultivariateFunction func, int m, double[] x, double gtol, int maxIter) {
        if (gtol <= 0.0) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + gtol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (m <= 0) {
            throw new IllegalArgumentException("Invalid m: " + m);
        }

        final int dim = x.length;
        double[] candidate = new double[dim];
        double[] candidateGrad = new double[dim];
        double[] direction = new double[dim];
        double[][] steps = new double[m][dim];       // s_k = x_{k+1} - x_k
        double[][] gradDiffs = new double[m][dim];   // y_k = g_{k+1} - g_k
        double[] rho = new double[m];                // 1 / (y_k . s_k)
        double[] alpha = new double[m];
        double[] grad = new double[dim];

        double fx = func.g(x, grad);
        logger.info(String.format("L-BFGS: initial function value: %.5f", fx));

        for (int i = 0; i < dim; ++i) {
            direction[i] = -grad[i];
        }

        final double maxStep = 100.0 * Math.max(MathEx.norm(x), (double) dim);

        int slot = 0; // ring-buffer position that receives the newest pair
        for (int iter = 1; iter <= maxIter; ++iter) {
            // The line search only supplies the candidate point; value and
            // gradient are re-evaluated together right after.
            linesearch(func, x, fx, grad, direction, candidate, maxStep);
            fx = func.g(candidate, candidateGrad);

            final double[] sk = steps[slot];
            final double[] yk = gradDiffs[slot];
            double moveTest = 0.0;
            for (int i = 0; i < dim; ++i) {
                sk[i] = candidate[i] - x[i];
                yk[i] = candidateGrad[i] - grad[i];
                x[i] = candidate[i];
                grad[i] = candidateGrad[i];

                double rel = Math.abs(sk[i]) / Math.max(Math.abs(x[i]), 1.0);
                if (rel > moveTest) {
                    moveTest = rel;
                }
            }
            if (moveTest < TOLX) {
                logger.info(String.format("L-BFGS converges on x after %d iterations: %.5f", iter, fx));
                return fx;
            }

            final double scale = Math.max(fx, 1.0);
            double gradTest = 0.0;
            for (int i = 0; i < dim; ++i) {
                double rel = Math.abs(grad[i]) * Math.max(Math.abs(x[i]), 1.0) / scale;
                if (rel > gradTest) {
                    gradTest = rel;
                }
            }
            if (gradTest < gtol) {
                logger.info(String.format("L-BFGS converges on gradient after %d iterations: %.5f", iter, fx));
                return fx;
            }

            if (iter % 100 == 0) {
                logger.info(String.format("L-BFGS: the function value after %3d iterations: %.5f", iter, fx));
            }

            final double ys = MathEx.dot(yk, sk);
            final double yy = MathEx.dot(yk, yk);
            final double scaling = ys / yy; // initial inverse-Hessian scaling
            rho[slot] = 1.0 / ys;

            for (int i = 0; i < dim; ++i) {
                direction[i] = -grad[i];
            }

            // First loop of the two-loop recursion: newest pair to oldest.
            final int depth = Math.min(iter, m);
            int cursor = slot;
            for (int pass = 0; pass < depth; ++pass) {
                alpha[cursor] = rho[cursor] * MathEx.dot(steps[cursor], direction);
                MathEx.axpy(-alpha[cursor], gradDiffs[cursor], direction);
                cursor = (cursor == 0) ? m - 1 : cursor - 1;
            }

            for (int i = 0; i < dim; ++i) {
                direction[i] *= scaling;
            }

            // Second loop: oldest pair back to newest.
            for (int pass = 0; pass < depth; ++pass) {
                cursor = (cursor + 1 == m) ? 0 : cursor + 1;
                double beta = rho[cursor] * MathEx.dot(gradDiffs[cursor], direction);
                MathEx.axpy(alpha[cursor] - beta, steps[cursor], direction);
            }

            slot = (slot + 1 == m) ? 0 : slot + 1;
        }

        logger.warn(String.format("L-BFGS reaches maximum %d iterations: %.5f", maxIter, fx));
        return fx;
    }

    /**
     * Backtracking line search along {@code p} starting from {@code xold}.
     * Shrinks the step with quadratic / cubic interpolation until the
     * sufficient-decrease condition
     * {@code f(x) <= fold + ftol * alam * slope} holds, or the step becomes
     * smaller than the resolution implied by {@code EPSILON}.
     *
     * <p>Fixes relative to the previous version:
     * <ul>
     *   <li>the minimum-step scaling uses {@code |xold[i]|}; without the
     *       absolute value, large negative coordinates were scaled as if
     *       they had magnitude 1;</li>
     *   <li>when the step underflows, {@code x} is reset to {@code xold} and
     *       {@code fold} is returned, so the returned value always belongs
     *       to the returned point (previously the value of the rejected
     *       trial point was returned);</li>
     *   <li>a NaN interpolated step (e.g. after a NaN or infinite function
     *       value) falls back to halving instead of turning {@code alam}
     *       into NaN and looping forever.</li>
     * </ul>
     *
     * @param func the objective function.
     * @param xold the starting point; not modified.
     * @param fold the function value at {@code xold}.
     * @param g the gradient at {@code xold}.
     * @param p the search direction; rescaled in place when its norm exceeds {@code stpmax}.
     * @param x output: the accepted point, or a copy of {@code xold} when no
     *          acceptable step was found.
     * @param stpmax the upper bound on the step length; must be positive.
     * @return the function value at the point stored in {@code x}.
     * @throws IllegalArgumentException if {@code stpmax} is not positive.
     */
    private static double linesearch(MultivariateFunction func, double[] xold, double fold, double[] g, double[] p, double[] x, double stpmax) {
        if (stpmax <= 0.0) {
            throw new IllegalArgumentException("Invalid upper bound of linear search step: " + stpmax);
        }

        // Relative step resolution below which the search gives up.
        final double xtol = EPSILON;
        // Tolerance of the sufficient-decrease condition.
        final double ftol = 1.0E-4;

        final int n = xold.length;

        // Scale the direction down if the attempted step is too long.
        double pnorm = MathEx.norm(p);
        if (pnorm > stpmax) {
            double r = stpmax / pnorm;
            for (int i = 0; i < n; ++i) {
                p[i] *= r;
            }
        }

        // Directional derivative; should be negative for a descent direction.
        double slope = 0.0;
        for (int i = 0; i < n; ++i) {
            slope += g[i] * p[i];
        }
        if (slope >= 0.0) {
            logger.warn("Line Search: the search direction is not a descent direction, which may be caused by roundoff problem.");
        }

        // Largest scaled component of p, measured against |xold|.
        double test = 0.0;
        for (int i = 0; i < n; ++i) {
            double temp = Math.abs(p[i]) / Math.max(Math.abs(xold[i]), 1.0);
            if (temp > test) {
                test = temp;
            }
        }

        // test == 0 (zero direction) yields +Infinity, which ends the search at once.
        final double alammin = xtol / test;
        double alam = 1.0;
        double alam2 = 0.0;
        double f2 = 0.0;

        while (true) {
            if (alam < alammin) {
                // Step is below resolution: stay at xold and report its value.
                System.arraycopy(xold, 0, x, 0, n);
                return fold;
            }

            for (int i = 0; i < n; ++i) {
                x[i] = xold[i] + alam * p[i];
            }
            double f = func.apply(x);

            if (f <= fold + ftol * alam * slope) {
                return f;
            }

            double tmpalam;
            if (alam == 1.0) {
                // First backtrack: minimizer of the interpolating quadratic.
                tmpalam = -slope / (2.0 * (f - fold - slope));
            } else {
                // Later backtracks: minimizer of the interpolating cubic.
                double rhs1 = f - fold - alam * slope;
                double rhs2 = f2 - fold - alam2 * slope;
                double a = (rhs1 / (alam * alam) - rhs2 / (alam2 * alam2)) / (alam - alam2);
                double b = (-alam2 * rhs1 / (alam * alam) + alam * rhs2 / (alam2 * alam2)) / (alam - alam2);
                if (a == 0.0) {
                    tmpalam = -slope / (2.0 * b);
                } else {
                    double disc = b * b - 3.0 * a * slope;
                    if (disc < 0.0) {
                        tmpalam = 0.5 * alam;
                    } else if (b <= 0.0) {
                        tmpalam = (-b + Math.sqrt(disc)) / (3.0 * a);
                    } else {
                        tmpalam = -slope / (b + Math.sqrt(disc));
                    }
                }
                if (tmpalam > 0.5 * alam) {
                    tmpalam = 0.5 * alam;
                }
            }

            // A NaN here would poison alam (Math.max propagates NaN) and the
            // loop could never terminate; fall back to plain halving.
            if (Double.isNaN(tmpalam)) {
                tmpalam = 0.5 * alam;
            }

            alam2 = alam;
            f2 = f;
            alam = Math.max(tmpalam, 0.1 * alam);
        }
    }

    /**
     * Minimizes a differentiable function subject to the box constraints
     * {@code l[i] <= x[i] <= u[i]} with the limited-memory L-BFGS-B method.
     * Each iteration computes a generalized Cauchy point, refines it with a
     * subspace minimization over the free variables, and then line-searches
     * towards the result.
     *
     * <p>Fixes relative to the previous version:
     * <ul>
     *   <li>every correction pair is copied before it is stored; previously
     *       the same two working arrays were added to the history on each
     *       iteration, so all stored pairs aliased the most recent one;</li>
     *   <li>the middle matrix is rebuilt in a fresh matrix on every update;
     *       previously, once the history was full, it was assembled on top
     *       of the previous inverse and the off-diagonal entries of its
     *       top-left block kept stale values.</li>
     * </ul>
     *
     * @param func the objective; {@code func.g(x, g)} is used as "return f(x)
     *             and write the gradient into g".
     * @param m the number of correction pairs to keep; must be positive.
     * @param x on input the starting point; overwritten in place with the
     *          iterates. When a non-finite point or value is produced, the
     *          previous point is restored before returning.
     * @param l the lower bounds; same length as {@code x}.
     * @param u the upper bounds; same length as {@code x}.
     * @param gtol the threshold on the projected gradient norm; must be positive.
     * @param maxIter the maximum number of iterations; must be positive.
     * @return the function value at the point left in {@code x}.
     * @throws IllegalArgumentException if {@code gtol}, {@code maxIter} or
     *         {@code m} is not positive, or a bound array has the wrong length.
     */
    public static double minimize(DifferentiableMultivariateFunction func, int m, double[] x, double[] l, double[] u, double gtol, int maxIter) {
        if (gtol <= 0.0) {
            throw new IllegalArgumentException("Invalid gradient tolerance: " + gtol);
        }
        if (maxIter <= 0) {
            throw new IllegalArgumentException("Invalid maximum number of iterations: " + maxIter);
        }
        if (m <= 0) {
            throw new IllegalArgumentException("Invalid m: " + m);
        }
        if (l.length != x.length) {
            throw new IllegalArgumentException("Invalid lower bound size: " + l.length);
        }
        if (u.length != x.length) {
            throw new IllegalArgumentException("Invalid upper bound size: " + u.length);
        }

        int n = x.length;
        double theta = 1.0;
        Matrix Y = null;
        Matrix S = null;
        // Until the first correction pair exists, W and M are zero placeholders.
        Matrix W = new Matrix(n, 1);
        Matrix M = new Matrix(1, 1);
        ArrayList<double[]> yHistory = new ArrayList<>();
        ArrayList<double[]> sHistory = new ArrayList<>();
        double[] y = new double[n];
        double[] s = new double[n];
        double[] p = new double[n];
        double[] g = new double[n];
        double[] cauchy = new double[n];
        double f = func.g(x, g);
        double[] x_old = new double[n];
        double[] g_old = new double[n];
        double stpmax = 100.0 * Math.max(MathEx.norm(x), (double) n);

        for (int iter = 1; iter <= maxIter; ++iter) {
            double f_old = f;
            System.arraycopy(x, 0, x_old, 0, n);
            System.arraycopy(g, 0, g_old, 0, n);

            // Generalized Cauchy point, starting from the current x.
            System.arraycopy(x, 0, cauchy, 0, n);
            double[] c = BFGS.cauchy(x, g, cauchy, l, u, theta, W, M);
            BFGS.clampToBound(cauchy, l, u);

            // Refine over the variables that are free at the Cauchy point.
            double[] subspaceMin = BFGS.subspaceMinimization(x, g, cauchy, c, l, u, theta, W, M);
            BFGS.clampToBound(subspaceMin, l, u);

            for (int i = 0; i < n; ++i) {
                p[i] = subspaceMin[i] - x[i];
            }
            BFGS.linesearch(func, x_old, f_old, g, p, x, stpmax);
            BFGS.clampToBound(x, l, u);

            for (double xi : x) {
                if (Double.isNaN(xi) || Double.isInfinite(xi)) {
                    logger.warn("L-BFGS-B: bad x produced by line search, return previous good x");
                    System.arraycopy(x_old, 0, x, 0, n);
                    return f_old;
                }
            }

            f = func.g(x, g);
            if (Double.isNaN(f) || Double.isInfinite(f)) {
                logger.warn("L-BFGS-B: bad f(x) produced by line search, return previous good x");
                System.arraycopy(x_old, 0, x, 0, n);
                return f_old;
            }

            if (BFGS.gnorm(x, g, l, u) < gtol) {
                logger.info(String.format("L-BFGS-B converges on gradient after %d iterations: %.5f", iter, f));
                return f;
            }
            if (iter % 100 == 0) {
                logger.info(String.format("L-BFGS-B: the function value after %3d iterations: %.5f", iter, f));
            }

            for (int i = 0; i < n; ++i) {
                y[i] = g[i] - g_old[i];
                s[i] = x[i] - x_old[i];
            }
            double sy = MathEx.dot(s, y);
            double yy = MathEx.dot(y, y);
            double test = Math.abs(sy);

            // Only accept the pair when the curvature term is not negligible.
            if (test > EPSILON * yy) {
                if (yHistory.size() >= m) {
                    yHistory.remove(0);
                    sHistory.remove(0);
                }
                // y and s are working buffers overwritten every iteration,
                // so the history must own independent copies.
                yHistory.add(y.clone());
                sHistory.add(s.clone());

                int h = yHistory.size();
                if (Y == null || Y.ncol() < h) {
                    Y = new Matrix(n, h);
                    S = new Matrix(n, h);
                    W = new Matrix(n, 2 * h);
                }

                theta = yy / sy;
                for (int j = 0; j < h; ++j) {
                    double[] yj = yHistory.get(j);
                    double[] sj = sHistory.get(j);
                    for (int i = 0; i < n; ++i) {
                        Y.set(i, j, yj[i]);
                        S.set(i, j, sj[i]);
                        W.set(i, j, yj[i]);
                        W.set(i, h + j, sj[i] * theta);
                    }
                }

                Matrix SY = S.tm(Y);
                Matrix SS = S.ata();

                // Assemble the middle matrix in a zeroed matrix: M currently
                // holds the previous inverse, and the loops below never write
                // the off-diagonal entries of the top-left block.
                Matrix middle = new Matrix(2 * h, 2 * h);
                for (int j = 0; j < h; ++j) {
                    middle.set(j, j, -SY.get(j, j));
                    for (int i = 0; i <= j; ++i) {
                        middle.set(h + i, j, 0.0);
                        middle.set(j, h + i, 0.0);
                    }
                    for (int i = j + 1; i < h; ++i) {
                        middle.set(h + i, j, SY.get(i, j));
                        middle.set(j, h + i, SY.get(i, j));
                    }
                    for (int i = 0; i < h; ++i) {
                        middle.set(h + i, h + j, theta * SS.get(i, j));
                    }
                }
                middle.uplo(UPLO.LOWER);
                M = middle.inverse();
            }

            // Arrays.toString is O(n) per call; skip it unless it will be logged.
            if (logger.isDebugEnabled()) {
                logger.debug("L-BFGS-B iteration {} moves from {} to {} where f(x) = {}", new Object[]{iter, Arrays.toString(x_old), Arrays.toString(x), f});
            }

            if (Math.abs(f_old - f) < TOLF) {
                logger.info(String.format("L-BFGS-B converges on f(x) after %d iterations: %.5f", iter, f));
                return f;
            }
        }

        logger.warn(String.format("L-BFGS-B reaches maximum %d iterations: %.5f", maxIter, f));
        return f;
    }

    /**
     * Computes the generalized Cauchy point for L-BFGS-B: follows the
     * projected steepest-descent path from {@code x}, fixing variables at
     * their bounds as breakpoints are passed, until the one-dimensional
     * quadratic model stops decreasing.
     *
     * <p>Fixes relative to the previous version:
     * <ul>
     *   <li>a zero gradient component now really gets the breakpoint
     *       {@code Double.MAX_VALUE}; the value used to be assigned to an
     *       unused local, leaving the breakpoint at 0;</li>
     *   <li>the scan for the first positive breakpoint reads the sorted
     *       array by position, consistently with the rest of the method
     *       (it used to index the already-sorted array through the
     *       permutation);</li>
     *   <li>when no breakpoint is positive the method no longer reads one
     *       element past the end of the breakpoint array.</li>
     * </ul>
     *
     * @param x the current point; not modified.
     * @param g the gradient at {@code x}.
     * @param cauchy in: a copy of {@code x}; out: the Cauchy point.
     * @param l the lower bounds.
     * @param u the upper bounds.
     * @param theta the scaling of the limited-memory Hessian approximation.
     * @param W the limited-memory correction matrix.
     * @param M the (inverted) middle matrix.
     * @return the vector {@code c} accumulated along the path, consumed by
     *         the subspace minimization.
     */
    private static double[] cauchy(double[] x, double[] g, double[] cauchy, double[] l, double[] u, double theta, Matrix W, Matrix M) {
        int n = x.length;
        double[] t = new double[n];
        double[] d = new double[n];

        // Breakpoint of each variable along the projected gradient path.
        for (int i = 0; i < n; ++i) {
            if (g[i] == 0.0) {
                t[i] = Double.MAX_VALUE; // never reaches a bound
            } else if (g[i] < 0.0) {
                t[i] = (x[i] - u[i]) / g[i];
            } else {
                t[i] = (x[i] - l[i]) / g[i];
            }
            // Variables already at the bound they are pushed against stay fixed.
            if (t[i] != 0.0) {
                d[i] = -g[i];
            }
        }

        // Sorts t in place (ascending); index[i] is the variable owning t[i].
        int[] index = QuickSort.sort(t);

        double[] p = W.tv(d);
        double[] c = new double[p.length];
        double fPrime = -MathEx.dot(d, d);
        double fDoublePrime = Math.max(-theta * fPrime - M.xAx(p), EPSILON);
        double f_dp_orig = fDoublePrime;
        double dt_min = -fPrime / fDoublePrime;
        double t_old = 0.0;

        // Skip breakpoints that are not strictly positive.
        int i = 0;
        while (i < n && !(t[i] > 0.0)) {
            ++i;
        }

        // With no positive breakpoint there is no segment to examine.
        double dt = i < n ? t[i] : Double.POSITIVE_INFINITY;

        // NOTE(review): after the first pass the condition compares dt_min
        // with the length of the segment just processed rather than the next
        // one - confirm against the published algorithm before changing.
        while (dt_min >= dt && i < n) {
            int b = index[i];
            double tb = t[i];
            dt = tb - t_old;

            // Variable b hits its bound at this breakpoint.
            cauchy[b] = d[b] > 0.0 ? u[b] : (d[b] < 0.0 ? l[b] : cauchy[b]);
            double zb = cauchy[b] - x[b];
            for (int j = 0; j < c.length; ++j) {
                c[j] += p[j] * dt;
            }

            double gb = g[b];
            double[] wbt = W.row(b);
            fPrime += dt * fDoublePrime + gb * gb + theta * gb * zb - gb * MathEx.dot(wbt, M.mv(c));
            fDoublePrime -= theta * gb * gb + 2.0 * gb * MathEx.dot(wbt, M.mv(p)) + gb * gb * M.xAx(wbt);
            fDoublePrime = Math.max(fDoublePrime, EPSILON * f_dp_orig);
            for (int j = 0; j < p.length; ++j) {
                p[j] += wbt[j] * gb;
            }

            d[b] = 0.0;
            dt_min = -fPrime / fDoublePrime;
            t_old = tb;
            ++i;
        }

        dt_min = Math.max(dt_min, 0.0);
        t_old += dt_min;

        // Remaining variables move freely up to the minimizer on the last segment.
        for (int ii = i; ii < n; ++ii) {
            int si = index[ii];
            cauchy[si] = x[si] + t_old * d[si];
        }
        for (int j = 0; j < c.length; ++j) {
            c[j] += p[j] * dt_min;
        }
        return c;
    }

    /**
     * Refines the Cauchy point by minimizing the quadratic model over the
     * variables that are not sitting on a bound at the Cauchy point, then
     * backtracking so that the result stays inside the box.
     *
     * @param x the current point.
     * @param g the gradient at {@code x}.
     * @param cauchy the Cauchy point; not modified.
     * @param c the vector returned by the Cauchy-point computation.
     * @param l the lower bounds.
     * @param u the upper bounds.
     * @param theta the scaling of the limited-memory Hessian approximation.
     * @param W the limited-memory correction matrix.
     * @param M the (inverted) middle matrix.
     * @return a new array holding the subspace minimizer; a plain copy of
     *         {@code cauchy} when every variable is at a bound.
     */
    private static double[] subspaceMinimization(double[] x, double[] g, double[] cauchy, double[] c, double[] l, double[] u, double theta, Matrix W, Matrix M) {
        final int dim = x.length;
        final double invTheta = 1.0 / theta;

        // Collect the indices of the free variables.
        int[] scratch = new int[dim];
        int freeCount = 0;
        for (int k = 0; k < dim; ++k) {
            boolean atBound = cauchy[k] == u[k] || cauchy[k] == l[k];
            if (!atBound) {
                scratch[freeCount++] = k;
            }
        }
        if (freeCount == 0) {
            return cauchy.clone();
        }
        final int[] free = Arrays.copyOf(scratch, freeCount);

        // Reduced gradient of the model at the Cauchy point.
        final double[] wmc = W.mv(M.mv(c));
        final double[] reduced = new double[freeCount];
        for (int k = 0; k < freeCount; ++k) {
            int idx = free[k];
            reduced[k] = g[idx] + (cauchy[idx] - x[idx]) * theta - wmc[idx];
        }

        // Solve (I - M * WZ' * WZ / theta) v = M * WZ' * r.
        final Matrix wFree = W.rows(free);
        double[] v = M.mv(wFree.tv(reduced));
        Matrix system = wFree.ata().mul(-invTheta);
        system = M.mm(system);
        system.addDiag(1.0);
        Matrix.LU lu = system.lu();
        v = lu.solve(v);

        // Unconstrained step on the free variables.
        final double[] wzv = wFree.mv(v);
        final double[] du = new double[freeCount];
        for (int k = 0; k < freeCount; ++k) {
            du[k] = -invTheta * (reduced[k] + wzv[k] * invTheta);
        }

        // Shorten the step so that no free variable leaves the box.
        final double alphaStar = BFGS.findAlpha(cauchy, du, l, u, free);

        double[] result = cauchy.clone();
        for (int k = 0; k < freeCount; ++k) {
            result[free[k]] += du[k] * alphaStar;
        }
        return result;
    }

    /**
     * Finds the largest step fraction in {@code [.., 1]} such that
     * {@code cauchy + alpha * du} stays inside the box on every free variable.
     *
     * <p>A component with {@code du[i] == 0} does not move and therefore
     * imposes no limit. It is skipped; previously it fell into the
     * lower-bound branch and divided by zero, which could turn the result
     * into {@code -Infinity}.
     *
     * @param cauchy the Cauchy point (full length).
     * @param du the step on the free variables; {@code du[i]} belongs to
     *           variable {@code freeVar[i]}.
     * @param l the lower bounds (full length).
     * @param u the upper bounds (full length).
     * @param freeVar the indices of the free variables.
     * @return the step fraction, at most 1.
     */
    private static double findAlpha(double[] cauchy, double[] du, double[] l, double[] u, int[] freeVar) {
        double alphaStar = 1.0;
        int n = freeVar.length;
        for (int i = 0; i < n; ++i) {
            if (du[i] == 0.0) {
                // No movement along this coordinate, hence no constraint.
                continue;
            }
            int fi = freeVar[i];
            double bound = du[i] > 0.0 ? u[fi] : l[fi];
            alphaStar = Math.min(alphaStar, (bound - cauchy[fi]) / du[i]);
        }
        return alphaStar;
    }

    /**
     * Returns the infinity norm of the projected gradient: each gradient
     * component is limited by the distance from {@code x} to the bound it
     * points away from, so components blocked by an active bound count as
     * zero.
     *
     * @param x the current point.
     * @param g the gradient at {@code x}.
     * @param l the lower bounds.
     * @param u the upper bounds.
     * @return the largest absolute projected gradient component; 0 for empty input.
     */
    private static double gnorm(double[] x, double[] g, double[] l, double[] u) {
        double largest = 0.0;
        for (int k = 0; k < x.length; ++k) {
            final double raw = g[k];
            final double projected;
            if (raw < 0.0) {
                // Descent would increase x[k]; room left is up to the upper bound.
                projected = Math.max(x[k] - u[k], raw);
            } else {
                // Descent would decrease x[k]; room left is down to the lower bound.
                projected = Math.min(x[k] - l[k], raw);
            }
            largest = Math.max(largest, Math.abs(projected));
        }
        return largest;
    }

    /**
     * Projects {@code v} onto the box {@code [l, u]} in place. The upper
     * bound is checked first; the lower bound is only consulted when the
     * value does not exceed the upper bound. Values that compare false
     * against both bounds (including NaN) are left untouched.
     *
     * @param v the vector to clamp; modified in place.
     * @param l the lower bounds.
     * @param u the upper bounds.
     */
    private static void clampToBound(double[] v, double[] l, double[] u) {
        for (int k = 0; k < v.length; ++k) {
            final double value = v[k];
            if (value > u[k]) {
                v[k] = u[k];
            } else if (value < l[k]) {
                v[k] = l[k];
            }
        }
    }
}

