/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.math.trajectories.generators;

import gnu.trove.list.array.TDoubleArrayList;
import java.util.ArrayList;
import java.util.List;
import org.ejml.data.DMatrix;
import org.ejml.data.DMatrixD1;
import org.ejml.data.DMatrixRMaj;
import org.ejml.dense.row.CommonOps_DDRM;
import us.ihmc.commons.lists.RecyclingArrayList;
import us.ihmc.robotics.math.trajectories.generators.MultiCubicSpline1DSolver;
import us.ihmc.robotics.time.ExecutionTimer;
import us.ihmc.yoVariables.registry.YoRegistry;
import us.ihmc.yoVariables.variable.YoDouble;
import us.ihmc.yoVariables.variable.YoInteger;

/**
 * Optimizes the timing of a multi-dimensional trajectory made of cubic segments that starts at given
 * end-point conditions, passes through a sequence of waypoints, and ends at given target conditions.
 * <p>
 * Time is normalized: the trajectory runs from 0.0 to 1.0. Each dimension is solved independently by a
 * {@link MultiCubicSpline1DSolver}. The interval durations between consecutive waypoints are shared across
 * all dimensions. They are adjusted by finite-difference gradient descent on the summed solver cost while
 * keeping the total duration equal to 1.0.
 */
public class TrajectoryPointOptimizer {
    /** Maximum number of waypoints accepted by {@link #setWaypoints(List)}. */
    public static final int maxWaypoints = 200;
    /** Default number of time-update iterations used by {@link #compute()}. */
    public static final int maxIterations = 200;
    /** Perturbation used for the finite-difference estimate of the cost gradient w.r.t. interval times. */
    private static final double epsilon = 1.0E-7;
    /** Gain applied to the time gradient at the start of every compute call. */
    private static final double initialTimeGain = 0.001;
    /** Optimization stops once the cost changes by less than this between two iterations. */
    private static final double costEpsilon = 0.01;
    /** Number of polynomial coefficients per interval (cubic segments). */
    public static final int coefficients = 4;
    /** Largest allowed update of any interval, as a fraction of the currently shortest interval. */
    private static final double maxRelativeTimeUpdate = 0.4;
    /** Number of gain halvings tried before a time update is applied unconditionally. */
    private static final int maxLineSearchSteps = 10;

    private final YoRegistry registry;
    private final YoInteger dimensions;
    private final YoInteger nWaypoints;
    private final YoInteger intervals;
    private final YoInteger problemSize;
    private final YoInteger iteration;

    // Per-dimension end-point conditions: x0/xd0 at t = 0, x1/xd1 at t = 1.
    private final TDoubleArrayList x0;
    private final TDoubleArrayList x1;
    private final TDoubleArrayList xd0;
    private final TDoubleArrayList xd1;
    // Pre-allocated storage for up to maxWaypoints waypoints, each a (dimensions x 1) column.
    private final ArrayList<DMatrixRMaj> waypoints = new ArrayList<>();
    private final MultiCubicSpline1DSolver solver;

    // Per-dimension end-point weights, reset to +infinity by clearWeights().
    private final TDoubleArrayList w0;
    private final TDoubleArrayList w1;
    private final TDoubleArrayList wd0;
    private final TDoubleArrayList wd1;

    // Duration of each interval; the intervals sum to 1.0.
    private final DMatrixRMaj intervalTimes = new DMatrixRMaj(1, 1);
    // Snapshot of intervalTimes used to undo perturbations and rejected updates.
    private final DMatrixRMaj saveIntervalTimes = new DMatrixRMaj(1, 1);
    // Cost after the initial solve and after each time update.
    private final TDoubleArrayList costs = new TDoubleArrayList(maxIterations + 1);
    // Solver output of the most recent solve, one entry per dimension.
    private final RecyclingArrayList<DMatrixRMaj> x = new RecyclingArrayList<>(0, () -> new DMatrixRMaj(1, 1));
    private final DMatrixRMaj timeGradient = new DMatrixRMaj(1, 1);
    private final DMatrixRMaj timeUpdate = new DMatrixRMaj(1, 1);
    private final YoDouble timeGain;
    private final ExecutionTimer computeTimer;
    private final ExecutionTimer timeUpdateTimer;
    private final DMatrixRMaj tempCoeffs = new DMatrixRMaj(1, 1);
    private final DMatrixRMaj tempLine = new DMatrixRMaj(1, 1);

