/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.functionApproximation;

import java.util.ArrayList;
import java.util.Random;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.functionApproximation.LinearRegression;

/**
 * Unit tests for {@link LinearRegression}.
 *
 * <p>The tests cover:
 * <ul>
 *   <li>least-squares fits of noisy and exact polynomial data;</li>
 *   <li>a basic solve on random data;</li>
 *   <li>the exceptions raised on misuse (mismatched data, querying results before {@code solve()}).</li>
 * </ul>
 */
public class LinearRegressionTest {
    /** When true, {@link #printResults} writes fit diagnostics to standard output. */
    private static final boolean VERBOSE = false;

    /** Fits y = c + x + 5 x^2 with small uniform noise, using regressors {1, x, x^2}. */
    @Test
    public void testTypicalExampleOne() {
        Random random = new Random(1984L);
        ArrayList<double[]> inputs = new ArrayList<>();
        ArrayList<Double> outputs = new ArrayList<>();
        double xCoefficient = 1.0;
        double xSquaredCoefficient = 5.0;
        for (int i = 0; i < 500; ++i) {
            // x is drawn uniformly from [-1, 1).
            double x = random.nextDouble() * 2.0 - 1.0;
            // Regressors: constant term, x, x^2.
            double[] input = new double[] {1.0, x, x * x};
            // Noise is uniform in [0, 0.1); its mean of 0.05 is absorbed by the constant term.
            double output = random.nextDouble() * 0.1 + xCoefficient * x + xSquaredCoefficient * x * x;
            inputs.add(input);
            outputs.add(output);
        }
        LinearRegression linearRegression = new LinearRegression(inputs, outputs);
        double runTimeInMilliseconds = solveAndReturnRuntimeInMilliseconds(linearRegression);
        double[] coefficientArray = new double[3];
        linearRegression.getCoefficientVector(coefficientArray);
        printResults(linearRegression, runTimeInMilliseconds, coefficientArray);
        double epsilon = 0.1;
        double maxSquaredError = 0.001;
        double maxRuntimeInMilliseconds = 7.2;
        // Expected constant term is the noise mean, 0.05.
        double[] expectedCoefficients = new double[] {0.05, xCoefficient, xSquaredCoefficient};
        assertResultsAreAsExpected(linearRegression, runTimeInMilliseconds, coefficientArray, epsilon, maxSquaredError, maxRuntimeInMilliseconds,
                                   expectedCoefficients);
    }

    /** Fits a two-variable quadratic with small uniform noise, using regressors {1, x, x^2, y, y^2, xy}. */
    @Test
    public void testTypicalExampleTwo() {
        Random random = new Random(1776L);
        int numberOfPoints = 500;
        double[][] inputs = new double[numberOfPoints][];
        double[] outputs = new double[numberOfPoints];
        double xCoefficient = 1.0;
        double xSquaredCoefficient = 0.0;
        double yCoefficient = 0.2;
        double ySquaredCoefficient = -3.0;
        double xyCoefficient = 5.0;
        for (int i = 0; i < numberOfPoints; ++i) {
            // x and y are drawn uniformly from [-1, 1).
            double x = random.nextDouble() * 2.0 - 1.0;
            double y = random.nextDouble() * 2.0 - 1.0;
            double[] input = new double[] {1.0, x, x * x, y, y * y, x * y};
            // Constant 4.0 plus uniform noise in [0, 0.1) (mean 0.05).
            double output = 4.0 + random.nextDouble() * 0.1 + xCoefficient * x + xSquaredCoefficient * x * x + yCoefficient * y
                            + ySquaredCoefficient * y * y + xyCoefficient * x * y;
            inputs[i] = input;
            outputs[i] = output;
        }
        LinearRegression linearRegression = new LinearRegression(inputs, outputs);
        double runTimeInMilliseconds = solveAndReturnRuntimeInMilliseconds(linearRegression);
        double[] coefficientArray = new double[6];
        linearRegression.getCoefficientVector(coefficientArray);
        printResults(linearRegression, runTimeInMilliseconds, coefficientArray);
        double epsilon = 0.1;
        double maxSquaredError = 0.001;
        double maxRuntimeInMilliseconds = 5.0;
        // Expected constant term is 4.0 plus the noise mean, 0.05.
        double[] expectedCoefficients = new double[] {4.05, xCoefficient, xSquaredCoefficient, yCoefficient, ySquaredCoefficient, xyCoefficient};
        assertResultsAreAsExpected(linearRegression, runTimeInMilliseconds, coefficientArray, epsilon, maxSquaredError, maxRuntimeInMilliseconds,
                                   expectedCoefficients);
    }

