/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.linearDynamicSystems;

import java.util.Random;
import org.ejml.EjmlUnitTests;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.ejml.dense.row.RandomMatrices_DDRM;
import org.jfree.chart.JFreeChart;
import org.jfree.chart.axis.AxisLocation;
import org.jfree.chart.axis.NumberAxis;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.CombinedDomainXYPlot;
import org.jfree.chart.plot.Plot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.XYPlot;
import org.jfree.chart.renderer.xy.StandardXYItemRenderer;
import org.jfree.chart.renderer.xy.XYItemRenderer;
import org.jfree.data.xy.XYDataset;
import org.jfree.data.xy.XYSeries;
import org.jfree.data.xy.XYSeriesCollection;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.MathTools;
import us.ihmc.matrixlib.MatrixTestTools;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.linearDynamicSystems.SingleMatrixExponentialStateSpaceSystemDiscretizer;
import us.ihmc.robotics.linearDynamicSystems.SplitUpMatrixExponentialStateSpaceSystemDiscretizer;
import us.ihmc.robotics.linearDynamicSystems.StateSpaceSystemDiscretizer;

/**
 * Tests for {@link StateSpaceSystemDiscretizer} implementations.
 * <p>
 * {@link #testWithSimpleSpringDamperSystem()} discretizes a unit-mass spring-damper model and compares the
 * resulting trajectory against a finely sub-stepped forward-Euler simulation of the continuous system, and
 * against a first-order (I + A dt) discretization. {@link #testCompareDifferentImplementations()} checks that
 * {@link SingleMatrixExponentialStateSpaceSystemDiscretizer} and
 * {@link SplitUpMatrixExponentialStateSpaceSystemDiscretizer} agree on a random system.
 */
public class StateSpaceSystemDiscretizerTest {
    /** When {@code true}, {@link #printIfDebug(String)} writes its argument to standard out. */
    private static final boolean DEBUG = false;
    // NOTE(review): neither this flag nor plot(...) is referenced anywhere in this file.
    private static final boolean DISPLAY_GRAPHS_AND_SLEEP_FOREVER = false;

