/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.avatar.multiContact;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.BiPredicate;
import java.util.function.Function;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrixRMaj;
import us.ihmc.euclid.orientation.interfaces.Orientation3DReadOnly;
import us.ihmc.euclid.referenceFrame.FrameVector3D;
import us.ihmc.euclid.referenceFrame.ReferenceFrame;
import us.ihmc.euclid.referenceFrame.interfaces.FrameTuple3DReadOnly;
import us.ihmc.euclid.transform.RigidBodyTransform;
import us.ihmc.euclid.transform.interfaces.RigidBodyTransformReadOnly;
import us.ihmc.euclid.tuple3D.Point3D;
import us.ihmc.euclid.tuple3D.Vector3D;
import us.ihmc.euclid.tuple3D.interfaces.Point3DBasics;
import us.ihmc.euclid.tuple3D.interfaces.Point3DReadOnly;
import us.ihmc.euclid.tuple3D.interfaces.Tuple3DReadOnly;
import us.ihmc.euclid.tuple3D.interfaces.Vector3DBasics;
import us.ihmc.euclid.tuple4D.Quaternion;
import us.ihmc.euclid.tuple4D.interfaces.QuaternionReadOnly;
import us.ihmc.mecano.multiBodySystem.interfaces.RigidBodyReadOnly;
import us.ihmc.mecano.spatial.SpatialVector;
import us.ihmc.mecano.spatial.interfaces.SpatialVectorReadOnly;
import us.ihmc.robotics.optimization.LevenbergMarquardtParameterOptimizer;
import us.ihmc.robotics.optimization.OutputCalculator;
import us.ihmc.robotics.referenceFrames.PoseReferenceFrame;
import us.ihmc.robotics.screwTheory.SelectionMatrix3D;
import us.ihmc.robotics.screwTheory.SelectionMatrix6D;
import us.ihmc.robotics.weightMatrices.WeightMatrix3D;
import us.ihmc.robotics.weightMatrices.WeightMatrix6D;

/**
 * Estimates the rigid-body transform from robot B to robot A that best aligns their rigid bodies.
 * <p>
 * Both robots are expected to share the same kinematic structure: their rigid bodies are paired by
 * index in {@code subtreeArray()}. Each registered {@link RigidBodyPairErrorCalculator} contributes
 * one scalar residual, the length of the error between its body A and its body B once body B has
 * been moved by a candidate transform. A {@link LevenbergMarquardtParameterOptimizer} searches the
 * six parameters {@code (x, y, z, roll, pitch, yaw)} of that candidate transform.
 * </p>
 * <p>
 * Typical usage: register error calculators (e.g. with
 * {@link #addDefaultRigidBodySpatialErrorCalculators(BiPredicate)}), call {@link #compute()}, then
 * read {@link #getTransformFromBToA()}.
 * </p>
 */
public class RobotTransformOptimizer {
    /** Number of optimization parameters: translation (x, y, z), then roll, pitch, yaw. */
    private static final int INPUT_PARAMETER_DIMENSION = 6;

