/*
 * Decompiled with CFR 0.152.
 */
package us.ihmc.robotics.functionApproximation.NeuralNetwork;

import java.util.ArrayList;
import us.ihmc.robotics.functionApproximation.NeuralNetwork.Layer;
import us.ihmc.robotics.functionApproximation.NeuralNetwork.Neuron;
import us.ihmc.robotics.functionApproximation.NeuralNetwork.activationFunction.ActivationFunction;
import us.ihmc.robotics.functionApproximation.NeuralNetwork.activationFunction.PassThrough;
import us.ihmc.robotics.functionApproximation.NeuralNetwork.importing.NeuralNetworkConfiguration;

/**
 * A layered neural network built from {@link Layer}s of {@link Neuron}s.
 *
 * <p>The first layer must be created with {@link #createInputLayer(int)}. It consists of
 * {@link PassThrough} neurons whose bias is overwritten with the input values by
 * {@link #setInput(double[])}. Each layer added with {@link #createLayer} connects every one of its
 * neurons to every neuron of the previously added layer. To evaluate, call
 * {@link #setInput(double[])} and then {@link #compute(double[])}.
 *
 * <p>NOTE(review): this class performs no synchronization; thread-safety is not established here.
 */
public class NeuralNetwork {
    /** Layers in creation order: index 0 is the input layer, the last entry is the output layer. */
    private final ArrayList<Layer> layers = new ArrayList<>();
    /** Input variable names from the configuration; {@code null} for the no-arg constructor. */
    private String[] inputOrder;
    /** Output variable names from the configuration; {@code null} for the no-arg constructor. */
    private String[] outputOrder;

    /** Creates an empty network; call {@link #createInputLayer(int)} before adding other layers. */
    public NeuralNetwork() {
    }

    /**
     * Builds a network from a configuration: an input layer of
     * {@code getNumberOfNeuronsPerLayer()[0]} neurons followed by one layer per remaining entry.
     *
     * <p>For layer {@code i >= 1}, {@code getBias()[i]}, {@code getWeights()[i]} and
     * {@code getActivationFunctions()[i]} are used; index 0 of those arrays is never read
     * (presumably a placeholder for the input layer — verify against the configuration loader).
     *
     * @param config source of layer sizes, biases, weights, activation functions and variable names
     */
    public NeuralNetwork(NeuralNetworkConfiguration config) {
        int[] numberOfNeuronsPerLayer = config.getNumberOfNeuronsPerLayer();
        this.createInputLayer(numberOfNeuronsPerLayer[0]);
        ActivationFunction[] activationFunctions = config.getActivationFunctions();
        double[][] bias = config.getBias();
        double[][][] weights = config.getWeights();
        for (int i = 1; i < numberOfNeuronsPerLayer.length; ++i) {
            this.createLayer(numberOfNeuronsPerLayer[i], bias[i], weights[i], activationFunctions[i]);
        }
        this.inputOrder = config.getInputVariableNames();
        this.outputOrder = config.getOutputVariableNames();
    }

    /**
     * Creates the input layer of {@code numberOfNeurons} pass-through neurons with zero bias.
     *
     * @param numberOfNeurons number of network inputs
     * @throws IllegalArgumentException if any layer has already been created
     */
    public void createInputLayer(int numberOfNeurons) {
        if (this.layers.size() > 0) {
            throw new IllegalArgumentException("Neural Network already contains a layer, make sure you create your input layer first! (or fix this class to be more modular)");
        }
        PassThrough passThrough = new PassThrough();
        Layer layer = new Layer();
        for (int i = 0; i < numberOfNeurons; ++i) {
            // A single PassThrough instance is shared by all input neurons.
            layer.addNeuron(new Neuron(passThrough, 0.0));
        }
        this.layers.add(layer);
    }

