/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.reflect.Constructor;
import java.util.List;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.gradient.NeuralNetworkGradientListener;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.nn.learning.AdaGrad;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/**
 * Common base for single-layer networks built from a visible-to-hidden weight matrix {@code W},
 * a hidden bias {@code hBias} and a visible bias {@code vBias}.
 *
 * <p>Concrete subclasses provide {@link #reconstruct(DoubleMatrix)},
 * {@link #lossFunction(Object[])} and {@link #train(DoubleMatrix, double, Object[])}. This class
 * supplies:
 * <ul>
 *   <li>weight initialisation;</li>
 *   <li>gradient post-processing: learning rate or AdaGrad, L2, momentum, sparsity and row
 *       normalisation;</li>
 *   <li>several loss measures;</li>
 *   <li>merging, cloning and Java serialization.</li>
 * </ul>
 */
public abstract class BaseNeuralNetwork implements NeuralNetwork, Persistable {
    private static final long serialVersionUID = -7074102204433996574L;

    /** Number of visible units; the row count of {@link #W} when it is created here. */
    public int nVisible;
    /** Number of hidden units; the column count of {@link #W} when it is created here. */
    protected int nHidden;
    /** Weight matrix; {@code nVisible x nHidden} when created by {@link #initWeights()}. */
    protected DoubleMatrix W;
    /** Hidden-unit bias. */
    protected DoubleMatrix hBias;
    /** Visible-unit bias. */
    protected DoubleMatrix vBias;
    /** Random generator; defaults to {@code MersenneTwister(1234)} in the full constructor. */
    protected RandomGenerator rng;
    /** Most recent input matrix; the loss measures and row normalisation read it. */
    protected DoubleMatrix input;
    protected double sparsity = 0.0;
    protected double momentum = 0.5;
    /**
     * Distribution used to sample weights.
     *
     * <p>It is {@code null} in two cases:
     * <ul>
     *   <li>after deserialization, because the field is transient;</li>
     *   <li>after the no-arg constructor.</li>
     * </ul>
     * A default is created on demand from {@link #rng}. It is deliberately not initialised
     * inline: at field-initialisation time {@link #rng} is still {@code null}.
     */
    protected transient RealDistribution dist;
    protected double l2 = 0.1;
    protected transient NeuralNetworkOptimizer optimizer;
    /** Plot gradients every this many epochs; a value {@code <= 0} disables plotting. */
    protected int renderWeightsEveryNumEpochs = -1;
    /** Explicit fan-in; a negative value makes {@link #fanIn()} return {@code 1.0 / nVisible}. */
    protected double fanIn = -1.0;
    protected boolean useRegularization = false;
    protected boolean useAdaGrad = false;
    protected boolean firstTimeThrough = false;
    /** When true, gradients and losses are divided by {@code input.rows}. */
    protected boolean normalizeByInputRows = false;
    protected boolean applySparsity = true;
    protected List<NeuralNetworkGradientListener> gradientListeners;
    protected double dropOut = 0.0;
    /** Drop-out mask produced by {@link #applyDropOutIfNecessary(DoubleMatrix)}. */
    protected DoubleMatrix doMask;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgo;
    protected NeuralNetwork.LossFunction lossFunction;
    protected AdaGrad wAdaGrad;
    protected AdaGrad hBiasAdaGrad;
    protected AdaGrad vBiasAdaGrad;

    /**
     * Creates an uninitialised instance.
     *
     * <p>Used reflectively by {@link #clone()}, {@link #transpose()} and the {@link Builder}.
     */
    protected BaseNeuralNetwork() {
    }

    /**
     * Creates a network without input.
     *
     * <p>Equivalent to
     * {@link #BaseNeuralNetwork(DoubleMatrix, int, int, DoubleMatrix, DoubleMatrix, DoubleMatrix,
     * RandomGenerator, double, RealDistribution)} with a {@code null} input.
     */
    public BaseNeuralNetwork(int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        this(null, nVisible, nHidden, W, hbias, vbias, rng, fanIn, dist);
    }

