/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.linearAlgebra.careSolvers;

import java.util.ArrayList;
import java.util.List;
import org.ejml.EjmlUnitTests;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.junit.jupiter.api.Test;
import us.ihmc.matrixlib.NativeCommonOps;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.linearAlgebra.careSolvers.CARESolver;
import us.ihmc.robotics.linearAlgebra.careSolvers.CARETools;
import us.ihmc.robotics.linearAlgebra.careSolvers.DefectCorrectionCARESolver;
import us.ihmc.robotics.linearAlgebra.careSolvers.EigenvectorCARESolver;
import us.ihmc.robotics.linearAlgebra.careSolvers.Newton2CARESolver;
import us.ihmc.robotics.linearAlgebra.careSolvers.NewtonCARESolver;
import us.ihmc.robotics.linearAlgebra.careSolvers.SignFunctionCARESolver;

/**
 * Exercises every {@link CARESolver} configuration returned by {@link #getSolvers()} against a
 * handful of small, hand-written problems.
 *
 * <p>Each test feeds matrices {@code A}, {@code B}, {@code Q} and {@code R} to a solver, asks it
 * to compute {@code P}, and then checks {@code P} through {@link #assertSolutionIsValid} (and,
 * where available, against known reference values).
 */
public class CARESolversTest {
    /** Absolute tolerance shared by every numerical comparison in this class. */
    private static final double epsilon = 1.0E-4;

    /**
     * Builds the list of solver configurations under test: the two stand-alone solvers
     * ({@link EigenvectorCARESolver}, {@link SignFunctionCARESolver}) plus each of them wrapped in
     * {@link NewtonCARESolver}, {@link Newton2CARESolver} and {@link DefectCorrectionCARESolver}.
     *
     * @return a freshly created list with freshly created solver instances
     */
    private List<CARESolver> getSolvers() {
        List<CARESolver> solvers = new ArrayList<>();
        solvers.add(new EigenvectorCARESolver());
        solvers.add(new NewtonCARESolver(new EigenvectorCARESolver()));
        solvers.add(new Newton2CARESolver(new EigenvectorCARESolver()));
        solvers.add(new DefectCorrectionCARESolver(new EigenvectorCARESolver()));
        solvers.add(new SignFunctionCARESolver());
        solvers.add(new NewtonCARESolver(new SignFunctionCARESolver()));
        solvers.add(new Newton2CARESolver(new SignFunctionCARESolver()));
        solvers.add(new DefectCorrectionCARESolver(new SignFunctionCARESolver()));
        return solvers;
    }

