/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize;

import cc.mallet.optimize.Optimizable;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.optimize.NeuralNetEpochListener;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.NonZeroStoppingConjugateGradient;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Adapts a {@link BaseNeuralNetwork} to MALLET's {@link Optimizable.ByGradientValue}
 * interface so that it can be trained with {@link NonZeroStoppingConjugateGradient}.
 *
 * <p>The optimizer sees the network's parameters as one flat vector laid out as
 * {@code [W | vBias | hBias]}:
 * <ul>
 *   <li>indices {@code [0, W.length)} address the weight matrix;</li>
 *   <li>the next {@code vBias.length} indices address the visible bias;</li>
 *   <li>all remaining indices address the hidden bias.</li>
 * </ul>
 * Each matrix is addressed through its linear index
 * ({@code DoubleMatrix.get(int)} / {@code put(int, double)}).
 *
 * <p>Subclasses supply the gradient via {@link #getValueGradient(double[])}.
 */
public abstract class NeuralNetworkOptimizer
        implements Optimizable.ByGradientValue, Serializable, NeuralNetEpochListener {
    private static final long serialVersionUID = 4455143696487934647L;

    /** Iteration cap handed to the conjugate-gradient optimizer on every {@link #train} call. */
    private static final int MAX_ITERATIONS = 10000;

    /** The network whose parameters are being optimized. */
    protected BaseNeuralNetwork network;
    /** Learning rate supplied at construction (not read by this base class). */
    protected double lr;
    /**
     * Extra training parameters. {@link #train} reads index 2 as the epoch count, and
     * {@link #epochDone} passes the whole array to {@code network.getGradient}.
     */
    protected Object[] extraParams;
    /** Convergence tolerance applied to the optimizer before each {@link #train} call. */
    protected double tolerance = 1.0E-4;
    protected static Logger log = LoggerFactory.getLogger(NeuralNetworkOptimizer.class);
    /** Error history; not populated by this base class. */
    protected List<Double> errors = new ArrayList<Double>();
    /** Lower bound for the learning rate (not read by this base class). */
    protected double minLearningRate = 0.001;
    /**
     * Lazily created optimizer. It is {@code transient}, so it is {@code null} after
     * deserialization and is rebuilt on the next {@link #train} call.
     */
    protected transient NonZeroStoppingConjugateGradient opt;

    /**
     * @param network        the network to optimize
     * @param lr             learning rate
     * @param trainingParams extra parameters; index 2 must hold an {@link Integer} epoch count
     *                       before {@link #train} is called
     */
    public NeuralNetworkOptimizer(BaseNeuralNetwork network, double lr, Object[] trainingParams) {
        this.network = network;
        this.lr = lr;
        this.extraParams = trainingParams;
    }

    /**
     * Runs conjugate-gradient optimization over the network's parameters.
     *
     * <p>The optimizer is created on first use. Its tolerance is set to {@link #tolerance}
     * and its iteration cap to {@value #MAX_ITERATIONS}. {@code extraParams[2]} is unboxed
     * as an {@code int} and passed to {@code optimize}.
     *
     * @param x training input. NOTE(review): this method never reads it; presumably the
     *          network already holds its input. Confirm against callers.
     */
    public void train(DoubleMatrix x) {
        if (this.opt == null) {
            // The cast selects the (Optimizable.ByGradientValue, listener) constructor.
            this.opt = new NonZeroStoppingConjugateGradient((Optimizable.ByGradientValue) this, this);
        }
        this.opt.setTolerance(this.tolerance);
        int epochs = (Integer) this.extraParams[2];
        this.opt.setMaxIterations(MAX_ITERATIONS);
        this.opt.optimize(epochs);
    }

    /**
     * Plots the network gradient every {@code network.getRenderEpochs()} epochs.
     * Rendering is disabled when that value is zero or negative.
     *
     * @param epoch the epoch that just finished
     */
    @Override
    public void epochDone(int epoch) {
        int plotEpochs = this.network.getRenderEpochs();
        if (plotEpochs <= 0) {
            return;
        }
        // Epoch 0 also plots, since 0 % plotEpochs == 0.
        if (epoch % plotEpochs == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            plotter.plotNetworkGradient(this.network, this.network.getGradient(this.extraParams));
        }
    }

    /**
     * @return the live, mutable error-history list (not a copy)
     */
    public List<Double> getErrors() {
        return this.errors;
    }

    /**
     * @return total number of flattened parameters: {@code W.length + hBias.length + vBias.length}
     */
    @Override
    public int getNumParameters() {
        return this.network.W.length + this.network.hBias.length + this.network.vBias.length;
    }

    /**
     * Copies the flattened parameters into {@code buffer}, filling {@code buffer.length}
     * entries. The buffer is presumably sized to {@link #getNumParameters()}.
     */
    @Override
    public void getParameters(double[] buffer) {
        for (int i = 0; i < buffer.length; ++i) {
            buffer[i] = this.getParameter(i);
        }
    }

    /**
     * Reads one parameter from the flattened {@code [W | vBias | hBias]} view.
     *
     * @param index flattened parameter index
     * @return the value at that position
     */
    @Override
    public double getParameter(int index) {
        int wLength = this.network.W.length;
        if (index < wLength) {
            return this.network.W.get(index);
        }
        int vBiasEnd = wLength + this.network.vBias.length;
        if (index < vBiasEnd) {
            return this.network.vBias.get(index - wLength);
        }
        return this.network.hBias.get(index - vBiasEnd);
    }

    /**
     * Writes {@code params[i]} to flattened index {@code i}, for every {@code i} in {@code params}.
     */
    @Override
    public void setParameters(double[] params) {
        for (int i = 0; i < params.length; ++i) {
            this.setParameter(i, params[i]);
        }
    }

    /**
     * Writes one parameter into the flattened {@code [W | vBias | hBias]} view.
     *
     * @param index flattened parameter index
     * @param value value to store
     */
    @Override
    public void setParameter(int index, double value) {
        int wLength = this.network.W.length;
        if (index < wLength) {
            this.network.W.put(index, value);
            return;
        }
        int vBiasEnd = wLength + this.network.vBias.length;
        if (index < vBiasEnd) {
            this.network.vBias.put(index - wLength, value);
        } else {
            this.network.hBias.put(index - vBiasEnd, value);
        }
    }

    /**
     * Fills {@code buffer} with the gradient of {@link #getValue()}; supplied by subclasses.
     *
     * @param buffer destination for the gradient, in the same flattened layout as the parameters
     */
    @Override
    public abstract void getValueGradient(double[] buffer);

    /**
     * @return the negated reconstruction cross-entropy of the network. NOTE(review): the
     *         negation is presumably because MALLET optimizers maximize {@code getValue()};
     *         confirm.
     */
    @Override
    public double getValue() {
        return -this.network.getReConstructionCrossEntropy();
    }
}

