/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.mecano.algorithms;

import java.util.List;
import java.util.Random;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrix1Row;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import us.ihmc.euclid.referenceFrame.FrameNameRestrictionLevel;
import us.ihmc.euclid.referenceFrame.FrameVector3D;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.euclid.referenceFrame.interfaces.EuclidFrameGeometry;
import us.ihmc.euclid.referenceFrame.interfaces.FrameVector3DBasics;
import us.ihmc.euclid.referenceFrame.interfaces.FrameVector3DReadOnly;
import us.ihmc.euclid.referenceFrame.tools.EuclidFrameTestTools;
import us.ihmc.mecano.algorithms.CentroidalMomentumCalculator;
import us.ihmc.mecano.algorithms.CentroidalMomentumRateCalculator;
import us.ihmc.mecano.algorithms.SpatialAccelerationCalculator;
import us.ihmc.mecano.frames.CenterOfMassReferenceFrame;
import us.ihmc.mecano.multiBodySystem.interfaces.JointBasics;
import us.ihmc.mecano.multiBodySystem.interfaces.JointReadOnly;
import us.ihmc.mecano.multiBodySystem.interfaces.OneDoFJointBasics;
import us.ihmc.mecano.multiBodySystem.interfaces.RigidBodyBasics;
import us.ihmc.mecano.multiBodySystem.interfaces.RigidBodyReadOnly;
import us.ihmc.mecano.spatial.Momentum;
import us.ihmc.mecano.spatial.SpatialAcceleration;
import us.ihmc.mecano.spatial.SpatialForce;
import us.ihmc.mecano.spatial.Wrench;
import us.ihmc.mecano.spatial.interfaces.MomentumReadOnly;
import us.ihmc.mecano.spatial.interfaces.SpatialAccelerationReadOnly;
import us.ihmc.mecano.spatial.interfaces.SpatialForceBasics;
import us.ihmc.mecano.spatial.interfaces.SpatialForceReadOnly;
import us.ihmc.mecano.spatial.interfaces.SpatialInertiaReadOnly;
import us.ihmc.mecano.spatial.interfaces.SpatialMotionReadOnly;
import us.ihmc.mecano.spatial.interfaces.SpatialVectorReadOnly;
import us.ihmc.mecano.spatial.interfaces.TwistReadOnly;
import us.ihmc.mecano.spatial.interfaces.WrenchBasics;
import us.ihmc.mecano.tools.JointStateType;
import us.ihmc.mecano.tools.MecanoTestTools;
import us.ihmc.mecano.tools.MultiBodySystemRandomTools;
import us.ihmc.mecano.tools.MultiBodySystemStateIntegrator;
import us.ihmc.mecano.tools.MultiBodySystemTools;

/**
 * Tests for {@link CentroidalMomentumRateCalculator}.
 * <p>
 * The calculator's outputs are compared against {@link CentroidalMomentumCalculator} (momentum),
 * against a per-body summation of dynamic wrenches (momentum rate, see
 * {@link #computeMomentumRate(RigidBodyReadOnly, SpatialAccelerationCalculator, ReferenceFrame)}),
 * and against a finite difference of the momentum over successive integration steps.
 * </p>
 */
public class CentroidalMomentumRateCalculatorTest {
    private static final ReferenceFrame worldFrame = ReferenceFrame.getWorldFrame();
    /** Number of random multi-body systems generated by each test. */
    private static final int ITERATIONS = 500;
    /** Tolerance used when comparing two analytical results. */
    private static final double EPSILON = 2.0E-10;
    /** Tolerance used when comparing an analytical result against a finite difference. */
    private static final double FD_EPSILON = 5.0E-4;

    /**
     * Sets the name restriction level of the world frame to {@link FrameNameRestrictionLevel#NONE}
     * before each test; every iteration creates a frame named {@code "centerOfMassFrame"}.
     */
    @BeforeEach
    public void disableNameRestriction() {
        ReferenceFrame.getWorldFrame().setNameRestrictionLevel(FrameNameRestrictionLevel.NONE);
    }