    /**
     * Creates an optimizer with an empty name prefix and attaches its registry to {@code parentRegistry}.
     *
     * @param dimensions number of independent dimensions (negative values are treated as 0)
     * @param parentRegistry registry the internal registry is added to as a child
     */
    public TrajectoryPointOptimizer(int dimensions, YoRegistry parentRegistry) {
        this("", dimensions, parentRegistry);
    }

    /**
     * Creates an optimizer and attaches its registry to {@code parentRegistry}.
     *
     * @param namePrefix prefix for the registry and all variable names
     * @param dimensions number of independent dimensions (negative values are treated as 0)
     * @param parentRegistry registry the internal registry is added to as a child
     */
    public TrajectoryPointOptimizer(String namePrefix, int dimensions, YoRegistry parentRegistry) {
        this(namePrefix, dimensions);
        parentRegistry.addChild(this.registry);
    }

    /**
     * Creates an optimizer with an empty name prefix whose registry is not attached to any parent.
     *
     * @param dimensions number of independent dimensions (negative values are treated as 0)
     */
    public TrajectoryPointOptimizer(int dimensions) {
        this("", dimensions);
    }

    /**
     * Creates an optimizer whose registry is not attached to any parent. All end points start at zero and all
     * end-point weights start at {@code +infinity}.
     *
     * @param namePrefix prefix for the registry and all variable names
     * @param dimensions number of independent dimensions (negative values are treated as 0)
     */
    public TrajectoryPointOptimizer(String namePrefix, int dimensions) {
        this.registry = new YoRegistry(namePrefix + this.getClass().getSimpleName());
        this.dimensions = new YoInteger(namePrefix + "Dimensions", this.registry);
        this.nWaypoints = new YoInteger(namePrefix + "NumberOfWaypoints", this.registry);
        this.intervals = new YoInteger(namePrefix + "NumberOfIntervals", this.registry);
        this.problemSize = new YoInteger(namePrefix + "ProblemSize", this.registry);
        this.iteration = new YoInteger(namePrefix + "Iteration", this.registry);
        this.computeTimer = new ExecutionTimer(namePrefix + "ComputeTimer", 0.0, this.registry);
        this.timeUpdateTimer = new ExecutionTimer(namePrefix + "TimeUpdateTimer", 0.0, this.registry);
        this.timeGain = new YoDouble(namePrefix + "TimeGain", this.registry);
        this.solver = new MultiCubicSpline1DSolver();

        int dims = Math.max(dimensions, 0);
        this.dimensions.set(dims);
        this.timeGain.set(initialTimeGain);

        this.x0 = zeros(dims);
        this.x1 = zeros(dims);
        this.xd0 = zeros(dims);
        this.xd1 = zeros(dims);

        for (int i = 0; i < maxWaypoints; ++i) {
            this.waypoints.add(new DMatrixRMaj(dims, 1));
        }

        this.w0 = new TDoubleArrayList(dims);
        this.w1 = new TDoubleArrayList(dims);
        this.wd0 = new TDoubleArrayList(dims);
        this.wd1 = new TDoubleArrayList(dims);
        this.clearWeights();

        this.tempCoeffs.reshape(coefficients, 1);
    }

    /** Returns a list holding {@code size} zeros. */
    private static TDoubleArrayList zeros(int size) {
        TDoubleArrayList list = new TDoubleArrayList(size);
        for (int i = 0; i < size; ++i) {
            list.add(0.0);
        }
        return list;
    }

