/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.optimization;

import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.mult.VectorVectorMult_DDRM;
import org.junit.jupiter.api.Test;
import us.ihmc.commons.MathTools;
import us.ihmc.log.LogTools;
import us.ihmc.robotics.Assert;
import us.ihmc.robotics.optimization.Optimizer;
import us.ihmc.robotics.optimization.WrappedGradientDescent;
import us.ihmc.robotics.optimization.constrainedOptimization.AugmentedLagrangeOptimizationProblem;
import us.ihmc.robotics.optimization.constrainedOptimization.MultiblockADMMOptimizer;
import us.ihmc.robotics.optimization.constrainedOptimization.MultiblockADMMProblem;

/**
 * Tests {@link MultiblockADMMOptimizer} on a small two-block problem.
 *
 * <p>Each block is a one-dimensional problem whose cost is the squared norm of its input and which
 * carries the inequality constraint {@code x - 1}. The two blocks are coupled by the equality
 * constraint {@code x1 + x2 - 4}. The test expects both blocks to converge to {@code 2.0}.
 */
public class MultiblockADMMOptimizerTest {
    /** Expected value of each block's single variable at the optimum. */
    private static final double EXPECTED_OPTIMUM = 2.0;
    /** Tolerance used when comparing the solver result against {@link #EXPECTED_OPTIMUM}. */
    private static final double TOLERANCE = 0.001;

    /** Number of iterations passed to {@code solveOverNIterations}. */
    private final int numLagrangeIterations = 15;
    /** Initial penalty passed to {@code MultiblockADMMProblem.initialize}. */
    private final double initialPenalty = 0.5;
    /** Penalty increase factor passed to {@code MultiblockADMMProblem.initialize}. */
    private final double penaltyIncreaseFactor = 1.1;
    /** Starting point for the first block. */
    private final DMatrixD1 initial1 = new DMatrixRMaj(new double[] {4.0});
    /** Starting point for the second block. */
    private final DMatrixD1 initial2 = new DMatrixRMaj(new double[] {0.0});
    /** Per-block starting points, in block order. */
    private final DMatrixD1[] initialValues = new DMatrixD1[] {initial1, initial2};

    /**
     * Returns the inner product of {@code inputs} with itself, i.e. the sum of squares of its
     * entries.
     */
    private static double costFunction(DMatrixD1 inputs) {
        return VectorVectorMult_DDRM.innerProd(inputs, inputs);
    }

    /**
     * Returns {@code inputs[1] - 5}. Not used by the current test; note that it reads index 1, so it
     * requires at least two entries.
     */
    private static double constraint1(DMatrixD1 inputs) {
        return inputs.get(1) - 5.0;
    }

    /** Returns {@code inputs[0] - 6}. Not used by the current test. */
    private static double constraint2(DMatrixD1 inputs) {
        return inputs.get(0) - 6.0;
    }

    /**
     * Returns {@code inputs[0] - 1}. Registered below as an inequality constraint on each block.
     * NOTE(review): given the expected optimum of 2.0, the sign convention is presumably
     * {@code g(x) >= 0}; confirm against AugmentedLagrangeOptimizationProblem.
     */
    private static double constraint3(DMatrixD1 inputs) {
        return inputs.get(0) - 1.0;
    }

    /**
     * Cross-block equality constraint {@code blocks[0][0] + blocks[1][0] - 4}, i.e. the first
     * entries of the two blocks must sum to 4.
     */
    private static double blockConstraint1(DMatrixD1... blocks) {
        return blocks[0].get(0) + blocks[1].get(0) - 4.0;
    }

    /**
     * Solves two identical squared-norm blocks coupled by {@code x1 + x2 = 4} and checks that both
     * converge to {@code 2.0}.
     */
    @Test
    public void testSimpleBlockConstraint() {
        AugmentedLagrangeOptimizationProblem augmentedLagrange1 =
                new AugmentedLagrangeOptimizationProblem(MultiblockADMMOptimizerTest::costFunction);
        augmentedLagrange1.addInequalityConstraint(MultiblockADMMOptimizerTest::constraint3);
        AugmentedLagrangeOptimizationProblem augmentedLagrange2 =
                new AugmentedLagrangeOptimizationProblem(MultiblockADMMOptimizerTest::costFunction);
        augmentedLagrange2.addInequalityConstraint(MultiblockADMMOptimizerTest::constraint3);

        MultiblockADMMProblem admm = new MultiblockADMMProblem();
        admm.addIsolatedProblem(augmentedLagrange1);
        admm.addIsolatedProblem(augmentedLagrange2);
        admm.addEqualityConstraint(MultiblockADMMOptimizerTest::blockConstraint1);
        admm.initialize(initialPenalty, penaltyIncreaseFactor);

        // One independent inner optimizer per block.
        int numBlocks = admm.getNumBlocks();
        Optimizer[] optimizers = new Optimizer[numBlocks];
        for (int i = 0; i < numBlocks; ++i) {
            optimizers[i] = new WrappedGradientDescent();
        }

        MultiblockADMMOptimizer admmOptimizer = new MultiblockADMMOptimizer(admm, optimizers);
        admmOptimizer.setVerbose(true);
        DMatrixD1[] optima = admmOptimizer.solveOverNIterations(numLagrangeIterations, initialValues);

        Assert.assertTrue("x1 arrived on desired value",
                          MathTools.epsilonCompare(optima[0].get(0), EXPECTED_OPTIMUM, TOLERANCE));
        Assert.assertTrue("x2 arrived on desired value",
                          MathTools.epsilonCompare(optima[1].get(0), EXPECTED_OPTIMUM, TOLERANCE));
        LogTools.debug("Test completed successfully");
    }
}

