/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.numericalMethods;

import gnu.trove.list.array.TDoubleArrayList;
import java.util.function.ToDoubleFunction;
import us.ihmc.commons.Conversions;
import us.ihmc.commons.MathTools;

/**
 * Minimizes a scalar function of a box-bounded real vector by normalized gradient descent.
 * <p>
 * Each iteration estimates the gradient with one-sided finite differences (step size {@code perturb}, probing
 * backward when the upper limit leaves no room forward). It moves {@code |learningRateToUse|} against the
 * normalized gradient, clamps the result to the input limits, and then shrinks the step magnitude by
 * {@code reducingLearningRateRatio}, never below {@code |minLearningRate|}. The new point is always accepted, even
 * if its query is larger than the previous one. Iteration stops when the relative change of the query drops below
 * the convergence threshold, when the estimated gradient is exactly zero, or when {@code maximumIterations} is
 * reached.
 * <p>
 * NOTE(review): the initial input is not clamped to the limits before the first evaluation; only the iterates are.
 */
public class GradientDescentModule {
    private static final boolean DEBUG = false;
    private ToDoubleFunction<TDoubleArrayList> function;
    private final int dimension;
    private final TDoubleArrayList initialInput;
    private boolean solved;
    private final TDoubleArrayList optimalInput;
    private double optimalQuery;
    private double computationTime;
    private final TDoubleArrayList inputUpperLimit;
    private final TDoubleArrayList inputLowerLimit;
    private double deltaThreshold = 1.0E-9;
    private int maximumIterations = 1000;
    // Learning rates are stored as negative numbers (the setters negate), so "input + gradient * rate" descends.
    private double learningRate = -1.0;
    private double minLearningRate = -1.0E-6;
    private double learningRateToUse = this.learningRate;
    private double perturb = 0.001;
    private double reducingLearningRateRatio = 1.1;

    /**
     * Creates a module that minimizes {@code function}, starting from a copy of {@code initial}.
     *
     * @param function objective evaluated on input vectors of length {@code initial.size()}
     * @param initial  starting point; its size fixes the problem dimension. It is copied, not retained.
     */
    public GradientDescentModule(ToDoubleFunction<TDoubleArrayList> function, TDoubleArrayList initial) {
        this.function = function;
        this.dimension = initial.size();
        this.initialInput = new TDoubleArrayList(this.dimension);
        this.optimalInput = new TDoubleArrayList(this.dimension);
        this.inputUpperLimit = new TDoubleArrayList(this.dimension);
        this.inputLowerLimit = new TDoubleArrayList(this.dimension);
        for (int i = 0; i < this.dimension; ++i) {
            this.initialInput.add(initial.get(i));
            this.optimalInput.add(0.0);
            // Unbounded by default.
            this.inputUpperLimit.add(Double.POSITIVE_INFINITY);
            this.inputLowerLimit.add(Double.NEGATIVE_INFINITY);
        }
    }

    /** Replaces the objective function; the dimension, limits and initial input are kept. */
    public void redefineModule(ToDoubleFunction<TDoubleArrayList> function) {
        this.function = function;
    }

    /**
     * Divides the current (negative) learning rate by {@code reducingLearningRateRatio}. Because both rates are
     * negative, {@code Math.min} keeps the value of larger magnitude, so the step never becomes smaller in
     * magnitude than {@code |minLearningRate|}.
     */
    private void reduceLearningRate() {
        this.learningRateToUse = Math.min(this.learningRateToUse / this.reducingLearningRateRatio, this.minLearningRate);
    }

    /** Sets the maximum number of descent iterations performed by {@link #run()}. */
    public void setMaximumIterations(int value) {
        this.maximumIterations = value;
    }

    /**
     * Sets the per-coordinate upper bounds; only the first {@code dimension} entries of {@code limit} are used.
     *
     * @throws IllegalArgumentException if {@code limit} has fewer than {@code dimension} entries (the current
     *                                  limits are left unchanged)
     */
    public void setInputUpperLimit(TDoubleArrayList limit) {
        copyLimit(limit, this.inputUpperLimit, "upper");
    }

    /**
     * Sets the per-coordinate lower bounds; only the first {@code dimension} entries of {@code limit} are used.
     *
     * @throws IllegalArgumentException if {@code limit} has fewer than {@code dimension} entries (the current
     *                                  limits are left unchanged)
     */
    public void setInputLowerLimit(TDoubleArrayList limit) {
        copyLimit(limit, this.inputLowerLimit, "lower");
    }

    /** Validates {@code source} before overwriting {@code destination}, so a bad argument cannot truncate limits. */
    private void copyLimit(TDoubleArrayList source, TDoubleArrayList destination, String which) {
        if (source.size() < this.dimension) {
            throw new IllegalArgumentException("Input " + which + " limit has " + source.size()
                    + " entries, expected at least " + this.dimension);
        }
        copyInto(source, destination);
    }

    /** Sets the relative change of the query below which {@link #run()} reports convergence. */
    public void setConvergenceThreshold(double value) {
        this.deltaThreshold = value;
    }

    /** Sets the initial step magnitude; the sign of {@code value} is ignored. */
    public void setLearningRate(double value) {
        this.learningRate = -Math.abs(value);
    }

    /** Sets the smallest step magnitude the learning-rate decay may reach; the sign of {@code value} is ignored. */
    public void setMinimumLearningRate(double value) {
        this.minLearningRate = -Math.abs(value);
    }

    /** Sets the finite-difference step used to estimate the gradient; the sign of {@code value} is ignored. */
    public void setPerturbationSize(double value) {
        this.perturb = Math.abs(value);
    }

    /**
     * Sets the factor by which the step magnitude is divided after every iteration. Values above 1 shrink the
     * step; the value is not validated here.
     */
    public void setReducingLearningRateRatio(double value) {
        this.reducingLearningRateRatio = value;
    }

