/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.util.Arrays;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.api.Persistable;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.optimizers.NeuralNetworkOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.Dl4jReflection;
import org.nd4j.linalg.api.activation.ActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseNeuralNetwork
implements NeuralNetwork,
Persistable {
    private static final long serialVersionUID = -7074102204433996574L;
    protected INDArray W;
    protected INDArray hBias;
    protected INDArray vBias;
    protected INDArray input;
    protected transient NeuralNetworkOptimizer optimizer;
    protected INDArray doMask;
    private static Logger log = LoggerFactory.getLogger(BaseNeuralNetwork.class);
    protected INDArray wGradient;
    protected INDArray vBiasGradient;
    protected INDArray hBiasGradient;
    protected int lastMiniBatchSize = 1;
    protected AdaGrad wAdaGrad;
    protected AdaGrad hBiasAdaGrad;
    protected AdaGrad vBiasAdaGrad;
    protected NeuralNetConfiguration conf;

    protected BaseNeuralNetwork() {
    }

    public BaseNeuralNetwork(INDArray input, INDArray W, INDArray hbias, INDArray vbias, NeuralNetConfiguration conf) {
        this.input = input;
        this.W = W;
        this.conf = conf;
        if (this.W != null) {
            this.wAdaGrad = new AdaGrad(this.W.rows(), this.W.columns());
        }
        this.vBias = vbias;
        if (this.vBias != null) {
            this.vBiasAdaGrad = new AdaGrad(this.vBias.rows(), this.vBias.columns());
        }
        this.hBias = hbias;
        if (this.hBias != null) {
            this.hBiasAdaGrad = new AdaGrad(this.hBias.rows(), this.hBias.columns());
        }
        this.initWeights();
    }

    @Override
    public INDArray params() {
        return Nd4j.toFlattened((INDArray[])new INDArray[]{this.W, this.vBias, this.hBias});
    }

    public double l2RegularizedCoefficient() {
        return (Double)Transforms.pow((INDArray)this.getW(), (Number)2).sum(Integer.MAX_VALUE).element() / 2.0 * (double)this.conf.getL2() + (double)1.0E-6f;
    }

    protected void initWeights() {
        if (this.conf.getnIn() < 1) {
            throw new IllegalStateException("Number of visible can not be less than 1");
        }
        if (this.conf.getnOut() < 1) {
            throw new IllegalStateException("Number of hidden can not be less than 1");
        }
        int nVisible = this.conf.getnIn();
        int nHidden = this.conf.getnOut();
        if (this.W == null) {
            this.W = Nd4j.zeros((int)nVisible, (int)nHidden);
            for (int i = 0; i < this.W.rows(); ++i) {
                this.W.putRow(i, Nd4j.create((double[])this.conf.getDist().sample(this.W.columns())));
            }
        }
        this.wAdaGrad = new AdaGrad(this.W.rows(), this.W.columns());
        if (this.hBias == null) {
            this.hBias = Nd4j.zeros((int)nHidden);
        }
        this.hBiasAdaGrad = new AdaGrad(this.hBias.rows(), this.hBias.columns());
        if (this.vBias == null) {
            this.vBias = this.input != null ? Nd4j.zeros((int)nVisible) : Nd4j.zeros((int)nVisible);
        }
        this.vBiasAdaGrad = new AdaGrad(this.vBias.rows(), this.vBias.columns());
    }

    @Override
    public int numParams() {
        return this.conf.getnIn() * this.conf.getnOut() + this.conf.getnIn() + this.conf.getnOut();
    }

    @Override
    public void setParams(INDArray params) {
        assert (params.length() == this.numParams()) : "Illegal number of parameters passed in, must be of length " + this.numParams();
        int weightLength = this.conf.getnIn() * this.conf.getnOut();
        INDArray weights = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)weightLength)});
        INDArray vBias = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)weightLength, (int)(weightLength + this.conf.getnIn()))});
        INDArray hBias = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)(weightLength + this.conf.getnIn()), (int)(weightLength + this.conf.getnIn() + this.conf.getnOut()))});
        this.setW(weights.reshape(this.conf.getnIn(), this.conf.getnOut()));
        this.setvBias(vBias.dup());
        this.sethBias(hBias.dup());
    }

    @Override
    public void backProp(double lr, int iterations, Object[] extraParams) {
        double currRecon = LossFunctions.score((INDArray)this.input, (LossFunctions.LossFunction)LossFunctions.LossFunction.SQUARED_LOSS, (INDArray)this.transform(this.input), (double)this.conf.getL2(), (boolean)this.conf.isUseRegularization());
        boolean train = true;
        NeuralNetwork revert = this.clone();
        while (train && iterations <= iterations) {
            double newRecon = LossFunctions.score((INDArray)this.input, (LossFunctions.LossFunction)LossFunctions.LossFunction.SQUARED_LOSS, (INDArray)this.transform(this.input), (double)this.conf.getL2(), (boolean)this.conf.isUseRegularization());
            if (newRecon > currRecon || currRecon < 0.0 && newRecon < currRecon) {
                this.update((BaseNeuralNetwork)revert);
                log.info("Converged for new recon; breaking...");
                break;
            }
            if (Double.isNaN(newRecon) || Double.isInfinite(newRecon)) {
                this.update((BaseNeuralNetwork)revert);
                log.info("Converged for new recon; breaking...");
                break;
            }
            if (newRecon == currRecon) break;
            currRecon = newRecon;
            revert = this.clone();
            log.info("Recon went down " + currRecon);
            ++iterations;
            int plotIterations = this.conf.getRenderWeightsEveryNumEpochs();
            if (plotIterations <= 0) continue;
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            if (iterations % plotIterations != 0) continue;
            plotter.plotNetworkGradient(this, this.getGradient(extraParams), this.getInput().rows());
        }
    }

    @Override
    public void fit(INDArray data) {
        this.fit(data, null);
    }

    protected void applySparsity(INDArray hBiasGradient) {
        if (this.conf.isUseAdaGrad()) {
            INDArray change = this.hBiasAdaGrad.getLearningRates(this.hBias).neg().muli((Number)Float.valueOf(this.conf.getSparsity())).mul(hBiasGradient.mul((Number)Float.valueOf(this.conf.getSparsity())));
            hBiasGradient.addi(change);
        } else {
            INDArray change = hBiasGradient.mul((Number)Float.valueOf(this.conf.getSparsity())).mul((Number)Float.valueOf(-this.conf.getLr() * this.conf.getSparsity()));
            hBiasGradient.addi(change);
        }
    }

    protected void updateGradientAccordingToParams(NeuralNetworkGradient gradient, int iteration, double learningRate) {
        int key;
        INDArray wGradient = gradient.getwGradient();
        INDArray hBiasGradient = gradient.gethBiasGradient();
        INDArray vBiasGradient = gradient.getvBiasGradient();
        if (iteration != 0 && this.conf.getResetAdaGradIterations() > 0 && iteration % this.conf.getResetAdaGradIterations() == 0) {
            this.wAdaGrad.historicalGradient = null;
            this.hBiasAdaGrad.historicalGradient = null;
            this.vBiasAdaGrad.historicalGradient = null;
            if (this.W != null && this.wAdaGrad == null) {
                this.wAdaGrad = new AdaGrad(this.W.rows(), this.W.columns());
            }
            if (this.vBias != null && this.vBiasAdaGrad == null) {
                this.vBiasAdaGrad = new AdaGrad(this.vBias.rows(), this.vBias.columns());
            }
            if (this.hBias != null && this.hBiasAdaGrad == null) {
                this.hBiasAdaGrad = new AdaGrad(this.hBias.rows(), this.hBias.columns());
            }
            log.info("Resetting adagrad");
        }
        INDArray wLearningRates = this.wAdaGrad.getLearningRates(wGradient);
        double momentum = this.conf.getMomentum();
        if (this.conf.getMomentumAfter() != null && !this.conf.getMomentumAfter().isEmpty() && iteration >= (key = this.conf.getMomentumAfter().keySet().iterator().next().intValue())) {
            momentum = this.conf.getMomentumAfter().get(key).floatValue();
        }
        if (this.conf.isUseAdaGrad()) {
            wGradient.muli(wLearningRates);
        } else {
            wGradient.muli((Number)learningRate);
        }
        if (this.conf.isUseAdaGrad()) {
            hBiasGradient.muli(this.hBiasAdaGrad.getLearningRates(hBiasGradient));
        } else {
            hBiasGradient.muli((Number)learningRate);
        }
        if (this.conf.isUseAdaGrad()) {
            vBiasGradient.muli(this.vBiasAdaGrad.getLearningRates(vBiasGradient));
        } else {
            vBiasGradient.muli((Number)learningRate);
        }
        if (this.hBiasGradient != null && this.conf.getSparsity() != 0.0f) {
            this.applySparsity(hBiasGradient);
        }
        if (momentum != 0.0 && this.wGradient != null) {
            wGradient.addi(this.wGradient.mul((Number)momentum).addi(wGradient.mul((Number)(1.0 - momentum))));
        }
        if (momentum != 0.0 && this.vBiasGradient != null) {
            vBiasGradient.addi(this.vBiasGradient.mul((Number)momentum).addi(vBiasGradient.mul((Number)(1.0 - momentum))));
        }
        if (momentum != 0.0 && this.hBiasGradient != null) {
            hBiasGradient.addi(this.hBiasGradient.mul((Number)momentum).addi(hBiasGradient.mul((Number)(1.0 - momentum))));
        }
        wGradient.divi((Number)this.lastMiniBatchSize);
        vBiasGradient.divi((Number)this.lastMiniBatchSize);
        hBiasGradient.divi((Number)this.lastMiniBatchSize);
        if (this.conf.isUseRegularization() && this.conf.getL2() > 0.0f) {
            if (this.conf.isUseAdaGrad()) {
                wGradient.subi(this.W.mul((Number)Float.valueOf(this.conf.getL2())).muli(wLearningRates));
            } else {
                wGradient.subi(this.W.mul((Number)((double)this.conf.getL2() * learningRate)));
            }
        }
        if (this.conf.isConstrainGradientToUnitNorm()) {
            wGradient.divi(wGradient.norm2(Integer.MAX_VALUE));
            vBiasGradient.divi(vBiasGradient.norm2(Integer.MAX_VALUE));
            hBiasGradient.divi(hBiasGradient.norm2(Integer.MAX_VALUE));
        }
        this.wGradient = wGradient;
        this.vBiasGradient = vBiasGradient;
        this.hBiasGradient = hBiasGradient;
    }

    @Override
    public double score() {
        if (this.conf.getLossFunction() != LossFunctions.LossFunction.RECONSTRUCTION_CROSSENTROPY) {
            return LossFunctions.score((INDArray)this.input, (LossFunctions.LossFunction)this.conf.getLossFunction(), (INDArray)this.transform(this.input), (double)this.conf.getL2(), (boolean)this.conf.isUseRegularization());
        }
        return -LossFunctions.reconEntropy((INDArray)this.input, (INDArray)this.hBias, (INDArray)this.vBias, (INDArray)this.W, (ActivationFunction)this.conf.getActivationFunction());
    }

    @Override
    public void clearInput() {
        this.input = null;
    }

    @Override
    public AdaGrad getAdaGrad() {
        return this.wAdaGrad;
    }

    @Override
    public void setAdaGrad(AdaGrad adaGrad) {
        this.wAdaGrad = adaGrad;
    }

    @Override
    public NeuralNetwork transpose() {
        try {
            Constructor<?> c = Dl4jReflection.getEmptyConstructor(this.getClass());
            c.setAccessible(true);
            NeuralNetwork ret = (NeuralNetwork)c.newInstance(new Object[0]);
            ret.setVBiasAdaGrad(this.hBiasAdaGrad);
            ret.sethBias(this.vBias.dup());
            ret.setConf(this.conf);
            ret.setvBias(Nd4j.zeros((int)this.hBias.rows(), (int)this.hBias.columns()));
            ret.setW(this.W.transpose());
            return ret;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    @Override
    public NeuralNetwork clone() {
        try {
            Constructor<?> c = Dl4jReflection.getEmptyConstructor(this.getClass());
            c.setAccessible(true);
            NeuralNetwork ret = (NeuralNetwork)c.newInstance(new Object[0]);
            ret.setConf(this.conf);
            ret.setHbiasAdaGrad(this.hBiasAdaGrad);
            ret.setVBiasAdaGrad(this.vBiasAdaGrad);
            ret.sethBias(this.hBias.dup());
            ret.setvBias(this.vBias.dup());
            ret.setW(this.W.dup());
            ret.setAdaGrad(this.wAdaGrad);
            return ret;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void merge(NeuralNetwork network, int batchSize) {
        this.W.addi(network.getW().sub(this.W).divi((Number)batchSize));
        this.hBias.addi(network.gethBias().sub(this.hBias).divi((Number)batchSize));
        this.vBias.addi(network.getvBias().subi(this.vBias).divi((Number)batchSize));
    }

    public void update(BaseNeuralNetwork n) {
        this.W = n.W;
        this.conf = n.conf;
        this.hBias = n.hBias;
        this.vBias = n.vBias;
        this.wAdaGrad = n.wAdaGrad;
        this.hBiasAdaGrad = n.hBiasAdaGrad;
        this.vBiasAdaGrad = n.vBiasAdaGrad;
    }

    @Override
    public void load(InputStream is) {
        try {
            ObjectInputStream ois = new ObjectInputStream(is);
            BaseNeuralNetwork loaded = (BaseNeuralNetwork)ois.readObject();
            this.update(loaded);
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public INDArray getW() {
        return this.W;
    }

    @Override
    public void setW(INDArray w) {
        assert (Arrays.equals(w.shape(), new int[]{this.conf.getnIn(), this.conf.getnOut()})) : "Invalid shape for w, must be " + Arrays.toString(new int[]{this.conf.getnIn(), this.conf.getnOut()});
        this.W = w;
    }

    @Override
    public INDArray gethBias() {
        return this.hBias;
    }

    @Override
    public void sethBias(INDArray hBias) {
        assert (Arrays.equals(hBias.shape(), new int[]{this.conf.getnOut()})) : "Illegal shape for visible bias, must be of shape " + new int[]{this.conf.getnOut()};
        this.hBias = hBias;
    }

    @Override
    public INDArray getvBias() {
        return this.vBias;
    }

    @Override
    public void setvBias(INDArray vBias) {
        assert (Arrays.equals(vBias.shape(), new int[]{this.conf.getnIn()})) : "Illegal shape for visible bias, must be of shape " + Arrays.toString(new int[]{this.conf.getnIn()});
        this.vBias = vBias;
    }

    @Override
    public INDArray getInput() {
        return this.input;
    }

    @Override
    public void setInput(INDArray input) {
        this.input = input;
    }

    @Override
    public AdaGrad gethBiasAdaGrad() {
        return this.hBiasAdaGrad;
    }

    @Override
    public void setHbiasAdaGrad(AdaGrad adaGrad) {
        this.hBiasAdaGrad = adaGrad;
    }

    @Override
    public AdaGrad getVBiasAdaGrad() {
        return this.vBiasAdaGrad;
    }

    @Override
    public void setVBiasAdaGrad(AdaGrad adaGrad) {
        this.vBiasAdaGrad = adaGrad;
    }

    @Override
    public void write(OutputStream os) {
        try {
            ObjectOutputStream os2 = new ObjectOutputStream(os);
            os2.writeObject(this);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public abstract INDArray transform(INDArray var1);

    protected void applyDropOutIfNecessary(INDArray input) {
        this.doMask = this.conf.getDropOut() > 0.0f ? Nd4j.rand((int)input.rows(), (int)input.columns()).gt((Number)Float.valueOf(this.conf.getDropOut())) : Nd4j.ones((int)input.rows(), (int)input.columns());
        input.muli(this.doMask);
    }

    @Override
    public INDArray hBiasMean() {
        INDArray hbiasMean = this.getInput().mmul(this.getW()).addRowVector(this.gethBias());
        return hbiasMean;
    }

    protected INDArray preProcessInput(INDArray input) {
        if (this.conf.isConcatBiases()) {
            return Nd4j.hstack((INDArray[])new INDArray[]{input, Nd4j.ones((int)input.rows(), (int)1)});
        }
        return input;
    }

    @Override
    public void iterationDone(int iteration) {
        int plotEpochs = this.conf.getRenderWeightsEveryNumEpochs();
        if (plotEpochs <= 0) {
            return;
        }
        if (iteration % plotEpochs == 0 || iteration == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            plotter.plotNetworkGradient(this, this.getGradient(new Object[]{1, 0.001, 1000}), this.getInput().rows());
        }
    }

    public static class Builder<E extends BaseNeuralNetwork> {
        private E ret = null;
        private INDArray W;
        protected Class<? extends NeuralNetwork> clazz;
        private INDArray vBias;
        private INDArray hBias;
        private INDArray input;
        private NeuralNetConfiguration conf;

        public Builder<E> configure(NeuralNetConfiguration conf) {
            this.conf = conf;
            return this;
        }

        public E buildEmpty() {
            try {
                return (E)((BaseNeuralNetwork)this.clazz.newInstance());
            }
            catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }

        public Builder<E> withClazz(Class<? extends BaseNeuralNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withInput(INDArray input) {
            this.input = input;
            return this;
        }

        public Builder<E> asType(Class<E> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withWeights(INDArray W) {
            this.W = W;
            return this;
        }

        public Builder<E> withVisibleBias(INDArray vBias) {
            this.vBias = vBias;
            return this;
        }

        public Builder<E> withHBias(INDArray hBias) {
            this.hBias = hBias;
            return this;
        }

        public E build() {
            return this.buildWithInput();
        }

        private E buildWithInput() {
            Constructor<?>[] c = this.clazz.getDeclaredConstructors();
            for (int i = 0; i < c.length; ++i) {
                Constructor<?> curr = c[i];
                curr.setAccessible(true);
                Class<?>[] classes = curr.getParameterTypes();
                if (classes == null || classes.length <= 0 || !classes[0].isAssignableFrom(INDArray.class)) continue;
                try {
                    this.ret = (BaseNeuralNetwork)curr.newInstance(this.input, this.W, this.hBias, this.vBias, this.conf);
                    return this.ret;
                }
                catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            return this.ret;
        }
    }
}