    /**
     * Discretizes the spring-damper system xDDot = -k x - b xDot + u with both the discretizer under test and
     * a first-order approximation, simulates all three models with the same cosine input, and checks the
     * discretizer's position error against the sub-stepped continuous simulation.
     */
    @Test
    public void testWithSimpleSpringDamperSystem() {
        int numberOfStates = 2;
        int numberOfInputs = 1;
        SingleMatrixExponentialStateSpaceSystemDiscretizer stateSpaceSystemDiscretizer = new SingleMatrixExponentialStateSpaceSystemDiscretizer(numberOfStates, numberOfInputs);

        double springK = 100.0;
        double damperB = 1.0;
        double discretizationTimeStep = 0.01;
        double totalSimulationTime = 5.0;
        double xProcessCovariance = 0.1;
        double xDotProcessCovariance = 0.012;
        double xSensorCovariance = 0.03;
        double xInitial = 1.0;
        double xDotInitial = 0.1;

        // Continuous model with state [x, xDot] and a single input entering the xDot row.
        DMatrixRMaj continuousA = new DMatrixRMaj(new double[][] {{0.0, 1.0}, {-springK, -damperB}});
        DMatrixRMaj continuousB = new DMatrixRMaj(new double[][] {{0.0}, {1.0}});
        DMatrixRMaj continuousQ = new DMatrixRMaj(new double[][] {{xProcessCovariance, 0.0}, {0.0, xDotProcessCovariance}});
        DMatrixRMaj continuousR = new DMatrixRMaj(new double[][] {{xSensorCovariance}});
        printSystemMatrices("Continuous: ", continuousA, continuousB, continuousQ, continuousR, 1.0);

        // Baseline first-order discretization: A_d = I + A dt, B_d = B dt, Q_d = Q dt, R_d = R.
        DMatrixRMaj simpleDiscreteA = new DMatrixRMaj(continuousA);
        DMatrixRMaj simpleDiscreteB = new DMatrixRMaj(continuousB);
        DMatrixRMaj simpleDiscreteQ = new DMatrixRMaj(continuousQ);
        DMatrixRMaj simpleDiscreteR = new DMatrixRMaj(continuousR);
        CommonOps_DDRM.scale(discretizationTimeStep, simpleDiscreteA);
        CommonOps_DDRM.add(simpleDiscreteA, CommonOps_DDRM.identity(numberOfStates), simpleDiscreteA);
        CommonOps_DDRM.scale(discretizationTimeStep, simpleDiscreteB);
        CommonOps_DDRM.scale(discretizationTimeStep, simpleDiscreteQ);
        printSystemMatrices("Simple Discrete: ", simpleDiscreteA, simpleDiscreteB, simpleDiscreteQ, simpleDiscreteR, discretizationTimeStep);

        // Discretization under test; A, B and Q are overwritten in place.
        DMatrixRMaj discreteA = new DMatrixRMaj(continuousA);
        DMatrixRMaj discreteB = new DMatrixRMaj(continuousB);
        DMatrixRMaj discreteQ = new DMatrixRMaj(continuousQ);
        DMatrixRMaj discreteR = new DMatrixRMaj(continuousR);
        stateSpaceSystemDiscretizer.discretize(discreteA, discreteB, discreteQ, discretizationTimeStep);
        // R is not passed to discretize(...), so it must still equal the continuous value.
        MatrixTestTools.assertMatrixEquals(discreteR, continuousR, 1.0E-7);
        printSystemMatrices("Discrete: ", discreteA, discreteB, discreteQ, discreteR, discretizationTimeStep);

        int numberOfPoints = (int) (totalSimulationTime / discretizationTimeStep);
        // The "continuous" reference is forward Euler with many sub-steps per discretization step.
        int numberOfIntegrationSteps = 1000;
        double eulerStepSize = discretizationTimeStep / numberOfIntegrationSteps;

        double[] time = new double[numberOfPoints];
        double[] xContinuous = new double[numberOfPoints];
        double[] xDiscreteSimple = new double[numberOfPoints];
        double[] xDiscrete = new double[numberOfPoints];

        double t = 0.0;
        double x = xInitial;
        double xDot = xDotInitial;
        double uAmplitude = 2.5;
        double uFreq = 2.0;
        DMatrixRMaj stateDiscreteSimple = new DMatrixRMaj(new double[][] {{xInitial}, {xDotInitial}});
        DMatrixRMaj stateDiscrete = new DMatrixRMaj(new double[][] {{xInitial}, {xDotInitial}});
        DMatrixRMaj input = new DMatrixRMaj(1, 1);

        for (int pointNumber = 0; pointNumber < numberOfPoints; ++pointNumber) {
            time[pointNumber] = t;
            xContinuous[pointNumber] = x;
            xDiscreteSimple[pointNumber] = stateDiscreteSimple.get(0, 0);
            xDiscrete[pointNumber] = stateDiscrete.get(0, 0);

            // The input is sampled once per discretization step and held constant over the Euler sub-steps,
            // matching how the discrete models see it.
            double u = uAmplitude * Math.cos(Math.PI * 2 * uFreq * t);
            input.set(0, 0, u);

            for (int i = 0; i < numberOfIntegrationSteps; ++i) {
                double xNext = x + xDot * eulerStepSize;
                double xDotNext = xDot + (-springK * x - damperB * xDot) * eulerStepSize + u * eulerStepSize;
                t += eulerStepSize;
                x = xNext;
                xDot = xDotNext;
            }

            stateDiscreteSimple = computeNextState(stateDiscreteSimple, input, simpleDiscreteA, simpleDiscreteB);
            stateDiscrete = computeNextState(stateDiscrete, input, discreteA, discreteB);
        }

        double meanSquaredErrorDiscrete = computeMeanSquaredError(xContinuous, xDiscrete);
        double meanSquaredErrorDiscreteSimple = computeMeanSquaredError(xContinuous, xDiscreteSimple);
        printIfDebug("squaredErrorDiscrete = " + meanSquaredErrorDiscrete);
        printIfDebug("squaredErrorDiscreteSimple = " + meanSquaredErrorDiscreteSimple);

        Assert.assertTrue(meanSquaredErrorDiscrete < 1.0E-4);
        // The discretizer under test must be at least two orders of magnitude better than the first-order one.
        Assert.assertTrue(meanSquaredErrorDiscrete * 100.0 < meanSquaredErrorDiscreteSimple);
        Assert.assertEquals("Regression Test. Only will be true for certain values. If failing, check changes.", 1.5445857883898253E-5, meanSquaredErrorDiscrete, 1.0E-7);
        Assert.assertEquals("Regression Test. Only will be true for certain values. If failing, check changes.", 0.25136292086153883, meanSquaredErrorDiscreteSimple, 1.0E-7);
    }