    /** Checks the calculator on random chains made of one-degree-of-freedom joints. */
    @Test
    public void testMomentumRateWithOneDoFJointChain() {
        Random random = new Random(360675L);
        for (int i = 0; i < ITERATIONS; ++i) {
            int numberOfJoints = random.nextInt(50) + 1;
            List<? extends JointBasics> joints = MultiBodySystemRandomTools.nextOneDoFJointChain(random, numberOfJoints);
            assertMomentumRateMatchesReference(random, joints);
        }
    }

    /** Checks the calculator on random trees made of one-degree-of-freedom joints. */
    @Test
    public void testMomentumRateWithOneDoFJointTree() {
        Random random = new Random(360675L);
        for (int i = 0; i < ITERATIONS; ++i) {
            int numberOfJoints = random.nextInt(50) + 1;
            List<? extends JointBasics> joints = MultiBodySystemRandomTools.nextOneDoFJointTree(random, numberOfJoints);
            assertMomentumRateMatchesReference(random, joints);
        }
    }

    /** Checks the calculator on random chains produced by {@code nextJointChain}. */
    @Test
    public void testMomentumRateWithJointChain() {
        Random random = new Random(360675L);
        for (int i = 0; i < ITERATIONS; ++i) {
            int numberOfJoints = random.nextInt(50) + 1;
            List<? extends JointBasics> joints = MultiBodySystemRandomTools.nextJointChain(random, numberOfJoints);
            assertMomentumRateMatchesReference(random, joints);
        }
    }

    /**
     * Compares the momentum rate reported by the calculator with the finite difference of the
     * centroidal momentum between two consecutive integration steps.
     * <p>
     * Joint velocities and accelerations are drawn in {@code [-0.5, 0.5]} and the system is
     * integrated with a time step of {@code 1.0E-7}. The first step only records the momentum; the
     * comparison starts at the second step.
     * </p>
     */
    @Test
    public void testAgainsFiniteDifference() {
        Random random = new Random(360675L);
        double dt = 1.0E-7;
        int numberOfSteps = 10;
        MultiBodySystemStateIntegrator integrator = new MultiBodySystemStateIntegrator();
        integrator.setIntegrationDT(dt);

        for (int i = 0; i < ITERATIONS; ++i) {
            int numberOfJoints = random.nextInt(50) + 1;
            List<? extends OneDoFJointBasics> joints = MultiBodySystemRandomTools.nextOneDoFJointChain(random, numberOfJoints);
            MultiBodySystemRandomTools.nextState(random, JointStateType.CONFIGURATION, joints);
            MultiBodySystemRandomTools.nextState(random, JointStateType.VELOCITY, -0.5, 0.5, joints);
            MultiBodySystemRandomTools.nextState(random, JointStateType.ACCELERATION, -0.5, 0.5, joints);

            RigidBodyBasics rootBody = joints.get(0).getPredecessor();
            CenterOfMassReferenceFrame centerOfMassFrame = new CenterOfMassReferenceFrame("centerOfMassFrame", worldFrame, rootBody);
            CentroidalMomentumCalculator centroidalMomentumCalculator = new CentroidalMomentumCalculator(rootBody, centerOfMassFrame);
            CentroidalMomentumRateCalculator centroidalMomentumRateCalculator = new CentroidalMomentumRateCalculator(rootBody, centerOfMassFrame);

            Momentum previousMomentum = null;
            SpatialForce expectedMomentumRate = new SpatialForce();

            for (int j = 0; j < numberOfSteps; ++j) {
                // Refresh frames and invalidate both calculators for the current state.
                rootBody.updateFramesRecursively();
                centerOfMassFrame.update();
                centroidalMomentumCalculator.reset();
                centroidalMomentumRateCalculator.reset();

                Momentum currentMomentum = new Momentum(centroidalMomentumCalculator.getMomentum());

                if (previousMomentum != null) {
                    // expected = (current - previous) / dt
                    expectedMomentumRate.setIncludingFrame(currentMomentum);
                    expectedMomentumRate.sub(previousMomentum);
                    expectedMomentumRate.scale(1.0 / dt);
                    SpatialForceReadOnly actualMomentumRate = centroidalMomentumRateCalculator.getMomentumRate();
                    MecanoTestTools.assertSpatialVectorEquals(expectedMomentumRate, actualMomentumRate, FD_EPSILON);
                }

                previousMomentum = currentMomentum;
                integrator.doubleIntegrateFromAccelerationSubtree(rootBody);
            }
        }
    }