    /** Body of robot A with the largest mass (the root body counts as massless). */
    private final RigidBodyReadOnly heaviestBodyA;
    /** Counterpart of {@link #heaviestBodyA} in robot B. */
    private final RigidBodyReadOnly heaviestBodyB;
    /** All bodies of robot A in subtree order; index {@code i} is paired with {@code rigidBodiesB[i]}. */
    private final RigidBodyReadOnly[] rigidBodiesA;
    /** All bodies of robot B in subtree order; index {@code i} is paired with {@code rigidBodiesA[i]}. */
    private final RigidBodyReadOnly[] rigidBodiesB;
    /** Error terms fed to the optimizer, one residual per entry. */
    private final List<RigidBodyPairErrorCalculator> rigidBodyErrorCalculators = new ArrayList<>();
    /** Body pairs indexed by the name of their robot-A body. */
    private final Map<String, Pair<RigidBodyReadOnly, RigidBodyReadOnly>> nameAToBodyMap = new HashMap<>();
    /** Body pairs indexed by the name of their robot-B body. */
    private final Map<String, Pair<RigidBodyReadOnly, RigidBodyReadOnly>> nameBToBodyMap = new HashMap<>();
    private boolean initializeWithHeaviestBody = false;
    private int maxIterations = 500;
    private double convergenceThreshold = 1.0E-7;
    /**
     * Result of the optimization. While the optimizer runs it is also overwritten with every candidate
     * by {@link #inputFunction}.
     */
    private final RigidBodyTransform transformFromBToA = new RigidBodyTransform();
    /** Per-parameter perturbation handed to the optimizer (presumably finite-difference step sizes). */
    private final DMatrixRMaj perturbationVector = new DMatrixRMaj(INPUT_PARAMETER_DIMENSION, 1);
    /**
     * Maps a parameter vector {@code (x, y, z, roll, pitch, yaw)} onto {@link #transformFromBToA}
     * (overwritten in place) and returns that same instance.
     */
    private final Function<DMatrixRMaj, RigidBodyTransform> inputFunction = input -> {
        this.transformFromBToA.getRotation().setYawPitchRoll(input.get(5), input.get(4), input.get(3));
        this.transformFromBToA.getTranslation().set(input.get(0), input.get(1), input.get(2));
        return this.transformFromBToA;
    };
    /** Residual storage, reshaped to one row per error calculator on every evaluation. */
    private final DMatrixRMaj errorSpace = new DMatrixRMaj(50, 1);
    /** Residual function: entry {@code i} is the length of the error of calculator {@code i}. */
    private final OutputCalculator errorCalculator = inputParameter -> {
        this.errorSpace.reshape(this.rigidBodyErrorCalculators.size(), 1);
        RigidBodyTransform correction = this.inputFunction.apply((DMatrixRMaj) inputParameter);
        for (int i = 0; i < this.rigidBodyErrorCalculators.size(); ++i) {
            this.errorSpace.set(i, this.rigidBodyErrorCalculators.get(i).computeError(correction).length());
        }
        return this.errorSpace;
    };
    /** Optimizer created by the last call to {@link #compute()}; {@code null} before that. */
    private LevenbergMarquardtParameterOptimizer optimizer;

    /**
     * Creates an optimizer aligning the subtree of {@code rootBodyB} onto the subtree of {@code rootBodyA}.
     *
     * @param rootBodyA root of robot A, the reference robot.
     * @param rootBodyB root of robot B, the robot to be moved onto A.
     * @throws IllegalArgumentException if the two subtrees do not contain the same number of bodies,
     *         in which case pairing bodies by index would be meaningless.
     */
    public RobotTransformOptimizer(RigidBodyReadOnly rootBodyA, RigidBodyReadOnly rootBodyB) {
        this.rigidBodiesA = rootBodyA.subtreeArray();
        this.rigidBodiesB = rootBodyB.subtreeArray();
        if (this.rigidBodiesA.length != this.rigidBodiesB.length) {
            throw new IllegalArgumentException("Robots A and B must have the same number of rigid bodies, but A has "
                    + this.rigidBodiesA.length + " and B has " + this.rigidBodiesB.length + ".");
        }

        RigidBodyReadOnly candidateA = this.rigidBodiesA[0];
        RigidBodyReadOnly candidateB = this.rigidBodiesB[0];
        double candidateMass = massOf(candidateA);

        for (int i = 0; i < this.rigidBodiesA.length; ++i) {
            RigidBodyReadOnly bodyA = this.rigidBodiesA[i];
            RigidBodyReadOnly bodyB = this.rigidBodiesB[i];

            double currentMass = massOf(bodyA);
            if (currentMass > candidateMass) {
                candidateA = bodyA;
                candidateB = bodyB;
                candidateMass = currentMass;
            }

            // The pair is immutable, so a single instance can back both lookups.
            Pair<RigidBodyReadOnly, RigidBodyReadOnly> bodyPair = new ImmutablePair<>(bodyA, bodyB);
            this.nameAToBodyMap.put(bodyA.getName(), bodyPair);
            this.nameBToBodyMap.put(bodyB.getName(), bodyPair);
        }

        this.heaviestBodyA = candidateA;
        this.heaviestBodyB = candidateB;

        // Translation parameters (0-2) and rotation parameters (3-5) get different perturbations.
        this.perturbationVector.set(0, 1.25E-4);
        this.perturbationVector.set(1, 1.25E-4);
        this.perturbationVector.set(2, 1.25E-4);
        this.perturbationVector.set(3, 2.5E-5);
        this.perturbationVector.set(4, 2.5E-5);
        this.perturbationVector.set(5, 2.5E-5);
    }

