/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.featuredetectors.rbm;

import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.optimizers.NeuralNetworkOptimizer;
import org.deeplearning4j.optimize.optimizers.rbm.RBMOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.RBMUtil;
import org.nd4j.linalg.api.activation.Activations;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.sampling.Sampling;

/**
 * Restricted Boltzmann machine trained with k-step contrastive divergence.
 *
 * <p>The visible and hidden unit types come from the configuration ({@code getVisibleUnit()} /
 * {@code getHiddenUnit()}). They select how activations are propagated ({@link #propUp},
 * {@link #propDown}) and sampled ({@link #sampleHiddenGivenVisible},
 * {@link #sampleVisibleGivenHidden}).
 */
public class RBM extends BaseNeuralNetwork {
    private static final long serialVersionUID = 6189188205731511957L;

    /** Optimizer created by the most recent {@link #fit(INDArray, Object[])} call. */
    protected NeuralNetworkOptimizer optimizer;
    /** Variance along dimension 0 of the visible data divided by its row count; only set for Gaussian visible units. */
    protected INDArray sigma;
    /** Variance along dimension 1 of the hidden means; only set for Gaussian hidden units. */
    protected INDArray hiddenSigma;

    protected RBM() {
    }

    public RBM(INDArray input, INDArray W, INDArray hbias, INDArray vbias, NeuralNetConfiguration conf) {
        super(input, W, hbias, vbias, conf);
    }

    /**
     * Runs one contrastive-divergence update (iteration number -1) and adds it to W, hBias and vBias.
     *
     * <p>NOTE(review): {@code learningRate} and {@code k} are forwarded to {@link #getGradient}, which
     * currently reads both values from the configuration instead.
     *
     * @param learningRate learning rate forwarded in the gradient parameters
     * @param k number of Gibbs steps forwarded in the gradient parameters
     * @param input new training batch, or {@code null} to keep the current input
     */
    public void contrastiveDivergence(double learningRate, int k, INDArray input) {
        this.applyContrastiveDivergence(learningRate, k, input, -1);
    }

    /**
     * Runs one contrastive-divergence update for the given iteration and adds it to W, hBias and vBias.
     *
     * @param learningRate learning rate forwarded in the gradient parameters
     * @param k number of Gibbs steps forwarded in the gradient parameters
     * @param input new training batch, or {@code null} to keep the current input
     * @param iteration iteration number passed on to the gradient update
     */
    public void contrastiveDivergence(double learningRate, int k, INDArray input, int iteration) {
        this.applyContrastiveDivergence(learningRate, k, input, iteration);
    }

    /** Sets the input (if given), computes the CD gradient and adds it in place to the parameters. */
    private void applyContrastiveDivergence(double learningRate, int k, INDArray input, int iteration) {
        if (input != null) {
            this.input = input;
        }
        // Read the field rather than the argument: the argument is allowed to be null here.
        this.lastMiniBatchSize = this.input.rows();
        NeuralNetworkGradient gradient = this.getGradient(new Object[]{k, learningRate, iteration});
        this.getW().addi(gradient.getwGradient());
        this.gethBias().addi(gradient.gethBiasGradient());
        this.getvBias().addi(gradient.getvBiasGradient());
    }