    /**
     * Discretizes the same random (A, B, Q) with both discretizer implementations and checks that the
     * resulting discrete matrices agree to 1e-12.
     */
    @Test
    public void testCompareDifferentImplementations() {
        int numberOfStates = 30;
        int numberOfInputs = 10;
        SingleMatrixExponentialStateSpaceSystemDiscretizer singleMatrixExponentialDiscretizer = new SingleMatrixExponentialStateSpaceSystemDiscretizer(numberOfStates, numberOfInputs);
        SplitUpMatrixExponentialStateSpaceSystemDiscretizer splitUpMatrixExponentialDiscretizer = new SplitUpMatrixExponentialStateSpaceSystemDiscretizer(numberOfStates, numberOfInputs);
        StateSpaceSystemDiscretizer[] discretizers = new StateSpaceSystemDiscretizer[] {singleMatrixExponentialDiscretizer, splitUpMatrixExponentialDiscretizer};

        // Fixed seed keeps the test deterministic.
        Random random = new Random(125L);
        DMatrixRMaj A = RandomMatrices_DDRM.rectangle(numberOfStates, numberOfStates, random);
        DMatrixRMaj B = RandomMatrices_DDRM.rectangle(numberOfStates, numberOfInputs, random);
        DMatrixRMaj Q = RandomMatrices_DDRM.symmetricPosDef(numberOfStates, random);

        // Each discretizer works on its own copies because discretize(...) overwrites its arguments.
        DMatrixRMaj[] As = new DMatrixRMaj[] {new DMatrixRMaj(A), new DMatrixRMaj(A)};
        DMatrixRMaj[] Bs = new DMatrixRMaj[] {new DMatrixRMaj(B), new DMatrixRMaj(B)};
        DMatrixRMaj[] Qs = new DMatrixRMaj[] {new DMatrixRMaj(Q), new DMatrixRMaj(Q)};

        double dt = 1.0;
        for (int i = 0; i < discretizers.length; ++i) {
            discretizers[i].discretize(As[i], Bs[i], Qs[i], dt);
        }

        double tol = 1.0E-12;
        EjmlUnitTests.assertEquals(As[0], As[1], tol);
        EjmlUnitTests.assertEquals(Bs[0], Bs[1], tol);
        EjmlUnitTests.assertEquals(Qs[0], Qs[1], tol);
    }

    /**
     * Prints a system's matrices through {@link #printIfDebug(String)}.
     * <p>
     * Q is printed divided by {@code deltaT}. This lets a discrete Q that was formed as Q dt be compared
     * directly with the continuous Q. {@code Q} itself is not modified.
     */
    private void printSystemMatrices(String name, DMatrixRMaj A, DMatrixRMaj B, DMatrixRMaj Q, DMatrixRMaj R, double deltaT) {
        DMatrixRMaj QDividedByDeltaT = new DMatrixRMaj(Q);
        CommonOps_DDRM.scale(1.0 / deltaT, QDividedByDeltaT);
        printIfDebug(name);
        printIfDebug("A = " + A);
        printIfDebug("B = " + B);
        printIfDebug("Scaled Q = " + QDividedByDeltaT);
        printIfDebug("R = " + R);
    }