    /**
     * Creates a network and initialises any parameters that were not supplied.
     *
     * @param input    initial input matrix, may be {@code null}
     * @param nVisible number of visible units (must be at least 1)
     * @param nHidden  number of hidden units (must be at least 1)
     * @param W        weights, or {@code null} to sample them from {@code dist}
     * @param hbias    hidden bias, or {@code null} for zeros
     * @param vbias    visible bias, or {@code null} for zeros
     * @param rng      generator, or {@code null} for {@code MersenneTwister(1234)}
     * @param fanIn    explicit fan-in; negative means {@code 1.0 / nVisible}
     * @param dist     weight distribution, or {@code null} for a normal distribution
     *                 (mean 0, sd 0.01) driven by the resolved generator
     * @throws IllegalStateException if {@code nVisible} or {@code nHidden} is less than 1
     */
    public BaseNeuralNetwork(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        this.nVisible = nVisible;
        this.nHidden = nHidden;
        this.fanIn = fanIn;
        this.input = input;
        // Resolve the generator before the distribution, so a default distribution is seeded from
        // it rather than from a possibly-null argument.
        this.rng = rng == null ? new MersenneTwister(1234) : rng;
        this.dist = dist;
        this.ensureDistribution();
        this.W = W;
        if (this.W != null) {
            this.wAdaGrad = new AdaGrad(this.W.rows, this.W.columns);
        }
        this.vBias = vbias;
        if (this.vBias != null) {
            this.vBiasAdaGrad = new AdaGrad(this.vBias.rows, this.vBias.columns);
        }
        this.hBias = hbias;
        if (this.hBias != null) {
            this.hBiasAdaGrad = new AdaGrad(this.hBias.rows, this.hBias.columns);
        }
        // NOTE(review): overridable method called from a constructor; subclass fields are not yet
        // initialised when an override runs.
        this.initWeights();
    }

    /**
     * Returns {@link #dist}.
     *
     * <p>If it is unset, a normal distribution (mean 0, sd 0.01) driven by {@link #rng} is
     * created first.
     */
    private RealDistribution ensureDistribution() {
        if (this.dist == null) {
            this.dist = new NormalDistribution(this.rng, 0.0, 0.01, 1.0E-9);
        }
        return this.dist;
    }

    /**
     * Samples a fresh {@code nVisible x nHidden} weight matrix.
     *
     * <p>Rows are filled one at a time from {@link #dist}.
     */
    private DoubleMatrix sampleWeights() {
        RealDistribution d = this.ensureDistribution();
        DoubleMatrix sampled = DoubleMatrix.zeros(this.nVisible, this.nHidden);
        for (int i = 0; i < sampled.rows; ++i) {
            sampled.putRow(i, new DoubleMatrix(d.sample(sampled.columns)));
        }
        return sampled;
    }

    /** Returns the L2 penalty {@code l2 / 2 * sum(W^2)} on the current weights. */
    @Override
    public double l2RegularizedCoefficient() {
        return MatrixFunctions.pow(this.getW(), 2.0).sum() / 2.0 * this.l2;
    }

    /**
     * Fills in missing parameters and (re)creates the three AdaGrad instances.
     *
     * <p>A missing {@code W} is sampled from the distribution. Missing biases become zero vectors
     * of length {@code nHidden} and {@code nVisible} respectively.
     *
     * @throws IllegalStateException if {@code nVisible} or {@code nHidden} is less than 1
     */
    protected void initWeights() {
        if (this.nVisible < 1) {
            throw new IllegalStateException("Number of visible can not be less than 1");
        }
        if (this.nHidden < 1) {
            throw new IllegalStateException("Number of hidden can not be less than 1");
        }
        this.ensureDistribution();
        if (this.W == null) {
            this.W = this.sampleWeights();
        }
        this.wAdaGrad = new AdaGrad(this.W.rows, this.W.columns);
        if (this.hBias == null) {
            this.hBias = DoubleMatrix.zeros(this.nHidden);
        }
        this.hBiasAdaGrad = new AdaGrad(this.hBias.rows, this.hBias.columns);
        if (this.vBias == null) {
            this.vBias = DoubleMatrix.zeros(this.nVisible);
        }
        this.vBiasAdaGrad = new AdaGrad(this.vBias.rows, this.vBias.columns);
    }

    /**
     * Replaces the weight AdaGrad with one using learning rate {@code lr}.
     *
     * <p>NOTE(review): {@code firstTimeThrough} starts {@code false} and is only ever set to
     * {@code false}. The guard therefore always passes and the reset happens on every call. If a
     * one-time reset was intended, the assignment should be {@code true}. Behaviour is left
     * unchanged here.
     */
    @Override
    public void resetAdaGrad(double lr) {
        if (!this.firstTimeThrough) {
            this.wAdaGrad = new AdaGrad(this.getW().rows, this.getW().columns, lr);
            this.firstTimeThrough = false;
        }
    }