    /** Noise-free data: the fit must recover every coefficient to near machine precision. */
    @Test
    public void testPerfectMatch() {
        Random random = new Random(2000L);
        ArrayList<double[]> inputs = new ArrayList<>();
        ArrayList<Double> outputs = new ArrayList<>();
        double unitsCoefficient = 90.0;
        double xCoefficient = 5.0;
        double xSquaredCoefficient = 6.0;
        double yCoefficient = 7.0;
        double ySquaredCoefficient = -8.0;
        double xyCoefficient = 13.0;
        for (int i = 0; i < 500; ++i) {
            // x and y are drawn uniformly from [-1, 3).
            double x = random.nextDouble() * 4.0 - 1.0;
            double y = random.nextDouble() * 4.0 - 1.0;
            double[] input = new double[] {1.0, x, x * x, y, y * y, x * y};
            double output = unitsCoefficient + xCoefficient * x + xSquaredCoefficient * x * x + yCoefficient * y + ySquaredCoefficient * y * y
                            + xyCoefficient * x * y;
            inputs.add(input);
            outputs.add(output);
        }
        LinearRegression linearRegression = new LinearRegression(inputs, outputs);
        double runTimeInMilliseconds = solveAndReturnRuntimeInMilliseconds(linearRegression);
        double[] coefficientArray = new double[6];
        linearRegression.getCoefficientVector(coefficientArray);
        printResults(linearRegression, runTimeInMilliseconds, coefficientArray);
        double epsilon = 1.0E-10;
        double maxSquaredError = 1.0E-14;
        double maxRuntimeInMilliseconds = 5.0;
        double[] expectedCoefficients = new double[] {unitsCoefficient, xCoefficient, xSquaredCoefficient, yCoefficient, ySquaredCoefficient, xyCoefficient};
        assertResultsAreAsExpected(linearRegression, runTimeInMilliseconds, coefficientArray, epsilon, maxSquaredError, maxRuntimeInMilliseconds,
                                   expectedCoefficients);
    }

    /** Unstructured random data: solve must still succeed and report a non-negative squared error. */
    @Test
    public void testRandomness() {
        Random random = new Random(1776L);
        ArrayList<double[]> inputs = new ArrayList<>();
        ArrayList<Double> outputs = new ArrayList<>();
        for (int i = 0; i < 5; ++i) {
            double[] input = new double[] {random.nextDouble() * 100.0, random.nextDouble() * 10.0};
            double output = random.nextDouble() * 500.0;
            inputs.add(input);
            outputs.add(output);
        }
        LinearRegression linearRegression = new LinearRegression(inputs, outputs);
        boolean foundSolution = linearRegression.solve();
        Assert.assertTrue(foundSolution);
        double squaredError = linearRegression.getSquaredError();
        Assert.assertTrue("linearRegression.getSquaredError() was negative: " + squaredError, squaredError >= 0.0);
    }

    /** One input row but two outputs: construction or solving must reject the mismatched data. */
    @Test
    public void testNotEnoughPoints() {
        Assertions.assertThrows(IllegalArgumentException.class, () -> {
            ArrayList<double[]> inputs = new ArrayList<>();
            ArrayList<Double> outputs = new ArrayList<>();
            inputs.add(new double[] {1.0});
            outputs.add(2.0);
            outputs.add(3.0);
            new LinearRegression(inputs, outputs).solve();
        });
    }