    /**
     * Returns the mass of {@code body}. The root body is treated as massless and its inertia is never queried.
     */
    private static double massOf(RigidBodyReadOnly body) {
        return body.isRootBody() ? 0.0 : body.getInertia().getMass();
    }

    /**
     * Selects how {@link #compute()} seeds the optimizer.
     *
     * @param initializeWithHeaviestBody {@code true} to start from the transform that exactly aligns the
     *        heaviest body pair, {@code false} (default) to start from a pure translation equal to the
     *        average linear error over all registered calculators.
     */
    public void setInitializeWithHeaviestBody(boolean initializeWithHeaviestBody) {
        this.initializeWithHeaviestBody = initializeWithHeaviestBody;
    }

    /**
     * Sets the maximum number of optimizer iterations performed by {@link #compute()} (default 500).
     *
     * @param maxIterations non-negative iteration cap; 0 keeps the initial guess.
     * @throws IllegalArgumentException if {@code maxIterations} is negative.
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 0) {
            throw new IllegalArgumentException("maxIterations must be non-negative, was " + maxIterations);
        }
        this.maxIterations = maxIterations;
    }

    /**
     * Sets the quality value at or below which {@link #compute()} stops iterating (default 1.0E-7).
     *
     * @param convergenceThreshold the threshold compared against the value returned by each optimizer iteration.
     * @throws IllegalArgumentException if {@code convergenceThreshold} is NaN.
     */
    public void setConvergenceThreshold(double convergenceThreshold) {
        if (Double.isNaN(convergenceThreshold)) {
            throw new IllegalArgumentException("convergenceThreshold must not be NaN");
        }
        this.convergenceThreshold = convergenceThreshold;
    }

    /** Removes every registered error calculator. */
    public void clearErrorCalculators() {
        this.rigidBodyErrorCalculators.clear();
    }

    /**
     * Replaces all registered error calculators with one {@link RigidBodyPairSpatialErrorCalculator} per
     * body pair accepted by {@code bodySelector}.
     *
     * @param bodySelector predicate receiving {@code (bodyA, bodyB)}; pairs for which it returns {@code false} are skipped.
     */
    public void addDefaultRigidBodySpatialErrorCalculators(BiPredicate<RigidBodyReadOnly, RigidBodyReadOnly> bodySelector) {
        this.replaceWithDefaultErrorCalculators(bodySelector, RigidBodyPairSpatialErrorCalculator::new);
    }

    /**
     * Replaces all registered error calculators with one {@link RigidBodyPairLinearErrorCalculator} per
     * body pair accepted by {@code bodySelector}.
     *
     * @param bodySelector predicate receiving {@code (bodyA, bodyB)}; pairs for which it returns {@code false} are skipped.
     */
    public void addDefaultRigidBodyLinearErrorCalculators(BiPredicate<RigidBodyReadOnly, RigidBodyReadOnly> bodySelector) {
        this.replaceWithDefaultErrorCalculators(bodySelector, RigidBodyPairLinearErrorCalculator::new);
    }

    /**
     * Replaces all registered error calculators with one {@link RigidBodyPairAngularErrorCalculator} per
     * body pair accepted by {@code bodySelector}.
     *
     * @param bodySelector predicate receiving {@code (bodyA, bodyB)}; pairs for which it returns {@code false} are skipped.
     */
    public void addDefaultRigidBodyAngularErrorCalculators(BiPredicate<RigidBodyReadOnly, RigidBodyReadOnly> bodySelector) {
        this.replaceWithDefaultErrorCalculators(bodySelector, RigidBodyPairAngularErrorCalculator::new);
    }