    /**
     * Resets all end-point weights of every dimension to {@code +infinity}.
     * NOTE(review): an infinite weight presumably makes the end point a hard constraint in
     * {@link MultiCubicSpline1DSolver} — confirm there.
     */
    public void clearWeights() {
        int dims = this.dimensions.getValue();
        this.w0.fill(0, dims, Double.POSITIVE_INFINITY);
        this.w1.fill(0, dims, Double.POSITIVE_INFINITY);
        this.wd0.fill(0, dims, Double.POSITIVE_INFINITY);
        this.wd1.fill(0, dims, Double.POSITIVE_INFINITY);
    }

    /**
     * Sets the start and target conditions of a single dimension.
     *
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, dimensions[}
     */
    public void setEndPoints(int dimension, double startPosition, double startVelocity, double targetPosition, double targetVelocity) {
        this.checkDimension(dimension);
        this.x0.set(dimension, startPosition);
        this.xd0.set(dimension, startVelocity);
        this.x1.set(dimension, targetPosition);
        this.xd1.set(dimension, targetVelocity);
    }

    /**
     * Sets the start and target conditions of all dimensions.
     *
     * @throws RuntimeException if any list does not have exactly one entry per dimension
     */
    public void setEndPoints(TDoubleArrayList startPosition, TDoubleArrayList startVelocity, TDoubleArrayList targetPosition, TDoubleArrayList targetVelocity) {
        this.checkInputSize(startPosition);
        this.checkInputSize(startVelocity);
        this.checkInputSize(targetPosition);
        this.checkInputSize(targetVelocity);
        for (int i = 0; i < this.dimensions.getIntegerValue(); ++i) {
            this.x0.set(i, startPosition.get(i));
            this.xd0.set(i, startVelocity.get(i));
            this.x1.set(i, targetPosition.get(i));
            this.xd1.set(i, targetVelocity.get(i));
        }
    }

    /**
     * Sets the end-point weights of a single dimension.
     *
     * @throws IllegalArgumentException if {@code dimension} is outside {@code [0, dimensions[}
     */
    public void setEndPointWeights(int dimension, double startPositionWeight, double startVelocityWeight, double targetPositionWeight, double targetVelocityWeight) {
        this.checkDimension(dimension);
        this.w0.set(dimension, startPositionWeight);
        this.wd0.set(dimension, startVelocityWeight);
        this.w1.set(dimension, targetPositionWeight);
        this.wd1.set(dimension, targetVelocityWeight);
    }

    /**
     * Sets the end-point weights of all dimensions. A {@code null} argument resets the corresponding weight
     * to {@code +infinity} in every dimension.
     *
     * @throws RuntimeException if a non-null list does not have exactly one entry per dimension
     */
    public void setEndPointWeights(TDoubleArrayList startPositionWeight, TDoubleArrayList startVelocityWeight, TDoubleArrayList targetPositionWeight, TDoubleArrayList targetVelocityWeight) {
        this.checkOptionalInputSize(startPositionWeight);
        this.checkOptionalInputSize(startVelocityWeight);
        this.checkOptionalInputSize(targetPositionWeight);
        this.checkOptionalInputSize(targetVelocityWeight);
        // Each argument maps onto its own weight list. The previous implementation crossed the
        // start-velocity and target-position weights, which could also throw NullPointerExceptions.
        this.copyOrClearWeights(startPositionWeight, this.w0);
        this.copyOrClearWeights(startVelocityWeight, this.wd0);
        this.copyOrClearWeights(targetPositionWeight, this.w1);
        this.copyOrClearWeights(targetVelocityWeight, this.wd1);
    }

