/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.optimization;

import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.mult.VectorVectorMult_DDRM;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.MathTools;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.optimization.Optimizer;
import us.ihmc.robotics.optimization.WrappedGradientDescent;
import us.ihmc.robotics.optimization.constrainedOptimization.AugmentedLagrangeOptimizationProblem;
import us.ihmc.robotics.optimization.constrainedOptimization.AugmentedLagrangeOptimizer;

/**
 * Tests for {@link AugmentedLagrangeOptimizer}.
 *
 * <p>Each test builds an {@link AugmentedLagrangeOptimizationProblem} from a cost function plus equality/inequality
 * constraints. It solves the problem with a {@link WrappedGradientDescent} inner optimizer and checks every
 * component of the returned solution against a known optimum.
 */
public class AugmentedLagrangeOptimizerTest {
    /** Number of outer iterations passed to {@link AugmentedLagrangeOptimizer#optimize}. */
    private static final int NUM_LAGRANGE_ITERATIONS = 10;
    /** Initial penalty passed to {@link AugmentedLagrangeOptimizationProblem#initialize}. */
    private static final double INITIAL_PENALTY = 1.0;
    /** Penalty-increase factor passed to {@link AugmentedLagrangeOptimizationProblem#initialize}. */
    private static final double PENALTY_INCREASE_FACTOR = 1.5;
    /** Tolerance handed to {@link MathTools#epsilonCompare} when checking solution components. */
    private static final double EPSILON = 0.001;

    /**
     * Quadratic cost: the inner product of {@code inputs} with itself (squared Euclidean norm).
     */
    private static double costFunctionQuadratic(DMatrixD1 inputs) {
        return VectorVectorMult_DDRM.innerProd(inputs, inputs);
    }

    /**
     * Nonconvex cost {@code -5 * sin(r) / r}, where {@code r} is the Euclidean norm of {@code inputs}.
     * At {@code r == 0} it returns {@code -5}, the limit of the expression, to avoid a division by zero.
     */
    private static double costFunctionNonconvex(DMatrixD1 inputs) {
        double norm = Math.sqrt(VectorVectorMult_DDRM.innerProd(inputs, inputs));
        if (norm == 0.0) {
            return -5.0;
        }
        return -5.0 * Math.sin(norm) / norm;
    }

    /** {@code x[1] - 5}; zero exactly when {@code x[1] == 5}. */
    private static double constraint1(DMatrixD1 inputs) {
        return inputs.get(1) - 5.0;
    }

    /** {@code x[0] - 6}; zero exactly when {@code x[0] == 6}. */
    private static double constraint2(DMatrixD1 inputs) {
        return inputs.get(0) - 6.0;
    }

    /** {@code 3 - x[2]}; zero exactly when {@code x[2] == 3}. */
    private static double constraint3(DMatrixD1 inputs) {
        return -inputs.get(2) + 3.0;
    }

    /** {@code x[0] + x[1] + x[2] - 6}; zero exactly when the three components sum to 6. */
    private static double constraint4(DMatrixD1 inputs) {
        return inputs.get(0) + inputs.get(1) + inputs.get(2) - 6.0;
    }

    /** {@code x[0] + x[1] - 6.354061535}; zero exactly on the line {@code x[0] + x[1] == 6.354061535}. */
    private static double constraintNonconvex(DMatrixD1 inputs) {
        return inputs.get(0) + inputs.get(1) - 6.354061535;
    }

    /**
     * Runs the full augmented-Lagrangian solve with the settings shared by every test.
     *
     * <p>It initializes {@code problem} with {@link #INITIAL_PENALTY} and {@link #PENALTY_INCREASE_FACTOR}, wraps
     * {@code optimizer} in a verbose {@link AugmentedLagrangeOptimizer}, and optimizes from {@code initial}.
     *
     * @param problem   cost function plus constraints; must already have all constraints added
     * @param optimizer inner (unconstrained) optimizer, configured by the caller
     * @param initial   starting point
     * @return the solution returned by {@link AugmentedLagrangeOptimizer#optimize}
     */
    private static DMatrixD1 solve(AugmentedLagrangeOptimizationProblem problem, Optimizer optimizer, DMatrixD1 initial) {
        problem.initialize(INITIAL_PENALTY, PENALTY_INCREASE_FACTOR);
        AugmentedLagrangeOptimizer aloOptimizer = new AugmentedLagrangeOptimizer(optimizer, problem);
        aloOptimizer.setVerbose(true);
        return aloOptimizer.optimize(NUM_LAGRANGE_ITERATIONS, initial);
    }