    /**
     * Computes the CD-k gradient for the current input.
     *
     * <p>k and the learning rate are taken from the configuration. Only the last element of
     * {@code params} is read, as the iteration number (0 when {@code params} is null/empty or the
     * element is null). NOTE(review): callers such as {@link #contrastiveDivergence} also put k and
     * the learning rate in {@code params}, but they are not consulted.
     *
     * @throws IllegalStateException if the configured k is below 1 (no negative phase would exist)
     */
    @Override
    public NeuralNetworkGradient getGradient(Object[] params) {
        int k = this.conf.getK();
        float learningRate = this.conf.getLr();
        int iteration = RBM.iterationFrom(params);
        if (k < 1) {
            throw new IllegalStateException("Contrastive divergence requires k >= 1 but was " + k);
        }
        if (this.wAdaGrad != null) {
            this.wAdaGrad.setMasterStepSize((double) learningRate);
        }
        if (this.hBiasAdaGrad != null) {
            this.hBiasAdaGrad.setMasterStepSize((double) learningRate);
        }
        if (this.vBiasAdaGrad != null) {
            this.vBiasAdaGrad.setMasterStepSize((double) learningRate);
        }
        // Positive phase.
        Pair<INDArray, INDArray> probHidden = this.sampleHiddenGivenVisible(this.input);
        INDArray positiveHidden = probHidden.getSecond();

        // Negative phase: run k Gibbs steps h -> v -> h, starting from the positive hidden sample.
        INDArray hiddenSample = positiveHidden;
        INDArray nvSamples = null;
        INDArray nhMeans = null;
        for (int i = 0; i < k; ++i) {
            Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> step = this.gibbhVh(hiddenSample);
            nvSamples = step.getFirst().getSecond();
            nhMeans = step.getSecond().getFirst();
            hiddenSample = step.getSecond().getSecond();
        }

        INDArray wGradient = this.input.transpose().mmul(positiveHidden).sub(nvSamples.transpose().mmul(nhMeans));
        // rsub (not rsubi) so the positive-phase sample returned by sampleHiddenGivenVisible is left untouched.
        INDArray hBiasGradient = this.conf.getSparsity() != 0.0f
                ? positiveHidden.rsub((Number) Float.valueOf(this.conf.getSparsity())).mean(0)
                : positiveHidden.sub(nhMeans).mean(0);
        INDArray vBiasGradient = this.input.sub(nvSamples).mean(0);
        NeuralNetworkGradient ret = new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
        this.updateGradientAccordingToParams(ret, iteration, learningRate);
        return ret;
    }

    /** Returns the last element of {@code params} as the iteration number, or 0 when absent. */
    private static int iterationFrom(Object[] params) {
        if (params == null || params.length == 0) {
            return 0;
        }
        Object last = params[params.length - 1];
        return last == null ? 0 : (Integer) last;
    }

    @Override
    public void fit(INDArray data) {
        this.fit(data, null);
    }

    /**
     * Returns the transposed network from the superclass, sharing {@code sigma} and
     * {@code hiddenSigma} with this instance.
     */
    @Override
    public NeuralNetwork transpose() {
        RBM r = (RBM) super.transpose();
        // NOTE(review): h and v are computed but never applied to r's configuration, so the unit
        // types of the transposed RBM are not swapped. Presumably they were meant to be set on r.
        HiddenUnit h = RBMUtil.inverse(this.conf.getVisibleUnit());
        VisibleUnit v = RBMUtil.inverse(this.conf.getHiddenUnit());
        if (h == null) {
            h = this.conf.getHiddenUnit();
        }
        if (v == null) {
            v = this.conf.getVisibleUnit();
        }
        r.sigma = this.sigma;
        r.hiddenSigma = this.hiddenSigma;
        return r;
    }

    /** Clones via the superclass; {@code sigma} and {@code hiddenSigma} are shared by reference, not copied. */
    @Override
    public NeuralNetwork clone() {
        RBM r = (RBM) super.clone();
        r.sigma = this.sigma;
        r.hiddenSigma = this.hiddenSigma;
        return r;
    }

    /**
     * Computes {@code -(sum(log(1 + exp(v W + hBias))) + dot(v, vBias))} for the given visible sample.
     *
     * @param visibleSample visible configuration
     * @return the free energy of {@code visibleSample}
     */
    public double freeEnergy(INDArray visibleSample) {
        INDArray wxB = visibleSample.mmul(this.W).addiRowVector(this.hBias);
        double vBiasTerm = Nd4j.getBlasWrapper().dot(visibleSample, this.vBias);
        double hBiasTerm = (Double) Transforms.log(Transforms.exp(wxB).add((Number) 1)).sum(Integer.MAX_VALUE).element();
        return -hBiasTerm - vBiasTerm;
    }