    /** Copies {@code source} into {@code destination}, or fills it with +infinity if {@code source} is null. */
    private void copyOrClearWeights(TDoubleArrayList source, TDoubleArrayList destination) {
        int dims = this.dimensions.getValue();
        if (source == null) {
            destination.fill(0, dims, Double.POSITIVE_INFINITY);
            return;
        }
        for (int i = 0; i < dims; ++i) {
            destination.set(i, source.get(i));
        }
    }

    /** Throws an IllegalArgumentException if {@code dimension} is not a valid dimension index. */
    private void checkDimension(int dimension) {
        if (dimension < 0 || dimension >= this.dimensions.getValue()) {
            throw new IllegalArgumentException("Illegal dimension, expected to be in [0, " + this.dimensions.getValue() + "[, but was: " + dimension);
        }
    }

    /** Throws if {@code input} does not have exactly one entry per dimension. */
    private void checkInputSize(TDoubleArrayList input) {
        if (input.size() != this.dimensions.getIntegerValue()) {
            throw new RuntimeException("Unexpected Size of Input");
        }
    }

    /** Like {@link #checkInputSize(TDoubleArrayList)} but accepts {@code null}. */
    private void checkOptionalInputSize(TDoubleArrayList input) {
        if (input != null) {
            this.checkInputSize(input);
        }
    }

    /**
     * Sets the waypoints the trajectory must pass through, in order. Every waypoint is validated before any
     * state is modified, so a rejected call leaves the previous waypoints untouched.
     *
     * @param waypoints at most {@link #maxWaypoints} entries, each with one value per dimension
     * @throws RuntimeException if there are too many waypoints or a waypoint has the wrong size
     */
    public void setWaypoints(List<TDoubleArrayList> waypoints) {
        if (waypoints.size() > maxWaypoints) {
            throw new RuntimeException("Too Many Waypoints");
        }
        for (TDoubleArrayList waypoint : waypoints) {
            this.checkInputSize(waypoint);
        }
        this.nWaypoints.set(waypoints.size());
        for (int i = 0; i < waypoints.size(); ++i) {
            waypoints.get(i).toArray(this.waypoints.get(i).data);
        }
    }

    /** Optimizes the waypoint times using at most {@link #maxIterations} iterations, starting from uniform intervals. */
    public void compute() {
        this.compute(maxIterations);
    }

    /**
     * Optimizes the waypoint times starting from uniformly spaced intervals.
     *
     * @param maxIterations maximum number of time-update iterations
     */
    public void compute(int maxIterations) {
        this.intervals.set(this.nWaypoints.getIntegerValue() + 1);
        this.intervalTimes.reshape(this.intervals.getValue(), 1);
        CommonOps_DDRM.fill(this.intervalTimes, 1.0 / this.intervals.getValue());
        this.computeInternal(maxIterations);
    }

    /**
     * Solves the trajectory for the given waypoint times without adjusting them.
     *
     * @param waypointTimes one non-decreasing time in [0, 1] per waypoint
     */
    public void computeForFixedTime(TDoubleArrayList waypointTimes) {
        this.compute(0, waypointTimes);
    }

    /**
     * Optimizes the waypoint times starting from the given initial guess.
     *
     * @param maxIterations maximum number of time-update iterations
     * @param waypointTimes one non-decreasing time in [0, 1] per waypoint
     * @throws RuntimeException if the number of times or their ordering/range is invalid
     */
    public void compute(int maxIterations, TDoubleArrayList waypointTimes) {
        this.intervals.set(this.nWaypoints.getIntegerValue() + 1);
        this.setIntervalTimes(waypointTimes);
        this.computeInternal(maxIterations);
    }

    /** Runs the initial solve, then applies time updates until convergence or {@code maxIterations}. */
    private void computeInternal(int maxIterations) {
        this.computeTimer.startMeasurement();
        try {
            this.timeGain.set(initialTimeGain);
            this.problemSize.set(this.dimensions.getIntegerValue() * coefficients * this.intervals.getValue());
            this.costs.reset();
            this.costs.add(this.solveMinAcceleration());
            this.iteration.set(0);
            for (int i = 0; i < maxIterations; ++i) {
                if (this.doFullTimeUpdate()) {
                    break;
                }
            }
        } finally {
            this.computeTimer.stopMeasurement();
        }
    }