    /** Requesting the coefficient vector without calling solve() must throw. */
    @Test
    public void testAskingForAnswerBeforeDone() {
        Assertions.assertThrows(IllegalStateException.class, () -> {
            double[][] inputs = new double[][] {{1.0}, {1.0}};
            double[] outputs = new double[] {1.0, 1.0};
            double[] coefficientVector = new double[1];
            LinearRegression linearRegression = new LinearRegression(inputs, outputs);
            linearRegression.getCoefficientVector(coefficientVector);
        });
    }

    /** Requesting the squared error without calling solve() must throw. */
    @Test
    public void testAskingForSquaredErrorBeforeDone() {
        Assertions.assertThrows(IllegalStateException.class, () -> {
            double[][] inputs = new double[][] {{1.0}, {1.0}};
            double[] outputs = new double[] {1.0, 1.0};
            new LinearRegression(inputs, outputs).getSquaredError();
        });
    }

    /** Requesting the coefficient matrix without calling solve() must throw. */
    @Test
    public void testAskingForCoefficientVectorAsMatrixBeforeDone() {
        Assertions.assertThrows(IllegalStateException.class, () -> {
            double[][] inputs = new double[][] {{1.0}, {1.0}};
            double[] outputs = new double[] {1.0, 1.0};
            new LinearRegression(inputs, outputs).getCoefficientVectorAsMatrix();
        });
    }

    /**
     * Checks fitted coefficients and the squared error.
     *
     * <p>The runtime arguments are accepted for reporting symmetry but are NOT asserted.
     * Wall-clock thresholds in unit tests tend to be environment-dependent.
     *
     * @param linearRegression         a solved regression
     * @param runTimeInMilliseconds    measured solve time (not asserted)
     * @param coefficientArray         fitted coefficients
     * @param epsilon                  per-coefficient absolute tolerance
     * @param maxSquaredError          exclusive upper bound on the reported squared error
     * @param maxRuntimeInMilliseconds nominal runtime budget (not asserted)
     * @param expectedCoefficients     expected coefficients, same order as the regressors
     */
    private void assertResultsAreAsExpected(LinearRegression linearRegression, double runTimeInMilliseconds, double[] coefficientArray, double epsilon,
                                            double maxSquaredError, double maxRuntimeInMilliseconds, double[] expectedCoefficients) {
        Assert.assertEquals(expectedCoefficients.length, coefficientArray.length);
        for (int i = 0; i < expectedCoefficients.length; ++i) {
            Assert.assertEquals(expectedCoefficients[i], coefficientArray[i], epsilon);
        }
        double squaredError = linearRegression.getSquaredError();
        // An exact fit may legitimately yield a squared error of exactly zero; only negative values are invalid.
        Assert.assertTrue("linearRegression.getSquaredError() was negative: " + squaredError, squaredError >= 0.0);
        Assert.assertTrue("linearRegression.getSquaredError() = " + squaredError, squaredError < maxSquaredError);
    }

    /** Prints runtime, coefficients and squared error when {@link #VERBOSE} is enabled; otherwise does nothing. */
    private void printResults(LinearRegression linearRegression, double runTimeInMilliseconds, double[] coefficientArray) {
        if (!VERBOSE) {
            return;
        }
        StringBuilder coefficients = new StringBuilder("[");
        for (int i = 0; i < coefficientArray.length; ++i) {
            if (i > 0) {
                coefficients.append(", ");
            }
            coefficients.append(coefficientArray[i]);
        }
        coefficients.append(']');
        System.out.println("runTimeInMilliseconds = " + runTimeInMilliseconds);
        System.out.println("coefficients = " + coefficients);
        System.out.println("squaredError = " + linearRegression.getSquaredError());
    }

    /**
     * Solves the regression and asserts success.
     *
     * @param linearRegression the regression to solve
     * @return wall-clock time spent in {@code solve()}, in milliseconds
     */
    private double solveAndReturnRuntimeInMilliseconds(LinearRegression linearRegression) {
        long startTime = System.nanoTime();
        boolean solveSucceeded = linearRegression.solve();
        // Stop the clock before asserting so only solve() is timed.
        long endTime = System.nanoTime();
        Assert.assertTrue(solveSucceeded);
        return (endTime - startTime) / 1.0e6;
    }
}