    /**
     * Randomizes the state of the given joints and asserts that
     * {@link CentroidalMomentumRateCalculator} is consistent with the reference computations.
     * <p>
     * The following are verified, in this order:
     * </p>
     * <ol>
     * <li>the momentum equals the one from {@link CentroidalMomentumCalculator};</li>
     * <li>the momentum rate equals the result of
     * {@link #computeMomentumRate(RigidBodyReadOnly, SpatialAccelerationCalculator, ReferenceFrame)};</li>
     * <li>the momentum rate computed from an explicit joint acceleration vector equals the same
     * reference;</li>
     * <li>the center of mass acceleration computed from an explicit joint acceleration vector equals
     * the one returned by the no-argument getter.</li>
     * </ol>
     *
     * @param random the random generator used to draw the joint configuration, velocity, and
     *               acceleration.
     * @param joints the joints of the system under test; the predecessor of the first joint is used
     *               as the root body.
     */
    private static void assertMomentumRateMatchesReference(Random random, List<? extends JointBasics> joints) {
        MultiBodySystemRandomTools.nextState(random, JointStateType.CONFIGURATION, joints);
        MultiBodySystemRandomTools.nextState(random, JointStateType.VELOCITY, joints);
        MultiBodySystemRandomTools.nextState(random, JointStateType.ACCELERATION, joints);

        RigidBodyBasics rootBody = joints.get(0).getPredecessor();
        rootBody.updateFramesRecursively();
        CenterOfMassReferenceFrame centerOfMassFrame = new CenterOfMassReferenceFrame("centerOfMassFrame", worldFrame, rootBody);
        centerOfMassFrame.update();

        CentroidalMomentumCalculator centroidalMomentumCalculator = new CentroidalMomentumCalculator(rootBody, centerOfMassFrame);
        centroidalMomentumCalculator.reset();
        CentroidalMomentumRateCalculator centroidalMomentumRateCalculator = new CentroidalMomentumRateCalculator(rootBody, centerOfMassFrame);
        centroidalMomentumRateCalculator.reset();

        // 1. Momentum.
        MomentumReadOnly actualMomentum = centroidalMomentumRateCalculator.getMomentum();
        MomentumReadOnly expectedMomentum = centroidalMomentumCalculator.getMomentum();
        MecanoTestTools.assertMomentumEquals(expectedMomentum, actualMomentum, EPSILON);

        // 2. Momentum rate from the joint accelerations stored in the joints.
        SpatialAccelerationCalculator spatialAccelerationCalculator = new SpatialAccelerationCalculator(rootBody, worldFrame);
        spatialAccelerationCalculator.reset();
        SpatialForce actualMomentumRate = new SpatialForce(centroidalMomentumRateCalculator.getMomentumRate());
        SpatialForce expectedMomentumRate = computeMomentumRate(rootBody, spatialAccelerationCalculator, centroidalMomentumRateCalculator.getReferenceFrame());
        MecanoTestTools.assertSpatialVectorEquals(expectedMomentumRate, actualMomentumRate, EPSILON);

        // 3. Momentum rate from an explicit acceleration vector, one row per degree of freedom.
        DMatrixRMaj jointAccelerationMatrix = new DMatrixRMaj(MultiBodySystemTools.computeDegreesOfFreedom(joints), 1);
        MultiBodySystemTools.extractJointsState(joints, JointStateType.ACCELERATION, jointAccelerationMatrix);
        actualMomentumRate = new SpatialForce();
        centroidalMomentumRateCalculator.getMomentumRate(jointAccelerationMatrix, actualMomentumRate);
        MecanoTestTools.assertSpatialVectorEquals(expectedMomentumRate, actualMomentumRate, EPSILON);

        // 4. Both center of mass acceleration accessors of the calculator must agree.
        FrameVector3DReadOnly expectedCenterOfMassAcceleration = centroidalMomentumRateCalculator.getCenterOfMassAcceleration();
        FrameVector3D actualCenterOfMassAcceleration = new FrameVector3D();
        centroidalMomentumRateCalculator.getCenterOfMassAcceleration(jointAccelerationMatrix, actualCenterOfMassAcceleration);
        EuclidFrameTestTools.assertEquals(expectedCenterOfMassAcceleration, actualCenterOfMassAcceleration, EPSILON);
    }