    @Override
    public List<NeuralNetworkGradientListener> getGradientListeners() {
        return this.gradientListeners;
    }

    @Override
    public synchronized void setGradientListeners(List<NeuralNetworkGradientListener> gradientListeners) {
        this.gradientListeners = gradientListeners;
    }

    /** Sets how often (in epochs) {@link #epochDone(int)} plots; {@code <= 0} disables it. */
    @Override
    public void setRenderEpochs(int renderEpochs) {
        this.renderWeightsEveryNumEpochs = renderEpochs;
    }

    @Override
    public int getRenderEpochs() {
        return this.renderWeightsEveryNumEpochs;
    }

    /**
     * Returns the explicit fan-in, or {@code 1.0 / nVisible} when none was set (negative).
     *
     * <p>Uses floating-point division; integer division would yield 0 for any
     * {@code nVisible > 1}.
     */
    @Override
    public double fanIn() {
        return this.fanIn < 0.0 ? 1.0 / this.nVisible : this.fanIn;
    }

    @Override
    public void setFanIn(double fanIn) {
        this.fanIn = fanIn;
    }

    /**
     * Re-samples the weight matrix from the weight distribution and installs it as {@code W}.
     *
     * <p>The new matrix is {@code nVisible x nHidden}. The AdaGrad state is left untouched.
     */
    public void jostleWeighMatrix() {
        this.W = this.sampleWeights();
    }

    /**
     * Adds a sparsity adjustment to {@code hBiasGradient} in place.
     *
     * <p>With AdaGrad the step size comes from {@code hBiasAdaGrad.getLearningRates(hBias)};
     * otherwise it is {@code -learningRate}. In both cases the step is scaled by
     * {@code sparsity^2 * hBiasGradient}.
     */
    protected void applySparsity(DoubleMatrix hBiasGradient, double learningRate) {
        if (this.useAdaGrad) {
            DoubleMatrix change = this.hBiasAdaGrad.getLearningRates(this.hBias).neg().mul(this.sparsity).mul(hBiasGradient.mul(this.sparsity));
            hBiasGradient.addi(change);
        } else {
            DoubleMatrix change = hBiasGradient.mul(this.sparsity).mul(-learningRate * this.sparsity);
            hBiasGradient.addi(change);
        }
    }

    /**
     * Post-processes the three gradient matrices held by {@code gradient}, in place.
     *
     * <p>The steps are:
     * <ol>
     *   <li>scale by the learning rate (or AdaGrad rates);</li>
     *   <li>subtract the L2 term from the weight gradient;</li>
     *   <li>apply the momentum terms;</li>
     *   <li>optionally apply sparsity to the hidden-bias gradient;</li>
     *   <li>optionally divide everything by {@code input.rows}.</li>
     * </ol>
     *
     * <p>All updates mutate the matrices owned by {@code gradient}. Earlier versions rebound the
     * bias gradients to new matrices, which silently dropped those adjustments. The weights
     * themselves are never modified.
     */
    protected void updateGradientAccordingToParams(NeuralNetworkGradient gradient, double learningRate) {
        DoubleMatrix wGradient = gradient.getwGradient();
        DoubleMatrix hBiasGradient = gradient.gethBiasGradient();
        DoubleMatrix vBiasGradient = gradient.getvBiasGradient();
        if (this.useAdaGrad) {
            wGradient.muli(this.wAdaGrad.getLearningRates(wGradient));
        } else {
            wGradient.muli(learningRate);
        }
        if (this.useRegularization) {
            // mul (not muli): computing the penalty must not rescale the live weights.
            wGradient.subi(this.W.mul(this.l2));
        }
        if (this.momentum != 0.0) {
            DoubleMatrix change = wGradient.mul(this.momentum).add(wGradient.mul(1.0 - this.momentum));
            wGradient.addi(change);
        }
        this.scaleBiasGradient(hBiasGradient, this.hBiasAdaGrad, learningRate);
        this.scaleBiasGradient(vBiasGradient, this.vBiasAdaGrad, learningRate);
        if (this.applySparsity) {
            this.applySparsity(hBiasGradient, learningRate);
        }
        if (this.normalizeByInputRows) {
            double rows = this.input.rows;
            wGradient.divi(rows);
            vBiasGradient.divi(rows);
            hBiasGradient.divi(rows);
        }
    }

