/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.optimize.optimizers.rbm;

import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.optimize.optimizers.NeuralNetworkOptimizer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * {@link NeuralNetworkOptimizer} specialization that maintains an adaptive
 * counter {@link #k}, seeded from the first extra training parameter and
 * incremented once every {@value #K_INCREMENT_INTERVAL} gradient evaluations.
 *
 * <p>NOTE(review): {@code k} is presumably the number of contrastive-divergence
 * (Gibbs sampling) steps used when training the RBM. This class only writes the
 * field and never reads it — confirm how the superclass or subclasses consume it.
 */
public class RBMOptimizer
extends NeuralNetworkOptimizer {
    private static final long serialVersionUID = 3676032651650426749L;

    /** Value used for {@code k} when the extra parameters do not supply one. */
    private static final int DEFAULT_K = 1;

    /** {@link #k} grows by one each time the call count reaches a multiple of this value. */
    private static final int K_INCREMENT_INTERVAL = 10;

    /** Current adaptive {@code k}; a non-positive value means "not yet seeded". */
    protected int k = -1;

    /** Number of times {@link #getValueGradient(int)} has been invoked on this instance. */
    protected int numTimesIterated = 0;

    /**
     * Creates the optimizer by forwarding every argument to the superclass
     * constructor, then fills in a default {@code k} if the single extra
     * parameter slot was left empty.
     *
     * @param network               forwarded to the superclass
     * @param lr                    forwarded to the superclass
     * @param trainingParams        forwarded to the superclass
     *                              (NOTE(review): presumably what the superclass
     *                              exposes as {@code extraParams} — confirm)
     * @param optimizationAlgorithm forwarded to the superclass
     * @param lossFunction          forwarded to the superclass
     */
    public RBMOptimizer(BaseNeuralNetwork network, float lr, Object[] trainingParams, NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm, LossFunctions.LossFunction lossFunction) {
        super(network, lr, trainingParams, optimizationAlgorithm, lossFunction);
        // Guard against a null array: getValueGradient tolerates it, so the
        // constructor must not be the place that blows up on it.
        if (this.extraParams != null && this.extraParams.length == 1 && this.extraParams[0] == null) {
            this.extraParams[0] = DEFAULT_K;
        }
    }

    /**
     * Updates the adaptive {@code k} bookkeeping and then delegates the actual
     * gradient computation to the superclass.
     *
     * <p>On each call: the configured {@code k} is read from the extra
     * parameters, the call counter is incremented, {@link #k} is seeded from
     * the configured value if it is still non-positive, and {@link #k} is
     * incremented when the call counter is a multiple of
     * {@value #K_INCREMENT_INTERVAL}.
     *
     * @param iteration forwarded unchanged to the superclass implementation
     * @return whatever the superclass implementation returns
     * @throws ClassCastException if the first extra parameter is present but
     *                            is not an {@link Integer}
     */
    @Override
    public INDArray getValueGradient(int iteration) {
        // Resolve before touching any state so a bad parameter leaves the
        // counters unchanged.
        int configuredK = this.resolveConfiguredK();
        ++this.numTimesIterated;
        if (this.k <= 0) {
            this.k = configuredK;
        }
        if (this.numTimesIterated % K_INCREMENT_INTERVAL == 0) {
            ++this.k;
        }
        return super.getValueGradient(iteration);
    }

    /**
     * Reads the configured {@code k} from the first extra parameter, falling
     * back to {@link #DEFAULT_K} when the array is null, empty, or its first
     * slot is null.
     */
    private int resolveConfiguredK() {
        Object[] params = this.extraParams;
        if (params == null || params.length < 1 || params[0] == null) {
            return DEFAULT_K;
        }
        return (Integer) params[0];
    }
}