    /** Converts absolute waypoint times into interval durations, validating count and range first. */
    private void setIntervalTimes(TDoubleArrayList waypointTimes) {
        if (waypointTimes.size() != this.nWaypoints.getValue()) {
            throw new RuntimeException("Unexpected number of waypoint times. Need " + this.nWaypoints.getValue() + ", got " + waypointTimes.size() + ".");
        }
        this.intervalTimes.reshape(this.intervals.getValue(), 1);
        for (int i = 0; i < this.intervals.getValue(); ++i) {
            double previousWaypointTime = i == 0 ? 0.0 : waypointTimes.get(i - 1);
            // The last interval ends at the target, which is fixed at t = 1.
            double waypointTime = i == this.nWaypoints.getValue() ? 1.0 : waypointTimes.get(i);
            double intervalTime = waypointTime - previousWaypointTime;
            if (intervalTime < 0.0 || intervalTime > 1.0) {
                throw new RuntimeException("Time in this trajectory is from 0.0 to 1.0. Got invalid waypoint times:\n" + waypointTimes.toString());
            }
            this.intervalTimes.set(i, intervalTime);
        }
    }

    /**
     * Performs one gradient step on the interval times and records the resulting cost.
     *
     * @return {@code true} if the cost changed by less than {@link #costEpsilon} (converged)
     */
    public boolean doFullTimeUpdate() {
        double oldCost = this.costs.get(this.iteration.getIntegerValue());
        double newCost = this.computeTimeUpdate(oldCost);
        this.costs.add(newCost);
        this.iteration.increment();
        return Math.abs(oldCost - newCost) < costEpsilon;
    }

    /**
     * Estimates the time gradient and applies a backtracking step: the gain is halved up to
     * {@link #maxLineSearchSteps} times while the cost increases, after which one more step is applied
     * unconditionally.
     *
     * @param cost cost at the current interval times
     * @return cost after the applied update
     */
    private double computeTimeUpdate(double cost) {
        this.timeUpdateTimer.startMeasurement();
        try {
            this.computeTimeGradient(cost);
            for (int attempt = 0; attempt < maxLineSearchSteps; ++attempt) {
                double newCost = this.applyTimeUpdate();
                if (!(newCost > cost)) {
                    return newCost;
                }
                this.timeGain.set(this.timeGain.getDoubleValue() * 0.5);
                this.intervalTimes.set((DMatrixD1) this.saveIntervalTimes);
            }
            return this.applyTimeUpdate();
        } finally {
            // Stop on every exit path; the early return above previously skipped this.
            this.timeUpdateTimer.stopMeasurement();
        }
    }

    /**
     * Fills {@link #timeGradient} with a forward-difference estimate of d(cost)/d(interval time) and
     * snapshots the current times into {@link #saveIntervalTimes}. Each perturbation lengthens one interval
     * by {@link #epsilon} and shortens all others equally so the total duration stays constant. The
     * gradient's mean is then removed so that a step along it also preserves the total duration.
     */
    private void computeTimeGradient(double cost) {
        int intervals = this.intervals.getIntegerValue();
        this.timeGradient.reshape(intervals, 1);
        this.saveIntervalTimes.set((DMatrixD1) this.intervalTimes);
        for (int i = 0; i < intervals; ++i) {
            for (int j = 0; j < intervals; ++j) {
                if (j == i) {
                    this.intervalTimes.add(j, 0, epsilon);
                } else {
                    this.intervalTimes.add(j, 0, -epsilon / (intervals - 1));
                }
            }
            this.timeGradient.set(i, (this.solveMinAcceleration() - cost) / epsilon);
            this.intervalTimes.set((DMatrixD1) this.saveIntervalTimes);
        }
        double sum = CommonOps_DDRM.elementSum(this.timeGradient);
        CommonOps_DDRM.add(this.timeGradient, -sum / intervals);
    }