    /**
     * Solves the problem with {@code A = B = Q = R = I} (2x2) and checks that:
     * <ul>
     * <li>the solver did not modify its input matrices,</li>
     * <li>{@code -(A'P + PA - P B R^-1 B' P)}, assembled by hand here, equals {@code Q},</li>
     * <li>{@code P} is symmetric,</li>
     * <li>{@code P} passes {@link #assertSolutionIsValid}.</li>
     * </ul>
     */
    @Test
    public void testSimple() {
        int size = 2;
        for (CARESolver solver : getSolvers()) {
            DMatrixRMaj A = CommonOps_DDRM.identity(size);
            DMatrixRMaj B = CommonOps_DDRM.identity(size);
            DMatrixRMaj Q = CommonOps_DDRM.identity(size);
            DMatrixRMaj R = CommonOps_DDRM.identity(size);

            // Pristine copies, used afterwards to detect whether the solver mutated its inputs.
            DMatrixRMaj AInput = new DMatrixRMaj(A);
            DMatrixRMaj BInput = new DMatrixRMaj(B);
            DMatrixRMaj QInput = new DMatrixRMaj(Q);
            DMatrixRMaj RInput = new DMatrixRMaj(R);

            solver.setMatrices(A, B, CommonOps_DDRM.identity(size), CommonOps_DDRM.identity(size), Q, R, null);
            solver.computeP();
            DMatrixRMaj P = solver.getP();

            // assembledQ = A' P + P A
            DMatrixRMaj assembledQ = new DMatrixRMaj(size, size);
            CommonOps_DDRM.multTransA(A, P, assembledQ);
            CommonOps_DDRM.multAdd(P, A, assembledQ);

            // PBRInvBTransposeP = P B R^-1 B' P
            DMatrixRMaj RInv = new DMatrixRMaj(size, size);
            NativeCommonOps.invert(R, RInv);
            DMatrixRMaj BTransposeP = new DMatrixRMaj(size, size);
            CommonOps_DDRM.multTransA(B, P, BTransposeP);
            DMatrixRMaj BRInv = new DMatrixRMaj(size, size);
            CommonOps_DDRM.mult(B, RInv, BRInv);
            DMatrixRMaj PBRInv = new DMatrixRMaj(size, size);
            CommonOps_DDRM.mult(P, BRInv, PBRInv);
            DMatrixRMaj PBRInvBTransposeP = new DMatrixRMaj(size, size);
            CommonOps_DDRM.mult(PBRInv, BTransposeP, PBRInvBTransposeP);

            // assembledQ = -(A' P + P A - P B R^-1 B' P)
            CommonOps_DDRM.addEquals(assembledQ, -1.0, PBRInvBTransposeP);
            CommonOps_DDRM.scale(-1.0, assembledQ);

            EjmlUnitTests.assertEquals(AInput, A, epsilon);
            EjmlUnitTests.assertEquals(BInput, B, epsilon);
            EjmlUnitTests.assertEquals(QInput, Q, epsilon);
            EjmlUnitTests.assertEquals(RInput, R, epsilon);

            // The hand-assembled matrix was previously computed but never checked.
            EjmlUnitTests.assertEquals(QInput, assembledQ, epsilon);

            assertIsSymmetric(P, epsilon);
            assertSolutionIsValid(AInput, BInput, QInput, RInput, P, epsilon);
        }
    }

    /**
     * Solves a 2-state, 1-input problem with {@code Q = C'C} and compares {@code P} against
     * hard-coded reference values.
     *
     * <p>NOTE(review): judging by the test name, the reference values presumably come from
     * MATLAB's {@code care} - confirm against the original source of the numbers.
     */
    @Test
    public void testMatlabCare() {
        for (CARESolver solver : getSolvers()) {
            int n = 2;
            int m = 1;
            DMatrixRMaj A = new DMatrixRMaj(n, n);
            DMatrixRMaj B = new DMatrixRMaj(n, m);
            DMatrixRMaj C = new DMatrixRMaj(1, n);
            DMatrixRMaj Q = new DMatrixRMaj(n, n);
            DMatrixRMaj R = new DMatrixRMaj(m, m);

            A.set(0, 0, -3.0);
            A.set(0, 1, 2.0);
            A.set(1, 0, 1.0);
            A.set(1, 1, 1.0);
            B.set(1, 0, 1.0);
            C.set(0, 0, 1.0);
            C.set(0, 1, -1.0);
            R.set(0, 0, 3.0);

            // Q = C' C
            CommonOps_DDRM.multInner(C, Q);

            solver.setMatrices(A, B, CommonOps_DDRM.identity(n), CommonOps_DDRM.identity(n), Q, R, null);
            solver.computeP();

            DMatrixRMaj PExpected = new DMatrixRMaj(n, n);
            PExpected.set(0, 0, 0.5895);
            PExpected.set(0, 1, 1.8216);
            PExpected.set(1, 0, 1.8216);
            PExpected.set(1, 1, 8.8188);

            DMatrixRMaj P = solver.getP();
            assertSolutionIsValid(A, B, Q, R, P, epsilon);
            EjmlUnitTests.assertEquals(PExpected, P, epsilon);
        }
    }

