/*
 * Decompiled with CFR 0.152.
 */
package smile.math.matrix.fp32;

import java.io.Serializable;
import java.nio.FloatBuffer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import smile.math.MathEx;
import smile.math.blas.BLAS;
import smile.math.blas.LAPACK;
import smile.math.blas.Layout;
import smile.math.blas.Transpose;
import smile.math.blas.UPLO;
import smile.math.matrix.fp32.IMatrix;
import smile.math.matrix.fp32.Matrix;

/**
 * A single-precision band matrix.
 * <p>
 * Only the band is stored, column by column, in the LAPACK general band
 * format: element {@code A(i, j)} with
 * {@code max(0, j - ku) <= i <= min(m - 1, j + kl)} is kept at
 * {@code AB[j * ld + ku + i - j]}, where {@code ld = kl + ku + 1}.
 * Elements outside the band read as zero and cannot be written.
 * <p>
 * A square matrix with {@code kl == ku} may be flagged as symmetric through
 * {@link #uplo(UPLO)}. Matrix-vector products then use the symmetric band
 * BLAS routine, and {@link #cholesky()} becomes available.
 */
public class BandMatrix
extends IMatrix {
    private static final long serialVersionUID = 2L;
    private static final Logger logger = LoggerFactory.getLogger(BandMatrix.class);
    /** Band storage, column-major with leading dimension {@code ld}. */
    final float[] AB;
    /** The number of rows. */
    final int m;
    /** The number of columns. */
    final int n;
    /** The number of subdiagonals. */
    final int kl;
    /** The number of superdiagonals. */
    final int ku;
    /** The leading dimension of {@code AB}, always {@code kl + ku + 1}. */
    final int ld;
    /** The triangle used by symmetric operations, or null if the matrix is not flagged symmetric. */
    UPLO uplo = null;

    /**
     * Creates an all-zero band matrix.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @param kl the number of subdiagonals, {@code 0 <= kl < m}.
     * @param ku the number of superdiagonals, {@code 0 <= ku < n}.
     * @throws IllegalArgumentException if a dimension or bandwidth is out of range.
     */
    public BandMatrix(int m, int n, int kl, int ku) {
        this(m, n, kl, ku, true);
    }

    /**
     * Creates a band matrix and copies its band from a 2-D array.
     *
     * @param m the number of rows.
     * @param n the number of columns.
     * @param kl the number of subdiagonals, {@code 0 <= kl < m}.
     * @param ku the number of superdiagonals, {@code 0 <= ku < n}.
     * @param AB the band in LAPACK layout: {@code AB[ku + i - j][j]} holds
     *           {@code A(i, j)}. It must have at least {@code kl + ku + 1}
     *           rows of at least {@code n} entries each.
     * @throws IllegalArgumentException if a dimension or bandwidth is out of range.
     */
    public BandMatrix(int m, int n, int kl, int ku, float[][] AB) {
        this(m, n, kl, ku);
        for (int j = 0; j < n; ++j) {
            for (int i = 0; i < this.ld; ++i) {
                this.AB[j * this.ld + i] = AB[i][j];
            }
        }
    }

    /**
     * Shared constructor.
     *
     * @param checkBandwidth when false, the {@code kl < m} and {@code ku < n}
     *        checks are skipped. Internal factorization storage needs this:
     *        LAPACK GBTRF requires {@code 2 * kl} subdiagonal rows, and that
     *        count may legitimately reach or exceed {@code m}.
     */
    private BandMatrix(int m, int n, int kl, int ku, boolean checkBandwidth) {
        if (m <= 0 || n <= 0) {
            throw new IllegalArgumentException(String.format("Invalid matrix size: %d x %d", m, n));
        }
        if (kl < 0 || ku < 0) {
            throw new IllegalArgumentException(String.format("Invalid subdiagonals or superdiagonals: kl = %d, ku = %d", kl, ku));
        }
        if (checkBandwidth) {
            if (kl >= m) {
                throw new IllegalArgumentException(String.format("Invalid subdiagonals %d >= %d", kl, m));
            }
            if (ku >= n) {
                throw new IllegalArgumentException(String.format("Invalid superdiagonals %d >= %d", ku, n));
            }
        }
        this.m = m;
        this.n = n;
        this.kl = kl;
        this.ku = ku;
        this.ld = kl + ku + 1;
        this.AB = new float[this.ld * n];
    }

    /**
     * Returns a deep copy of this matrix. The symmetric flag is copied when
     * the matrix is square with {@code kl == ku}.
     */
    @Override
    public BandMatrix clone() {
        // Skip the bandwidth check: this matrix may itself be LU factor storage,
        // whose kl is allowed to exceed the public limit.
        BandMatrix matrix = new BandMatrix(this.m, this.n, this.kl, this.ku, false);
        System.arraycopy(this.AB, 0, matrix.AB, 0, this.AB.length);
        if (this.m == this.n && this.kl == this.ku) {
            matrix.uplo(this.uplo);
        }
        return matrix;
    }

    @Override
    public int nrow() {
        return this.m;
    }

    @Override
    public int ncol() {
        return this.n;
    }

    /** Returns the number of stored band elements, i.e. {@code ld * n}. */
    @Override
    public long size() {
        return this.AB.length;
    }

    /** Returns the number of subdiagonals. */
    public int kl() {
        return this.kl;
    }

    /** Returns the number of superdiagonals. */
    public int ku() {
        return this.ku;
    }

    /** Returns the storage layout, which is always column-major. */
    public Layout layout() {
        return Layout.COL_MAJOR;
    }

    /** Returns the leading dimension of the band storage. */
    public int ld() {
        return this.ld;
    }

    /** Returns true if the matrix has been flagged symmetric through {@link #uplo(UPLO)}. */
    public boolean isSymmetric() {
        return this.uplo != null;
    }

    /**
     * Flags the matrix as symmetric, or clears the flag.
     *
     * @param uplo the triangle that symmetric routines read, or null to clear the flag.
     * @return this matrix.
     * @throws IllegalArgumentException if the matrix is not square or {@code kl != ku}.
     */
    public BandMatrix uplo(UPLO uplo) {
        if (this.m != this.n) {
            throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", this.m, this.n));
        }
        if (this.kl != this.ku) {
            throw new IllegalArgumentException(String.format("kl != ku: %d != %d", this.kl, this.ku));
        }
        this.uplo = uplo;
        return this;
    }

    /** Returns the symmetric triangle flag, or null if the matrix is not flagged symmetric. */
    public UPLO uplo() {
        return this.uplo;
    }

    /** Element-wise comparison with an absolute tolerance of {@code 1e-7}. */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof BandMatrix)) {
            return false;
        }
        return this.equals((BandMatrix)o, 1.0E-7f);
    }

    /**
     * Returns a hash code consistent with {@link #equals(Object)}. That
     * comparison is tolerance-based, so element values cannot take part in the
     * hash without breaking the equals/hashCode contract. The dimensions can.
     */
    @Override
    public int hashCode() {
        return 31 * this.m + this.n;
    }

    /**
     * Returns true if both matrices have the same dimensions and every pair of
     * corresponding elements, including implicit zeros, differs by at most
     * {@code epsilon}.
     *
     * @param o the other matrix.
     * @param epsilon the absolute tolerance.
     * @return true if the matrices are equal within {@code epsilon}.
     */
    public boolean equals(BandMatrix o, float epsilon) {
        if (this.m != o.m || this.n != o.n) {
            return false;
        }
        for (int j = 0; j < this.n; ++j) {
            for (int i = 0; i < this.m; ++i) {
                if (!MathEx.isZero(this.get(i, j) - o.get(i, j), epsilon)) {
                    return false;
                }
            }
        }
        return true;
    }

    /** Returns {@code A(i, j)}; elements outside the band are zero. */
    @Override
    public float get(int i, int j) {
        if (Math.max(0, j - this.ku) <= i && i <= Math.min(this.m - 1, j + this.kl)) {
            return this.AB[j * this.ld + this.ku + i - j];
        }
        return 0.0f;
    }

    /**
     * Sets {@code A(i, j) = x}.
     *
     * @throws UnsupportedOperationException if {@code (i, j)} lies outside the band.
     */
    @Override
    public void set(int i, int j, float x) {
        if (Math.max(0, j - this.ku) > i || i > Math.min(this.m - 1, j + this.kl)) {
            throw new UnsupportedOperationException(String.format("Set element at (%d, %d)", i, j));
        }
        this.AB[j * this.ld + this.ku + i - j] = x;
    }

    /**
     * Computes {@code y = alpha * op(A) * x + beta * y}.
     *
     * @param trans the transpose operation; ignored when the matrix is flagged
     *              symmetric, since the symmetric band routine is used then.
     */
    @Override
    public void mv(Transpose trans, float alpha, float[] x, float beta, float[] y) {
        if (this.uplo != null) {
            BLAS.engine.sbmv(this.layout(), this.uplo, this.n, this.kl, alpha, this.AB, this.ld, x, 1, beta, y, 1);
        } else {
            BLAS.engine.gbmv(this.layout(), trans, this.m, this.n, this.kl, this.ku, alpha, this.AB, this.ld, x, 1, beta, y, 1);
        }
    }

    /**
     * Computes {@code A * x} in place within {@code work}: reads {@code n}
     * entries starting at {@code inputOffset} and writes {@code m} entries
     * starting at {@code outputOffset}.
     */
    @Override
    public void mv(float[] work, int inputOffset, int outputOffset) {
        FloatBuffer xb = FloatBuffer.wrap(work, inputOffset, this.n);
        FloatBuffer yb = FloatBuffer.wrap(work, outputOffset, this.m);
        if (this.uplo != null) {
            BLAS.engine.sbmv(this.layout(), this.uplo, this.n, this.kl, 1.0f, FloatBuffer.wrap(this.AB), this.ld, xb, 1, 0.0f, yb, 1);
        } else {
            BLAS.engine.gbmv(this.layout(), Transpose.NO_TRANSPOSE, this.m, this.n, this.kl, this.ku, 1.0f, FloatBuffer.wrap(this.AB), this.ld, xb, 1, 0.0f, yb, 1);
        }
    }

    /**
     * Computes {@code A' * x} in place within {@code work}: reads {@code m}
     * entries starting at {@code inputOffset} and writes {@code n} entries
     * starting at {@code outputOffset}.
     */
    @Override
    public void tv(float[] work, int inputOffset, int outputOffset) {
        FloatBuffer xb = FloatBuffer.wrap(work, inputOffset, this.m);
        FloatBuffer yb = FloatBuffer.wrap(work, outputOffset, this.n);
        if (this.uplo != null) {
            // A symmetric matrix equals its transpose.
            BLAS.engine.sbmv(this.layout(), this.uplo, this.n, this.kl, 1.0f, FloatBuffer.wrap(this.AB), this.ld, xb, 1, 0.0f, yb, 1);
        } else {
            BLAS.engine.gbmv(this.layout(), Transpose.TRANSPOSE, this.m, this.n, this.kl, this.ku, 1.0f, FloatBuffer.wrap(this.AB), this.ld, xb, 1, 0.0f, yb, 1);
        }
    }

    /**
     * Computes the LU decomposition with partial pivoting via LAPACK GBTRF.
     * A singular matrix does not throw; it is reported by {@link LU#isSingular()}.
     *
     * @return the LU decomposition.
     * @throws ArithmeticException if LAPACK reports an illegal argument.
     */
    public LU lu() {
        // GBTRF needs kl extra rows above the band for the fill-in produced by
        // row interchanges, i.e. storage with 2*kl subdiagonals. 2*kl may reach
        // or exceed m (for example, any 2 x 2 tridiagonal matrix), so the public
        // constructor's bandwidth check must be bypassed here.
        BandMatrix lu = new BandMatrix(this.m, this.n, 2 * this.kl, this.ku, false);
        for (int j = 0; j < this.n; ++j) {
            // Column j of the band goes below the kl workspace rows.
            System.arraycopy(this.AB, j * this.ld, lu.AB, j * lu.ld + this.kl, this.ld);
        }
        int[] ipiv = new int[this.n];
        int info = LAPACK.engine.gbtrf(lu.layout(), lu.m, lu.n, lu.kl / 2, lu.ku, lu.AB, lu.ld, ipiv);
        if (info < 0) {
            logger.error("LAPACK GBTRF error code: {}", info);
            throw new ArithmeticException("LAPACK GBTRF error code: " + info);
        }
        return new LU(lu, ipiv, info);
    }

    /**
     * Computes the Cholesky decomposition of a symmetric positive definite
     * band matrix via LAPACK PBTRF, using the triangle selected by {@link #uplo()}.
     *
     * @return the Cholesky decomposition.
     * @throws IllegalArgumentException if the matrix is not flagged symmetric.
     * @throws ArithmeticException if PBTRF returns a nonzero code, for example
     *         when the matrix is not positive definite.
     */
    public Cholesky cholesky() {
        if (this.uplo == null) {
            throw new IllegalArgumentException("The matrix is not symmetric");
        }
        boolean lower = this.uplo == UPLO.LOWER;
        // Keep only the referenced triangle: kl subdiagonals (lower) or ku superdiagonals (upper).
        BandMatrix lu = new BandMatrix(this.m, this.n, lower ? this.kl : 0, lower ? 0 : this.ku);
        lu.uplo = this.uplo;
        if (lower) {
            for (int j = 0; j < this.n; ++j) {
                for (int i = 0; i <= this.kl; ++i) {
                    lu.AB[j * lu.ld + i] = this.get(j + i, j);
                }
            }
        } else {
            for (int j = 0; j < this.n; ++j) {
                for (int i = 0; i <= this.ku; ++i) {
                    lu.AB[j * lu.ld + this.ku - i] = this.get(j - i, j);
                }
            }
        }
        int kd = lower ? lu.kl : lu.ku;
        int info = LAPACK.engine.pbtrf(lu.layout(), lu.uplo, lu.n, kd, lu.AB, lu.ld);
        if (info != 0) {
            logger.error("LAPACK PBTRF error code: {}", info);
            throw new ArithmeticException("LAPACK PBTRF error code: " + info);
        }
        return new Cholesky(lu);
    }

    /** The LU decomposition of a band matrix, as produced by LAPACK GBTRF. */
    public static class LU
    implements Serializable {
        private static final long serialVersionUID = 2L;
        /** The factors in GBTRF band storage (its {@code kl} is twice the original's). */
        public final BandMatrix lu;
        /** The 1-based pivot indices from GBTRF. */
        public final int[] ipiv;
        /** The GBTRF return code; positive means U is exactly singular. */
        public final int info;

        /**
         * Wraps a factorization result.
         *
         * @param lu the factors in GBTRF band storage.
         * @param ipiv the 1-based pivot indices.
         * @param info the GBTRF return code.
         */
        public LU(BandMatrix lu, int[] ipiv, int info) {
            this.lu = lu;
            this.ipiv = ipiv;
            this.info = info;
        }

        /** Returns true if the factorized matrix is singular. */
        public boolean isSingular() {
            return this.info > 0;
        }

        /**
         * Returns the determinant.
         *
         * @throws IllegalArgumentException if the matrix is not square.
         */
        public float det() {
            int m = this.lu.m;
            int n = this.lu.n;
            if (m != n) {
                throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", m, n));
            }
            // The diagonal of U is at row (original kl) + ku of each stored column.
            int diag = this.lu.kl / 2 + this.lu.ku;
            double d = 1.0;
            for (int j = 0; j < n; ++j) {
                d *= this.lu.AB[j * this.lu.ld + diag];
            }
            // Each row interchange (ipiv is 1-based) flips the sign.
            for (int j = 0; j < n; ++j) {
                if (j + 1 != this.ipiv[j]) {
                    d = -d;
                }
            }
            return (float)d;
        }

        /** Returns the inverse as a dense matrix. */
        public Matrix inverse() {
            Matrix inv = Matrix.eye(this.lu.n);
            this.solve(inv);
            return inv;
        }

        /**
         * Solves {@code A * x = b}.
         *
         * @param b the right-hand side.
         * @return the solution vector.
         */
        public float[] solve(float[] b) {
            Matrix x = Matrix.column(b);
            this.solve(x);
            return x.A;
        }

        /**
         * Solves {@code A * X = B} in place via LAPACK GBTRS.
         *
         * @param B the right-hand sides on input; the solution on output.
         * @throws IllegalArgumentException on non-square A, mismatched rows, or a different layout.
         * @throws RuntimeException if the matrix is singular.
         * @throws ArithmeticException if GBTRS returns a nonzero code.
         */
        public void solve(Matrix B) {
            if (this.lu.m != this.lu.n) {
                throw new IllegalArgumentException(String.format("The matrix is not square: %d x %d", this.lu.m, this.lu.n));
            }
            if (B.m != this.lu.m) {
                throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", this.lu.m, this.lu.n, B.m, B.n));
            }
            if (this.lu.layout() != B.layout()) {
                throw new IllegalArgumentException("The matrix layout is inconsistent.");
            }
            if (this.info > 0) {
                throw new RuntimeException("The matrix is singular.");
            }
            int ret = LAPACK.engine.gbtrs(this.lu.layout(), Transpose.NO_TRANSPOSE, this.lu.n, this.lu.kl / 2, this.lu.ku, B.n, this.lu.AB, this.lu.ld, this.ipiv, B.A, B.ld);
            if (ret != 0) {
                logger.error("LAPACK GBTRS error code: {}", ret);
                throw new ArithmeticException("LAPACK GBTRS error code: " + ret);
            }
        }
    }

    /** The Cholesky decomposition of a symmetric positive definite band matrix. */
    public static class Cholesky
    implements Serializable {
        private static final long serialVersionUID = 2L;
        /** The triangular factor in band storage; its {@code uplo} selects the triangle. */
        public final BandMatrix lu;

        /**
         * Wraps a Cholesky factor.
         *
         * @param lu the triangular factor in band storage.
         * @throws UnsupportedOperationException if the factor is not square.
         */
        public Cholesky(BandMatrix lu) {
            if (lu.nrow() != lu.ncol()) {
                throw new UnsupportedOperationException("Cholesky constructor on a non-square matrix");
            }
            this.lu = lu;
        }

        /** Returns the determinant: the squared product of the factor's diagonal. */
        public float det() {
            double d = 1.0;
            for (int i = 0; i < this.lu.n; ++i) {
                d *= this.lu.get(i, i);
            }
            return (float)(d * d);
        }

        /** Returns the log of the determinant: twice the sum of the logs of the factor's diagonal. */
        public float logdet() {
            int n = this.lu.n;
            double d = 0.0;
            for (int i = 0; i < n; ++i) {
                d += Math.log(this.lu.get(i, i));
            }
            return (float)(2.0 * d);
        }

        /** Returns the inverse as a dense matrix. */
        public Matrix inverse() {
            Matrix inv = Matrix.eye(this.lu.n);
            this.solve(inv);
            return inv;
        }

        /**
         * Solves {@code A * x = b}.
         *
         * @param b the right-hand side.
         * @return the solution vector.
         */
        public float[] solve(float[] b) {
            Matrix x = Matrix.column(b);
            this.solve(x);
            return x.A;
        }

        /**
         * Solves {@code A * X = B} in place via LAPACK PBTRS.
         *
         * @param B the right-hand sides on input; the solution on output.
         * @throws IllegalArgumentException on mismatched rows or a different layout.
         * @throws ArithmeticException if PBTRS returns a nonzero code.
         */
        public void solve(Matrix B) {
            if (B.m != this.lu.m) {
                throw new IllegalArgumentException(String.format("Row dimensions do not agree: A is %d x %d, but B is %d x %d", this.lu.m, this.lu.n, B.m, B.n));
            }
            // Same precondition as LU.solve: LAPACK reads B with the factor's layout.
            if (this.lu.layout() != B.layout()) {
                throw new IllegalArgumentException("The matrix layout is inconsistent.");
            }
            int kd = this.lu.uplo == UPLO.LOWER ? this.lu.kl : this.lu.ku;
            int info = LAPACK.engine.pbtrs(this.lu.layout(), this.lu.uplo, this.lu.n, kd, B.n, this.lu.AB, this.lu.ld, B.A, B.ld);
            if (info != 0) {
                logger.error("LAPACK PBTRS error code: {}", info);
                throw new ArithmeticException("LAPACK PBTRS error code: " + info);
            }
        }
    }
}

