/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.classifiers.dbn;

import java.util.List;
import java.util.Map;
import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.models.featuredetectors.rbm.RBM;
import org.deeplearning4j.nn.BaseMultiLayerNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.layers.Layer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.transformation.MatrixTransform;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Deep belief network built from a stack of {@link RBM} layers (see {@link #createLayer} and
 * {@link #createNetworkLayers}).
 *
 * <p>Pretraining is greedy and layer-wise: each layer is trained in turn on the activations
 * produced by the layers below it, with layer 0 trained on the raw features. All other
 * multi-layer behaviour comes from {@link BaseMultiLayerNetwork}.
 */
public class DBN extends BaseMultiLayerNetwork {
    private static final long serialVersionUID = -9068772752220902983L;
    private static final Logger log = LoggerFactory.getLogger(DBN.class);
    /**
     * Copied from the builder in {@link Builder#build()}.
     * NOTE(review): not read anywhere in this class, and the builder offers no setter for it.
     * Confirm whether serialization or reflection depends on it before removing.
     */
    private boolean useRBMPropUpAsActivations = true;

    /** Delegates to the base implementation, narrowing the result to {@link Layer}. */
    @Override
    public Layer createHiddenLayer(int index, INDArray layerInput) {
        return (Layer) super.createHiddenLayer(index, layerInput);
    }

    /**
     * Pretrains on the batches supplied by {@code iter}. Uses k, the learning rate and the
     * iteration count from the default configuration.
     *
     * <p>The previous implementation ignored {@code iter} and always trained on the current
     * network input. That input is now used only as a fallback when {@code iter} is null.
     *
     * @param iter source of training batches; if {@code null}, the current network input is used
     * @param otherParams optional. If it has more than three elements, element 3 (an
     *     {@link Integer}) is the number of full pretraining passes. Otherwise one pass is run.
     *     May be {@code null}.
     */
    @Override
    public void pretrain(DataSetIterator iter, Object[] otherParams) {
        if (!this.pretrain) {
            return;
        }
        int passes = otherParams != null && otherParams.length > 3 ? (Integer) otherParams[3] : 1;
        int k = this.defaultConfiguration.getK();
        float learningRate = this.defaultConfiguration.getLr();
        int iterations = this.defaultConfiguration.getNumIterations();
        for (int pass = 0; pass < passes; ++pass) {
            if (iter == null) {
                this.pretrain(this.input, k, learningRate, iterations);
            } else {
                // The iterator overload resets iter after each layer, so repeated passes see all data.
                this.pretrain(iter, k, learningRate, iterations);
            }
        }
    }

    /**
     * Pretrains on {@code input}. Uses k, the learning rate and the iteration count from the
     * default configuration; {@code otherParams} is ignored.
     */
    @Override
    public void pretrain(INDArray input, Object[] otherParams) {
        this.pretrain(input, this.defaultConfiguration.getK(), this.defaultConfiguration.getLr(), this.defaultConfiguration.getNumIterations());
    }

    /**
     * Greedy layer-wise pretraining over the batches of {@code iter}.
     *
     * <p>Layer {@code i} is trained on every batch before layer {@code i + 1} starts, and the
     * iterator is reset after each layer.
     * <ul>
     *   <li>For layer 0, each batch's feature matrix also becomes the network input. The layers
     *       are initialized from it if that has not happened yet.</li>
     *   <li>For a deeper layer {@code i}, the batch features are propagated through layers
     *       {@code 0..i-1}, matching the single-matrix overload, to form that layer's input.</li>
     * </ul>
     * Does nothing when pretraining is disabled.
     *
     * @param iter source of training batches; fully consumed and reset once per layer
     * @param k passed as the first element of each layer's training parameters
     *     (presumably the contrastive-divergence step count for RBMs — confirm)
     * @param learningRate rate used when {@code forceNumIterations()} is set. Otherwise each
     *     layer's configured rate is used.
     * @param epochs number of training iterations per batch
     */
    public void pretrain(DataSetIterator iter, int k, float learningRate, int epochs) {
        if (!this.pretrain) {
            return;
        }
        for (int i = 0; i < this.getnLayers(); ++i) {
            while (iter.hasNext()) {
                DataSet next = (DataSet) iter.next();
                INDArray features = next.getFeatureMatrix();
                INDArray layerInput;
                String scoreLogPrefix;
                if (i == 0) {
                    this.input = features;
                    this.prepareInputForPretraining(this.input);
                    layerInput = features;
                    scoreLogPrefix = "Error on iteration ";
                } else {
                    // Input to layer i is the output of layers 0..i-1 applied in order; the old
                    // loop ran j = 1..i, skipping layer 0 and applying layer i to its own input.
                    layerInput = features;
                    for (int j = 0; j < i; ++j) {
                        layerInput = this.activationFromPrevLayer(j, layerInput);
                    }
                    log.info("Training on layer " + (i + 1));
                    scoreLogPrefix = "Error on epoch ";
                }
                float layerLearningRate = ((NeuralNetConfiguration) this.layerWiseConfigurations.get(i)).getLr();
                this.runLayerPretraining(i, layerInput, k, learningRate, layerLearningRate, epochs, scoreLogPrefix);
            }
            iter.reset();
        }
    }

    /**
     * Greedy layer-wise pretraining on a single input matrix.
     *
     * <p>Sets {@code input} as the network input, initializing the layers first if needed. Then
     * each layer is trained in order on the activation of the layer below it; layer 0 is trained
     * on the network input. Does nothing when pretraining is disabled. Logs a warning if
     * Gauss-Newton vector-product backprop is enabled.
     *
     * @param input training data for the first layer
     * @param k passed as the first element of each layer's training parameters
     * @param learningRate rate used when {@code forceNumIterations()} is set. Otherwise each
     *     layer's configured rate is used.
     * @param epochs number of training iterations per layer
     */
    public void pretrain(INDArray input, int k, float learningRate, int epochs) {
        if (!this.pretrain) {
            return;
        }
        if (this.isUseGaussNewtonVectorProductBackProp()) {
            log.warn("WARNING; Gauss newton back vector back propagation is primarily used for hessian free which does not involve pretrain; just finetune. Use this at your own risk");
        }
        this.prepareInputForPretraining(input);
        INDArray layerInput = null;
        for (int i = 0; i < this.getnLayers(); ++i) {
            layerInput = i == 0 ? this.getInput() : this.activationFromPrevLayer(i - 1, layerInput);
            log.info("Training on layer " + (i + 1));
            float layerLearningRate = this.layers[i].conf().getLr();
            this.runLayerPretraining(i, layerInput, k, learningRate, layerLearningRate, epochs, "Error on epoch ");
        }
    }

    /** Pretrains on the current network input; see {@link #pretrain(INDArray, int, float, int)}. */
    public void pretrain(int k, float learningRate, int epochs) {
        this.pretrain(this.getInput(), k, learningRate, epochs);
    }

    /**
     * Sets {@code input} as the network input. The layers are first initialized from it when
     * the network has no input yet, has no layer array, or has not created its first layer.
     */
    private void prepareInputForPretraining(INDArray input) {
        // Evaluate before setInput so the check reflects the state prior to this call.
        boolean needsInitialization = this.getInput() == null
                || this.getNeuralNets() == null
                || this.getNeuralNets()[0] == null;
        this.setInput(input);
        if (needsInitialization) {
            this.initializeLayers(input);
        }
    }

    /**
     * Trains layer {@code layerIndex} on {@code layerInput}.
     *
     * <p>When {@code forceNumIterations()} is set, it runs exactly {@code epochs} explicit
     * {@code iterate} calls with {@code learningRate}, logging the layer score before each call.
     * Otherwise it delegates to the layer's {@code fit} with {@code layerLearningRate}.
     * NOTE(review): the two paths use different learning rates (caller-supplied vs. configured
     * per layer). This is preserved from the original; confirm it is intended.
     *
     * @param scoreLogPrefix start of each per-iteration score log line
     */
    private void runLayerPretraining(int layerIndex, INDArray layerInput, int k, float learningRate,
            float layerLearningRate, int epochs, String scoreLogPrefix) {
        if (!this.forceNumIterations()) {
            this.getNeuralNets()[layerIndex].fit(layerInput, new Object[]{k, Float.valueOf(layerLearningRate), epochs});
            return;
        }
        for (int epoch = 0; epoch < epochs; ++epoch) {
            log.info(scoreLogPrefix + epoch + " for layer " + (layerIndex + 1) + " is " + this.getNeuralNets()[layerIndex].score());
            this.getNeuralNets()[layerIndex].iterate(layerInput, new Object[]{k, Float.valueOf(learningRate)});
            this.getNeuralNets()[layerIndex].iterationDone(epoch);
        }
    }

    /** Creates the {@link RBM} for layer {@code index}, configured from that layer's configuration. */
    @Override
    public NeuralNetwork createLayer(INDArray input, INDArray W, INDArray hBias, INDArray vBias, int index) {
        return (RBM) new RBM.Builder()
                .withInput(input)
                .withWeights(W)
                .withHBias(hBias)
                .withVisibleBias(vBias)
                .configure((NeuralNetConfiguration) this.layerWiseConfigurations.get(index))
                .build();
    }

    /** Allocates an (empty) {@link RBM} array with {@code numLayers} slots. */
    @Override
    public NeuralNetwork[] createNetworkLayers(int numLayers) {
        return new RBM[numLayers];
    }

    /**
     * Fits the network by pretraining on {@code data} with the default configuration's k,
     * learning rate and iteration count; {@code params} is ignored.
     */
    @Override
    public void fit(INDArray data, Object[] params) {
        this.pretrain(data, this.defaultConfiguration.getK(), this.defaultConfiguration.getLr(), this.defaultConfiguration.getNumIterations());
    }

    /**
     * Fluent builder for {@link DBN}. Each setter delegates to, or sets the corresponding field
     * of, {@link BaseMultiLayerNetwork.Builder} and returns this builder for chaining.
     */
    public static class Builder extends BaseMultiLayerNetwork.Builder<DBN> {
        /** Copied onto the built network. NOTE(review): no setter exists, so this is always false. */
        private boolean useRBMPropUpAsActivation = false;

        public Builder() {
            this.clazz = DBN.class;
        }

        public Builder configure(NeuralNetConfiguration conf) {
            super.configure(conf);
            return this;
        }

        public Builder useGaussNewtonVectorProductBackProp(boolean useGaussNewtonVectorProductBackProp) {
            super.useGaussNewtonVectorProductBackProp(useGaussNewtonVectorProductBackProp);
            return this;
        }

        public Builder useDropConnection(boolean useDropConnect) {
            super.useDropConnection(useDropConnect);
            return this;
        }

        public Builder layerWiseConfiguration(List<NeuralNetConfiguration> layerWiseConfiguration) {
            super.layerWiseConfiguration(layerWiseConfiguration);
            return this;
        }

        public Builder withVisibleBiasTransforms(Map<Integer, MatrixTransform> visibleBiasTransforms) {
            super.withVisibleBiasTransforms(visibleBiasTransforms);
            return this;
        }

        public Builder withHiddenBiasTransforms(Map<Integer, MatrixTransform> hiddenBiasTransforms) {
            super.withHiddenBiasTransforms(hiddenBiasTransforms);
            return this;
        }

        /** Sets {@code shouldForceEpochs}, which selects the explicit-iteration pretraining path. */
        public Builder forceIterations() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Sets the base builder's {@code backProp} flag to false. */
        public Builder disableBackProp() {
            this.backProp = false;
            return this;
        }

        /** Registers a weight transform for a single layer index. */
        public Builder transformWeightsAt(int layer, MatrixTransform transform) {
            this.weightTransforms.put(layer, transform);
            return this;
        }

        /** Registers weight transforms for several layer indices at once. */
        public Builder transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            this.weightTransforms.putAll(transforms);
            return this;
        }

        public Builder hiddenLayerSizes(Integer[] hiddenLayerSizes) {
            super.hiddenLayerSizes(hiddenLayerSizes);
            return this;
        }

        public Builder hiddenLayerSizes(int[] hiddenLayerSizes) {
            super.hiddenLayerSizes(hiddenLayerSizes);
            return this;
        }

        public Builder withInput(INDArray input) {
            super.withInput(input);
            return this;
        }

        public Builder withLabels(INDArray labels) {
            super.withLabels(labels);
            return this;
        }

        /** Overrides the network class instantiated by {@link #build()}. */
        public Builder withClazz(Class<? extends BaseMultiLayerNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        /** Enables or disables pretraining on the built network. */
        public Builder pretrain(boolean pretrain) {
            this.pretrain = pretrain;
            return this;
        }

        /**
         * Builds the network, then initializes its layers from a 1 x nIn zero matrix. The nIn
         * value comes from the default configuration, so the layers exist before any training
         * data is seen.
         */
        @Override
        public DBN build() {
            DBN ret = (DBN) super.build();
            ret.useRBMPropUpAsActivations = this.useRBMPropUpAsActivation;
            ret.initializeLayers(Nd4j.zeros(1, (int) ret.defaultConfiguration.getnIn()));
            return ret;
        }
    }
}

