/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.rbm;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.util.Dropout;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Restricted Boltzmann Machine (RBM) layer.
 *
 * <p>Parameters are looked up by key: {@code "W"} (weights; {@code v.mmul(W)} maps visible to
 * hidden), {@code "b"} (hidden bias, added in {@link #preOutput(INDArray, boolean)}) and
 * {@code "vb"} (visible bias, added in {@link #propDown(INDArray)}). Pretraining gradients come
 * from k-step contrastive divergence in {@link #computeGradientAndScore()}. Backpropagation of an
 * externally supplied error signal is handled by {@link #backpropGradient(INDArray)}.
 *
 * <p>NOTE(review): this source appears to be decompiled (CFR). The file-level
 * {@code import org.deeplearning4j.nn.conf.layers.RBM} clashes with this class's own simple name
 * (JLS 7.5.1) and should be removed. This class body always spells the configuration type
 * fully-qualified, so it does not depend on that import.
 */
public class RBM
extends BasePretrainNetwork<org.deeplearning4j.nn.conf.layers.RBM> {
    /** Seed copied from the configuration at construction time. */
    private long seed;
    /** Written by the deprecated {@link #iterate(INDArray)} for Gaussian visible units. */
    @Deprecated
    protected INDArray sigma;
    /** Only propagated by the deprecated {@link #transpose()}. */
    @Deprecated
    protected INDArray hiddenSigma;

    public RBM(NeuralNetConfiguration conf) {
        super(conf);
        this.seed = conf.getSeed();
    }

    public RBM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
        this.seed = conf.getSeed();
    }

    /** Returns this layer's configuration typed as the RBM layer configuration. */
    private org.deeplearning4j.nn.conf.layers.RBM rbmConf() {
        return (org.deeplearning4j.nn.conf.layers.RBM) this.layerConf();
    }

    /**
     * Applies one in-place update: subtracts the current gradient of {@code "vb"}, {@code "b"} and
     * {@code "W"} (in that order) from the matching parameters.
     *
     * @deprecated retained for the deprecated {@link #iterate(INDArray)} path.
     */
    @Deprecated
    public void contrastiveDivergence() {
        Gradient gradient = this.gradient();
        for (String key : new String[] {"vb", "b", "W"}) {
            this.getParam(key).subi(gradient.gradientForVariable().get(key));
        }
    }

    /**
     * Computes CD-k gradients for {@code "W"}, {@code "b"} and {@code "vb"}, stores them in
     * {@code this.gradient}, sets the score from the final visible samples, and notifies training
     * listeners.
     *
     * @throws IllegalStateException if the configured number of Gibbs steps {@code k} is below 1
     *     (previously this surfaced as a NullPointerException)
     */
    @Override
    public void computeGradientAndScore() {
        org.deeplearning4j.nn.conf.layers.RBM rbm = rbmConf();
        int k = rbm.getK();
        if (k < 1) {
            throw new IllegalStateException(
                    "RBM contrastive divergence requires k >= 1 Gibbs steps but got k=" + k + " " + this.layerId());
        }

        // Positive phase: hidden activations driven by the data.
        Pair<INDArray, INDArray> probHidden = this.sampleHiddenGivenVisible(this.input());
        INDArray chainStart = probHidden.getFirst();

        // Negative phase: k steps of h -> v -> h Gibbs sampling, starting from chainStart.
        INDArray negVProb = null;
        INDArray negVSamples = null;
        INDArray negHProb = null;
        INDArray hidden = chainStart;
        for (int i = 0; i < k; ++i) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> matrices = this.gibbhVh(hidden);
            negVProb = matrices.getFirst().getFirst();
            negVSamples = matrices.getFirst().getSecond();
            negHProb = matrices.getSecond().getFirst();
            hidden = matrices.getSecond().getSecond();
        }

        // Use the non-mutating transpose(): transposei() may permute the live layer input in place,
        // which would break the input.sub(...) below.
        INDArray wGradient = this.input().transpose().mmul(chainStart)
                .subi(negVProb.transpose().mmul(negHProb));
        double sparsity = rbm.getSparsity();
        INDArray hBiasGradient = sparsity != 0.0
                ? chainStart.rsub((Number) sparsity).sum(new int[]{0})
                : chainStart.sub(negHProb).sum(new int[]{0});
        INDArray vBiasGradient = this.input.sub(negVProb).sum(new int[]{0});

        if (this.conf.isPretrain()) {
            // Negated for the pretraining path; the sign convention is defined by the caller.
            wGradient.negi();
            hBiasGradient.negi();
            vBiasGradient.negi();
        }
        this.gradient = this.createGradient(wGradient, vBiasGradient, hBiasGradient);
        this.setScoreWithZ(negVSamples);
        if (this.trainingListeners != null) {
            for (TrainingListener tl : this.trainingListeners) {
                tl.onBackwardPass(this);
            }
        }
    }

    /**
     * Runs one Gibbs step h -&gt; v -&gt; h.
     *
     * @param h hidden values to start from
     * @return ((visible mean, visible sample), (hidden mean, hidden sample)); the hidden pair is
     *     computed from the visible mean, not the visible sample
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray h) {
        Pair<INDArray, INDArray> visible = this.sampleVisibleGivenHidden(h);
        Pair<INDArray, INDArray> hiddenAgain = this.sampleHiddenGivenVisible(visible.getFirst());
        return new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(visible, hiddenAgain);
    }

    /**
     * Computes hidden activations for {@code v} and draws a sample according to the configured
     * hidden unit type.
     *
     * @return (hidden activations from {@link #propUp(INDArray)}, hidden sample)
     * @throws IllegalStateException for an unsupported hidden unit type
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray v) {
        INDArray hSample;
        INDArray hProb = this.propUp(v);
        switch (rbmConf().getHiddenUnit()) {
            case IDENTITY: {
                hSample = hProb;
                break;
            }
            case BINARY: {
                Distribution dist = Nd4j.getDistributions().createBinomial(1, hProb);
                hSample = dist.sample(hProb.shape());
                break;
            }
            case GAUSSIAN: {
                Distribution dist = Nd4j.getDistributions().createNormal(hProb, 1.0);
                hSample = dist.sample(hProb.shape());
                break;
            }
            case RECTIFIED: {
                // Noise drawn around hProb, scaled by sqrt(sigmoid(hProb)), added to hProb, then clipped at 0.
                INDArray sqrtSigH1Mean = Transforms.sqrt(Transforms.sigmoid(hProb));
                INDArray sample = Nd4j.getDistributions().createNormal(hProb, 1.0).sample(hProb.shape());
                sample.muli(sqrtSigH1Mean);
                hSample = Transforms.max(hProb.add(sample), 0.0);
                break;
            }
            case SOFTMAX: {
                hSample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", hProb));
                break;
            }
            default: {
                throw new IllegalStateException("Hidden unit type must either be Binary, Gaussian, SoftMax or Rectified " + this.layerId());
            }
        }
        return new Pair<INDArray, INDArray>(hProb, hSample);
    }

    /**
     * Computes visible activations for {@code h} and draws a sample according to the configured
     * visible unit type.
     *
     * @return (visible activations from {@link #propDown(INDArray)}, visible sample)
     * @throws IllegalStateException for an unsupported visible unit type
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray h) {
        INDArray vSample;
        INDArray vProb = this.propDown(h);
        switch (rbmConf().getVisibleUnit()) {
            case IDENTITY: {
                vSample = vProb;
                break;
            }
            case BINARY: {
                Distribution dist = Nd4j.getDistributions().createBinomial(1, vProb);
                vSample = dist.sample(vProb.shape());
                break;
            }
            case GAUSSIAN:
            case LINEAR: {
                Distribution dist = Nd4j.getDistributions().createNormal(vProb, 1.0);
                vSample = dist.sample(vProb.shape());
                break;
            }
            case SOFTMAX: {
                vSample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", vProb));
                break;
            }
            default: {
                throw new IllegalStateException("Visible type must be one of Binary, Gaussian, SoftMax or Linear " + this.layerId());
            }
        }
        return new Pair<INDArray, INDArray>(vProb, vSample);
    }

    /**
     * Returns {@code v.mmul(W) + b} (row-broadcast hidden bias). During training with drop-connect
     * enabled and dropout &gt; 0, {@code W} is replaced by {@code Dropout.applyDropConnect(this, "W")}.
     */
    @Override
    public INDArray preOutput(INDArray v, boolean training) {
        INDArray hBias = this.getParam("b");
        INDArray W = this.getParam("W");
        if (training && this.conf.isUseDropConnect() && this.conf.getLayer().getDropOut() > 0.0) {
            W = Dropout.applyDropConnect(this, "W");
        }
        return v.mmul(W).addiRowVector(hBias);
    }

    /** Equivalent to {@code propUp(v, true)}. */
    public INDArray propUp(INDArray v) {
        return this.propUp(v, true);
    }

    /**
     * Maps visible values to hidden activations: applies the hidden unit's transform to
     * {@link #preOutput(INDArray, boolean)}. For GAUSSIAN units this draws a normal sample
     * (mean = pre-activation, std 1.0) rather than returning a deterministic value.
     *
     * @throws IllegalStateException for an unsupported hidden unit type
     */
    public INDArray propUp(INDArray v, boolean training) {
        INDArray preSig = this.preOutput(v, training);
        switch (rbmConf().getHiddenUnit()) {
            case IDENTITY: {
                return preSig;
            }
            case BINARY: {
                return Transforms.sigmoid(preSig);
            }
            case GAUSSIAN: {
                Distribution dist = Nd4j.getDistributions().createNormal(preSig, 1.0);
                return dist.sample(preSig.shape());
            }
            case RECTIFIED: {
                return Transforms.max(preSig, 0.0);
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preSig));
            }
        }
        throw new IllegalStateException("Hidden unit type should either be binary, gaussian, or rectified linear " + this.layerId());
    }

    /**
     * Derivative of the hidden-unit transform evaluated at pre-activation {@code z}.
     * NOTE(review): the GAUSSIAN branch is stochastic ({@code -2 * z * N(z, 1)}); confirm intended.
     *
     * @throws IllegalStateException for an unsupported hidden unit type
     */
    public INDArray propUpDerivative(INDArray z) {
        switch (rbmConf().getHiddenUnit()) {
            case IDENTITY: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("identity", z).derivative());
            }
            case BINARY: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("sigmoid", z).derivative());
            }
            case GAUSSIAN: {
                Distribution dist = Nd4j.getDistributions().createNormal(z, 1.0);
                INDArray gaussian = dist.sample(z.shape());
                return z.mul((Number) (-2)).mul(gaussian);
            }
            case RECTIFIED: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("relu", z).derivative());
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", z).derivative());
            }
        }
        throw new IllegalStateException("Hidden unit type should either be binary, gaussian, or rectified linear " + this.layerId());
    }

    /**
     * Maps hidden values back to visible activations: {@code h.mmul(W^T) + vb}, then the visible
     * unit's transform. For GAUSSIAN units this draws a normal sample (std 1.0).
     *
     * @throws IllegalStateException for an unsupported visible unit type
     */
    public INDArray propDown(INDArray h) {
        INDArray W = this.getParam("W").transpose();
        INDArray vBias = this.getParam("vb");
        INDArray vMean = h.mmul(W).addiRowVector(vBias);
        switch (rbmConf().getVisibleUnit()) {
            case IDENTITY:
            case LINEAR: {
                return vMean;
            }
            case BINARY: {
                return Transforms.sigmoid(vMean);
            }
            case GAUSSIAN: {
                Distribution dist = Nd4j.getDistributions().createNormal(vMean, 1.0);
                return dist.sample(vMean.shape());
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", vMean));
            }
        }
        throw new IllegalStateException("Visible unit type should either be binary or gaussian " + this.layerId());
    }

    /**
     * Forward activation: {@code propUp(input, training)}. When training with dropout &gt; 0,
     * {@code Dropout.applyDropout} is called on {@code this.input} first.
     */
    @Override
    public INDArray activate(boolean training) {
        if (training && this.conf.getLayer().getDropOut() > 0.0) {
            // Return value is ignored, so this relies on applyDropout modifying this.input in place.
            Dropout.applyDropout(this.input, this.conf.getLayer().getDropOut());
        }
        return this.propUp(this.input, training);
    }

    /**
     * Backpropagates {@code epsilon} through the hidden transform. The {@code "W"} and {@code "b"}
     * gradients are written into the pre-allocated gradient views.
     *
     * <p>Side effect: {@code epsilon} is multiplied in place by the activation derivative.
     *
     * @return (gradient for "W", "b", "vb"; epsilon for the layer below, {@code (W * delta^T)^T})
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray z = this.preOutput(this.input, true);
        INDArray delta = epsilon.muli(this.propUpDerivative(z));
        if (this.maskArray != null) {
            delta.muliColumnVector(this.maskArray);
        }
        DefaultGradient ret = new DefaultGradient();
        // weightGrad = input^T * delta, written straight into the gradient view.
        INDArray weightGrad = (INDArray) this.gradientViews.get("W");
        Nd4j.gemm(this.input, delta, weightGrad, true, false, 1.0, 0.0);
        INDArray biasGrad = (INDArray) this.gradientViews.get("b");
        delta.sum(biasGrad, new int[]{0});
        // NOTE(review): the "vb" view is returned without being written here; confirm intended.
        INDArray vBiasGradient = (INDArray) this.gradientViews.get("vb");
        ret.gradientForVariable().put("W", weightGrad);
        ret.gradientForVariable().put("b", biasGrad);
        ret.gradientForVariable().put("vb", vBiasGradient);
        INDArray epsilonNext = ((INDArray) this.params.get("W")).mmul(delta.transpose()).transpose();
        return new Pair<Gradient, INDArray>(ret, epsilonNext);
    }

    /**
     * Legacy single training step. It records {@code sigma} for Gaussian visible units, copies
     * {@code input}, applies dropout if needed, then runs {@link #contrastiveDivergence()}.
     */
    @Override
    @Deprecated
    public void iterate(INDArray input) {
        if (rbmConf().getVisibleUnit() == org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit.GAUSSIAN) {
            this.sigma = input.var(new int[]{0}).divi((Number) input.rows());
        }
        this.input = input.dup();
        this.applyDropOutIfNecessary(true);
        this.contrastiveDivergence();
    }

    /**
     * Builds the transposed RBM. Visible and hidden unit types are swapped via
     * {@code RBMUtil.inverse}, falling back to this layer's own type when inverse returns null.
     * The bias vectors are swapped (as copies), and {@code sigma}/{@code hiddenSigma} are
     * shared by reference.
     */
    @Override
    @Deprecated
    public Layer transpose() {
        RBM r = (RBM) super.transpose();
        org.deeplearning4j.nn.conf.layers.RBM source = rbmConf();
        org.deeplearning4j.nn.conf.layers.RBM.HiddenUnit h = RBMUtil.inverse(source.getVisibleUnit());
        org.deeplearning4j.nn.conf.layers.RBM.VisibleUnit v = RBMUtil.inverse(source.getHiddenUnit());
        if (h == null) {
            h = source.getHiddenUnit();
        }
        if (v == null) {
            v = source.getVisibleUnit();
        }
        org.deeplearning4j.nn.conf.layers.RBM target = r.rbmConf();
        target.setHiddenUnit(h);
        target.setVisibleUnit(v);
        INDArray vb = this.getParam("b").dup();
        INDArray b = this.getParam("vb").dup();
        r.setParam("vb", vb);
        r.setParam("b", b);
        r.sigma = this.sigma;
        r.hiddenSigma = this.hiddenSigma;
        return r;
    }

    /** RBMs are always pretrain layers. */
    @Override
    public boolean isPretrainLayer() {
        return true;
    }
}