    /** Clears the calculators, then creates one with {@code factory} for every body pair accepted by {@code bodySelector}. */
    private void replaceWithDefaultErrorCalculators(BiPredicate<RigidBodyReadOnly, RigidBodyReadOnly> bodySelector,
            BiFunction<RigidBodyReadOnly, RigidBodyReadOnly, ? extends RigidBodyPairErrorCalculator> factory) {
        this.rigidBodyErrorCalculators.clear();
        for (int i = 0; i < this.rigidBodiesA.length; ++i) {
            RigidBodyReadOnly bodyA = this.rigidBodiesA[i];
            RigidBodyReadOnly bodyB = this.rigidBodiesB[i];
            if (bodySelector.test(bodyA, bodyB)) {
                this.rigidBodyErrorCalculators.add(factory.apply(bodyA, bodyB));
            }
        }
    }

    /**
     * Registers a pose error calculator for the body pair named {@code bodyName}.
     *
     * @param bodyName name of the body in robot A or, if not found there, in robot B.
     * @return the new calculator (so its selection/weight can be tuned), or {@code null} if no body has that name.
     */
    public RigidBodyPairSpatialErrorCalculator addSpatialRigidBodyErrorCalculator(String bodyName) {
        return this.addErrorCalculator(bodyName, RigidBodyPairSpatialErrorCalculator::new);
    }

    /**
     * Registers a position error calculator for the body pair named {@code bodyName}.
     *
     * @param bodyName name of the body in robot A or, if not found there, in robot B.
     * @return the new calculator (so its selection/weight can be tuned), or {@code null} if no body has that name.
     */
    public RigidBodyPairLinearErrorCalculator addLinearRigidBodyErrorCalculator(String bodyName) {
        return this.addErrorCalculator(bodyName, RigidBodyPairLinearErrorCalculator::new);
    }

    /**
     * Registers an orientation error calculator for the body pair named {@code bodyName}.
     *
     * @param bodyName name of the body in robot A or, if not found there, in robot B.
     * @return the new calculator (so its selection/weight can be tuned), or {@code null} if no body has that name.
     */
    public RigidBodyPairAngularErrorCalculator addAngularRigidBodyErrorCalculator(String bodyName) {
        return this.addErrorCalculator(bodyName, RigidBodyPairAngularErrorCalculator::new);
    }

    /** Creates and registers a calculator for the pair named {@code bodyName}; returns {@code null} if the name is unknown. */
    private <T extends RigidBodyPairErrorCalculator> T addErrorCalculator(String bodyName,
            BiFunction<RigidBodyReadOnly, RigidBodyReadOnly, T> factory) {
        Pair<RigidBodyReadOnly, RigidBodyReadOnly> bodyPair = this.findBodyPair(bodyName);
        if (bodyPair == null) {
            return null;
        }
        T calculator = factory.apply(bodyPair.getLeft(), bodyPair.getRight());
        this.rigidBodyErrorCalculators.add(calculator);
        return calculator;
    }

    /** Looks {@code bodyName} up among robot-A names first, then robot-B names; {@code null} if absent from both. */
    private Pair<RigidBodyReadOnly, RigidBodyReadOnly> findBodyPair(String bodyName) {
        Pair<RigidBodyReadOnly, RigidBodyReadOnly> bodyPair = this.nameAToBodyMap.get(bodyName);
        return bodyPair != null ? bodyPair : this.nameBToBodyMap.get(bodyName);
    }