    /**
     * Moves the interval times against the gradient, scaled by the current gain and clamped so that no
     * interval changes by more than {@link #maxRelativeTimeUpdate} of the shortest interval, then re-solves.
     *
     * @return cost at the updated interval times
     */
    private double applyTimeUpdate() {
        this.timeUpdate.set((DMatrixD1) this.timeGradient);
        CommonOps_DDRM.scale(-this.timeGain.getDoubleValue(), this.timeUpdate);
        double maxUpdate = CommonOps_DDRM.elementMaxAbs(this.timeUpdate);
        double minIntervalTime = CommonOps_DDRM.elementMin(this.intervalTimes);
        if (maxUpdate > maxRelativeTimeUpdate * minIntervalTime) {
            CommonOps_DDRM.scale(maxRelativeTimeUpdate * minIntervalTime / maxUpdate, this.timeUpdate);
        }
        for (int i = 0; i < this.intervals.getIntegerValue(); ++i) {
            this.intervalTimes.add(i, 0, this.timeUpdate.get(i));
        }
        return this.solveMinAcceleration();
    }

    /**
     * Solves every dimension at the current interval times, storing each solution in {@link #x}.
     *
     * @return the sum of the per-dimension costs reported by the solver
     */
    private double solveMinAcceleration() {
        double cost = 0.0;
        this.x.clear();
        for (int dimension = 0; dimension < this.dimensions.getValue(); ++dimension) {
            cost += this.solveDimension(dimension, this.x.add());
        }
        return cost;
    }

    /** Configures the 1D solver for {@code dimension}, solves it into {@code solutionToPack}, and returns its cost. */
    private double solveDimension(int dimension, DMatrixRMaj solutionToPack) {
        this.solver.setEndpoints(this.x0.get(dimension), this.xd0.get(dimension), this.x1.get(dimension), this.xd1.get(dimension));
        this.solver.setEndpointWeights(this.w0.get(dimension), this.wd0.get(dimension), this.w1.get(dimension), this.wd1.get(dimension));
        this.solver.clearWaypoints();
        double time = 0.0;
        for (int w = 0; w < this.nWaypoints.getValue(); ++w) {
            // Waypoint w sits at the end of interval w.
            time += this.intervalTimes.get(w);
            this.solver.addWaypoint(this.waypoints.get(w).get(dimension), time);
        }
        return this.solver.solveAndComputeCost(solutionToPack);
    }

    /**
     * Packs the absolute time of every waypoint (cumulative sum of interval durations).
     *
     * @param timesToPack cleared, then filled with one time per waypoint
     */
    public void getWaypointTimes(TDoubleArrayList timesToPack) {
        timesToPack.reset();
        double time = 0.0;
        for (int i = 0; i < this.nWaypoints.getIntegerValue(); ++i) {
            time += this.intervalTimes.get(i);
            timesToPack.add(time);
        }
    }

    /**
     * Returns the absolute time of a single waypoint.
     *
     * @throws RuntimeException if {@code waypoint} is not in {@code [0, nWaypoints - 1]}
     */
    public double getWaypointTime(int waypoint) {
        if (waypoint < 0) {
            throw new RuntimeException("Unexpected Waypoint Index");
        }
        if (waypoint > this.nWaypoints.getIntegerValue() - 1) {
            throw new RuntimeException("Unexpected Waypoint Index");
        }
        double time = this.intervalTimes.get(0);
        for (int i = 1; i <= waypoint; ++i) {
            time += this.intervalTimes.get(i);
        }
        return time;
    }

