/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.featuredetectors.rbm;

import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;

/**
 * Restricted Boltzmann machine trained with k-step contrastive divergence.
 *
 * <p>The model parameters are looked up by key: {@code "W"} (weights), {@code "b"} (hidden bias)
 * and {@code "vb"} (visible bias). The visible and hidden unit types, the number of Gibbs steps
 * {@code k}, the sparsity target and the random generator are all read from the
 * {@link NeuralNetConfiguration} held in {@code conf}.
 */
public class RBM extends BasePretrainNetwork {
    private static final long serialVersionUID = 6189188205731511957L;

    /**
     * Variance of the visible data along dimension 0 divided by its row count; only refreshed
     * when the visible unit type is {@link VisibleUnit#GAUSSIAN}.
     */
    protected INDArray sigma;

    /**
     * Variance of the hidden activations along dimension 1; only refreshed when sampling
     * {@link HiddenUnit#GAUSSIAN} hidden units.
     */
    protected INDArray hiddenSigma;

    /**
     * Creates an RBM from a configuration only.
     *
     * @param conf the network configuration
     */
    public RBM(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates an RBM from a configuration and an initial input.
     *
     * @param conf  the network configuration
     * @param input the input passed on to the superclass constructor
     */
    public RBM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Runs one contrastive-divergence update: computes the gradient for the current input and
     * subtracts each component in place from the matching parameter ({@code "vb"}, {@code "b"},
     * {@code "W"}).
     */
    public void contrastiveDivergence() {
        Gradient gradient = this.getGradient();
        // Same update order as before: visible bias, hidden bias, weights.
        for (String key : new String[]{"vb", "b", "W"}) {
            this.getParam(key).subi(gradient.gradientLookupTable().get(key));
        }
    }

    /**
     * Computes the contrastive-divergence gradient for the current input.
     *
     * <p>The chain starts from a hidden sample drawn from the input and is advanced {@code k}
     * Gibbs steps. The weight gradient is the difference between the data-driven and the
     * chain-driven visible/hidden products; the bias gradients are means along dimension 0.
     * When the configured sparsity is non-zero, the hidden bias gradient is built from the
     * sparsity target instead of the chain's hidden means.
     *
     * @return a gradient holding entries for {@code "vb"}, {@code "b"} and {@code "W"}
     * @throws IllegalStateException if the configured {@code k} is smaller than 1
     */
    @Override
    public Gradient getGradient() {
        int k = this.conf.getK();
        if (k < 1) {
            // Without at least one Gibbs step there are no negative-phase statistics; previously
            // this surfaced as a bare NullPointerException further down.
            throw new IllegalStateException("Contrastive divergence requires k >= 1 but k was " + k);
        }

        Pair<INDArray, INDArray> probHidden = this.sampleHiddenGivenVisible(this.input);
        INDArray positiveHiddenSample = probHidden.getSecond();

        INDArray nvSamples = null;
        INDArray nhMeans = null;
        // The first step starts from the data-driven hidden sample, later steps from the
        // hidden sample produced by the previous step.
        INDArray chainState = positiveHiddenSample;
        for (int step = 0; step < k; ++step) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> matrices = this.gibbhVh(chainState);
            nvSamples = matrices.getFirst().getSecond();
            nhMeans = matrices.getSecond().getFirst();
            chainState = matrices.getSecond().getSecond();
        }

        INDArray wGradient = this.input.transpose().mmul(positiveHiddenSample)
                .subi(nvSamples.transpose().mmul(nhMeans));

        INDArray hBiasGradient;
        if (this.conf.getSparsity() != 0.0) {
            hBiasGradient = positiveHiddenSample.rsub((Number)this.conf.getSparsity()).mean(0);
        } else {
            hBiasGradient = positiveHiddenSample.sub(nhMeans).mean(0);
        }

        INDArray vBiasGradient = this.input.sub(nvSamples).mean(0);

        DefaultGradient ret = new DefaultGradient();
        ret.gradientLookupTable().put("vb", vBiasGradient);
        ret.gradientLookupTable().put("b", hBiasGradient);
        ret.gradientLookupTable().put("W", wGradient);
        return ret;
    }