    /**
     * Runs the optimization and stores the result, available from {@link #getTransformFromBToA()}.
     * <p>
     * Every registered calculator first captures the current poses of its control frames, so the robots'
     * current configurations are the ones being aligned. The optimizer is seeded as selected by
     * {@link #setInitializeWithHeaviestBody(boolean)}. It then iterates until an iteration reports a
     * quality at or below the convergence threshold, or until the maximum number of iterations is reached.
     * </p>
     *
     * @throws IllegalStateException if no error calculator has been registered.
     */
    public void compute() {
        if (this.rigidBodyErrorCalculators.isEmpty()) {
            // Without residuals the default initial guess divides by zero and the optimizer has no output.
            throw new IllegalStateException("No error calculator registered: add at least one before calling compute().");
        }

        DMatrixRMaj initialGuess = new DMatrixRMaj(INPUT_PARAMETER_DIMENSION, 1);
        this.rigidBodyErrorCalculators.forEach(RigidBodyPairErrorCalculator::initialize);

        if (this.initializeWithHeaviestBody) {
            // Seek X such that X * T_B = T_A, i.e. X = T_A * T_B^-1, for the heaviest body pair.
            RigidBodyTransform initialTransform = new RigidBodyTransform((RigidBodyTransformReadOnly) this.heaviestBodyB.getBodyFixedFrame().getTransformToRoot());
            initialTransform.preMultiplyInvertThis((RigidBodyTransformReadOnly) this.heaviestBodyA.getBodyFixedFrame().getTransformToRoot());
            initialGuess.set(0, initialTransform.getTranslationX());
            initialGuess.set(1, initialTransform.getTranslationY());
            initialGuess.set(2, initialTransform.getTranslationZ());
            initialGuess.set(3, initialTransform.getRotation().getRoll());
            initialGuess.set(4, initialTransform.getRotation().getPitch());
            initialGuess.set(5, initialTransform.getRotation().getYaw());
        } else {
            // Pure translation seed: the mean of the uncorrected linear errors; rotation parameters stay at zero.
            Vector3D average = new Vector3D();
            this.rigidBodyErrorCalculators.forEach(bodyPair -> average.add((Tuple3DReadOnly) bodyPair.computeError().getLinearPart()));
            average.scale(1.0 / (double) this.rigidBodyErrorCalculators.size());
            average.get((DMatrix) initialGuess);
        }

        int outputDimension = this.rigidBodyErrorCalculators.size();
        this.optimizer = new LevenbergMarquardtParameterOptimizer(this.inputFunction, this.errorCalculator, INPUT_PARAMETER_DIMENSION, outputDimension);
        this.optimizer.setInitialOptimalGuess(initialGuess);
        this.optimizer.setPerturbationVector(this.perturbationVector);
        this.optimizer.setCorrespondenceThreshold(Double.POSITIVE_INFINITY);
        this.optimizer.initialize();

        for (int iteration = 0; iteration < this.maxIterations; ++iteration) {
            double quality = this.optimizer.iterate();
            if (quality <= this.convergenceThreshold) {
                break;
            }
        }

        // inputFunction writes straight into transformFromBToA, replacing the last candidate evaluated.
        this.inputFunction.apply(this.optimizer.getOptimalParameter());
    }

    /**
     * Returns the transform from robot B to robot A computed by the last {@link #compute()} call.
     * <p>
     * This is the internal instance: it is overwritten by the next call to {@link #compute()}.
     * </p>
     */
    public RigidBodyTransform getTransformFromBToA() {
        return this.transformFromBToA;
    }

    /**
     * Returns the quality reported by the optimizer of the last {@link #compute()} call, or {@code NaN}
     * if {@link #compute()} has never been called. Presumably lower is better, since iteration stops once
     * the quality falls to or below the convergence threshold.
     */
    public double getSolutionQuality() {
        if (this.optimizer != null) {
            return this.optimizer.getQuality();
        }
        return Double.NaN;
    }

