/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.lang.reflect.Constructor;
import java.util.Arrays;
import org.deeplearning4j.nn.WeightInitUtil;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Base implementation of {@link Layer} that holds a weight matrix {@code W}, a bias {@code b},
 * the most recently seen input, the layer configuration and the last dropout mask.
 *
 * <p>{@link #clone()} and {@link #transpose()} instantiate the concrete subclass reflectively, so
 * every subclass must expose a public
 * {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)} constructor.
 */
public abstract class BaseLayer
implements Layer {
    protected INDArray W;
    protected INDArray b;
    protected INDArray input;
    protected NeuralNetConfiguration conf;
    protected INDArray dropoutMask;

    /**
     * Creates a layer, initializing any parameter that was not supplied.
     *
     * @param conf  layer configuration; read by {@link #createWeightMatrix()} and {@link #createBias()}
     * @param W     weight matrix, or {@code null} to create one with {@link #createWeightMatrix()}
     * @param b     bias, or {@code null} to create one with {@link #createBias()}
     * @param input initial input; may be {@code null}
     */
    public BaseLayer(NeuralNetConfiguration conf, INDArray W, INDArray b, INDArray input) {
        this.input = input;
        // conf must be assigned before the factory methods below run, since they read it.
        this.conf = conf;
        // NOTE(review): createWeightMatrix/createBias are overridable and run before subclass
        // field initializers; overrides must not depend on subclass state.
        this.W = W == null ? this.createWeightMatrix() : W;
        this.b = b == null ? this.createBias() : b;
    }

    /**
     * Creates the default bias: a zero vector of length {@code conf.getnOut()}.
     *
     * @return a new zero bias
     */
    protected INDArray createBias() {
        return Nd4j.zeros((int)this.conf.getnOut());
    }

    /**
     * Creates the default weight matrix via {@link WeightInitUtil#initWeights}, using nIn, nOut,
     * weight-init scheme, activation function and distribution from the configuration.
     *
     * @return a new weight matrix
     */
    protected INDArray createWeightMatrix() {
        return WeightInitUtil.initWeights(this.conf.getnIn(), this.conf.getnOut(), this.conf.getWeightInit(), this.conf.getActivationFunction(), this.conf.getDist());
    }

    /**
     * Computes the pre-activation {@code x * W} plus bias and stores {@code x} as the current input.
     *
     * <p>When {@code conf.isConcatBiases()} is true, the bias is horizontally stacked onto the
     * product instead of being added to every row.
     *
     * @param x input; must not be {@code null}
     * @return the pre-activation
     * @throws IllegalArgumentException if {@code x} is {@code null}
     * @throws IllegalStateException    if the product's column count differs from the bias's
     */
    @Override
    public INDArray preOutput(INDArray x) {
        if (x == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.input = x;
        INDArray ret = this.input.mmul(this.W);
        if (ret.columns() != this.b.columns()) {
            throw new IllegalStateException("Pre-output has " + ret.columns() + " columns but bias has " + this.b.columns() + " columns");
        }
        if (this.conf.isConcatBiases()) {
            ret = Nd4j.hstack(ret, this.b);
        } else {
            ret.addiRowVector(this.b);
        }
        return ret;
    }

    /**
     * Applies the configured activation function to {@code input * W + b}, using the stored input.
     *
     * <p>Unlike {@link #preOutput(INDArray)}, this always adds the bias row-wise and does not
     * consult {@code conf.isConcatBiases()}.
     *
     * @return the activation
     */
    @Override
    public INDArray activate() {
        INDArray activation = (INDArray)this.conf.getActivationFunction().apply((Object)this.getInput().mmul(this.getW()).addRowVector(this.getB()));
        return activation;
    }

    /**
     * Replaces the stored input with a stabilized copy of {@code input} (when non-null) and
     * then calls {@link #activate()}.
     *
     * @param input new input, or {@code null} to reuse the stored input
     * @return the activation
     */
    @Override
    public INDArray activate(INDArray input) {
        if (input != null) {
            this.input = Transforms.stabilize(input, 1.0);
        }
        return this.activate();
    }

    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    @Override
    public void setConfiguration(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    @Override
    public INDArray getW() {
        return this.W;
    }

    /**
     * Replaces the weight matrix. With assertions enabled, checks that its shape is
     * {@code (nIn, nOut)}.
     *
     * @param W new weight matrix
     */
    @Override
    public void setW(INDArray W) {
        assert (W.rows() == this.conf().getnIn() && W.columns() == this.conf.getnOut()) : "Weight matrix must be of shape " + Arrays.toString(new int[]{this.conf().getnIn(), this.conf.getnOut()});
        this.W = W;
    }

    @Override
    public INDArray getB() {
        return this.b;
    }

    /**
     * Replaces the bias. With assertions enabled, checks that it has {@code nOut} columns.
     *
     * @param b new bias
     */
    @Override
    public void setB(INDArray b) {
        assert (b.columns() == this.conf().getnOut()) : "The bias must have " + this.conf().getnOut() + " columns";
        this.b = b;
    }

    @Override
    public INDArray getInput() {
        return this.input;
    }

    @Override
    public void setInput(INDArray input) {
        this.input = input;
    }

    /**
     * Multiplies {@code input} in place by a dropout mask and stores that mask in
     * {@link #dropoutMask}.
     *
     * <p>When {@code conf.getDropOut() > 0}, an entry is kept (mask 1) where a value from
     * {@code Nd4j.rand} exceeds the dropout probability and is zeroed otherwise. When dropout is
     * disabled, the mask is all ones, so the input is unchanged.
     *
     * @param input activations to mask; modified in place
     */
    protected void applyDropOutIfNecessary(INDArray input) {
        if (this.conf.getDropOut() > 0.0f) {
            this.dropoutMask = Nd4j.rand((int)input.rows(), (int)input.columns()).gt((Number)Float.valueOf(this.conf.getDropOut()));
        } else {
            // Shape the identity mask like the input (not nOut) so the muli below is always valid.
            this.dropoutMask = Nd4j.ones((int)input.rows(), (int)input.columns());
        }
        input.muli(this.dropoutMask);
    }

    /**
     * Moves this layer's parameters toward those of {@code l}.
     *
     * <p>With regularization enabled, each parameter moves by {@code (l - this) / batchSize}.
     * Otherwise the parameters effectively become copies of {@code l}'s values. {@code l} itself is
     * not modified.
     *
     * @param l         layer whose parameters are merged in
     * @param batchSize divisor applied to the update when regularization is enabled
     */
    public void merge(Layer l, int batchSize) {
        // sub (not subi): the difference must not overwrite the other layer's parameters.
        INDArray deltaW = l.getW().sub(this.W);
        INDArray deltaB = l.getB().sub(this.b);
        if (this.conf.isUseRegularization()) {
            deltaW = deltaW.div((Number)batchSize);
            deltaB = deltaB.div((Number)batchSize);
        }
        this.W.addi(deltaW);
        this.b.addi(deltaB);
    }

    /**
     * Creates a new instance of the same concrete class. It shares this layer's configuration
     * and gets copies of {@code W}, {@code b} and (if present) the input.
     *
     * @return the copy
     * @throws IllegalStateException if the subclass cannot be instantiated reflectively
     */
    @Override
    public Layer clone() {
        try {
            Constructor<?> c = this.getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            return (Layer)c.newInstance(this.conf, this.W.dup(), this.b.dup(), this.input != null ? this.input.dup() : null);
        }
        catch (Exception e) {
            throw new IllegalStateException("Unable to clone layer " + this.getClass().getName(), e);
        }
    }

    /**
     * Creates a new instance of the same concrete class whose configuration has nIn and nOut
     * swapped. It gets transposed copies of {@code W}, {@code b} and (if present) the input.
     *
     * @return the transposed layer
     * @throws IllegalStateException if the configuration cannot be cloned or the subclass cannot
     *                               be instantiated reflectively
     */
    @Override
    public Layer transpose() {
        try {
            Constructor<?> c = this.getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            NeuralNetConfiguration transposedConf = this.conf.clone();
            int nIn = transposedConf.getnOut();
            int nOut = transposedConf.getnIn();
            transposedConf.setnIn(nIn);
            transposedConf.setnOut(nOut);
            // NOTE(review): transposing b yields a column vector of length nOut. It does not match
            // the transposed layer's expected bias of nIn columns — confirm intended semantics.
            return (Layer)c.newInstance(transposedConf, this.W.transpose().dup(), this.b.transpose().dup(), this.input != null ? this.input.transpose().dup() : null);
        }
        catch (Exception e) {
            throw new IllegalStateException("Unable to transpose layer " + this.getClass().getName(), e);
        }
    }

    @Override
    public String toString() {
        return "BaseLayer{W=" + this.W + ", b=" + this.b + ", input=" + this.input + ", conf=" + this.conf + ", dropoutMask=" + this.dropoutMask + '}';
    }
}