    /**
     * Runs gradient descent from the initial input. Results are available afterwards via
     * {@link #getOptimalInput()}, {@link #getOptimalQuery()}, {@link #isSolved()} and
     * {@link #getComputationTime()}.
     *
     * @return the number of iterations performed
     */
    public int run() {
        long startTime = System.nanoTime();
        this.solved = false;
        this.learningRateToUse = this.learningRate;

        TDoubleArrayList pastInput = copyOf(this.initialInput);
        // Seed the result with the starting point so optimalInput and optimalQuery agree even if no step is taken.
        copyInto(this.initialInput, this.optimalInput);
        this.optimalQuery = this.function.applyAsDouble(pastInput);

        int iteration = 0;
        while (iteration < this.maximumIterations) {
            ++iteration;
            double pastQuery = this.optimalQuery;

            TDoubleArrayList gradient = estimateGradient(pastInput, pastQuery);
            double gradientNorm = 0.0;
            for (int j = 0; j < this.dimension; ++j) {
                gradientNorm += gradient.get(j) * gradient.get(j);
            }
            gradientNorm = Math.sqrt(gradientNorm);
            if (gradientNorm == 0.0) {
                // Flat in every probed direction: there is no descent direction, and normalizing would yield NaN.
                // optimalInput/optimalQuery still hold pastInput and its query.
                this.solved = true;
                break;
            }

            // Step a fixed distance against the unit gradient (learningRateToUse is negative), then clamp.
            this.optimalInput.clear();
            for (int j = 0; j < this.dimension; ++j) {
                double input = pastInput.get(j) + gradient.get(j) / gradientNorm * this.learningRateToUse;
                this.optimalInput.add(MathTools.clamp(input, this.inputLowerLimit.get(j), this.inputUpperLimit.get(j)));
            }
            this.optimalQuery = this.function.applyAsDouble(this.optimalInput);
            reduceLearningRate();

            if (relativeChange(pastQuery, this.optimalQuery) < this.deltaThreshold) {
                this.solved = true;
                break;
            }
            copyInto(this.optimalInput, pastInput);
        }
        this.computationTime = Conversions.nanosecondsToSeconds(System.nanoTime() - startTime);
        return iteration;
    }

    /**
     * Estimates the gradient at {@code base} by one-sided finite differences. Each coordinate is probed forward by
     * {@code perturb}. If the upper limit cuts that probe short, the coordinate is probed backward instead when
     * that leaves more room. The quotient divides by the step actually taken after clamping.
     *
     * @param base      point at which to estimate the gradient (not modified)
     * @param baseQuery function value at {@code base}
     * @return a new list of {@code dimension} partial-derivative estimates
     */
    private TDoubleArrayList estimateGradient(TDoubleArrayList base, double baseQuery) {
        TDoubleArrayList gradient = new TDoubleArrayList(this.dimension);
        for (int j = 0; j < this.dimension; ++j) {
            double value = base.get(j);
            double lower = this.inputLowerLimit.get(j);
            double upper = this.inputUpperLimit.get(j);

            double probe = MathTools.clamp(value + this.perturb, lower, upper);
            if (value + this.perturb > upper) {
                double backward = MathTools.clamp(value - this.perturb, lower, upper);
                if (value - backward > probe - value) {
                    probe = backward;
                }
            }

            double step = probe - value;
            if (step == 0.0) {
                // No room to probe this coordinate (e.g. lower == upper, or a zero perturbation size): treat it as flat.
                gradient.add(0.0);
                continue;
            }

            // A fresh copy per coordinate so the objective never sees a list that is later mutated.
            TDoubleArrayList perturbedInput = copyOf(base);
            perturbedInput.set(j, probe);
            gradient.add((this.function.applyAsDouble(perturbedInput) - baseQuery) / step);
        }
        return gradient;
    }

    /**
     * Returns {@code |previous - current| / |current|}. When {@code current} is zero it falls back to the absolute
     * change, so the result is never infinite or NaN because of the denominator.
     */
    private static double relativeChange(double previous, double current) {
        double change = Math.abs(previous - current);
        double scale = Math.abs(current);
        return scale == 0.0 ? change : change / scale;
    }

    /** Returns a new list holding the first {@code dimension} entries of {@code source}. */
    private TDoubleArrayList copyOf(TDoubleArrayList source) {
        TDoubleArrayList copy = new TDoubleArrayList(this.dimension);
        for (int i = 0; i < this.dimension; ++i) {
            copy.add(source.get(i));
        }
        return copy;
    }

    /** Replaces the contents of {@code destination} with the first {@code dimension} entries of {@code source}. */
    private void copyInto(TDoubleArrayList source, TDoubleArrayList destination) {
        destination.clear();
        for (int i = 0; i < this.dimension; ++i) {
            destination.add(source.get(i));
        }
    }

    /**
     * Returns {@code true} if the last {@link #run()} stopped because the relative query change fell below the
     * convergence threshold or the estimated gradient was zero. Returns {@code false} if it exhausted
     * {@code maximumIterations}.
     */
    public boolean isSolved() {
        return this.solved;
    }

    /**
     * Returns the input reached by the last {@link #run()}. This is the module's internal list, which the next
     * run overwrites.
     */
    public TDoubleArrayList getOptimalInput() {
        return this.optimalInput;
    }

    /** Returns the function value at {@link #getOptimalInput()} after the last {@link #run()}. */
    public double getOptimalQuery() {
        return this.optimalQuery;
    }

    /** Returns the wall-clock duration of the last {@link #run()}, as converted by {@code Conversions}. */
    public double getComputationTime() {
        return this.computationTime;
    }
}