    /**
     * Scales a bias gradient in place.
     *
     * <p>The result is {@code g = g * rate + g * momentum}. {@code rate} is the AdaGrad rate
     * matrix when AdaGrad is enabled, else {@code learningRate}. Both terms are computed from the
     * original {@code g}.
     */
    private void scaleBiasGradient(DoubleMatrix biasGradient, AdaGrad adaGrad, double learningRate) {
        DoubleMatrix adaGradRates = this.useAdaGrad ? adaGrad.getLearningRates(biasGradient) : null;
        DoubleMatrix momentumTerm = biasGradient.mul(this.momentum);
        if (adaGradRates != null) {
            biasGradient.muli(adaGradRates);
        } else {
            biasGradient.muli(learningRate);
        }
        biasGradient.addi(momentumTerm);
    }

    /** Passes {@code gradient} to every registered listener, if any. */
    protected void triggerGradientEvents(NeuralNetworkGradient gradient) {
        if (this.gradientListeners != null && !this.gradientListeners.isEmpty()) {
            for (NeuralNetworkGradientListener listener : this.gradientListeners) {
                listener.onGradient(gradient);
            }
        }
    }

    @Override
    public void setDropOut(double dropOut) {
        this.dropOut = dropOut;
    }

    @Override
    public double dropOut() {
        return this.dropOut;
    }

    /** Returns the AdaGrad used for the weight matrix. */
    @Override
    public AdaGrad getAdaGrad() {
        return this.wAdaGrad;
    }

    /** Sets the AdaGrad used for the weight matrix. */
    @Override
    public void setAdaGrad(AdaGrad adaGrad) {
        this.wAdaGrad = adaGrad;
    }

    /**
     * Instantiates this object's runtime class through its no-arg constructor.
     *
     * <p>The constructor is invoked even if it is non-public.
     */
    private NeuralNetwork newEmptyInstance() throws ReflectiveOperationException {
        Constructor<? extends BaseNeuralNetwork> ctor = this.getClass().getDeclaredConstructor();
        ctor.setAccessible(true);
        return ctor.newInstance();
    }