    /**
     * Appends a layer that is fully connected to the most recently added layer.
     *
     * <p>Neuron {@code n} of the new layer gets bias {@code bias[n]} and receives input from neuron
     * {@code p} of the previous layer with weight {@code weights[n][p]}. Entries beyond the sizes
     * actually needed are ignored.
     *
     * @param numberOfNeurons number of neurons in the new layer
     * @param bias per-neuron bias; must have at least {@code numberOfNeurons} entries
     * @param weights weight matrix indexed {@code [neuron][previousNeuron]}; must have at least
     *     {@code numberOfNeurons} rows, each with at least as many entries as the previous layer has
     *     neurons
     * @param activationFunction activation function shared by every neuron of the new layer
     * @throws IllegalArgumentException if no input layer exists yet, or if {@code bias} or
     *     {@code weights} is too small for the requested layer
     */
    public void createLayer(int numberOfNeurons, double[] bias, double[][] weights, ActivationFunction activationFunction) {
        if (this.layers.size() < 1) {
            throw new IllegalArgumentException("Neural Network does not contain an input layer, make sure you create your input layer first! (or fix this class to be more modular)");
        }
        Layer previousLayer = this.layers.get(this.layers.size() - 1);
        ArrayList<Neuron> previousLayerNeurons = previousLayer.getNeurons();
        // Validate everything before wiring any neuron so a bad argument cannot leave the
        // previous layer's neurons partially connected.
        checkLayerParameters(numberOfNeurons, previousLayerNeurons.size(), bias, weights);
        Layer layer = new Layer();
        for (int currentLayerNeuronIndex = 0; currentLayerNeuronIndex < numberOfNeurons; ++currentLayerNeuronIndex) {
            Neuron neuron = new Neuron(activationFunction, bias[currentLayerNeuronIndex]);
            double[] neuronWeights = weights == null ? null : (currentLayerNeuronIndex < weights.length ? weights[currentLayerNeuronIndex] : null);
            for (int previousLayerNeuronIndex = 0; previousLayerNeuronIndex < previousLayerNeurons.size(); ++previousLayerNeuronIndex) {
                double weight = neuronWeights[previousLayerNeuronIndex];
                Neuron inputNeuron = previousLayerNeurons.get(previousLayerNeuronIndex);
                neuron.addInputNeuron(inputNeuron, weight);
            }
            layer.addNeuron(neuron);
        }
        this.layers.add(layer);
    }

    /**
     * Verifies that {@code bias} and {@code weights} cover every value {@link #createLayer} reads.
     * Arrays are only inspected when they would actually be read, so callers that previously
     * passed {@code null} for an unused argument are unaffected.
     */
    private static void checkLayerParameters(int numberOfNeurons, int numberOfInputs, double[] bias, double[][] weights) {
        if (numberOfNeurons <= 0) {
            return; // an empty layer reads neither bias nor weights
        }
        if (bias.length < numberOfNeurons) {
            throw new IllegalArgumentException("bias array has " + bias.length + " entries but the layer has " + numberOfNeurons + " neurons");
        }
        if (numberOfInputs == 0) {
            return; // weights are never read when the previous layer is empty
        }
        if (weights.length < numberOfNeurons) {
            throw new IllegalArgumentException("weights array has " + weights.length + " rows but the layer has " + numberOfNeurons + " neurons");
        }
        for (int row = 0; row < numberOfNeurons; ++row) {
            if (weights[row] == null || weights[row].length < numberOfInputs) {
                int rowLength = weights[row] == null ? 0 : weights[row].length;
                throw new IllegalArgumentException("weights row " + row + " has " + rowLength + " entries but the previous layer has " + numberOfInputs + " neurons");
            }
        }
    }

    /**
     * Loads input values into the input layer by setting each input neuron's bias.
     * (Presumably {@link PassThrough} forwards the bias as the neuron's output — confirm in Neuron.)
     *
     * @param input one value per input neuron
     * @throws IllegalStateException if the network has no layers
     * @throws IllegalArgumentException if {@code input.length} differs from the input layer size
     */
    public void setInput(double[] input) {
        if (this.layers.isEmpty()) {
            throw new IllegalStateException("Neural Network has no layers, create the input layer first");
        }
        Layer inputLayer = this.layers.get(0);
        ArrayList<Neuron> inputNeurons = inputLayer.getNeurons();
        if (input.length != inputNeurons.size()) {
            throw new IllegalArgumentException("input array does not equal NN input size");
        }
        for (int i = 0; i < inputNeurons.size(); ++i) {
            inputNeurons.get(i).setBias(input[i]);
        }
    }

    /**
     * Evaluates every layer in creation order (input layer first) and copies the output layer's
     * neuron outputs into {@code output}.
     *
     * @param output destination array; must have exactly one slot per output neuron
     * @throws IllegalStateException if the network has no layers
     * @throws IllegalArgumentException if {@code output.length} differs from the output layer size
     */
    public void compute(double[] output) {
        if (this.layers.isEmpty()) {
            throw new IllegalStateException("Neural Network has no layers, create the input layer first");
        }
        Layer outputLayer = this.layers.get(this.layers.size() - 1);
        ArrayList<Neuron> outputNeurons = outputLayer.getNeurons();
        if (output.length != outputNeurons.size()) {
            throw new IllegalArgumentException("output array does not equal NN output size");
        }
        for (Layer layer : this.layers) {
            layer.compute();
        }
        for (int i = 0; i < outputNeurons.size(); ++i) {
            output[i] = outputNeurons.get(i).getOutput();
        }
    }

    /**
     * Returns the input variable names supplied by the configuration, or {@code null} if the
     * network was built with the no-arg constructor.
     * NOTE(review): the internal array is returned directly, so caller mutations are visible here.
     */
    public String[] getInputOrder() {
        return this.inputOrder;
    }

    /**
     * Returns the output variable names supplied by the configuration, or {@code null} if the
     * network was built with the no-arg constructor.
     * NOTE(review): the internal array is returned directly, so caller mutations are visible here.
     */
    public String[] getOutputOrder() {
        return this.outputOrder;
    }
}

