/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.optimizers;

import java.io.Serializable;
import java.util.List;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.gradient.OutputLayerGradient;
import org.deeplearning4j.optimize.api.OptimizableByGradientValueMatrix;
import org.deeplearning4j.optimize.api.TrainingEvaluator;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Exposes the output layer of a {@link BaseMultiLayerNetwork} as an optimizable object.
 *
 * <p>The flat parameter vector consists of every element of the output layer's weight matrix
 * ({@code getW()}), followed by every element of its bias ({@code getB()}). All flat-index
 * accessors in this class use that same layout.
 */
public class MultiLayerNetworkOptimizer
implements Serializable,
OptimizableByGradientValueMatrix {
    private static final long serialVersionUID = -3012638773299331828L;
    private static final Logger log = LoggerFactory.getLogger(MultiLayerNetworkOptimizer.class);

    /** The network whose output layer is being optimized. */
    protected BaseMultiLayerNetwork network;
    /** Learning rate passed to the output layer when computing gradients. */
    private final double lr;
    /** Last value given to {@link #setCurrentIteration(int)}; not read elsewhere in this class. */
    private int currentIteration;

    /**
     * Creates an optimizer for the given network.
     *
     * @param network the network whose output layer parameters are exposed
     * @param lr learning rate used by {@link #getValueGradient(double[])}
     */
    public MultiLayerNetworkOptimizer(BaseMultiLayerNetwork network, double lr) {
        this.network = network;
        this.lr = lr;
    }

    /**
     * Records the current iteration number.
     *
     * @param value the iteration number to store
     */
    @Override
    public void setCurrentIteration(int value) {
        this.currentIteration = value;
    }

    /**
     * Sets the labels on the output layer and trains.
     *
     * <p>If the network does not force a fixed number of epochs, only backprop runs, and only
     * when it is enabled. Otherwise the output layer is trained for {@code epochs} passes on the
     * last activation from {@code feedForward()}, and then backprop runs if it is enabled.
     *
     * @param labels the target labels
     * @param lr learning rate for this run
     * @param epochs number of epochs to train
     * @param eval evaluator forwarded to backprop
     */
    public void optimize(INDArray labels, double lr, int epochs, TrainingEvaluator eval) {
        this.network.getOutputLayer().setLabels(labels);
        if (!this.network.isForceNumEpochs()) {
            if (this.network.isShouldBackProp()) {
                this.network.backProp(lr, epochs, eval);
            }
            return;
        }
        log.info("Training for {} epochs", epochs);
        List<INDArray> activations = this.network.feedForward();
        // presumably the last activation is the input fed to the output layer; verify against
        // BaseMultiLayerNetwork.feedForward()
        INDArray train = activations.get(activations.size() - 1);
        for (int i = 0; i < epochs; ++i) {
            // Periodically discard AdaGrad's accumulated history.
            // NOTE(review): a reset interval of 0 makes this modulo throw ArithmeticException;
            // confirm the configuration guarantees a non-zero value.
            if (i % this.network.getDefaultConfiguration().getResetAdaGradIterations() == 0) {
                this.network.getOutputLayer().getAdaGrad().historicalGradient = null;
            }
            this.network.getOutputLayer().train(train, labels, lr);
        }
        if (this.network.isShouldBackProp()) {
            this.network.backProp(lr, epochs, eval);
        }
    }

    /**
     * Sets the labels on the output layer and trains.
     *
     * <p>If backprop is enabled, it always runs. When the network does not force a fixed number
     * of epochs, the output layer is then also trained until convergence.
     *
     * @param labels the target labels
     * @param lr learning rate for this run
     * @param iteration iteration count forwarded to backprop and output-layer training
     */
    public void optimize(INDArray labels, double lr, int iteration) {
        this.network.getOutputLayer().setLabels(labels);
        if (!this.network.isForceNumEpochs()) {
            if (this.network.isShouldBackProp()) {
                this.network.backProp(lr, iteration);
            }
            this.network.getOutputLayer().trainTillConvergence(lr, iteration);
        } else {
            log.info("Training for {} iteration", iteration);
            if (this.network.isShouldBackProp()) {
                this.network.backProp(lr, iteration);
            }
        }
    }

    /**
     * Returns the length of the flat parameter vector.
     *
     * @return number of weight elements plus number of bias elements of the output layer
     */
    @Override
    public int getNumParameters() {
        return this.network.getOutputLayer().getW().length() + this.network.getOutputLayer().getB().length();
    }

    /**
     * Copies the output layer's parameters (weights, then biases) into {@code buffer}.
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} elements
     */
    public void getParameters(double[] buffer) {
        int idx = copyOut(this.network.getOutputLayer().getW(), buffer, 0);
        copyOut(this.network.getOutputLayer().getB(), buffer, idx);
    }

    /**
     * Returns a single parameter by flat index.
     *
     * @param index flat index; indices below the weight count address the weights, and the
     *     remaining indices address the bias
     * @return the parameter value
     */
    @Override
    public double getParameter(int index) {
        INDArray weights = this.network.getOutputLayer().getW();
        int numWeights = weights.length();
        if (index >= numWeights) {
            // Biases follow all weights in the flat layout, so offset by the weight count.
            return elementAt(this.network.getOutputLayer().getB(), index - numWeights);
        }
        return elementAt(weights, index);
    }

    /**
     * Overwrites the output layer's parameters from a flat array (weights, then biases).
     *
     * @param params source values; must hold at least {@link #getNumParameters()} elements
     */
    public void setParameters(double[] params) {
        int idx = copyIn(params, 0, this.network.getOutputLayer().getW());
        copyIn(params, idx, this.network.getOutputLayer().getB());
    }

    /**
     * Sets a single parameter by flat index.
     *
     * @param index flat index, using the same layout as {@link #getParameter(int)}
     * @param value new value
     */
    @Override
    public void setParameter(int index, double value) {
        INDArray weights = this.network.getOutputLayer().getW();
        int numWeights = weights.length();
        if (index >= numWeights) {
            // Biases follow all weights in the flat layout, so offset by the weight count.
            this.network.getOutputLayer().getB().putScalar(index - numWeights, value);
        } else {
            weights.putScalar(index, value);
        }
    }

    /**
     * Computes the output layer gradient at this optimizer's learning rate and copies it into
     * {@code buffer} (weight gradient, then bias gradient).
     *
     * @param buffer destination; must hold at least {@link #getNumParameters()} elements
     */
    public void getValueGradient(double[] buffer) {
        OutputLayerGradient gradient = this.network.getOutputLayer().getGradient(this.lr);
        int idx = copyOut(gradient.getwGradient(), buffer, 0);
        copyOut(gradient.getbGradient(), buffer, idx);
    }

    /**
     * Returns the objective value.
     *
     * @return the network's score
     */
    @Override
    public double getValue() {
        return this.network.score();
    }

    /**
     * Returns the flat parameter vector as a new array.
     *
     * @return weights followed by biases
     */
    @Override
    public INDArray getParameters() {
        double[] d = new double[this.getNumParameters()];
        this.getParameters(d);
        return Nd4j.create(d);
    }

    /**
     * Overwrites the output layer's parameters from a flat array (weights, then biases).
     *
     * @param params flat parameter values
     */
    @Override
    public void setParameters(INDArray params) {
        this.setParameters(params.data().asDouble());
    }

    /**
     * Returns the flat gradient vector as a new array.
     *
     * @param iteration ignored by this implementation
     * @return weight gradient followed by bias gradient
     */
    @Override
    public INDArray getValueGradient(int iteration) {
        double[] buffer = new double[this.getNumParameters()];
        this.getValueGradient(buffer);
        return Nd4j.create(buffer);
    }

    /** Reads element {@code i} of {@code arr} as a double. */
    private static double elementAt(INDArray arr, int i) {
        // element() returns a boxed value; converting via Number instead of casting to Double
        // avoids a ClassCastException if it is a non-Double box such as Float.
        return ((Number) arr.getScalar(i).element()).doubleValue();
    }

    /** Copies every element of {@code source} into {@code buffer} from {@code offset}; returns the next free index. */
    private static int copyOut(INDArray source, double[] buffer, int offset) {
        int n = source.length();
        for (int i = 0; i < n; ++i) {
            buffer[offset + i] = elementAt(source, i);
        }
        return offset + n;
    }

    /** Writes {@code target.length()} values from {@code params}, starting at {@code offset}, into {@code target}; returns the next unread index. */
    private static int copyIn(double[] params, int offset, INDArray target) {
        int n = target.length();
        for (int i = 0; i < n; ++i) {
            target.putScalar(i, params[offset + i]);
        }
        return offset + n;
    }
}