    /**
     * Returns a new network of the same class with visible and hidden sizes swapped.
     *
     * <p>{@code W} is transposed. The biases are copied as-is, and the rng, weight AdaGrad and
     * distribution are shared with this network.
     *
     * <p>NOTE(review): the biases are not swapped even though the dimensions are; confirm this is
     * intended.
     *
     * @throws RuntimeException if the class has no usable no-arg constructor
     */
    @Override
    public NeuralNetwork transpose() {
        try {
            NeuralNetwork ret = this.newEmptyInstance();
            ret.sethBias(this.hBias.dup());
            ret.setvBias(this.vBias.dup());
            ret.setnHidden(this.getnVisible());
            ret.setnVisible(this.getnHidden());
            ret.setW(this.W.transpose());
            ret.setRng(this.getRng());
            ret.setAdaGrad(this.wAdaGrad);
            ret.setDist(this.getDist());
            return ret;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Returns a copy of this network.
     *
     * <p>What is copied and what is shared:
     * <ul>
     *   <li>Weights and biases are duplicated.</li>
     *   <li>The AdaGrad instances, rng and distribution are shared with this network.</li>
     *   <li>The l2, momentum, render epochs, sparsity, loss function and optimization algorithm
     *       settings are copied.</li>
     * </ul>
     *
     * <p>The runtime class's no-arg constructor is used; previously an arbitrary declared
     * constructor was chosen, because {@code getDeclaredConstructors()} has no defined order.
     *
     * @throws RuntimeException if the class has no usable no-arg constructor
     */
    @Override
    public NeuralNetwork clone() {
        try {
            NeuralNetwork ret = this.newEmptyInstance();
            ret.setHbiasAdaGrad(this.hBiasAdaGrad);
            ret.setVBiasAdaGrad(this.vBiasAdaGrad);
            ret.sethBias(this.hBias.dup());
            ret.setvBias(this.vBias.dup());
            ret.setnHidden(this.getnHidden());
            ret.setnVisible(this.getnVisible());
            ret.setW(this.W.dup());
            ret.setL2(this.l2);
            ret.setMomentum(this.momentum);
            ret.setRenderEpochs(this.getRenderEpochs());
            ret.setSparsity(this.sparsity);
            ret.setRng(this.getRng());
            ret.setDist(this.getDist());
            ret.setAdaGrad(this.wAdaGrad);
            ret.setLossFunction(this.lossFunction);
            ret.setOptimizationAlgorithm(this.optimizationAlgo);
            return ret;
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public NeuralNetwork.LossFunction getLossFunction() {
        return this.lossFunction;
    }

    @Override
    public void setLossFunction(NeuralNetwork.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    @Override
    public NeuralNetwork.OptimizationAlgorithm getOptimizationAlgorithm() {
        return this.optimizationAlgo;
    }

    @Override
    public void setOptimizationAlgorithm(NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm) {
        this.optimizationAlgo = optimizationAlgorithm;
    }

    /** Returns the weight distribution; may be {@code null} (see field docs). */
    @Override
    public RealDistribution getDist() {
        return this.dist;
    }

    @Override
    public void setDist(RealDistribution dist) {
        this.dist = dist;
    }

    /**
     * Moves this network's parameters a fraction toward {@code network}'s.
     *
     * <p>For each of {@code W}, {@code hBias} and {@code vBias}, {@code p += (other - p) / batchSize}.
     * Only this network is modified; {@code network} is left untouched.
     */
    @Override
    public void merge(NeuralNetwork network, int batchSize) {
        this.W.addi(network.getW().sub(this.W).divi((double) batchSize));
        this.hBias.addi(network.gethBias().sub(this.hBias).divi((double) batchSize));
        this.vBias.addi(network.getvBias().sub(this.vBias).divi((double) batchSize));
    }

    /**
     * Copies the parameters and training settings of {@code n} into this network.
     *
     * <p>References are copied, not the matrices themselves.
     */
    public void update(BaseNeuralNetwork n) {
        this.W = n.W;
        this.normalizeByInputRows = n.normalizeByInputRows;
        this.hBias = n.hBias;
        this.vBias = n.vBias;
        this.l2 = n.l2;
        this.useRegularization = n.useRegularization;
        this.momentum = n.momentum;
        this.nHidden = n.nHidden;
        this.nVisible = n.nVisible;
        this.rng = n.rng;
        this.sparsity = n.sparsity;
        this.wAdaGrad = n.wAdaGrad;
        this.hBiasAdaGrad = n.hBiasAdaGrad;
        this.vBiasAdaGrad = n.vBiasAdaGrad;
        this.optimizationAlgo = n.optimizationAlgo;
        this.lossFunction = n.lossFunction;
    }

    /**
     * Reads a Java-serialized {@code BaseNeuralNetwork} from {@code is} and copies its state into
     * this instance via {@link #update(BaseNeuralNetwork)}.
     *
     * <p>SECURITY: this uses native Java deserialization. Only load streams from trusted sources.
     * The stream is not closed here, because it belongs to the caller.
     *
     * @throws RuntimeException wrapping any I/O, class-resolution or cast failure
     */
    @Override
    public void load(InputStream is) {
        try {
            ObjectInputStream ois = new ObjectInputStream(is);
            BaseNeuralNetwork loaded = (BaseNeuralNetwork) ois.readObject();
            this.update(loaded);
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Computes the mean column-sum of the binary cross-entropy between {@code in} and its
     * reconstruction.
     *
     * <p>When regularization is enabled, {@code l2 / 2 * sum(W^2)} is added. This is the same
     * penalty used by {@link #squaredLoss()}; previously it was the inverted {@code 2 / l2}.
     */
    private double crossEntropyWithPenalty(DoubleMatrix in) {
        DoubleMatrix z = this.reconstruct(in);
        double loss = -in.mul(MatrixUtil.log(z))
                .add(MatrixUtil.oneMinus(in).mul(MatrixUtil.log(MatrixUtil.oneMinus(z))))
                .columnSums()
                .mean();
        if (this.useRegularization) {
            loss += 0.5 * this.l2 * MatrixFunctions.pow(this.W, 2.0).sum();
        }
        return loss;
    }

    /**
     * Returns the negative log likelihood of the stored input.
     *
     * <p>The value is divided by {@code input.rows} when {@code normalizeByInputRows} is set.
     */
    @Override
    public double negativeLogLikelihood() {
        double ret = this.crossEntropyWithPenalty(this.input);
        if (this.normalizeByInputRows) {
            ret /= (double) this.input.rows;
        }
        return ret;
    }

    /** Returns the negative log likelihood of {@code input}, with no row normalisation. */
    public double negativeLoglikelihood(DoubleMatrix input) {
        return this.crossEntropyWithPenalty(input);
    }

    /**
     * Returns the reconstruction cross-entropy of the stored input.
     *
     * <p>The reconstruction is {@code sigmoid(sigmoid(input W + hBias) W^T + vBias)}. The
     * returned value is the negated mean row-sum of the element-wise cross-entropy. With
     * regularization it is additionally divided by {@code input.length + l2RegularizedCoefficient()}.
     */
    @Override
    public double getReConstructionCrossEntropy() {
        DoubleMatrix preSigH = this.input.mmul(this.W).addRowVector(this.hBias);
        DoubleMatrix sigH = MatrixUtil.sigmoid(preSigH);
        DoubleMatrix preSigV = sigH.mmul(this.W.transpose()).addRowVector(this.vBias);
        DoubleMatrix sigV = MatrixUtil.sigmoid(preSigV);
        DoubleMatrix inner = this.input.mul(MatrixUtil.log(sigV)).add(MatrixUtil.oneMinus(this.input).mul(MatrixUtil.log(MatrixUtil.oneMinus(sigV))));
        double l = inner.length;
        if (this.useRegularization) {
            double normalized = l + this.l2RegularizedCoefficient();
            double ret = -inner.rowSums().mean() / normalized;
            if (this.normalizeByInputRows) {
                ret /= (double) this.input.rows;
            }
            return ret;
        }
        double ret = -inner.rowSums().mean();
        if (this.normalizeByInputRows) {
            ret /= (double) this.input.rows;
        }
        return ret;
    }

    @Override
    public boolean normalizeByInputRows() {
        return this.normalizeByInputRows;
    }

    @Override
    public int getnVisible() {
        return this.nVisible;
    }

    @Override
    public void setnVisible(int nVisible) {
        this.nVisible = nVisible;
    }

    @Override
    public int getnHidden() {
        return this.nHidden;
    }

    @Override
    public void setnHidden(int nHidden) {
        this.nHidden = nHidden;
    }

    @Override
    public DoubleMatrix getW() {
        return this.W;
    }

    @Override
    public void setW(DoubleMatrix w) {
        this.W = w;
    }

    @Override
    public DoubleMatrix gethBias() {
        return this.hBias;
    }

    @Override
    public void sethBias(DoubleMatrix hBias) {
        this.hBias = hBias;
    }

    @Override
    public DoubleMatrix getvBias() {
        return this.vBias;
    }

    @Override
    public void setvBias(DoubleMatrix vBias) {
        this.vBias = vBias;
    }

    @Override
    public RandomGenerator getRng() {
        return this.rng;
    }

    @Override
    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    @Override
    public DoubleMatrix getInput() {
        return this.input;
    }

    @Override
    public void setInput(DoubleMatrix input) {
        this.input = input;
    }

    @Override
    public double getSparsity() {
        return this.sparsity;
    }

    @Override
    public void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    @Override
    public double getMomentum() {
        return this.momentum;
    }

    @Override
    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    @Override
    public double getL2() {
        return this.l2;
    }

    @Override
    public void setL2(double l2) {
        this.l2 = l2;
    }

    @Override
    public AdaGrad gethBiasAdaGrad() {
        return this.hBiasAdaGrad;
    }

    @Override
    public void setHbiasAdaGrad(AdaGrad adaGrad) {
        this.hBiasAdaGrad = adaGrad;
    }

    @Override
    public AdaGrad getVBiasAdaGrad() {
        return this.vBiasAdaGrad;
    }

    @Override
    public void setVBiasAdaGrad(AdaGrad adaGrad) {
        this.vBiasAdaGrad = adaGrad;
    }

    /**
     * Java-serializes this network to {@code os}.
     *
     * <p>The object stream is flushed, so every byte reaches {@code os}. It is not closed,
     * because that would close the caller's stream.
     *
     * @throws RuntimeException wrapping any {@link IOException}
     */
    @Override
    public void write(OutputStream os) {
        try {
            ObjectOutputStream out = new ObjectOutputStream(os);
            out.writeObject(this);
            // ObjectOutputStream buffers block data internally; without a flush the tail of the
            // object can remain unwritten.
            out.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Reconstructs {@code input} through the network; defined by the concrete subclass. */
    public abstract DoubleMatrix reconstruct(DoubleMatrix input);

    /** Subclass-defined loss; the meaning of {@code params} depends on the implementation. */
    public abstract double lossFunction(Object[] params);

    /** Equivalent to {@code lossFunction(null)}. */
    public double lossFunction() {
        return this.lossFunction(null);
    }

    /**
     * Sets {@link #doMask} to a {@code input.rows x nHidden} mask.
     *
     * <p>With drop-out enabled, each entry is 1 where a uniform random draw exceeds
     * {@code dropOut}; otherwise the mask is all ones.
     */
    protected void applyDropOutIfNecessary(DoubleMatrix input) {
        this.doMask = this.dropOut > 0.0
                ? DoubleMatrix.rand(input.rows, this.nHidden).gt(this.dropOut)
                : DoubleMatrix.ones(input.rows, this.nHidden);
    }

    /**
     * Trains on {@code input}; defined by the concrete subclass.
     *
     * <p>Presumably {@code lr} is the learning rate; confirm against the implementations.
     */
    @Override
    public abstract void train(DoubleMatrix input, double lr, Object[] params);

    /**
     * Returns the negated squared loss of the stored input.
     *
     * <p>The loss is {@code sum((reconstruct(input) - input)^2) / input.rows}. When
     * regularization is enabled, {@code l2 / 2 * sum(W^2)} is added before negation.
     */
    @Override
    public double squaredLoss() {
        DoubleMatrix reconstructed = this.reconstruct(this.input);
        double loss = MatrixFunctions.powi(reconstructed.sub(this.input), 2.0).sum() / (double) this.input.rows;
        if (this.useRegularization) {
            loss += 0.5 * this.l2 * MatrixFunctions.pow(this.W, 2.0).sum();
        }
        return -loss;
    }

    /** Returns {@code input W + hBias}, the pre-activation of the hidden layer. */
    @Override
    public DoubleMatrix hBiasMean() {
        return this.getInput().mmul(this.getW()).addRowVector(this.gethBias());
    }

    /**
     * Plots the network gradient every {@link #getRenderEpochs()} epochs.
     *
     * <p>Plotting is skipped when the render interval is not positive. The fixed arguments
     * {@code {1, 0.001, 1000}} are interpreted by the concrete {@code getGradient}.
     */
    @Override
    public void epochDone(int epoch) {
        int plotEpochs = this.getRenderEpochs();
        if (plotEpochs <= 0) {
            return;
        }
        // epoch 0 is covered too, since 0 % plotEpochs == 0.
        if (epoch % plotEpochs == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            plotter.plotNetworkGradient(this, this.getGradient(new Object[]{1, 0.001, 1000}));
        }
    }

    /**
     * Fluent builder that reflectively instantiates a concrete {@link BaseNeuralNetwork} subclass.
     *
     * <p>Defaults:
     * <ul>
     *   <li>{@code MersenneTwister(123)} generator;</li>
     *   <li>sparsity 0.01, l2 0.01, momentum 0.5, fan-in 0.1;</li>
     *   <li>no regularization, no AdaGrad, no drop-out, no row normalisation, rendering
     *       disabled;</li>
     *   <li>reconstruction cross-entropy loss and conjugate-gradient optimisation.</li>
     * </ul>
     */
    public static class Builder<E extends BaseNeuralNetwork> {
        /** Parameter count of the full {@code (input, nVisible, ..., dist)} constructor. */
        private static final int FULL_CONSTRUCTOR_ARITY = 9;

        private E ret = null;
        private DoubleMatrix W;
        protected Class<? extends NeuralNetwork> clazz;
        private DoubleMatrix vBias;
        private DoubleMatrix hBias;
        private int numVisible;
        private int numHidden;
        private RandomGenerator gen = new MersenneTwister(123);
        private DoubleMatrix input;
        private double sparsity = 0.01;
        private double l2 = 0.01;
        private double momentum = 0.5;
        private int renderWeightsEveryNumEpochs = -1;
        private double fanIn = 0.1;
        private boolean useRegularization = false;
        private RealDistribution dist;
        private boolean useAdaGrad = false;
        private boolean normalizeByInputRows = false;
        private double dropOut = 0.0;
        private NeuralNetwork.LossFunction lossFunction = NeuralNetwork.LossFunction.RECONSTRUCTION_CROSSENTROPY;
        private NeuralNetwork.OptimizationAlgorithm optimizationAlgo = NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT;

        public Builder<E> withOptmizationAlgo(NeuralNetwork.OptimizationAlgorithm optimizationAlgo) {
            this.optimizationAlgo = optimizationAlgo;
            return this;
        }

        public Builder<E> withLossFunction(NeuralNetwork.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        public Builder<E> withDropOut(double dropOut) {
            this.dropOut = dropOut;
            return this;
        }

        public Builder<E> normalizeByInputRows(boolean normalizeByInputRows) {
            this.normalizeByInputRows = normalizeByInputRows;
            return this;
        }

        public Builder<E> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder<E> withDistribution(RealDistribution dist) {
            this.dist = dist;
            return this;
        }

        public Builder<E> useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        public Builder<E> fanIn(double fanIn) {
            this.fanIn = fanIn;
            return this;
        }

        public Builder<E> withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        public Builder<E> renderWeights(int numEpochs) {
            this.renderWeightsEveryNumEpochs = numEpochs;
            return this;
        }

        /**
         * Instantiates the configured class through {@link Class#newInstance()} without applying
         * any builder settings.
         *
         * @throws RuntimeException if instantiation fails
         */
        @SuppressWarnings("unchecked")
        public E buildEmpty() {
            try {
                return (E) ((BaseNeuralNetwork) this.clazz.newInstance());
            } catch (IllegalAccessException | InstantiationException e) {
                throw new RuntimeException(e);
            }
        }

        public Builder<E> withClazz(Class<? extends BaseNeuralNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withSparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }

        public Builder<E> withMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder<E> withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        public Builder<E> asType(Class<E> clazz) {
            this.clazz = clazz;
            return this;
        }

        public Builder<E> withWeights(DoubleMatrix W) {
            this.W = W;
            return this;
        }

        public Builder<E> withVisibleBias(DoubleMatrix vBias) {
            this.vBias = vBias;
            return this;
        }

        public Builder<E> withHBias(DoubleMatrix hBias) {
            this.hBias = hBias;
            return this;
        }

        public Builder<E> numberOfVisible(int numVisible) {
            this.numVisible = numVisible;
            return this;
        }

        public Builder<E> numHidden(int numHidden) {
            this.numHidden = numHidden;
            return this;
        }

        public Builder<E> withRandom(RandomGenerator gen) {
            this.gen = gen;
            return this;
        }

        /** Builds the network; see {@link #buildWithInput()}. */
        public E build() {
            return this.buildWithInput();
        }

        /**
         * Builds the network using the configured class's full constructor.
         *
         * <p>The full constructor has 9 parameters, the first accepting a {@link DoubleMatrix}.
         * After construction the remaining builder settings are copied onto the instance.
         *
         * @return the new instance, or the previously built one (possibly {@code null}) when no
         *     matching constructor exists
         * @throws RuntimeException if the constructor invocation fails
         */
        @SuppressWarnings("unchecked")
        private E buildWithInput() {
            for (Constructor<?> curr : this.clazz.getDeclaredConstructors()) {
                Class<?>[] paramTypes = curr.getParameterTypes();
                // Only the full constructor can accept the 9 arguments passed below; skipping
                // other DoubleMatrix-first constructors avoids an arity mismatch that depends on
                // the unspecified declaration order.
                if (paramTypes.length != FULL_CONSTRUCTOR_ARITY || !paramTypes[0].isAssignableFrom(DoubleMatrix.class)) {
                    continue;
                }
                curr.setAccessible(true);
                try {
                    BaseNeuralNetwork net = (BaseNeuralNetwork) curr.newInstance(this.input, this.numVisible, this.numHidden, this.W, this.hBias, this.vBias, this.gen, this.fanIn, this.dist);
                    net.sparsity = this.sparsity;
                    net.normalizeByInputRows = this.normalizeByInputRows;
                    net.renderWeightsEveryNumEpochs = this.renderWeightsEveryNumEpochs;
                    net.l2 = this.l2;
                    net.momentum = this.momentum;
                    net.useRegularization = this.useRegularization;
                    net.useAdaGrad = this.useAdaGrad;
                    net.dropOut = this.dropOut;
                    net.optimizationAlgo = this.optimizationAlgo;
                    net.lossFunction = this.lossFunction;
                    this.ret = (E) net;
                    return this.ret;
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            return this.ret;
        }
    }
}