    /**
     * Computes the momentum as the product of the calculator's centroidal momentum matrix and the
     * joint velocity vector extracted from the given joints.
     *
     * @param joints                           the joints to read the velocities from.
     * @param centroidalMomentumRateCalculator the calculator providing the centroidal momentum
     *                                         matrix and the reference frame of the result.
     * @return a new momentum expressed in the calculator's reference frame.
     */
    public static Momentum extractMomentum(List<? extends JointReadOnly> joints, CentroidalMomentumRateCalculator centroidalMomentumRateCalculator) {
        DMatrixRMaj jointVelocities = new DMatrixRMaj(MultiBodySystemTools.computeDegreesOfFreedom(joints), 1);
        MultiBodySystemTools.extractJointsState(joints, JointStateType.VELOCITY, jointVelocities);
        DMatrixRMaj momentumMatrix = new DMatrixRMaj(6, 1);
        CommonOps_DDRM.mult(centroidalMomentumRateCalculator.getCentroidalMomentumMatrix(), jointVelocities, momentumMatrix);
        return new Momentum(centroidalMomentumRateCalculator.getReferenceFrame(), momentumMatrix);
    }

    /**
     * Computes a reference momentum rate by summing, over every body of the subtree that has an
     * inertia, the dynamic wrench obtained from the body's inertia, acceleration, and twist.
     *
     * @param rootBody                      the root of the subtree to iterate over.
     * @param spatialAccelerationCalculator provides the acceleration of each body.
     * @param referenceFrame                the frame in which each wrench is expressed before being
     *                                      accumulated, and the frame of the result.
     * @return a new spatial force holding the accumulated wrenches.
     */
    public static SpatialForce computeMomentumRate(RigidBodyReadOnly rootBody, SpatialAccelerationCalculator spatialAccelerationCalculator, ReferenceFrame referenceFrame) {
        SpatialForce momentumRate = new SpatialForce(referenceFrame);
        for (RigidBodyReadOnly rigidBody : rootBody.subtreeIterable()) {
            SpatialInertiaReadOnly inertia = rigidBody.getInertia();
            if (inertia == null) {
                // Bodies without inertia are skipped.
                continue;
            }
            Wrench bodyDynamicWrench = new Wrench();
            SpatialAcceleration bodyAcceleration = new SpatialAcceleration();
            bodyAcceleration.setIncludingFrame(spatialAccelerationCalculator.getAccelerationOfBody(rigidBody));
            TwistReadOnly bodyTwist = rigidBody.getBodyFixedFrame().getTwistOfFrame();
            inertia.computeDynamicWrenchFast(bodyAcceleration, bodyTwist, bodyDynamicWrench);
            bodyDynamicWrench.changeFrame(referenceFrame);
            momentumRate.add(bodyDynamicWrench);
        }
        return momentumRate;
    }
}

