/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.feedforward.rbm;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Restricted Boltzmann Machine layer trained by k-step contrastive divergence (CD-k).
 *
 * <p>Parameters are looked up by key: {@code "W"} (weights multiplied as {@code v.mmul(W)} when
 * propagating visible to hidden), {@code "b"} (hidden bias) and {@code "vb"} (visible bias). The
 * visible and hidden unit types come from {@code conf.getVisibleUnit()} and
 * {@code conf.getHiddenUnit()}.
 *
 * <p>NOTE(review): this (decompiled) file declares a class named {@code RBM} while also importing
 * {@code org.deeplearning4j.nn.conf.layers.RBM}; the {@code RBM.HiddenUnit}/{@code RBM.VisibleUnit}
 * references below are meant to be the configuration class's enums — confirm the imports compile.
 */
public class RBM
extends BasePretrainNetwork {
    private final Random rng = Nd4j.getRandom();
    private static final long serialVersionUID = 6189188205731511957L;
    /** Variance of the visible data along dimension 0 divided by its row count; only set for Gaussian visible units. */
    protected INDArray sigma;
    /** Copied to the transposed layer by {@link #transpose()}; not assigned anywhere in this class. */
    protected INDArray hiddenSigma;

    public RBM(NeuralNetConfiguration conf) {
        super(conf);
    }

    public RBM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Performs one contrastive-divergence update: computes {@link #gradient()} and subtracts each
     * entry in place from the matching parameter ({@code "vb"}, {@code "b"}, {@code "W"}).
     *
     * <p>NOTE(review): {@code gradient()} returns positive-phase minus negative-phase statistics,
     * so subtracting it moves the parameters against the usual CD ascent direction. Confirm this
     * matches the sign convention expected by the rest of the training pipeline.
     */
    public void contrastiveDivergence() {
        Gradient gradient = this.gradient();
        this.getParam("vb").subi(gradient.gradientForVariable().get("vb"));
        this.getParam("b").subi(gradient.gradientForVariable().get("b"));
        this.getParam("W").subi(gradient.gradientForVariable().get("W"));
    }

    /**
     * Computes the CD-k statistics for the current input.
     *
     * @return a gradient holding {@code "vb"}, {@code "b"} and {@code "W"} entries, each computed as
     *     positive-phase (data-driven) statistics minus negative-phase (Gibbs chain) statistics
     * @throws IllegalStateException if the configured number of Gibbs steps {@code k} is below 1
     */
    @Override
    public Gradient gradient() {
        int k = this.conf.getK();
        if (k < 1) {
            // Without at least one Gibbs step there are no negative-phase samples to compare against.
            throw new IllegalStateException("Contrastive divergence requires k >= 1 Gibbs steps, but k = " + k);
        }
        // Positive phase: hidden activity driven by the data.
        Pair<INDArray, INDArray> probHidden = this.sampleHiddenGivenVisible(this.input());
        INDArray positiveHidden = probHidden.getSecond();

        // Negative phase: k steps of alternating Gibbs sampling, seeded with the positive hidden sample.
        INDArray nvSamples = null;
        INDArray nhMeans = null;
        INDArray nhSamples = positiveHidden;
        for (int i = 0; i < k; ++i) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> matrices = this.gibbhVh(nhSamples);
            nvSamples = matrices.getFirst().getSecond();
            nhMeans = matrices.getSecond().getFirst();
            nhSamples = matrices.getSecond().getSecond();
        }

        // transpose() rather than transposei(): the layer input must not be reshaped in place,
        // because it is read again below for the visible-bias gradient.
        INDArray wGradient = this.input().transpose().mmul(positiveHidden)
                .subi(nvSamples.transpose().mmul(nhMeans));
        // With a sparsity target s, each hidden bias is pushed towards s: column mean of (s - sample).
        INDArray hBiasGradient = this.conf.getSparsity() != 0.0
                ? positiveHidden.rsub((Number) this.conf.getSparsity()).mean(0)
                : positiveHidden.sub(nhMeans).mean(0);
        INDArray vBiasGradient = this.input.sub(nvSamples).mean(0);

        DefaultGradient ret = new DefaultGradient();
        ret.gradientForVariable().put("vb", vBiasGradient);
        ret.gradientForVariable().put("b", hBiasGradient);
        ret.gradientForVariable().put("W", wGradient);
        return ret;
    }

    /**
     * Returns the layer produced by the superclass transpose, carrying over {@link #sigma} and
     * {@link #hiddenSigma}.
     */
    @Override
    public Layer transpose() {
        RBM r = (RBM) super.transpose();
        // NOTE(review): h and v (the inverted unit types) are computed but never applied to r's
        // configuration, so r keeps whatever unit types super.transpose() gave it. Confirm whether
        // they were meant to be written to r.conf.
        RBM.HiddenUnit h = RBMUtil.inverse(this.conf.getVisibleUnit());
        RBM.VisibleUnit v = RBMUtil.inverse(this.conf.getHiddenUnit());
        if (h == null) {
            h = this.conf.getHiddenUnit();
        }
        if (v == null) {
            v = this.conf.getVisibleUnit();
        }
        r.sigma = this.sigma;
        r.hiddenSigma = this.hiddenSigma;
        return r;
    }

    /**
     * Computes the hidden mean via {@link #propUp(INDArray)} and draws a hidden sample according to
     * the configured hidden unit type; dropout is applied to the sample when configured.
     *
     * @param v visible activations
     * @return (hidden mean, hidden sample)
     * @throws IllegalStateException for an unsupported hidden unit type
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray v) {
        INDArray h1Mean = this.propUp(v);
        INDArray h1Sample;
        switch (this.conf.getHiddenUnit()) {
            case RECTIFIED: {
                // Noisy rectified linear unit: max(0, mean + N(0, sigmoid(mean))). The noise must be
                // zero-mean; sampling N(mean, 1) would shift the result by an extra multiple of the mean.
                INDArray stdDev = Transforms.sqrt(Transforms.sigmoid(h1Mean));
                INDArray noise = Nd4j.randn((int) h1Mean.rows(), (int) h1Mean.columns(), this.rng).muli(stdDev);
                h1Sample = Transforms.max(h1Mean.add(noise), 0.0);
                break;
            }
            case GAUSSIAN: {
                h1Sample = h1Mean.add(Nd4j.randn((int) h1Mean.rows(), (int) h1Mean.columns(), this.rng));
                break;
            }
            case SOFTMAX: {
                h1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", h1Mean), 0);
                break;
            }
            case BINARY: {
                h1Sample = Nd4j.getDistributions().createBinomial(1, h1Mean).sample(h1Mean.shape());
                break;
            }
            default: {
                throw new IllegalStateException("Hidden unit type must be one of Rectified, Gaussian, Softmax or Binary");
            }
        }
        this.applyDropOutIfNecessary(h1Sample);
        return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
    }

    /**
     * One Gibbs step starting from hidden units: samples visible given {@code h}, then hidden
     * given that visible sample.
     *
     * @return ((visible mean, visible sample), (hidden mean, hidden sample))
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray h) {
        Pair<INDArray, INDArray> v1MeanAndSample = this.sampleVisibleGivenHidden(h);
        Pair<INDArray, INDArray> h1MeanAndSample = this.sampleHiddenGivenVisible(v1MeanAndSample.getSecond());
        return new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(v1MeanAndSample, h1MeanAndSample);
    }

    /**
     * Computes the visible mean via {@link #propDown(INDArray)} and draws a visible sample according
     * to the configured visible unit type.
     *
     * @param h hidden activations
     * @return (visible mean, visible sample)
     * @throws IllegalStateException for an unsupported visible unit type
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray h) {
        INDArray v1Mean = this.propDown(h);
        INDArray v1Sample;
        switch (this.conf.getVisibleUnit()) {
            case GAUSSIAN: {
                v1Sample = v1Mean.add(Nd4j.randn((int) v1Mean.rows(), (int) v1Mean.columns(), this.rng));
                break;
            }
            case LINEAR: {
                v1Sample = Nd4j.getDistributions().createNormal(v1Mean, 1.0).sample(v1Mean.shape());
                break;
            }
            case SOFTMAX: {
                v1Sample = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", v1Mean), 0);
                break;
            }
            case BINARY: {
                v1Sample = Nd4j.getDistributions().createBinomial(1, v1Mean).sample(v1Mean.shape());
                break;
            }
            default: {
                throw new IllegalStateException("Visible type must be one of Binary, Gaussian, SoftMax or Linear");
            }
        }
        return new Pair<INDArray, INDArray>(v1Mean, v1Sample);
    }

    /**
     * Propagates visible activations to the hidden layer: {@code v.mmul(W) + b}, followed by the
     * activation matching the hidden unit type. For Gaussian visible units this also refreshes
     * {@link #sigma} from {@code v}.
     *
     * @throws IllegalStateException for an unsupported hidden unit type
     */
    public INDArray propUp(INDArray v) {
        this.updateSigma(v);
        INDArray W = this.getParam("W");
        INDArray hBias = this.getParam("b");
        INDArray preSig = v.mmul(W).addiRowVector(hBias);
        switch (this.conf.getHiddenUnit()) {
            case RECTIFIED: {
                return Transforms.max(preSig, 0.0);
            }
            case GAUSSIAN: {
                // Add zero-mean, unit-variance noise once; adding a noisy copy of preSig back onto
                // itself would double the deterministic part.
                return preSig.addi(Nd4j.randn((int) preSig.rows(), (int) preSig.columns(), this.rng));
            }
            case BINARY: {
                return Transforms.sigmoid(preSig);
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", preSig), 0);
            }
        }
        throw new IllegalStateException("Hidden unit type should be one of binary, gaussian, rectified linear or softmax");
    }

    /**
     * Propagates hidden activations to the visible layer: {@code h.mmul(W^T) + vb}, followed by the
     * activation matching the visible unit type.
     *
     * @throws IllegalStateException for an unsupported visible unit type
     */
    public INDArray propDown(INDArray h) {
        INDArray W = this.getParam("W").transpose();
        INDArray vBias = this.getParam("vb");
        INDArray vMean = h.mmul(W).addiRowVector(vBias);
        switch (this.conf.getVisibleUnit()) {
            case GAUSSIAN: {
                // Zero-mean, unit-variance noise; sampling N(vMean, 1) and adding it to vMean would
                // double the mean.
                return vMean.addi(Nd4j.randn((int) vMean.rows(), (int) vMean.columns(), this.rng));
            }
            case LINEAR: {
                return vMean;
            }
            case BINARY: {
                return Transforms.sigmoid(vMean);
            }
            case SOFTMAX: {
                return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", vMean), 0);
            }
        }
        throw new IllegalStateException("Visible unit type should be one of binary, gaussian, linear or softmax");
    }

    /** Reconstructs {@code v} by propagating it up to the hidden layer and back down. */
    @Override
    public INDArray transform(INDArray v) {
        return this.propDown(this.propUp(v));
    }

    /** Refreshes {@link #sigma} from {@code input} (Gaussian visible units only), then delegates to the superclass. */
    @Override
    public void fit(INDArray input) {
        this.updateSigma(input);
        super.fit(input);
    }

    @Override
    public String toString() {
        return "RBM{, visibleType=" + this.conf.getVisibleUnit()
                + ", hiddenType=" + this.conf.getHiddenUnit()
                + ", sigma=" + this.sigma
                + ", hiddenSigma=" + this.hiddenSigma
                + "} " + super.toString();
    }

    /** Refreshes {@link #sigma}, stores {@code input} as the layer input and runs one CD update. */
    @Override
    public void iterate(INDArray input) {
        this.updateSigma(input);
        this.input = input;
        this.contrastiveDivergence();
    }

    /**
     * For Gaussian visible units, sets {@link #sigma} to the variance of {@code data} along
     * dimension 0 divided by the row count of that same array; otherwise does nothing.
     */
    private void updateSigma(INDArray data) {
        if (this.conf.getVisibleUnit() == RBM.VisibleUnit.GAUSSIAN) {
            this.sigma = data.var(0).divi((Number) data.rows());
        }
    }
}