    /**
     * Returns the transposed layer produced by the superclass, carrying over this layer's
     * {@code sigma} and {@code hiddenSigma}.
     *
     * @return the transposed RBM
     */
    @Override
    public Layer transpose() {
        RBM r = (RBM)super.transpose();
        // NOTE(review): the swapped unit types below are computed but never applied to the
        // transposed layer. This looks like a lost assignment (possibly a decompilation
        // artifact) - confirm against the original sources before wiring them in.
        HiddenUnit h = RBMUtil.inverse(this.conf.getVisibleUnit());
        VisibleUnit v = RBMUtil.inverse(this.conf.getHiddenUnit());
        if (h == null) {
            h = this.conf.getHiddenUnit();
        }
        if (v == null) {
            v = this.conf.getVisibleUnit();
        }
        r.sigma = this.sigma;
        r.hiddenSigma = this.hiddenSigma;
        return r;
    }

    /**
     * Computes the free energy of a visible sample as
     * {@code -sum(log(1 + exp(v * W + b))) - dot(v, vb)}.
     *
     * @param visibleSample the visible sample to score
     * @return the free energy value
     */
    public double freeEnergy(INDArray visibleSample) {
        INDArray W = this.getParam("W");
        INDArray hBias = this.getParam("b");
        INDArray vBias = this.getParam("vb");

        INDArray wxB = visibleSample.mmul(W).addRowVector(hBias);
        double vBiasTerm = Nd4j.getBlasWrapper().dot(visibleSample, vBias);
        // sum(Integer.MAX_VALUE) presumably reduces over all elements - confirm against ND4J.
        double hBiasTerm = Transforms.log(Transforms.exp(wxB).add((Number)1)).sum(Integer.MAX_VALUE).getDouble(0);
        return -hBiasTerm - vBiasTerm;
    }