    /**
     * Computes hidden means for {@code v} and draws a hidden sample according to the hidden unit type.
     *
     * @param v visible activations
     * @return pair of (hidden means, hidden sample)
     * @throws IllegalStateException if the configured hidden unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray v) {
        HiddenUnit hiddenUnit = this.conf.getHiddenUnit();
        if (hiddenUnit == HiddenUnit.RECTIFIED) {
            INDArray h1Mean = this.propUp(v);
            // Noisy ReLU: max(0, x + N(0, sigmoid(x))). The noise is zero-mean with standard deviation
            // sqrt(sigmoid(x)), drawn from the configured RNG. (Previously a MersenneTwister re-seeded
            // with 123 on every call produced identical noise each time, and the noise was centred on x.)
            INDArray noiseStd = Transforms.sqrt(Transforms.sigmoid(h1Mean));
            INDArray noise = Nd4j.randn((int) h1Mean.rows(), (int) h1Mean.columns(), (RandomGenerator) this.conf.getRng());
            noise.muli(noiseStd);
            INDArray h1Sample = Transforms.max(h1Mean.add(noise));
            this.applyDropOutIfNecessary(h1Sample);
            return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
        }
        if (hiddenUnit == HiddenUnit.GAUSSIAN) {
            INDArray h1Mean = this.propUp(v);
            this.hiddenSigma = h1Mean.var(1);
            // Sampling.normal draws around h1Mean (as in the LINEAR visible branch), so it is the sample
            // itself. Adding it to h1Mean in place would double the mean and alias mean and sample.
            INDArray h1Sample = Sampling.normal((RandomGenerator) this.conf.getRng(), h1Mean, this.hiddenSigma);
            this.applyDropOutIfNecessary(h1Sample);
            return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
        }
        if (hiddenUnit == HiddenUnit.SOFTMAX) {
            INDArray h1Mean = this.propUp(v);
            // NOTE(review): propUp already applies softMaxRows for SOFTMAX units, so this applies it twice.
            INDArray h1Sample = (INDArray) Activations.softMaxRows().apply((Object) h1Mean);
            this.applyDropOutIfNecessary(h1Sample);
            return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
        }
        if (hiddenUnit == HiddenUnit.BINARY) {
            INDArray h1Mean = this.propUp(v);
            INDArray h1Sample = Sampling.binomial(h1Mean, 1, (RandomGenerator) this.conf.getRng());
            this.applyDropOutIfNecessary(h1Sample);
            return new Pair<INDArray, INDArray>(h1Mean, h1Sample);
        }
        throw new IllegalStateException("Hidden unit type must be rectified linear, gaussian, softmax, or binary but was " + hiddenUnit);
    }

    /**
     * Performs one Gibbs step starting from hidden activations: samples v given h, then h given that v sample.
     *
     * @param h hidden activations to start from
     * @return pair of ((visible means, visible sample), (hidden means, hidden sample))
     */
    public Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>> gibbhVh(INDArray h) {
        Pair<INDArray, INDArray> v1MeanAndSample = this.sampleVisibleGivenHidden(h);
        Pair<INDArray, INDArray> h1MeanAndSample = this.sampleHiddenGivenVisible(v1MeanAndSample.getSecond());
        return new Pair<Pair<INDArray, INDArray>, Pair<INDArray, INDArray>>(v1MeanAndSample, h1MeanAndSample);
    }

