/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.Serializable;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/**
 * A fully connected layer whose output is {@code activation(input * W + b)}, with the bias
 * {@code b} added to every row of the product.
 *
 * <p>All accessors and computations synchronize on the instance, so state is read and written
 * under a single lock.
 */
public class HiddenLayer
implements Serializable {
    private static final long serialVersionUID = 915783367350830495L;
    /** Number of input units; the default weight matrix has this many rows. */
    private int nIn;
    /** Number of output units; the default weight matrix has this many columns. */
    private int nOut;
    /** Weight matrix; expected to be nIn x nOut (the shape of the default initialization). */
    private DoubleMatrix W;
    /** Bias added to every row of {@code input * W}; defaults to a zero vector of length nOut. */
    private DoubleMatrix b;
    /** Random source used for weight initialization and sampling. */
    private RandomGenerator rng;
    /** Most recent input; multiplied on the left of W, so it must have nIn columns. */
    private DoubleMatrix input;
    private ActivationFunction activationFunction = new Sigmoid();

    /** Used by {@link #clone()} and {@link #transpose()}, which assign every field themselves. */
    private HiddenLayer() {
    }

    /**
     * Creates a layer, filling in defaults for any null argument.
     *
     * @param nIn number of input units
     * @param nOut number of output units
     * @param W weight matrix, or null to sample an nIn x nOut matrix from a normal distribution
     *          with mean 0 and standard deviation 0.01
     * @param b bias vector, or null for a zero vector of length nOut
     * @param rng random source, or null for a {@link MersenneTwister} seeded with 1234
     * @param input initial input (may be null and supplied later)
     * @param activationFunction activation to apply, or null for {@link Sigmoid}
     */
    public HiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng, DoubleMatrix input, ActivationFunction activationFunction) {
        this.nIn = nIn;
        this.nOut = nOut;
        this.input = input;
        // Default a missing activation the same way the other collaborators are defaulted, instead
        // of storing null and failing later inside activate().
        this.activationFunction = activationFunction == null ? new Sigmoid() : activationFunction;
        this.rng = rng == null ? new MersenneTwister(1234) : rng;
        this.W = W == null ? initWeights(nIn, nOut, this.rng) : W;
        this.b = b == null ? DoubleMatrix.zeros(nOut) : b;
    }

    /**
     * Creates a layer with a {@link Sigmoid} activation; see the 7-argument constructor for how
     * null arguments are defaulted.
     */
    public HiddenLayer(int nIn, int nOut, DoubleMatrix W, DoubleMatrix b, RandomGenerator rng, DoubleMatrix input) {
        this(nIn, nOut, W, b, rng, input, new Sigmoid());
    }

    /**
     * Samples an nIn x nOut weight matrix from N(0, 0.01^2).
     *
     * <p>Values are drawn one row at a time, so a given seed reproduces the same weights as
     * earlier versions of this class.
     */
    private static DoubleMatrix initWeights(int nIn, int nOut, RandomGenerator rng) {
        DoubleMatrix weights = DoubleMatrix.zeros(nIn, nOut);
        if (weights.columns == 0) {
            // NormalDistribution.sample(int) rejects a non-positive sample size.
            return weights;
        }
        NormalDistribution dist = new NormalDistribution(rng, 0.0, 0.01, 1.0E-9);
        for (int row = 0; row < weights.rows; ++row) {
            weights.putRow(row, new DoubleMatrix(dist.sample(weights.columns)));
        }
        return weights;
    }

    /** Returns a duplicate of {@code m}, or null when {@code m} is null. */
    private static DoubleMatrix dupOrNull(DoubleMatrix m) {
        return m == null ? null : m.dup();
    }

    /** Returns the transpose of {@code m}, or null when {@code m} is null. */
    private static DoubleMatrix transposeOrNull(DoubleMatrix m) {
        return m == null ? null : m.transpose();
    }

    /** @return the number of input units */
    public synchronized int getnIn() {
        return this.nIn;
    }

    /** Sets the number of input units; does not resize the existing weights. */
    public synchronized void setnIn(int nIn) {
        this.nIn = nIn;
    }

    /** @return the number of output units */
    public synchronized int getnOut() {
        return this.nOut;
    }

    /** Sets the number of output units; does not resize the existing weights or bias. */
    public synchronized void setnOut(int nOut) {
        this.nOut = nOut;
    }

    /** @return the weight matrix (not a copy) */
    public synchronized DoubleMatrix getW() {
        return this.W;
    }

    /** Replaces the weight matrix (stored by reference). */
    public synchronized void setW(DoubleMatrix w) {
        this.W = w;
    }

    /** @return the bias vector (not a copy) */
    public synchronized DoubleMatrix getB() {
        return this.b;
    }

    /** Replaces the bias vector (stored by reference). */
    public synchronized void setB(DoubleMatrix b) {
        this.b = b;
    }

    /** @return the random source used for sampling */
    public synchronized RandomGenerator getRng() {
        return this.rng;
    }

    /** Replaces the random source used for sampling. */
    public synchronized void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    /** @return the current input (not a copy); may be null */
    public synchronized DoubleMatrix getInput() {
        return this.input;
    }

    /** Replaces the current input (stored by reference). */
    public synchronized void setInput(DoubleMatrix input) {
        this.input = input;
    }

    /** @return the activation function applied by {@link #activate()} */
    public synchronized ActivationFunction getActivationFunction() {
        return this.activationFunction;
    }

    /** Replaces the activation function applied by {@link #activate()}. */
    public synchronized void setActivationFunction(ActivationFunction activationFunction) {
        this.activationFunction = activationFunction;
    }

    /**
     * Returns a copy whose W, b and input are duplicated; null matrices stay null.
     *
     * <p>The activation function and random source are shared with this layer, so sampling from
     * either advances the same random stream.
     */
    @Override
    public synchronized HiddenLayer clone() {
        HiddenLayer copy = new HiddenLayer();
        copy.b = dupOrNull(this.b);
        copy.W = dupOrNull(this.W);
        copy.input = dupOrNull(this.input);
        copy.activationFunction = this.activationFunction;
        copy.nOut = this.nOut;
        copy.nIn = this.nIn;
        copy.rng = this.rng;
        return copy;
    }

    /**
     * Returns a layer with transposed weights and input, and with nIn and nOut swapped. Null
     * matrices stay null.
     *
     * <p>NOTE(review): the bias is copied unchanged, so it still has this layer's nOut entries
     * while the transposed layer produces nIn outputs. The input is also transposed rather than
     * cleared. Both only line up when the dimensions coincide; confirm the intended semantics with
     * callers.
     */
    public synchronized HiddenLayer transpose() {
        HiddenLayer flipped = new HiddenLayer();
        flipped.b = dupOrNull(this.b);
        flipped.W = transposeOrNull(this.W);
        flipped.input = transposeOrNull(this.input);
        flipped.activationFunction = this.activationFunction;
        flipped.nOut = this.nIn;
        flipped.nIn = this.nOut;
        flipped.rng = this.rng;
        return flipped;
    }

    /**
     * Computes {@code activation(input * W + b)} for the current input, adding b to every row.
     *
     * @return the activated output
     * @throws NullPointerException if no input has been set
     */
    public synchronized DoubleMatrix activate() {
        return (DoubleMatrix)this.getActivationFunction().apply(this.getInput().mmul(this.getW()).addRowVector(this.getB()));
    }

    /**
     * Computes the activation after optionally replacing the input.
     *
     * @param input new input, or null to reuse the current one
     * @return the activated output
     */
    public synchronized DoubleMatrix activate(DoubleMatrix input) {
        if (input != null) {
            this.input = input;
        }
        return this.activate();
    }

    /**
     * Sets the input (even when null) and returns {@code MatrixUtil.binomial(activate(), 1, rng)}.
     * Presumably this yields a 0/1 sample per activation; see {@code MatrixUtil.binomial}.
     * Synchronized so that storing the input and sampling from it happen atomically.
     */
    public synchronized DoubleMatrix sampleHGivenV(DoubleMatrix input) {
        this.input = input;
        return MatrixUtil.binomial(this.activate(), 1, this.rng);
    }

    /**
     * Samples from the activation of the current input via
     * {@code MatrixUtil.binomial(activate(), 1, rng)}.
     */
    public synchronized DoubleMatrix sample_h_given_v() {
        return MatrixUtil.binomial(this.activate(), 1, this.rng);
    }

    /**
     * Fluent builder for {@link HiddenLayer}. Unset weights, bias and rng are defaulted by the
     * layer constructor; nIn and nOut default to 0 and the activation to {@link Sigmoid}.
     */
    public static class Builder {
        private int nIn;
        private int nOut;
        private DoubleMatrix W;
        private DoubleMatrix b;
        private RandomGenerator rng;
        private DoubleMatrix input;
        private ActivationFunction activationFunction = new Sigmoid();

        /** Sets the number of input units. */
        public Builder nIn(int nIn) {
            this.nIn = nIn;
            return this;
        }

        /** Sets the number of output units. */
        public Builder nOut(int nOut) {
            this.nOut = nOut;
            return this;
        }

        /** Supplies explicit weights instead of the random default. */
        public Builder withWeights(DoubleMatrix W) {
            this.W = W;
            return this;
        }

        /** Supplies the random source used for initialization and sampling. */
        public Builder withRng(RandomGenerator gen) {
            this.rng = gen;
            return this;
        }

        /** Sets the activation function; null falls back to {@link Sigmoid}. */
        public Builder withActivation(ActivationFunction function) {
            this.activationFunction = function;
            return this;
        }

        /** Supplies an explicit bias instead of the zero default. */
        public Builder withBias(DoubleMatrix b) {
            this.b = b;
            return this;
        }

        /** Supplies the initial input. */
        public Builder withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        /** Builds the layer through the full constructor so every argument gets the same defaulting. */
        public HiddenLayer build() {
            return new HiddenLayer(this.nIn, this.nOut, this.W, this.b, this.rng, this.input, this.activationFunction);
        }
    }
}

