/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.rbm;

import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.NeuralNetworkOptimizer;
import org.deeplearning4j.rbm.RBMOptimizer;
import org.deeplearning4j.util.MatrixUtil;
import org.jblas.DoubleMatrix;

/**
 * Restricted Boltzmann machine built on {@link BaseNeuralNetwork}.
 *
 * <p>Training uses k-step contrastive divergence: a Gibbs chain is started
 * from a hidden sample drawn for the current input, run for {@code k}
 * visible/hidden alternations, and the difference between the statistics at
 * the start and at the end of the chain is used as the parameter update.
 *
 * <p>The weight matrix {@code W}, the biases {@code hBias}/{@code vBias}, the
 * random generator {@code rng} and the hyper-parameters {@code l2},
 * {@code momentum}, {@code sparsity} and {@code useRegularization} are not
 * declared in this class; they are read through {@code this} and are expected
 * to come from the superclass.
 */
public class RBM
extends BaseNeuralNetwork {
    private static final long serialVersionUID = 6189188205731511957L;

    /** Optimizer created by the most recent {@code trainTillConvergence} call. */
    protected NeuralNetworkOptimizer optimizer;

    /** Creates an unconfigured instance; all state is left to the superclass defaults. */
    public RBM() {
    }

    /**
     * Creates an RBM by forwarding every argument, unchanged and in the same
     * order, to the {@link BaseNeuralNetwork} constructor.
     */
    public RBM(DoubleMatrix input, int n_visible, int n_hidden, DoubleMatrix W, DoubleMatrix hbias, DoubleMatrix vbias, RandomGenerator rng, double fanIn, RealDistribution dist) {
        super(input, n_visible, n_hidden, W, hbias, vbias, rng, fanIn, dist);
    }

    /**
     * Trains with a freshly created {@link RBMOptimizer}.
     *
     * @param learningRate learning rate handed to the optimizer and packed as
     *                     the second training parameter
     * @param k            number of Gibbs steps, packed as the first training parameter
     * @param input        training data; when non-null it replaces the stored input.
     *                     The argument itself (possibly {@code null}) is what is
     *                     passed to the optimizer.
     */
    public void trainTillConvergence(double learningRate, int k, DoubleMatrix input) {
        if (input != null) {
            this.input = input;
        }
        this.optimizer = new RBMOptimizer(this, learningRate, new Object[]{k, learningRate});
        this.optimizer.train(input);
    }

    /**
     * Performs one contrastive-divergence update: computes the gradient for
     * the stored input and adds it in place to {@code W}, {@code hBias} and
     * {@code vBias}.
     *
     * @param learningRate scaling applied to the gradient
     * @param k            number of Gibbs steps; must be at least 1
     * @param input        training data; when non-null it replaces the stored input
     */
    public void contrastiveDivergence(double learningRate, int k, DoubleMatrix input) {
        if (input != null) {
            this.input = input;
        }
        NeuralNetworkGradient gradient = this.getGradient(new Object[]{k, learningRate});
        this.W.addi(gradient.getwGradient());
        this.hBias.addi(gradient.gethBiasGradient());
        this.vBias.addi(gradient.getvBiasGradient());
    }

    /**
     * Computes the CD-k gradient for the stored input without modifying the
     * model parameters.
     *
     * @param params {@code params[0]} is the number of Gibbs steps as an
     *               {@link Integer}; {@code params[1]} is the learning rate
     *               as a {@link Double}
     * @return the weight, visible-bias and hidden-bias gradients, already
     *         scaled by the learning rate
     * @throws IllegalArgumentException if the number of Gibbs steps is below 1
     */
    @Override
    public NeuralNetworkGradient getGradient(Object[] params) {
        int k = (Integer)params[0];
        double learningRate = (Double)params[1];
        if (k < 1) {
            // With no Gibbs step there are no negative-phase statistics to subtract.
            throw new IllegalArgumentException("Number of Gibbs steps k must be at least 1 but was " + k);
        }

        // Positive phase: hidden activations for the data.
        Pair<DoubleMatrix, DoubleMatrix> probHidden = this.sampleHiddenGivenVisible(this.input);
        DoubleMatrix hiddenSample = probHidden.getSecond();

        // Negative phase: run the chain, each step seeded with the previous hidden sample.
        DoubleMatrix nvSamples = null;
        DoubleMatrix nhMeans = null;
        DoubleMatrix nhSamples = hiddenSample;
        for (int i = 0; i < k; ++i) {
            Pair<Pair<DoubleMatrix, DoubleMatrix>, Pair<DoubleMatrix, DoubleMatrix>> step = this.gibbhVh(nhSamples);
            nvSamples = step.getFirst().getSecond();
            nhMeans = step.getSecond().getFirst();
            nhSamples = step.getSecond().getSecond();
        }

        DoubleMatrix positive = this.input.transpose().mmul(hiddenSample);
        DoubleMatrix negative = nvSamples.transpose().mmul(nhMeans);
        DoubleMatrix wGradient = positive.sub(negative).mul(learningRate);
        if (this.useRegularization) {
            // mul (not muli): the penalty term must be a copy, otherwise the
            // weights themselves are overwritten with W * l2.
            wGradient.subi(this.W.mul(this.l2));
        }
        if (this.momentum != 0.0) {
            wGradient.muli(1.0 - this.momentum);
        }
        wGradient.divi((double)this.input.rows);

        // A non-zero sparsity replaces the negative-phase term with the sparsity value.
        DoubleMatrix hBiasDelta = this.sparsity != 0.0
                ? hiddenSample.add(-this.sparsity)
                : hiddenSample.sub(nhMeans);
        DoubleMatrix hBiasGradient = MatrixUtil.mean(hBiasDelta, 0).mul(learningRate);
        DoubleMatrix vBiasGradient = MatrixUtil.mean(this.input.sub(nvSamples), 0).mul(learningRate);
        return new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
    }

    /**
     * Computes hidden activations for {@code v} and draws a sample from them.
     *
     * @param v visible values
     * @return pair of (hidden means from {@link #propUp}, sample drawn via
     *         {@code MatrixUtil.binomial} with one trial)
     */
    public Pair<DoubleMatrix, DoubleMatrix> sampleHiddenGivenVisible(DoubleMatrix v) {
        DoubleMatrix h1Mean = this.propUp(v);
        DoubleMatrix h1Sample = MatrixUtil.binomial(h1Mean, 1, this.rng);
        return new Pair<DoubleMatrix, DoubleMatrix>(h1Mean, h1Sample);
    }

    /**
     * Runs one Gibbs step hidden -> visible -> hidden.
     *
     * @param h hidden values to start from
     * @return pair whose first element is the visible (mean, sample) pair and
     *         whose second element is the hidden (mean, sample) pair computed
     *         from the visible sample
     */
    public Pair<Pair<DoubleMatrix, DoubleMatrix>, Pair<DoubleMatrix, DoubleMatrix>> gibbhVh(DoubleMatrix h) {
        Pair<DoubleMatrix, DoubleMatrix> v1MeanAndSample = this.sampleVGivenH(h);
        DoubleMatrix vSample = v1MeanAndSample.getSecond();
        Pair<DoubleMatrix, DoubleMatrix> h1MeanAndSample = this.sampleHiddenGivenVisible(vSample);
        return new Pair<Pair<DoubleMatrix, DoubleMatrix>, Pair<DoubleMatrix, DoubleMatrix>>(v1MeanAndSample, h1MeanAndSample);
    }

    /**
     * Computes visible activations for {@code h} and draws a sample from them.
     *
     * @param h hidden values
     * @return pair of (visible means from {@link #propDown}, sample drawn via
     *         {@code MatrixUtil.binomial} with one trial)
     */
    public Pair<DoubleMatrix, DoubleMatrix> sampleVGivenH(DoubleMatrix h) {
        DoubleMatrix v1Mean = this.propDown(h);
        DoubleMatrix v1Sample = MatrixUtil.binomial(v1Mean, 1, this.rng);
        return new Pair<DoubleMatrix, DoubleMatrix>(v1Mean, v1Sample);
    }

    /**
     * Propagates visible values up to the hidden layer.
     *
     * @param v visible values
     * @return {@code sigmoid(v * W + hBias)}, with the bias added to every row
     */
    public DoubleMatrix propUp(DoubleMatrix v) {
        // mmul yields a new matrix, so the in-place bias add does not touch v or W.
        DoubleMatrix preSig = v.mmul(this.W).addiRowVector(this.hBias);
        return MatrixUtil.sigmoid(preSig);
    }

    /**
     * Propagates hidden values down to the visible layer.
     *
     * @param h hidden values
     * @return {@code sigmoid(h * W^T + vBias)}, with the bias added to every row
     */
    public DoubleMatrix propDown(DoubleMatrix h) {
        // mmul yields a new matrix, so the in-place bias add does not touch h or W.
        DoubleMatrix preSig = h.mmul(this.W.transpose()).addiRowVector(this.vBias);
        return MatrixUtil.sigmoid(preSig);
    }

    /**
     * Reconstructs visible values by propagating up and then back down,
     * using the activation means (no sampling).
     *
     * @param v visible values
     * @return the reconstruction of {@code v}
     */
    @Override
    public DoubleMatrix reconstruct(DoubleMatrix v) {
        return this.propDown(this.propUp(v));
    }

    /**
     * Trains with a freshly created {@link RBMOptimizer}.
     *
     * @param input  training data; when non-null it replaces the stored input.
     *               The argument itself (possibly {@code null}) is what is
     *               passed to the optimizer.
     * @param lr     learning rate handed to the optimizer
     * @param params training parameters forwarded to the optimizer unchanged
     */
    @Override
    public void trainTillConvergence(DoubleMatrix input, double lr, Object[] params) {
        if (input != null) {
            this.input = input;
        }
        this.optimizer = new RBMOptimizer(this, lr, params);
        this.optimizer.train(input);
    }

    /**
     * Returns the reconstruction cross entropy as the loss.
     *
     * @param params ignored
     */
    @Override
    public double lossFunction(Object[] params) {
        return this.getReConstructionCrossEntropy();
    }

    /**
     * Runs a single contrastive-divergence update.
     *
     * @param input  training data; when non-null it replaces the stored input
     * @param lr     learning rate
     * @param params {@code params[0]} is the number of Gibbs steps as an {@link Integer}
     */
    @Override
    public void train(DoubleMatrix input, double lr, Object[] params) {
        int k = (Integer)params[0];
        this.contrastiveDivergence(lr, k, input);
    }

    /** Builder that sets the class to instantiate to {@link RBM}. */
    public static class Builder
    extends BaseNeuralNetwork.Builder<RBM> {
        public Builder() {
            this.clazz = RBM.class;
        }
    }
}

