/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.optimizers;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.solvers.VectorizedDeepLearningGradientAscent;
import org.deeplearning4j.optimize.solvers.VectorizedNonZeroStoppingConjugateGradient;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.OptimizerMatrix;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class that adapts a {@link NeuralNetwork} to the gradient/value
 * optimisation API ({@link OptimizableByGradientValueMatrix}) and drives one of
 * two solvers over it: {@link VectorizedNonZeroStoppingConjugateGradient} or
 * {@link VectorizedDeepLearningGradientAscent}.
 *
 * <p>The instance also registers itself as the solver's
 * {@link IterationListener} so that it can periodically plot the network
 * gradient (see {@link #iterationDone(int)}).
 *
 * <p>Layout of {@link #extraParams}: the caller-supplied training parameters
 * occupy indices {@code 0 .. length - 2}; the final slot is reserved by this
 * class and is overwritten with the current iteration number in
 * {@link #getValueGradient(int)}.
 */
public abstract class NeuralNetworkOptimizer
        implements OptimizableByGradientValueMatrix, Serializable, IterationListener {
    private static final long serialVersionUID = 4455143696487934647L;

    /** Epoch count used by {@link #train(INDArray)} when the caller did not supply one. */
    private static final int DEFAULT_EPOCHS = 1000;

    /** Network whose parameters are being optimised. */
    protected NeuralNetwork network;
    /** Learning rate handed to {@code network.backProp(...)} at the end of {@link #train(INDArray)}. */
    protected double lr;
    /** Caller-supplied training parameters plus one trailing slot reserved for the iteration number. */
    protected Object[] extraParams;
    /** Tolerance passed to the solver. NOTE: float literal widened to double, so not exactly 1e-5. */
    protected double tolerance = 1.0E-5f;
    protected static Logger log = LoggerFactory.getLogger(NeuralNetworkOptimizer.class);
    protected List<Double> errors = new ArrayList<Double>();
    /** Solver instance; transient, so it is rebuilt lazily by {@link #train(INDArray)} when null. */
    protected transient OptimizerMatrix opt;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm;
    protected LossFunctions.LossFunction lossFunction;
    protected NeuralNetPlotter plotter = new NeuralNetPlotter();
    /** Maximum step size for gradient ascent; only applied when strictly positive. */
    protected double maxStep = -1.0;
    /** Last iteration number accepted by {@link #setCurrentIteration(int)}; -1 until one is set. */
    protected int currIteration = -1;

    /**
     * Creates an optimizer for the given network.
     *
     * @param network               network to optimise
     * @param lr                    learning rate
     * @param trainingParams        extra training parameters, may be {@code null}; they are
     *                              copied into a new array that has one additional trailing slot
     * @param optimizationAlgorithm algorithm selecting which solver is created
     * @param lossFunction          loss function stored on this optimizer
     */
    public NeuralNetworkOptimizer(NeuralNetwork network, double lr, Object[] trainingParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, LossFunctions.LossFunction lossFunction) {
        this.network = network;
        this.lr = lr;
        if (trainingParams != null) {
            // One extra slot at the end is reserved for the iteration number.
            this.extraParams = new Object[trainingParams.length + 1];
            System.arraycopy(trainingParams, 0, this.extraParams, 0, trainingParams.length);
        } else {
            this.extraParams = new Object[1];
        }
        this.optimizationAlgorithm = optimizationAlgorithm;
        this.lossFunction = lossFunction;
    }

    /**
     * Instantiates the solver matching {@link #optimizationAlgorithm}, with this
     * object as both the function being optimised and the iteration listener.
     * Anything other than {@code CONJUGATE_GRADIENT} selects gradient ascent.
     */
    private void createOptimizationAlgorithm() {
        if (this.optimizationAlgorithm == NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT) {
            this.opt = new VectorizedNonZeroStoppingConjugateGradient((OptimizableByGradientValueMatrix)this, this);
            this.opt.setTolerance(this.tolerance);
        } else {
            this.opt = new VectorizedDeepLearningGradientAscent((OptimizableByGradientValueMatrix)this, this);
            this.opt.setTolerance(this.tolerance);
            if (this.maxStep > 0.0) {
                ((VectorizedDeepLearningGradientAscent)this.opt).setMaxStepSize(this.maxStep);
            }
        }
    }

    /**
     * Determines the epoch count for {@link #train(INDArray)}.
     *
     * <p>Index 2 of {@link #extraParams} is used only when it belongs to the
     * caller-supplied portion of the array (i.e. it is not the reserved trailing
     * slot) and holds a {@link Number}. Otherwise {@link #DEFAULT_EPOCHS} is
     * returned. Reading the reserved slot would yield {@code null} before the
     * first gradient evaluation, or a stale iteration number afterwards.
     *
     * @return the number of epochs to run
     */
    private int resolveEpochs() {
        int suppliedCount = this.extraParams.length - 1;
        if (suppliedCount > 2 && this.extraParams[2] instanceof Number) {
            return ((Number)this.extraParams[2]).intValue();
        }
        return DEFAULT_EPOCHS;
    }

    /** @return the network's current parameters, as returned by {@code network.params()} */
    @Override
    public INDArray getParameters() {
        return this.network.params();
    }

    /**
     * Trains the network on the given input: lazily creates the solver, sets the
     * input on the network, runs the solver for the resolved number of epochs and
     * finally calls {@code network.backProp(lr, epochs, extraParams)}.
     *
     * @param x input to set on the network before optimising
     */
    public void train(INDArray x) {
        if (this.opt == null) {
            this.createOptimizationAlgorithm();
        }
        this.network.setInput(x);
        int epochs = this.resolveEpochs();
        this.opt.setMaxIterations(epochs);
        this.opt.optimize(epochs);
        this.network.backProp(this.lr, epochs, this.extraParams);
    }

    /**
     * Listener callback. Plots the network gradient whenever {@code iteration}
     * is a multiple of the configured render interval; does nothing when that
     * interval is zero or negative.
     *
     * @param iteration iteration number reported by the caller
     */
    @Override
    public void iterationDone(int iteration) {
        int plotEpochs = this.network.conf().getRenderWeightIterations();
        if (plotEpochs <= 0) {
            // Also protects the modulo below from a zero divisor.
            return;
        }
        if (iteration % plotEpochs == 0) {
            this.plotter.plotNetworkGradient(this.network, this.network.getGradient(this.extraParams), 100);
        }
    }

    /** @return the value of {@code network.numParams()} */
    @Override
    public int getNumParameters() {
        return this.network.numParams();
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double getParameter(int index) {
        throw new UnsupportedOperationException();
    }

    /**
     * Installs {@code params} on the network. When the network configuration
     * reports {@code isConstrainGradientToUnitNorm()}, {@code params} is first
     * divided by {@code params.normmax(Integer.MAX_VALUE)} via {@code divi}
     * (presumably an in-place operation that mutates the caller's array - verify
     * against the ND4J API).
     *
     * @param params parameters to set on the network
     */
    @Override
    public void setParameters(INDArray params) {
        if (this.network.conf().isConstrainGradientToUnitNorm()) {
            params.divi(params.normmax(Integer.MAX_VALUE));
        }
        this.network.setParams(params);
    }

    /**
     * Not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public void setParameter(int index, double value) {
        throw new UnsupportedOperationException();
    }

    /**
     * Converts an index into a flat parameter vector into an offset within the
     * segment it falls in. Segments are laid out as the network's {@code W},
     * then its {@code vBias}, then everything after (presumably the hidden bias,
     * matching the flattening order used in {@link #getValueGradient(int)}).
     *
     * <p>Not called from within this class.
     *
     * @param index index into the flat parameter vector
     * @return the offset of {@code index} relative to the start of its segment
     */
    private int getAdjustedIndex(int index) {
        int wLength = this.network.getW().length();
        int vBiasLength = this.network.getvBias().length();
        if (index < wLength) {
            return index;
        }
        if (index >= wLength + vBiasLength) {
            return index - wLength - vBiasLength;
        }
        return index - wLength;
    }

    /**
     * Computes the network gradient and flattens it into a single array in the
     * order: weight gradient, visible-bias gradient, hidden-bias gradient.
     *
     * <p>When {@code iteration >= 1} the iteration number is written to the
     * reserved trailing slot of {@link #extraParams} before the gradient is
     * requested from the network.
     *
     * @param iteration current iteration number
     * @return the flattened gradient
     */
    @Override
    public INDArray getValueGradient(int iteration) {
        if (iteration >= 1) {
            this.extraParams[this.extraParams.length - 1] = iteration;
        }
        NeuralNetworkGradient g = this.network.getGradient(this.extraParams);
        return Nd4j.toFlattened(Arrays.asList(g.getwGradient(), g.getvBiasGradient(), g.gethBiasGradient()));
    }

    /** @return the negated network score, {@code -network.score()} */
    @Override
    public double getValue() {
        return -this.network.score();
    }

    /**
     * Records the current iteration number. Values below 1 are ignored and leave
     * {@link #currIteration} unchanged.
     *
     * @param value iteration number to record
     */
    @Override
    public void setCurrentIteration(int value) {
        if (value < 1) {
            return;
        }
        this.currIteration = value;
    }
}