    /**
     * Computes visible means for {@code h} and draws a visible sample according to the visible unit type.
     *
     * @param h hidden activations
     * @return pair of (visible means, visible sample)
     * @throws IllegalStateException if the configured visible unit type is not supported
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray h) {
        INDArray v1Mean = this.propDown(h);
        VisibleUnit visibleUnit = this.conf.getVisibleUnit();
        if (visibleUnit == VisibleUnit.GAUSSIAN) {
            INDArray v1Sample = v1Mean.add(Nd4j.randn((int) v1Mean.rows(), (int) v1Mean.columns(), (RandomGenerator) this.conf.getRng()));
            return new Pair<INDArray, INDArray>(v1Mean, v1Sample);
        }
        if (visibleUnit == VisibleUnit.LINEAR) {
            INDArray v1Sample = Sampling.normal((RandomGenerator) this.conf.getRng(), v1Mean, 1.0);
            return new Pair<INDArray, INDArray>(v1Mean, v1Sample);
        }
        if (visibleUnit == VisibleUnit.SOFTMAX) {
            INDArray v1Sample = (INDArray) Activations.softMaxRows().apply((Object) v1Mean);
            return new Pair<INDArray, INDArray>(v1Mean, v1Sample);
        }
        if (visibleUnit == VisibleUnit.BINARY) {
            INDArray v1Sample = Sampling.binomial(v1Mean, 1, (RandomGenerator) this.conf.getRng());
            return new Pair<INDArray, INDArray>(v1Mean, v1Sample);
        }
        throw new IllegalStateException("Visible type must either be binary,gaussian, softmax, or linear");
    }

    /**
     * Propagates visible activations to the hidden layer: {@code v W + hBias} followed by the hidden
     * unit non-linearity (or concatenation with hBias when concatBiases is enabled).
     *
     * <p>For Gaussian visible units this also refreshes {@link #sigma}.
     *
     * @param v visible activations
     * @return hidden activations
     * @throws IllegalStateException if the configured hidden unit type is not supported
     */
    public INDArray propUp(INDArray v) {
        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            // Normalise by the training batch size when an input is set; fall back to v's rows otherwise
            // (previously a null input caused an NPE here).
            int rows = this.input != null ? this.input.rows() : v.rows();
            this.sigma = v.var(0).divi((Number) rows);
        }
        INDArray preSig = v.mmul(this.W);
        if (this.conf.isConcatBiases()) {
            preSig = Nd4j.hstack(new INDArray[]{preSig, this.hBias});
        } else {
            preSig.addiRowVector(this.hBias);
        }
        HiddenUnit hiddenUnit = this.conf.getHiddenUnit();
        if (hiddenUnit == HiddenUnit.RECTIFIED) {
            return Transforms.max(preSig);
        }
        if (hiddenUnit == HiddenUnit.GAUSSIAN) {
            // Add unit Gaussian noise once. The previous code added (preSig + noise) to preSig, doubling the mean.
            preSig.addi(Nd4j.randn((int) preSig.rows(), (int) preSig.columns(), (RandomGenerator) this.conf.getRng()));
            return preSig;
        }
        if (hiddenUnit == HiddenUnit.BINARY) {
            return Transforms.sigmoid(preSig);
        }
        if (hiddenUnit == HiddenUnit.SOFTMAX) {
            return (INDArray) Activations.softMaxRows().apply((Object) preSig);
        }
        throw new IllegalStateException("Hidden unit type should either be binary, gaussian, softmax, or rectified linear");
    }

    @Override
    public INDArray hiddenActivation(INDArray input) {
        return this.propUp(input);
    }

    /**
     * Plots the network gradient every {@code getRenderWeightIterations()} iterations; does nothing when
     * that setting is not positive.
     */
    @Override
    public void iterationDone(int iteration) {
        int plotEpochs = this.conf.getRenderWeightIterations();
        if (plotEpochs <= 0) {
            return;
        }
        // iteration 0 is already covered: 0 % plotEpochs == 0.
        if (iteration % plotEpochs == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            plotter.plotNetworkGradient(this, this.getGradient(new Object[]{1, 0.001, 1000}), this.getInput().rows());
        }
    }

    /**
     * Propagates hidden activations to the visible layer: {@code h W^T + vBias} followed by the visible
     * unit transform (or concatenation with vBias when concatBiases is enabled).
     *
     * @param h hidden activations
     * @return visible activations
     * @throws IllegalStateException if the configured visible unit type is not supported
     */
    public INDArray propDown(INDArray h) {
        INDArray vMean = h.mmul(this.W.transpose());
        if (this.conf.isConcatBiases()) {
            vMean = Nd4j.hstack(new INDArray[]{vMean, this.vBias});
        } else {
            vMean.addiRowVector(this.vBias);
        }
        VisibleUnit visibleUnit = this.conf.getVisibleUnit();
        if (visibleUnit == VisibleUnit.GAUSSIAN || visibleUnit == VisibleUnit.LINEAR) {
            // Sampling.normal already draws around vMean. The previous GAUSSIAN branch added that draw to
            // vMean, doubling the mean. NOTE(review): sampleVisibleGivenHidden then adds noise a second time.
            return Sampling.normal((RandomGenerator) this.conf.getRng(), vMean, 1.0);
        }
        if (visibleUnit == VisibleUnit.BINARY) {
            return Transforms.sigmoid(vMean);
        }
        if (visibleUnit == VisibleUnit.SOFTMAX) {
            return (INDArray) Activations.softMaxRows().apply((Object) vMean);
        }
        throw new IllegalStateException("Visible unit type should either be binary, gaussian, softmax, or linear");
    }

    /** Reconstructs {@code v} by propagating up to the hidden layer and back down. */
    @Override
    public INDArray transform(INDArray v) {
        return this.propDown(this.propUp(v));
    }

    /**
     * Trains the RBM with a freshly created {@link RBMOptimizer}.
     *
     * <p>When {@code input} is non-null, a stabilized copy is stored as the model input and the raw
     * {@code input} is handed to the optimizer. When it is null, the existing input is used throughout
     * (previously this threw an NPE).
     *
     * @param input training data, or {@code null} to reuse the current input
     * @param params optimizer parameters, passed through to {@link RBMOptimizer}
     */
    @Override
    public void fit(INDArray input, Object[] params) {
        if (input != null) {
            this.input = Transforms.stabilize(input, 1.0);
        }
        INDArray data = input != null ? input : this.input;
        this.lastMiniBatchSize = data.rows();
        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            this.sigma = data.var(0).divi((Number) data.rows());
        }
        this.optimizer = new RBMOptimizer(this, this.conf.getLr(), params, this.conf.getOptimizationAlgo(), this.conf.getLossFunction());
        this.optimizer.train(data);
    }

    public String toString() {
        return "RBM{optimizer=" + this.optimizer + ", visibleType=" + this.conf.getVisibleUnit() + ", hiddenType=" + this.conf.getHiddenUnit() + ", sigma=" + this.sigma + ", hiddenSigma=" + this.hiddenSigma + "} " + super.toString();
    }

    /**
     * Runs one contrastive-divergence step on {@code input}.
     *
     * @param input training batch, or {@code null} to reuse the current input
     * @param params {@code params[0]} is k (an {@link Integer}); the configured k is used when it is absent
     */
    @Override
    public void iterate(INDArray input, Object[] params) {
        INDArray data = input != null ? input : this.input;
        if (this.conf.getVisibleUnit() == VisibleUnit.GAUSSIAN) {
            this.sigma = data.var(0).divi((Number) data.rows());
        }
        int k = params != null && params.length > 0 && params[0] != null ? (Integer) params[0] : this.conf.getK();
        this.contrastiveDivergence(this.conf.getLr(), k, input);
    }

    /** Fluent builder for {@link RBM}; each setter delegates to the base builder and returns this builder. */
    public static class Builder
    extends BaseNeuralNetwork.Builder<RBM> {
        public Builder() {
            this.clazz = RBM.class;
        }

        @Override
        public RBM buildEmpty() {
            return (RBM) super.buildEmpty();
        }

        public Builder withClazz(Class<? extends BaseNeuralNetwork> clazz) {
            super.withClazz(clazz);
            return this;
        }

        public Builder withInput(INDArray input) {
            super.withInput(input);
            return this;
        }

        public Builder asType(Class<RBM> clazz) {
            super.asType(clazz);
            return this;
        }

        public Builder withWeights(INDArray W) {
            super.withWeights(W);
            return this;
        }

        public Builder withVisibleBias(INDArray vBias) {
            super.withVisibleBias(vBias);
            return this;
        }

        public Builder withHBias(INDArray hBias) {
            super.withHBias(hBias);
            return this;
        }

        @Override
        public RBM build() {
            return (RBM) super.build();
        }
    }

    /** Hidden unit types supported by {@link #propUp} and {@link #sampleHiddenGivenVisible}. */
    public static enum HiddenUnit {
        RECTIFIED,
        BINARY,
        GAUSSIAN,
        SOFTMAX;

    }

    /** Visible unit types supported by {@link #propDown} and {@link #sampleVisibleGivenHidden}. */
    public static enum VisibleUnit {
        BINARY,
        GAUSSIAN,
        SOFTMAX,
        LINEAR;

    }
}