    /**
     * Solves a 3-state, 1-input problem with {@code Q = C'C} and only checks that the resulting
     * {@code P} passes {@link #assertSolutionIsValid}; no reference values are compared.
     */
    @Test
    public void testMatlabCare2() {
        for (CARESolver solver : getSolvers()) {
            int n = 3;
            int m = 1;
            DMatrixRMaj A = new DMatrixRMaj(n, n);
            DMatrixRMaj B = new DMatrixRMaj(n, m);
            DMatrixRMaj C = new DMatrixRMaj(1, n);
            DMatrixRMaj E = CommonOps_DDRM.identity(n);
            DMatrixRMaj Q = new DMatrixRMaj(n, n);
            DMatrixRMaj R = new DMatrixRMaj(m, m);

            A.set(0, 0, 1.0);
            A.set(0, 1, -2.0);
            A.set(0, 2, 3.0);
            A.set(1, 0, -4.0);
            A.set(1, 1, 5.0);
            A.set(1, 2, 6.0);
            A.set(2, 0, 7.0);
            A.set(2, 1, 8.0);
            A.set(2, 2, 9.0);
            B.set(0, 0, 5.0);
            B.set(1, 0, 6.0);
            B.set(2, 0, -7.0);
            C.set(0, 0, 7.0);
            C.set(0, 1, -8.0);
            C.set(0, 2, 9.0);
            R.set(0, 0, 1.0);

            // Q = C' C
            CommonOps_DDRM.multInner(C, Q);

            solver.setMatrices(A, B, CommonOps_DDRM.identity(n), E, Q, R, null);
            solver.computeP();

            assertSolutionIsValid(A, B, Q, R, solver.getP(), epsilon);
        }
    }

    /**
     * Asserts that {@code A} is square and that every off-diagonal entry matches its mirrored
     * entry within {@code epsilon}.
     *
     * @param A       matrix to check
     * @param epsilon absolute tolerance for each pair of mirrored entries
     * @throws AssertionError if {@code A} is not square or a mirrored pair differs by more than
     *                        {@code epsilon}
     */
    private static void assertIsSymmetric(DMatrixRMaj A, double epsilon) {
        int rows = A.getNumRows();
        int cols = A.getNumCols();
        // Fail as an assertion instead of running into an out-of-range access on A.get(col, row).
        if (rows != cols) {
            throw new AssertionError("Not symmetric! Matrix is " + rows + "x" + cols + " but must be square.");
        }
        // Comparing the strict upper triangle with the lower one covers every mirrored pair once.
        for (int row = 0; row < rows; ++row) {
            for (int col = row + 1; col < cols; ++col) {
                Assert.assertEquals("Not symmetric!", A.get(row, col), A.get(col, row), epsilon);
            }
        }
    }

    /**
     * Asserts that {@code P} is accepted as a solution for the problem given by {@code A},
     * {@code B}, {@code Q} and {@code R}.
     *
     * <p>The check transposes {@code B}, hands {@code B'} and {@code R} to
     * {@link CARETools#computeM}, passes the result together with {@code P}, {@code A} and
     * {@code Q} to {@link CARETools#computeRiccatiRate}, and requires the matrix written by that
     * call to equal the {@code n x n} zero matrix within {@code epsilon}.
     *
     * <p>NOTE(review): presumably the matrix written by {@code computeRiccatiRate} is the residual
     * of the Riccati equation - confirm against {@code CARETools}.
     *
     * @param A       state matrix; its row count is used as {@code n}
     * @param B       input matrix; its column count is used as {@code m}
     * @param Q       matrix forwarded to {@code CARETools.computeRiccatiRate}
     * @param R       matrix forwarded to {@code CARETools.computeM}
     * @param P       candidate solution to validate
     * @param epsilon absolute tolerance for the comparison against zero
     */
    static void assertSolutionIsValid(DMatrixRMaj A, DMatrixRMaj B, DMatrixRMaj Q, DMatrixRMaj R, DMatrixRMaj P, double epsilon) {
        int n = A.getNumRows();
        int m = B.getNumCols();

        DMatrixRMaj BTranspose = new DMatrixRMaj(m, n);
        CommonOps_DDRM.transpose(B, BTranspose);

        DMatrixRMaj M = new DMatrixRMaj(m, m);
        CARETools.computeM(BTranspose, R, null, M);

        DMatrixRMaj PDot = new DMatrixRMaj(n, n);
        CARETools.computeRiccatiRate(P, A, Q, M, PDot);

        // A freshly constructed DMatrixRMaj is all zeros, so this checks PDot == 0 within epsilon.
        EjmlUnitTests.assertEquals(new DMatrixRMaj(n, n), PDot, epsilon);
    }
}

