/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Abstract layer base that extends {@link BaseLayer} with three shared pieces:
 * a score computed from the layer's own input, a helper that randomly masks an
 * input array, and a helper that packages weight/bias gradients into a
 * {@link Gradient}.
 *
 * <p>Concrete subclasses supply the two sampling operations declared at the
 * bottom of this class.
 */
public abstract class BasePretrainNetwork extends BaseLayer {
    private static final long serialVersionUID = -7074102204433996574L;

    /**
     * Creates the layer from a configuration only, delegating to {@link BaseLayer}.
     *
     * @param conf configuration handed to the superclass constructor
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from a configuration and an input array, delegating to
     * {@link BaseLayer}.
     *
     * @param conf  configuration handed to the superclass constructor
     * @param input input array handed to the superclass constructor
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Recomputes {@code score} by comparing {@code input} with
     * {@code transform(input)} using the configured loss function, L2 value and
     * regularization flag.
     *
     * <p>The stored value is the raw loss when the configuration reports
     * {@code isMinimize()}, and the negated loss otherwise. When {@code input}
     * is {@code null} the method does nothing and {@code score} is left as is.
     */
    @Override
    public void setScore() {
        if (input != null) {
            INDArray transformed = transform(input);
            LossFunctions.LossFunction lossFunction = conf.getLossFunction();
            double l2 = conf.getL2();
            boolean regularize = conf.isUseRegularization();
            double loss = LossFunctions.score(input, lossFunction, transformed, l2, regularize);
            // Negating twice is an exact no-op for doubles, so only the sign choice remains.
            score = conf.isMinimize() ? loss : -loss;
        }
    }

    /**
     * Builds a randomly masked version of {@code x}.
     *
     * <p>An array with the shape of {@code x} is sampled from a single-trial
     * binomial distribution with success probability {@code 1 - corruptionLevel},
     * then combined with {@code x} through {@code muli}.
     * NOTE(review): with a single trial each sampled entry is presumably 0 or 1,
     * so entries of {@code x} are either kept or zeroed - confirm against the
     * distribution implementation.
     *
     * @param x               array to mask; its shape determines the sample shape
     * @param corruptionLevel value subtracted from 1 to obtain the binomial
     *                        success probability
     * @return the sampled array after {@code muli(x)} has been applied to it
     */
    public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
        double successProbability = 1.0 - corruptionLevel;
        INDArray mask = Nd4j.getDistributions()
                .createBinomial(1, successProbability)
                .sample(x.shape());
        mask.muli(x);
        return mask;
    }

    /**
     * Packages three gradient arrays into a new {@link DefaultGradient}.
     *
     * <p>The arrays are registered under the keys {@code "vb"}, {@code "b"} and
     * {@code "W"}, in that order.
     *
     * @param wGradient     gradient stored under {@code "W"}
     * @param vBiasGradient gradient stored under {@code "vb"}
     * @param hBiasGradient gradient stored under {@code "b"}
     * @return the newly created gradient holder
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put("vb", vBiasGradient);
        gradient.gradientForVariable().put("b", hBiasGradient);
        gradient.gradientForVariable().put("W", wGradient);
        return gradient;
    }

    /**
     * Sampling operation implemented by subclasses; no behaviour is defined here.
     * NOTE(review): judging by the name, the argument is presumably the visible
     * input - the meaning of the two returned arrays is not visible in this class.
     *
     * @param var1 array to sample from
     * @return a pair of arrays produced by the subclass
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray var1);

    /**
     * Sampling operation implemented by subclasses; no behaviour is defined here.
     * NOTE(review): judging by the name, the argument is presumably the hidden
     * activations - the meaning of the two returned arrays is not visible in this class.
     *
     * @param var1 array to sample from
     * @return a pair of arrays produced by the subclass
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray var1);
}