    /**
     * Error term comparing the full pose (orientation and position) of a pair of control frames.
     * <p>
     * The raw error transform is {@code T_A * T_B^-1 * X^-1}. Here {@code T_A} and {@code T_B} are the
     * control-frame poses captured by {@link #initialize()}, and {@code X} is the candidate transform
     * from B to A (omitted when {@code null}). Its rotation vector forms the angular part of the error
     * and its translation forms the linear part. The selection matrix, then the weight matrix, are
     * applied to each part. All weights default to 1.
     * </p>
     */
    public static class RigidBodyPairSpatialErrorCalculator
    extends RigidBodyPairErrorCalculator {
        private final RigidBodyTransform poseInitialA = new RigidBodyTransform();
        private final RigidBodyTransform poseInitialB = new RigidBodyTransform();
        private final WeightMatrix6D weightMatrix = new WeightMatrix6D();
        private final SelectionMatrix6D selectionMatrix = new SelectionMatrix6D();
        private final RigidBodyTransform errorTransform = new RigidBodyTransform();
        private final SpatialVector error = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final SpatialVector subSpaceError = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final SpatialVector weightedSubSpaceError = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final FrameVector3D tempError3D = new FrameVector3D(ReferenceFrame.getWorldFrame());

        protected RigidBodyPairSpatialErrorCalculator(RigidBodyReadOnly bodyA, RigidBodyReadOnly bodyB) {
            super(bodyA, bodyB);
            this.weightMatrix.setAngularWeights(1.0, 1.0, 1.0);
            this.weightMatrix.setLinearWeights(1.0, 1.0, 1.0);
        }

        /** Captures both control-frame poses relative to the root frame. */
        @Override
        protected void initialize() {
            this.poseInitialA.set(this.controlFrameA.getTransformToRoot());
            this.poseInitialB.set(this.controlFrameB.getTransformToRoot());
        }

        @Override
        protected SpatialVector computeError(RigidBodyTransform transformForB) {
            // errorTransform = T_A * T_B^-1 [* X^-1]
            this.errorTransform.set(this.poseInitialA);
            this.errorTransform.multiplyInvertOther((RigidBodyTransformReadOnly) this.poseInitialB);
            if (transformForB != null) {
                this.errorTransform.multiplyInvertOther((RigidBodyTransformReadOnly) transformForB);
            }
            this.errorTransform.getRotation().getRotationVector((Vector3DBasics) this.error.getAngularPart());
            this.error.getLinearPart().set((Tuple3DReadOnly) this.errorTransform.getTranslation());

            // Keep only the selected components.
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.error.getAngularPart());
            this.selectionMatrix.applyAngularSelection(this.tempError3D);
            this.subSpaceError.getAngularPart().set((FrameTuple3DReadOnly) this.tempError3D);
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.error.getLinearPart());
            this.selectionMatrix.applyLinearSelection(this.tempError3D);
            this.subSpaceError.getLinearPart().set((FrameTuple3DReadOnly) this.tempError3D);

