/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.LossFunction;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.LossCalculation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Abstract layer base that adds input corruption, gradient packaging and
 * reconstruction-style scoring on top of {@link BaseLayer}.
 *
 * <p>Subclasses supply the two sampling directions
 * ({@link #sampleHiddenGivenVisible(INDArray)} and
 * {@link #sampleVisibleGivenHidden(INDArray)}). Scoring in
 * {@link #setScoreWithZ(INDArray)} compares {@code z} against this layer's own
 * {@code input} rather than against external labels.
 *
 * @param <LayerConfT> the concrete layer-configuration type
 */
public abstract class BasePretrainNetwork<LayerConfT extends org.deeplearning4j.nn.conf.layers.BasePretrainNetwork>
        extends BaseLayer<LayerConfT> {

    /**
     * Creates the layer from a configuration only.
     *
     * @param conf network configuration forwarded to {@link BaseLayer}
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Creates the layer from a configuration and an initial input.
     *
     * @param conf  network configuration forwarded to {@link BaseLayer}
     * @param input initial input forwarded to {@link BaseLayer}
     */
    public BasePretrainNetwork(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
    }

    /**
     * Builds a randomly masked version of {@code x}.
     *
     * <p>A mask with the same shape as {@code x} is sampled from a binomial
     * distribution with one trial and success probability
     * {@code 1.0 - corruptionLevel}; the mask is then multiplied by {@code x}
     * via {@code muli} and returned.
     *
     * @param x               array to corrupt
     * @param corruptionLevel value subtracted from 1.0 to obtain the keep probability
     * @return the sampled mask after multiplication by {@code x}
     */
    public INDArray getCorruptedInput(INDArray x, double corruptionLevel) {
        double keepProbability = 1.0 - corruptionLevel;
        INDArray mask = Nd4j.getDistributions()
                .createBinomial(1, keepProbability)
                .sample(x.shape());
        mask.muli(x);
        return mask;
    }

    /**
     * Packs the three gradients into a {@link DefaultGradient}.
     *
     * <p>Entries are inserted in the order {@code "vb"}, {@code "b"},
     * {@code "W"}; that order is kept deliberately in case the backing map
     * is order-sensitive.
     *
     * @param wGradient     gradient stored under {@code "W"}
     * @param vBiasGradient gradient stored under {@code "vb"}
     * @param hBiasGradient gradient stored under {@code "b"}
     * @return a new gradient holding the three entries
     */
    protected Gradient createGradient(INDArray wGradient, INDArray vBiasGradient, INDArray hBiasGradient) {
        DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put("vb", vBiasGradient);
        gradient.gradientForVariable().put("b", hBiasGradient);
        gradient.gradientForVariable().put("W", wGradient);
        return gradient;
    }

    /**
     * Samples in the visible-to-hidden direction.
     *
     * @param var1 input array (presumably the visible activations; confirm in subclasses)
     * @return a pair of arrays whose exact meaning is defined by the subclass
     */
    public abstract Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray var1);

    /**
     * Samples in the hidden-to-visible direction.
     *
     * @param var1 input array (presumably the hidden activations; confirm in subclasses)
     * @return a pair of arrays whose exact meaning is defined by the subclass
     */
    public abstract Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray var1);

    /**
     * Computes and stores {@code score} for the given {@code z}, using this
     * layer's {@code input} as the comparison target.
     *
     * <p>For any loss function other than {@code CUSTOM} the score comes from
     * {@link LossCalculation}. For {@code CUSTOM}, an op is created through
     * the Nd4j op factory from the configured custom loss function, executed,
     * and its current result is taken as the score.
     *
     * @param z array scored against {@code input}
     */
    @Override
    protected void setScoreWithZ(INDArray z) {
        // Resolve the configuration once; assumes layerConf() is a plain accessor.
        org.deeplearning4j.nn.conf.layers.BasePretrainNetwork layerConfiguration =
                (org.deeplearning4j.nn.conf.layers.BasePretrainNetwork) this.layerConf();
        LossFunctions.LossFunction configuredLoss = layerConfiguration.getLossFunction();

        if (configuredLoss != LossFunctions.LossFunction.CUSTOM) {
            this.score = LossCalculation.builder()
                    .l1(this.calcL1())
                    .l2(this.calcL2())
                    .labels(this.input)
                    .z(z)
                    .lossFunction(configuredLoss)
                    .miniBatch(this.conf.isMiniBatch())
                    .miniBatchSize(this.getInputMiniBatchSize())
                    .useRegularization(this.conf.isUseRegularization())
                    .build()
                    .score();
            return;
        }

        LossFunction customLossOp = Nd4j.getOpFactory()
                .createLossFunction(layerConfiguration.getCustomLossFunction(), this.input, z);
        customLossOp.exec();
        this.score = customLossOp.currentResult().doubleValue();
    }
}