    /** Writes {@code string} to standard out when {@link #DEBUG} is enabled; otherwise does nothing. */
    private void printIfDebug(String string) {
        if (DEBUG) {
            System.out.println(string);
        }
    }

    /**
     * Returns the mean over all indices of (reference[i] - approximation[i])^2.
     * Both arrays are expected to have the same length; iteration runs over {@code reference.length}.
     */
    private static double computeMeanSquaredError(double[] reference, double[] approximation) {
        double sumOfSquaredErrors = 0.0;
        for (int i = 0; i < reference.length; ++i) {
            sumOfSquaredErrors += MathTools.square(reference[i] - approximation[i]);
        }
        return sumOfSquaredErrors / reference.length;
    }

    /**
     * Builds a chart with three vertically stacked subplots that share the time axis: the continuous,
     * first-order discrete and discretizer-under-test position trajectories.
     */
    private JFreeChart plot(String title, double[] time, double[] xContinuous, double[] xDiscreteSimple, double[] xDiscrete) {
        XYPlot subplot1 = createSubplot(createDataset(time, xContinuous), "Continuous", AxisLocation.BOTTOM_OR_LEFT);
        XYPlot subplot2 = createSubplot(createDataset(time, xDiscreteSimple), "DiscreteSimple", AxisLocation.TOP_OR_LEFT);
        XYPlot subplot3 = createSubplot(createDataset(time, xDiscrete), "Discrete", AxisLocation.TOP_OR_LEFT);

        CombinedDomainXYPlot plot = new CombinedDomainXYPlot();
        plot.add(subplot1, 1);
        plot.add(subplot2, 1);
        plot.add(subplot3, 1);
        plot.setOrientation(PlotOrientation.VERTICAL);
        return new JFreeChart(title, JFreeChart.DEFAULT_TITLE_FONT, plot, true);
    }

    /** Creates one subplot for {@code dataset} with a labelled range axis and a hidden legend entry for series 0. */
    private XYPlot createSubplot(XYDataset dataset, String rangeAxisLabel, AxisLocation rangeAxisLocation) {
        StandardXYItemRenderer renderer = new StandardXYItemRenderer();
        XYPlot subplot = new XYPlot(dataset, null, new NumberAxis(rangeAxisLabel), renderer);
        subplot.setRangeAxisLocation(rangeAxisLocation);
        renderer.setSeriesVisibleInLegend(0, Boolean.FALSE);
        return subplot;
    }

    /**
     * Wraps paired x/y samples in a single-series dataset named "data series".
     * The series is created with auto-sort disabled, so points keep their input order.
     */
    private XYDataset createDataset(double[] xdata, double[] ydata) {
        XYSeries series = new XYSeries("data series", false);
        for (int i = 0; i < xdata.length; ++i) {
            series.add(xdata[i], ydata[i]);
        }
        XYSeriesCollection dataset = new XYSeriesCollection();
        dataset.addSeries(series);
        return dataset;
    }

    /**
     * Returns A * state + B * input as a newly allocated column vector.
     * None of the arguments are modified.
     */
    private DMatrixRMaj computeNextState(DMatrixRMaj state, DMatrixRMaj input, DMatrixRMaj A, DMatrixRMaj B) {
        int numberOfStates = state.getNumRows();
        DMatrixRMaj nextState = new DMatrixRMaj(numberOfStates, 1);
        DMatrixRMaj Ax = new DMatrixRMaj(numberOfStates, 1);
        DMatrixRMaj Bu = new DMatrixRMaj(numberOfStates, 1);
        CommonOps_DDRM.mult(A, state, Ax);
        CommonOps_DDRM.mult(B, input, Bu);
        CommonOps_DDRM.add(Ax, Bu, nextState);
        return nextState;
    }
}