            // Scale the selected components by their weights.
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.subSpaceError.getAngularPart());
            this.weightMatrix.applyAngularWeight(this.tempError3D);
            this.weightedSubSpaceError.getAngularPart().set((FrameTuple3DReadOnly) this.tempError3D);
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.subSpaceError.getLinearPart());
            this.weightMatrix.applyLinearWeight(this.tempError3D);
            this.weightedSubSpaceError.getLinearPart().set((FrameTuple3DReadOnly) this.tempError3D);
            return this.weightedSubSpaceError;
        }

        @Override
        public SpatialVectorReadOnly getError() {
            return this.weightedSubSpaceError;
        }

        /** Returns the mutable selection matrix applied to the error. */
        public SelectionMatrix6D getSelectionMatrix() {
            return this.selectionMatrix;
        }

        /** Returns the mutable weight matrix applied to the selected error. */
        public WeightMatrix6D getWeightMatrix() {
            return this.weightMatrix;
        }
    }

    /**
     * Error term comparing only the positions of a pair of control frames.
     * <p>
     * The raw error is {@code p_A - X * p_B}, where {@code p_A} and {@code p_B} are the control-frame
     * positions captured by {@link #initialize()} and {@code X} is the candidate transform from B to A
     * (omitted when {@code null}). The selection matrix, then the weight matrix, are applied. The
     * angular part of the returned vector is never written. All weights default to 1.
     * </p>
     */
    public static class RigidBodyPairLinearErrorCalculator
    extends RigidBodyPairErrorCalculator {
        private final Point3D pointInitialA = new Point3D();
        private final Point3D pointInitialB = new Point3D();
        private final Point3D pointCorrectedB = new Point3D();
        private final WeightMatrix3D weightMatrix = new WeightMatrix3D();
        private final SelectionMatrix3D selectionMatrix = new SelectionMatrix3D();
        private final SpatialVector error = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final SpatialVector subSpaceError = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final SpatialVector weightedSubSpaceError = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final FrameVector3D tempError3D = new FrameVector3D(ReferenceFrame.getWorldFrame());

        protected RigidBodyPairLinearErrorCalculator(RigidBodyReadOnly bodyA, RigidBodyReadOnly bodyB) {
            super(bodyA, bodyB);
            this.weightMatrix.setWeights(1.0, 1.0, 1.0);
        }

        /** Captures both control-frame positions relative to the root frame. */
        @Override
        protected void initialize() {
            this.pointInitialA.set((Tuple3DReadOnly) this.controlFrameA.getTransformToRoot().getTranslation());
            this.pointInitialB.set((Tuple3DReadOnly) this.controlFrameB.getTransformToRoot().getTranslation());
        }

        @Override
        protected SpatialVector computeError(RigidBodyTransform transformForB) {
            if (transformForB != null) {
                transformForB.transform((Point3DReadOnly) this.pointInitialB, (Point3DBasics) this.pointCorrectedB);
                this.error.getLinearPart().sub((Tuple3DReadOnly) this.pointInitialA, (Tuple3DReadOnly) this.pointCorrectedB);
            } else {
                this.error.getLinearPart().sub((Tuple3DReadOnly) this.pointInitialA, (Tuple3DReadOnly) this.pointInitialB);
            }
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.error.getLinearPart());
            this.selectionMatrix.applySelection(this.tempError3D);
            this.subSpaceError.getLinearPart().set((FrameTuple3DReadOnly) this.tempError3D);
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.subSpaceError.getLinearPart());
            this.weightMatrix.applyWeight(this.tempError3D);
            this.weightedSubSpaceError.getLinearPart().set((FrameTuple3DReadOnly) this.tempError3D);
            return this.weightedSubSpaceError;
        }

        @Override
        public SpatialVectorReadOnly getError() {
            return this.weightedSubSpaceError;
        }

        /** Returns the mutable selection matrix applied to the error. */
        public SelectionMatrix3D getSelectionMatrix() {
            return this.selectionMatrix;
        }

        /** Returns the mutable weight matrix applied to the selected error. */
        public WeightMatrix3D getWeightMatrix() {
            return this.weightMatrix;
        }
    }

    /**
     * Error term comparing only the orientations of a pair of control frames.
     * <p>
     * The raw error is the rotation vector of {@code q_A * q_B^-1 * R_X^-1}. Here {@code q_A} and
     * {@code q_B} are the control-frame orientations captured by {@link #initialize()}, and {@code R_X}
     * is the rotation of the candidate transform from B to A (omitted when {@code null}). The selection
     * matrix, then the weight matrix, are applied. The linear part of the returned vector is never
     * written. All weights default to 1.
     * </p>
     */
    public static class RigidBodyPairAngularErrorCalculator
    extends RigidBodyPairErrorCalculator {
        private final Quaternion orientationInitialA = new Quaternion();
        private final Quaternion orientationInitialB = new Quaternion();
        private final WeightMatrix3D weightMatrix = new WeightMatrix3D();
        private final SelectionMatrix3D selectionMatrix = new SelectionMatrix3D();
        private final Quaternion orientationError = new Quaternion();
        private final SpatialVector error = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final SpatialVector subSpaceError = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final SpatialVector weightedSubSpaceError = new SpatialVector(ReferenceFrame.getWorldFrame());
        private final FrameVector3D tempError3D = new FrameVector3D(ReferenceFrame.getWorldFrame());

        protected RigidBodyPairAngularErrorCalculator(RigidBodyReadOnly bodyA, RigidBodyReadOnly bodyB) {
            super(bodyA, bodyB);
            this.weightMatrix.setWeights(1.0, 1.0, 1.0);
        }

        /** Captures both control-frame orientations relative to the root frame. */
        @Override
        protected void initialize() {
            this.orientationInitialA.set((Orientation3DReadOnly) this.controlFrameA.getTransformToRoot().getRotation());
            this.orientationInitialB.set((Orientation3DReadOnly) this.controlFrameB.getTransformToRoot().getRotation());
        }

        @Override
        protected SpatialVector computeError(RigidBodyTransform transformForB) {
            this.orientationError.set(this.orientationInitialA);
            this.orientationError.multiplyConjugateOther((QuaternionReadOnly) this.orientationInitialB);
            if (transformForB != null) {
                this.orientationError.appendInvertOther((Orientation3DReadOnly) transformForB.getRotation());
            }
            this.orientationError.getRotationVector((Vector3DBasics) this.error.getAngularPart());
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.error.getAngularPart());
            this.selectionMatrix.applySelection(this.tempError3D);
            this.subSpaceError.getAngularPart().set((FrameTuple3DReadOnly) this.tempError3D);
            this.tempError3D.setIncludingFrame((FrameTuple3DReadOnly) this.subSpaceError.getAngularPart());
            this.weightMatrix.applyWeight(this.tempError3D);
            this.weightedSubSpaceError.getAngularPart().set((FrameTuple3DReadOnly) this.tempError3D);
            return this.weightedSubSpaceError;
        }

        @Override
        public SpatialVectorReadOnly getError() {
            return this.weightedSubSpaceError;
        }

        /** Returns the mutable selection matrix applied to the error. */
        public SelectionMatrix3D getSelectionMatrix() {
            return this.selectionMatrix;
        }

        /** Returns the mutable weight matrix applied to the selected error. */
        public WeightMatrix3D getWeightMatrix() {
            return this.weightMatrix;
        }
    }

    /**
     * Base class for an error term between a body of robot A and its counterpart in robot B.
     * <p>
     * Each body gets a control frame, a {@link PoseReferenceFrame} whose parent is the body-fixed frame.
     * Errors are measured between the two control frames, so {@link #setControlFrameOffset(RigidBodyTransform)}
     * can move the point of comparison away from the body-fixed frames.
     * </p>
     */
    public static abstract class RigidBodyPairErrorCalculator {
        protected final RigidBodyReadOnly bodyA;
        protected final RigidBodyReadOnly bodyB;
        protected final PoseReferenceFrame controlFrameA;
        protected final PoseReferenceFrame controlFrameB;

        protected RigidBodyPairErrorCalculator(RigidBodyReadOnly bodyA, RigidBodyReadOnly bodyB) {
            this.bodyA = bodyA;
            this.bodyB = bodyB;
            this.controlFrameA = new PoseReferenceFrame(bodyA.getName() + "ControlFrame", (ReferenceFrame) bodyA.getBodyFixedFrame());
            this.controlFrameB = new PoseReferenceFrame(bodyB.getName() + "ControlFrame", (ReferenceFrame) bodyB.getBodyFixedFrame());
        }

        /**
         * Captures the current control-frame poses (or parts of them). Later calls to
         * {@link #computeError(RigidBodyTransform)} compare this snapshot, not the live frames.
         */
        protected abstract void initialize();

        /** Computes the error of the captured poses without any correction applied to body B. */
        protected SpatialVector computeError() {
            return this.computeError(null);
        }

        /**
         * Computes the selected and weighted error between body A and body B.
         *
         * @param transformForB candidate transform from B to A applied to body B, or {@code null} for none.
         * @return the error vector, owned by this calculator and overwritten by the next call.
         */
        protected abstract SpatialVector computeError(RigidBodyTransform transformForB);

        /** Returns the error computed by the last call to {@link #computeError(RigidBodyTransform)}. */
        public abstract SpatialVectorReadOnly getError();

        public RigidBodyReadOnly getBodyA() {
            return this.bodyA;
        }

        public RigidBodyReadOnly getBodyB() {
            return this.bodyB;
        }

        /**
         * Applies the same offset, relative to each body-fixed frame, to both control frames.
         * Call this before {@link RobotTransformOptimizer#compute()} so the offset is part of the captured poses.
         *
         * @param controlFrameOffset pose of the control frame expressed in its body-fixed frame.
         */
        public void setControlFrameOffset(RigidBodyTransform controlFrameOffset) {
            this.controlFrameA.setPoseAndUpdate((RigidBodyTransformReadOnly) controlFrameOffset);
            this.controlFrameB.setPoseAndUpdate((RigidBodyTransformReadOnly) controlFrameOffset);
        }

        public ReferenceFrame getControlFrameA() {
            return this.controlFrameA;
        }

        public ReferenceFrame getControlFrameB() {
            return this.controlFrameB;
        }
    }
}