    /**
     * Propagates a visible batch up and draws a hidden sample according to the configured
     * hidden unit type. Drop-out is applied to the sample via {@code applyDropOutIfNecessary}.
     *
     * @param v the visible batch
     * @return a pair of (hidden mean, hidden sample)
     * @throws IllegalStateException if the configured hidden unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray v) {
        HiddenUnit hiddenUnit = this.conf.getHiddenUnit();
        if (hiddenUnit != HiddenUnit.RECTIFIED && hiddenUnit != HiddenUnit.GAUSSIAN
                && hiddenUnit != HiddenUnit.SOFTMAX && hiddenUnit != HiddenUnit.BINARY) {
            throw new IllegalStateException("Unsupported hidden unit type " + hiddenUnit
                    + "; must be rectified, gaussian, softmax, or binary");
        }

        INDArray h1Mean = this.propUp(v);
        INDArray h1Sample;
        if (hiddenUnit == HiddenUnit.RECTIFIED) {
            // Noisy rectified linear unit: mean + noise * sqrt(sigmoid(mean)), clipped by max.
            INDArray sqrtSigH1Mean = Transforms.sqrt(Transforms.sigmoid(h1Mean));
            // Use the configured generator. A generator re-created with a constant seed on
            // every call would reproduce the identical noise on each Gibbs step.
            INDArray noise = Sampling.normal((RandomGenerator)this.conf.getRng(), h1Mean, 1.0);
            noise.muli(sqrtSigH1Mean);
            h1Sample = Transforms.max(h1Mean.add(noise));
        } else if (hiddenUnit == HiddenUnit.GAUSSIAN) {
            this.hiddenSigma = h1Mean.var(1);
            // NOTE(review): addi looks like an in-place add, which would make the returned mean
            // and sample the same array - confirm whether a copying add was intended.
            h1Sample = h1Mean.addi(Sampling.normal((RandomGenerator)this.conf.getRng(), h1Mean, this.hiddenSigma));
        } else if (hiddenUnit == HiddenUnit.SOFTMAX) {
            h1Sample = (INDArray)Activations.softMaxRows().apply((Object)h1Mean);
        } else {
            h1Sample = Sampling.binomial(h1Mean, 1, (RandomGenerator)this.conf.getRng());
        }

        this.applyDropOutIfNecessary(h1Sample);
        return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
    }

    /**
     * Performs one Gibbs step starting from a hidden state: samples the visible units given
     * {@code h}, then samples the hidden units given that visible sample.
     *
     * @param h the hidden state to start from
     * @return a pair of ((visible mean, visible sample), (hidden mean, hidden sample))
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray h) {
        Pair<INDArray, INDArray> visible = this.sampleVisibleGivenHidden(h);
        Pair<INDArray, INDArray> hidden = this.sampleHiddenGivenVisible(visible.getSecond());
        return new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(visible, hidden);
    }

    /**
     * Propagates a hidden batch down and draws a visible sample according to the configured
     * visible unit type.
     *
     * @param h the hidden batch
     * @return a pair of (visible mean, visible sample)
     * @throws IllegalStateException if the configured visible unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray h) {
        INDArray v1Mean = this.propDown(h);
        VisibleUnit visibleUnit = this.conf.getVisibleUnit();

        INDArray v1Sample;
        if (visibleUnit == VisibleUnit.GAUSSIAN) {
            v1Sample = v1Mean.add(Nd4j.randn((int)v1Mean.rows(), (int)v1Mean.columns(), (RandomGenerator)this.conf.getRng()));
        } else if (visibleUnit == VisibleUnit.LINEAR) {
            v1Sample = Sampling.normal((RandomGenerator)this.conf.getRng(), v1Mean, 1.0);
        } else if (visibleUnit == VisibleUnit.SOFTMAX) {
            v1Sample = (INDArray)Activations.softMaxRows().apply((Object)v1Mean);
        } else if (visibleUnit == VisibleUnit.BINARY) {
            v1Sample = Sampling.binomial(v1Mean, 1, (RandomGenerator)this.conf.getRng());
        } else {
            throw new IllegalStateException("Unsupported visible unit type " + visibleUnit
                    + "; must be binary, gaussian, softmax, or linear");
        }
        return new Pair<INDArray, INDArray>(v1Mean, v1Sample);
    }

    /**
     * Computes the hidden activations for a visible batch: {@code v * W} combined with the
     * hidden bias, followed by the transformation for the configured hidden unit type.
     *
     * <p>When the visible unit type is Gaussian, {@code sigma} is refreshed from {@code v} as a
     * side effect.
     *
     * @param v the visible batch
     * @return the hidden activations
     * @throws IllegalStateException if the configured hidden unit type is not supported
     */
    public INDArray propUp(INDArray v) {
        INDArray W = this.getParam("W");
        INDArray hBias = this.getParam("b");

        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            // Normalise by the rows of the batch being propagated (as fit/iterate do) rather
            // than by the stored input, which may not be set yet when transform() is called.
            this.sigma = v.var(0).divi((Number)v.rows());
        }

        INDArray preSig = v.mmul(W);
        if (this.conf.isConcatBiases()) {
            preSig = Nd4j.hstack(new INDArray[]{preSig, hBias});
        } else {
            preSig.addiRowVector(hBias);
        }

