/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.optimize.NeuralNetEpochListener;
import org.deeplearning4j.optimize.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.VectorizedNonZeroStoppingConjugateGradient;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.OptimizerMatrix;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class NeuralNetworkOptimizer
implements Optimizable.ByGradientValue,
OptimizableByGradientValueMatrix,
Serializable,
NeuralNetEpochListener {
    private static final long serialVersionUID = 4455143696487934647L;
    protected NeuralNetwork network;
    protected double lr;
    protected Object[] extraParams;
    protected double tolerance = 1.0E-5;
    protected static Logger log = LoggerFactory.getLogger(NeuralNetworkOptimizer.class);
    protected List<Double> errors = new ArrayList<Double>();
    protected double minLearningRate = 0.001;
    protected transient OptimizerMatrix opt;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm;
    protected NeuralNetwork.LossFunction lossFunction;

    public NeuralNetworkOptimizer(NeuralNetwork network, double lr, Object[] trainingParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, NeuralNetwork.LossFunction lossFunction) {
        this.network = network;
        this.lr = lr;
        this.extraParams = trainingParams;
        this.optimizationAlgorithm = optimizationAlgorithm;
        this.lossFunction = lossFunction;
    }

    private void createOptimizationAlgorithm() {
        if (this.optimizationAlgorithm == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            this.opt = new VectorizedNonZeroStoppingConjugateGradient((OptimizableByGradientValueMatrix)this, this);
            ((VectorizedNonZeroStoppingConjugateGradient)this.opt).setTolerance(this.tolerance);
        } else {
            this.opt = new VectorizedDeepLearningGradientAscent((OptimizableByGradientValueMatrix)this, this);
            ((VectorizedDeepLearningGradientAscent)this.opt).setTolerance(this.tolerance);
        }
    }

    public void train(DoubleMatrix x) {
        if (this.opt == null) {
            this.createOptimizationAlgorithm();
        }
        int epochs = this.extraParams.length < 3 ? 1000 : (Integer)this.extraParams[2];
        this.opt.optimize(epochs);
    }

    @Override
    public void epochDone(int epoch) {
        int plotEpochs = this.network.getRenderEpochs();
        if (plotEpochs <= 0) {
            return;
        }
        if (epoch % plotEpochs == 0 || epoch == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            plotter.plotNetworkGradient(this.network, this.network.getGradient(this.extraParams));
        }
    }

    public List<Double> getErrors() {
        return this.errors;
    }

    @Override
    public int getNumParameters() {
        return this.network.getW().length + this.network.gethBias().length + this.network.getvBias().length;
    }

    public void getParameters(double[] buffer) {
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = this.getParameter(i);
        }
    }

    @Override
    public double getParameter(int index) {
        if (index >= this.network.getW().length) {
            int i = this.getAdjustedIndex(index);
            if (index >= this.network.getvBias().length + this.network.getW().length) {
                return this.network.gethBias().get(i);
            }
            return this.network.getvBias().get(i);
        }
        return this.network.getW().get(index);
    }

    public void setParameters(double[] params) {
        for (int i = 0; i < params.length; ++i) {
            this.setParameter(i, params[i]);
        }
    }

    @Override
    public void setParameter(int index, double value) {
        if (index >= this.network.getW().length) {
            if (index >= this.network.getvBias().length + this.network.getW().length) {
                int i = this.getAdjustedIndex(index);
                this.network.gethBias().put(i, value);
            } else {
                int i = this.getAdjustedIndex(index);
                this.network.getvBias().put(i, value);
            }
        } else {
            this.network.getW().put(index, value);
        }
    }

    private int getAdjustedIndex(int index) {
        int wLength = this.network.getW().length;
        int vBiasLength = this.network.getvBias().length;
        if (index < wLength) {
            return index;
        }
        if (index >= wLength + vBiasLength) {
            int hIndex = index - wLength - vBiasLength;
            return hIndex;
        }
        int vIndex = index - wLength;
        return vIndex;
    }

    @Override
    public DoubleMatrix getParameters() {
        double[] params = new double[this.getNumParameters()];
        this.getParameters(params);
        return new DoubleMatrix(params);
    }

    @Override
    public void setParameters(DoubleMatrix params) {
        this.setParameters(params.toArray());
    }

    @Override
    public DoubleMatrix getValueGradient() {
        double[] d = new double[this.getNumParameters()];
        this.getValueGradient(d);
        return new DoubleMatrix(d);
    }

    public abstract void getValueGradient(double[] var1);

    @Override
    public double getValue() {
        if (this.lossFunction == NeuralNetwork.LossFunction.RECONSTRUCTION_CROSSENTROPY) {
            return -this.network.getReConstructionCrossEntropy();
        }
        if (this.lossFunction == NeuralNetwork.LossFunction.SQUARED_LOSS) {
            return -this.network.squaredLoss();
        }
        if (this.lossFunction == NeuralNetwork.LossFunction.NEGATIVELOGLIKELIHOOD) {
            return -this.network.negativeLogLikelihood();
        }
        return -this.network.getReConstructionCrossEntropy();
    }

    public double getTolerance() {
        return this.tolerance;
    }

    public void setTolerance(double tolerance) {
        this.tolerance = tolerance;
    }
}

