/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.solvers;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.exception.InvalidStepException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TerminationCondition;
import org.deeplearning4j.optimize.solvers.BackTrackLineSearch;
import org.deeplearning4j.optimize.terminations.EpsTermination;
import org.deeplearning4j.optimize.terminations.ZeroDirection;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseOptimizer {
    protected NeuralNetConfiguration conf;
    protected AdaGrad adaGrad;
    protected int iteration = 0;
    protected static Logger log = LoggerFactory.getLogger(BaseOptimizer.class);
    protected StepFunction stepFunction;
    private Collection<IterationListener> iterationListeners = new ArrayList<IterationListener>();
    protected Collection<TerminationCondition> terminationConditions = new ArrayList<TerminationCondition>();
    protected Model model;
    protected BackTrackLineSearch lineMaximizer;
    protected double step;
    private int batchSize = 10;
    protected double score;
    protected double oldScore;
    protected double stpMax = Double.MAX_VALUE;
    public static final String GRADIENT_KEY = "g";
    public static final String SCORE_KEY = "score";
    public static final String PARAMS_KEY = "params";
    protected Map<String, Object> searchState = new HashMap<String, Object>();

    public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Model model) {
        this(conf, stepFunction, iterationListeners, Arrays.asList(new ZeroDirection(), new EpsTermination()), model);
    }

    public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction, Collection<IterationListener> iterationListeners, Collection<TerminationCondition> terminationConditions, Model model) {
        this.conf = conf;
        this.stepFunction = stepFunction;
        this.iterationListeners = iterationListeners;
        this.terminationConditions = terminationConditions;
        this.model = model;
        this.lineMaximizer = new BackTrackLineSearch(model, stepFunction, this);
        this.lineMaximizer.setStpmax(this.stpMax);
    }

    public void updateGradientAccordingToParams(INDArray gradient, INDArray params, int batchSize) {
        int key;
        if (this.adaGrad == null) {
            this.adaGrad = new AdaGrad(1, gradient.length());
        }
        if (this.iteration != 0 && this.conf.getResetAdaGradIterations() > 0 && this.iteration % this.conf.getResetAdaGradIterations() == 0) {
            this.adaGrad.historicalGradient = null;
            log.info("Resetting adagrad");
        }
        double momentum = this.conf.getMomentum();
        if (this.conf.getMomentumAfter() != null && !this.conf.getMomentumAfter().isEmpty() && this.iteration >= (key = this.conf.getMomentumAfter().keySet().iterator().next().intValue())) {
            momentum = this.conf.getMomentumAfter().get(key);
        }
        gradient = this.adaGrad.getGradient(gradient);
        if (this.conf.isUseAdaGrad()) {
            gradient.assign(this.adaGrad.getGradient(gradient));
        } else {
            gradient.muli((Number)this.conf.getLr());
        }
        if (momentum > 0.0) {
            gradient.addi(gradient.mul((Number)momentum).addi(gradient.mul((Number)(1.0 - momentum))));
        }
        if (this.conf.isUseRegularization() && this.conf.getL2() > 0.0 && this.conf.isUseAdaGrad()) {
            gradient.subi(params.mul((Number)this.conf.getL2()));
        }
        if (this.conf.isConstrainGradientToUnitNorm()) {
            gradient.divi(gradient.norm2(Integer.MAX_VALUE));
        }
        gradient.divi((Number)batchSize);
    }

    public Pair<Gradient, Double> gradientAndScore() {
        Pair<Gradient, Double> pair = this.model.gradientAndScore();
        return pair;
    }

    public boolean optimize() {
        this.model.validateInput();
        Pair<Gradient, Double> pair = this.gradientAndScore();
        this.setupSearchState(pair);
        this.score = pair.getSecond();
        INDArray gradient = (INDArray)this.searchState.get(GRADIENT_KEY);
        for (TerminationCondition condition : this.terminationConditions) {
            if (!condition.terminate(0.0, 0.0, new Object[]{gradient})) continue;
            return true;
        }
        boolean testLineSearch = this.preFirstStepProcess(gradient);
        if (testLineSearch) {
            try {
                INDArray params = (INDArray)this.searchState.get(PARAMS_KEY);
                this.step = this.lineMaximizer.optimize(gradient, this.conf.getNumIterations(), this.step, params, gradient);
            }
            catch (InvalidStepException e) {
                e.printStackTrace();
            }
            gradient = (INDArray)this.searchState.get(GRADIENT_KEY);
            this.postFirstStep(gradient);
            if (this.step == 0.0) {
                log.warn("Unable to step in direction");
                return false;
            }
        }
        for (int i = 0; i < this.conf.getNumIterations(); ++i) {
            this.preProcessLine(gradient);
            try {
                INDArray params = (INDArray)this.searchState.get(PARAMS_KEY);
                this.step = this.lineMaximizer.optimize(gradient, this.conf.getNumIterations(), this.step, params, gradient);
            }
            catch (InvalidStepException e) {
                e.printStackTrace();
            }
            for (IterationListener listener : this.iterationListeners) {
                listener.iterationDone(i);
            }
            this.oldScore = this.score;
            pair = this.gradientAndScore();
            this.score = pair.getSecond();
            gradient = pair.getFirst().gradient(this.conf.getGradientList());
            this.searchState.put(GRADIENT_KEY, gradient);
            this.searchState.put(SCORE_KEY, this.score);
            for (TerminationCondition condition : this.terminationConditions) {
                if (!condition.terminate(this.score, this.oldScore, new Object[]{gradient})) continue;
                return true;
            }
            this.postStep();
            log.info("Score at iteration " + i + " is " + this.score);
            for (TerminationCondition condition : this.terminationConditions) {
                if (!condition.terminate(this.score, this.oldScore, new Object[]{gradient})) continue;
                return true;
            }
        }
        return true;
    }

    public double score() {
        return (Double)this.searchState.get(SCORE_KEY);
    }

    protected void postFirstStep(INDArray gradient) {
    }

    protected boolean preFirstStepProcess(INDArray gradient) {
        return false;
    }

    public int batchSize() {
        return this.batchSize;
    }

    public void setBatchSize(int batchSize) {
        this.batchSize = batchSize;
    }

    public void preProcessLine(INDArray line) {
    }

    public void postStep() {
    }

    public void setupSearchState(Pair<Gradient, Double> pair) {
        INDArray gradient = pair.getFirst().gradient(this.conf.getGradientList());
        INDArray params = this.model.params();
        this.updateGradientAccordingToParams(gradient, params, this.batchSize());
        this.searchState.put(GRADIENT_KEY, gradient);
        this.searchState.put(SCORE_KEY, pair.getSecond());
        this.searchState.put(PARAMS_KEY, params);
    }
}