    /**
     * Asserts that {@code solution.get(index)} equals {@code expected} to within {@link #EPSILON}
     * according to {@link MathTools#epsilonCompare}.
     */
    private static void assertComponent(String message, DMatrixD1 solution, int index, double expected) {
        Assert.assertTrue(message, MathTools.epsilonCompare(solution.get(index), expected, EPSILON));
    }

    /**
     * Quadratic cost with one equality constraint ({@code x[1] == 5}) and two inequality constraints
     * ({@code x[0] - 6} and {@code 3 - x[2]}). The expected optimum is (6, 5, 0).
     * NOTE(review): that expectation implies an inequality constraint g is satisfied when g(x) >= 0; confirm
     * against AugmentedLagrangeOptimizationProblem.
     */
    @Test
    public void testIsolatedConstraints() {
        DMatrixRMaj initial = new DMatrixRMaj(new double[]{10.0, 14.5, 16.0});
        AugmentedLagrangeOptimizationProblem problem =
                new AugmentedLagrangeOptimizationProblem(AugmentedLagrangeOptimizerTest::costFunctionQuadratic);
        problem.addEqualityConstraint(AugmentedLagrangeOptimizerTest::constraint1);
        problem.addInequalityConstraint(AugmentedLagrangeOptimizerTest::constraint2);
        problem.addInequalityConstraint(AugmentedLagrangeOptimizerTest::constraint3);

        DMatrixD1 optimumX = solve(problem, new WrappedGradientDescent(), initial);

        assertComponent("x1 arrived on desired value", optimumX, 0, 6.0);
        assertComponent("x2 arrived on desired value", optimumX, 1, 5.0);
        assertComponent("x3 arrived on desired value", optimumX, 2, 0.0);
    }

    /**
     * Quadratic cost with a single equality constraint coupling all components ({@code x[0] + x[1] + x[2] == 6}).
     * The expected optimum is (2, 2, 2).
     */
    @Test
    public void testJointConstraints() {
        DMatrixRMaj initial = new DMatrixRMaj(new double[]{10.0, 14.5, 16.0});
        AugmentedLagrangeOptimizationProblem problem =
                new AugmentedLagrangeOptimizationProblem(AugmentedLagrangeOptimizerTest::costFunctionQuadratic);
        problem.addEqualityConstraint(AugmentedLagrangeOptimizerTest::constraint4);

        DMatrixD1 optimumX = solve(problem, new WrappedGradientDescent(), initial);

        assertComponent("x1 arrived on desired value", optimumX, 0, 2.0);
        assertComponent("x2 arrived on desired value", optimumX, 1, 2.0);
        assertComponent("x3 arrived on desired value", optimumX, 2, 2.0);
    }

    /**
     * Nonconvex cost restricted to the line {@code x[0] + x[1] == 6.354061535}. The expected optimum is the
     * symmetric point on that line: each component is approximately 6.354061535 / 2.
     * The inner gradient descent uses a larger initial step size and learning rate than its defaults.
     */
    @Test
    public void testNonconvex() {
        DMatrixRMaj initial = new DMatrixRMaj(new double[]{13.0, 14.0});
        AugmentedLagrangeOptimizationProblem problem =
                new AugmentedLagrangeOptimizationProblem(AugmentedLagrangeOptimizerTest::costFunctionNonconvex);
        problem.addEqualityConstraint(AugmentedLagrangeOptimizerTest::constraintNonconvex);

        WrappedGradientDescent optimizer = new WrappedGradientDescent();
        optimizer.setInitialStepSize(10.0);
        optimizer.setLearningRate(0.9);

        DMatrixD1 optimumX = solve(problem, optimizer, initial);

        assertComponent("x1 arrived on desired value", optimumX, 0, 3.1770307678);
        assertComponent("x2 arrived on desired value", optimumX, 1, 3.1770307678);
    }
}