    /**
     * Packs the {@link #coefficients} polynomial coefficients of every interval for one dimension. Must be
     * called after a compute call.
     *
     * @param coefficientsToPack one list per interval; each list is cleared and refilled
     * @param dimension dimension to read
     * @throws RuntimeException if the output size or the dimension is invalid
     */
    public void getPolynomialCoefficients(List<TDoubleArrayList> coefficientsToPack, int dimension) {
        if (coefficientsToPack.size() != this.intervals.getIntegerValue()) {
            throw new RuntimeException("Unexpected Size of Output");
        }
        if (dimension > this.dimensions.getIntegerValue() - 1 || dimension < 0) {
            throw new RuntimeException("Unknown Dimension");
        }
        for (int i = 0; i < this.intervals.getIntegerValue(); ++i) {
            this.extractIntervalCoefficients(dimension, i);
            TDoubleArrayList intervalCoefficients = coefficientsToPack.get(i);
            intervalCoefficients.reset();
            intervalCoefficients.add(this.tempCoeffs.getData(), 0, coefficients);
        }
    }

    /**
     * Packs the velocity of every dimension at the given waypoint. The value is evaluated with the
     * polynomial of the interval that ends at that waypoint. Must be called after a compute call.
     *
     * @throws RuntimeException if {@code waypointIndex} is invalid
     */
    public void getWaypointVelocity(TDoubleArrayList velocityToPack, int waypointIndex) {
        double waypointTime = this.getWaypointTime(waypointIndex);
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getVelocityConstraintABlock(waypointTime, 0, 0, this.tempLine);
        this.evaluateInterval(waypointIndex, velocityToPack);
    }

    /** Packs the position of every dimension at t = 0. Must be called after a compute call. */
    public void getStartPosition(TDoubleArrayList positionToPack) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getPositionConstraintABlock(0.0, 0, 0, this.tempLine);
        this.evaluateInterval(0, positionToPack);
    }

    /** Packs the velocity of every dimension at t = 0. Must be called after a compute call. */
    public void getStartVelocity(TDoubleArrayList velocityToPack) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getVelocityConstraintABlock(0.0, 0, 0, this.tempLine);
        this.evaluateInterval(0, velocityToPack);
    }

    /** Packs the position of every dimension at t = 1. Must be called after a compute call. */
    public void getTargetPosition(TDoubleArrayList positionToPack) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getPositionConstraintABlock(1.0, 0, 0, this.tempLine);
        this.evaluateInterval(this.nWaypoints.getValue(), positionToPack);
    }

    /** Packs the velocity of every dimension at t = 1. Must be called after a compute call. */
    public void getTargetVelocity(TDoubleArrayList velocityToPack) {
        this.tempLine.reshape(1, coefficients);
        MultiCubicSpline1DSolver.getVelocityConstraintABlock(1.0, 0, 0, this.tempLine);
        this.evaluateInterval(this.nWaypoints.getValue(), velocityToPack);
    }

    /** Copies the coefficients of {@code interval} from the solution of {@code dimension} into {@link #tempCoeffs}. */
    private void extractIntervalCoefficients(int dimension, int interval) {
        int index = interval * coefficients;
        CommonOps_DDRM.extract(this.x.get(dimension), index, index + coefficients, 0, 1, this.tempCoeffs, 0, 0);
    }

    /**
     * Clears {@code valuesToPack} and fills it, for each dimension, with the dot product of {@link #tempLine}
     * (which the caller prepares) and that dimension's coefficients for {@code interval}.
     */
    private void evaluateInterval(int interval, TDoubleArrayList valuesToPack) {
        valuesToPack.reset();
        for (int dimension = 0; dimension < this.dimensions.getIntegerValue(); ++dimension) {
            this.extractIntervalCoefficients(dimension, interval);
            valuesToPack.add(CommonOps_DDRM.dot(this.tempCoeffs, this.tempLine));
        }
    }
}