        HiddenUnit hiddenUnit = this.conf.getHiddenUnit();
        if (hiddenUnit == HiddenUnit.RECTIFIED) {
            return Transforms.max(preSig);
        }
        if (hiddenUnit == HiddenUnit.GAUSSIAN) {
            // NOTE(review): this adds (preSig + noise) onto preSig, so the pre-activation is
            // counted twice - confirm whether only the noise term was meant to be added.
            INDArray add = preSig.add(Nd4j.randn((int)preSig.rows(), (int)preSig.columns(), (RandomGenerator)this.conf.getRng()));
            preSig.addi(add);
            return preSig;
        }
        if (hiddenUnit == HiddenUnit.BINARY) {
            return Transforms.sigmoid(preSig);
        }
        if (hiddenUnit == HiddenUnit.SOFTMAX) {
            return (INDArray)Activations.softMaxRows().apply((Object)preSig);
        }
        throw new IllegalStateException("Unsupported hidden unit type " + hiddenUnit
                + "; must be rectified, gaussian, binary, or softmax");
    }

    /**
     * Computes the visible activations for a hidden batch: {@code h * W^T} combined with the
     * visible bias, followed by the transformation for the configured visible unit type.
     *
     * @param h the hidden batch
     * @return the visible activations
     * @throws IllegalStateException if the configured visible unit type is not supported
     */
    public INDArray propDown(INDArray h) {
        INDArray W = this.getParam("W");
        INDArray vBias = this.getParam("vb");

        INDArray vMean = h.mmul(W.transpose());
        if (this.conf.isConcatBiases()) {
            vMean = Nd4j.hstack(new INDArray[]{vMean, vBias});
        } else {
            vMean.addiRowVector(vBias);
        }

        VisibleUnit visibleUnit = this.conf.getVisibleUnit();
        if (visibleUnit == VisibleUnit.GAUSSIAN) {
            INDArray sample = Sampling.normal((RandomGenerator)this.conf.getRng(), vMean, 1.0);
            vMean.addi(sample);
            return vMean;
        }
        if (visibleUnit == VisibleUnit.LINEAR) {
            return Sampling.normal((RandomGenerator)this.conf.getRng(), vMean, 1.0);
        }
        if (visibleUnit == VisibleUnit.BINARY) {
            return Transforms.sigmoid(vMean);
        }
        if (visibleUnit == VisibleUnit.SOFTMAX) {
            return (INDArray)Activations.softMaxRows().apply((Object)vMean);
        }
        throw new IllegalStateException("Unsupported visible unit type " + visibleUnit
                + "; must be gaussian, linear, binary, or softmax");
    }

    /**
     * Reconstructs a visible batch by propagating it up to the hidden layer and back down.
     *
     * @param v the visible batch
     * @return the reconstruction of {@code v}
     */
    @Override
    public INDArray transform(INDArray v) {
        INDArray hidden = this.propUp(v);
        return this.propDown(hidden);
    }

    /**
     * Fits the model to the given input, refreshing {@code sigma} first when the visible units
     * are Gaussian, then delegating to the superclass.
     *
     * @param input the training batch
     */
    @Override
    public void fit(INDArray input) {
        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            this.sigma = input.var(0);
            this.sigma.divi((Number)input.rows());
        }
        super.fit(input);
    }

    /**
     * Describes this RBM by its unit types and sigma values, followed by the superclass
     * description.
     *
     * @return a human-readable description
     */
    @Override
    public String toString() {
        // hiddenType previously printed the visible unit type by mistake.
        return "RBM{, visibleType=" + this.conf.getVisibleUnit()
                + ", hiddenType=" + this.conf.getHiddenUnit()
                + ", sigma=" + this.sigma
                + ", hiddenSigma=" + this.hiddenSigma
                + "} " + super.toString();
    }

    /**
     * Runs a single training iteration: stores {@code input} as the current input (refreshing
     * {@code sigma} for Gaussian visible units) and applies one contrastive-divergence update.
     *
     * @param input the training batch
     */
    @Override
    public void iterate(INDArray input) {
        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            this.sigma = input.var(0).divi((Number)input.rows());
        }
        this.input = input;
        this.contrastiveDivergence();
    }

    /** Supported hidden unit types. */
    public enum HiddenUnit {
        RECTIFIED,
        BINARY,
        GAUSSIAN,
        SOFTMAX;
    }

    /** Supported visible unit types. */
    public enum VisibleUnit {
        BINARY,
        GAUSSIAN,
        SOFTMAX,
        LINEAR;
    }
}

