/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.featuredetectors.da;

import java.io.Serializable;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.BaseNeuralNetwork;
import org.deeplearning4j.nn.api.NeuralNetwork;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.optimizers.da.DenoisingAutoEncoderOptimizer;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Denoising autoencoder layer.
 *
 * <p>During training the input is corrupted with a multiplicative binomial mask
 * ({@link #getCorruptedInput}), encoded with a sigmoid hidden layer
 * ({@link #getHiddenValues}) and decoded with a sigmoid reconstruction
 * ({@link #getReconstructedInput}). The decoder multiplies by {@code W.transpose()}, so the
 * weight matrix {@code W} is shared between encoder and decoder.
 */
public class DenoisingAutoEncoder
extends BaseNeuralNetwork
implements Serializable {
    private static final long serialVersionUID = -6445530486350763837L;

    /**
     * No-arg constructor. NOTE(review): presumably used reflectively by {@link Builder}
     * (which only sets {@code clazz}); confirm before removing.
     */
    private DenoisingAutoEncoder() {
    }

    /**
     * Creates a denoising autoencoder from explicit parameters.
     *
     * @param input initial input matrix
     * @param W     weight matrix shared by encoder and decoder
     * @param hbias hidden-layer bias
     * @param vbias visible (reconstruction) bias
     * @param conf  layer configuration (learning rate, corruption level, RNG, ...)
     */
    public DenoisingAutoEncoder(INDArray input, INDArray W, INDArray hbias, INDArray vbias, NeuralNetConfiguration conf) {
        super(input, W, hbias, vbias, conf);
    }

    /**
     * Applies masking noise to {@code x}.
     *
     * <p>Each entry of {@code x} is multiplied by an independent draw of
     * {@code MathUtils.binomial(conf.getRng(), 1, 1 - corruptionLevel)}. Assuming that helper
     * samples a single Bernoulli trial, each entry is kept with probability
     * {@code 1 - corruptionLevel} and zeroed otherwise.
     *
     * @param x               matrix to corrupt; it is not modified ({@code mul}, not {@code muli})
     * @param corruptionLevel probability of zeroing an entry
     * @return a new matrix holding the masked values
     */
    public INDArray getCorruptedInput(INDArray x, float corruptionLevel) {
        // Loop-invariant: compute the keep probability once instead of per element.
        float keepProbability = 1.0f - corruptionLevel;
        INDArray mask = Nd4j.zeros((int)x.rows(), (int)x.columns());
        for (int i = 0; i < x.rows(); ++i) {
            for (int j = 0; j < x.columns(); ++j) {
                mask.put(i, j, (Number)MathUtils.binomial(this.conf.getRng(), 1, keepProbability));
            }
        }
        return mask.mul(x);
    }

    /**
     * Returns the hidden activations of {@code v} as both elements of the pair. No sampling
     * is done beyond what {@link #getHiddenValues} itself does.
     */
    @Override
    public Pair<INDArray, INDArray> sampleHiddenGivenVisible(INDArray v) {
        INDArray ret = this.getHiddenValues(v);
        return new Pair<INDArray, INDArray>(ret, ret);
    }

    /**
     * Returns the reconstruction of {@code h} as both elements of the pair. No sampling is
     * done beyond what {@link #getReconstructedInput} itself does.
     */
    @Override
    public Pair<INDArray, INDArray> sampleVisibleGivenHidden(INDArray h) {
        INDArray ret = this.getReconstructedInput(h);
        return new Pair<INDArray, INDArray>(ret, ret);
    }

    /** Delegates to {@link #getHiddenValues(INDArray)}. */
    @Override
    public INDArray hiddenActivation(INDArray input) {
        return this.getHiddenValues(input);
    }

    /**
     * Encodes {@code x}: {@code sigmoid(x * W + hBias)}, then applies dropout when the base
     * class decides it is necessary.
     *
     * <p>When {@code conf.isConcatBiases()} is set, the bias is folded into the weight matrix by
     * horizontally stacking {@code W} and {@code hBias.transpose()} instead of being added.
     * NOTE(review): {@code hstack} requires both operands to have the same row count; confirm the
     * shapes for that path.
     *
     * @param x input rows to encode
     * @return hidden activations
     */
    public INDArray getHiddenValues(INDArray x) {
        INDArray preAct;
        if (this.conf.isConcatBiases()) {
            INDArray weightsWithBias = Nd4j.hstack(new INDArray[]{this.W, this.hBias.transpose()});
            preAct = x.mmul(weightsWithBias);
        } else {
            preAct = x.mmul(this.W).addiRowVector(this.hBias);
        }
        INDArray ret = Transforms.sigmoid(preAct);
        // NOTE(review): the return value is ignored, so this presumably mutates ret in place.
        this.applyDropOutIfNecessary(ret);
        return ret;
    }

    /**
     * Decodes hidden activations {@code y} back to visible space using the transposed weights:
     * {@code sigmoid(y * W^T + vBias)}.
     *
     * <p>When {@code conf.isConcatBiases()} is set, {@code vBias} is not used. Instead a column of
     * ones is appended to the pre-activation before the sigmoid. NOTE(review): that makes the
     * output one column wider than {@code y * W^T}; confirm callers expect this.
     *
     * @param y hidden activations
     * @return reconstructed visible values
     */
    public INDArray getReconstructedInput(INDArray y) {
        INDArray preAct = y.mmul(this.W.transpose());
        if (this.conf.isConcatBiases()) {
            preAct = Nd4j.hstack(new INDArray[]{preAct, Nd4j.ones((int)preAct.rows(), 1)});
        } else {
            preAct.addiRowVector(this.vBias);
        }
        return Transforms.sigmoid(preAct);
    }

    /**
     * Performs one gradient step on {@code x}, or on the current input when {@code x} is
     * {@code null}.
     *
     * <p>NOTE(review): {@link #getGradient(Object[])} in this class reads the corruption level,
     * learning rate and iteration count from {@code conf} and ignores the values packed here.
     *
     * @param x               mini-batch to train on; {@code null} reuses the current input
     * @param lr              learning rate (currently not read by this class's getGradient)
     * @param corruptionLevel corruption level (currently not read by this class's getGradient)
     * @param iteration       iteration number (currently not read by this class's getGradient)
     */
    public void train(INDArray x, float lr, float corruptionLevel, int iteration) {
        if (x != null) {
            this.input = x;
        }
        // x may be null (meaning "reuse the current input"), so size the batch from this.input.
        this.lastMiniBatchSize = this.input.rows();
        NeuralNetworkGradient gradient = this.getGradient(new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(lr), iteration});
        this.addGradientInPlace(gradient);
    }

    /** Encodes {@code x} and immediately reconstructs it. */
    @Override
    public INDArray transform(INDArray x) {
        INDArray y = this.getHiddenValues(x);
        return this.getReconstructedInput(y);
    }

    /**
     * Trains the layer with a fresh {@link DenoisingAutoEncoderOptimizer} configured from
     * {@code conf}.
     *
     * @param input  training data; {@code null} reuses the current input
     * @param params extra parameters forwarded to the optimizer (may be {@code null})
     */
    @Override
    public void fit(INDArray input, Object[] params) {
        if (input != null) {
            this.input = input;
        }
        // Use this.input rather than the argument: the argument may be null.
        this.lastMiniBatchSize = this.input.rows();
        this.optimizer = new DenoisingAutoEncoderOptimizer((NeuralNetwork)this, this.conf.getLr(), params, this.conf.getOptimizationAlgo(), this.conf.getLossFunction());
        this.optimizer.train(this.input);
    }

    /** Equivalent to {@code fit(data, null)}. */
    @Override
    public void fit(INDArray data) {
        this.fit(data, null);
    }

    /**
     * Performs one gradient step using the configured corruption level and learning rate.
     *
     * @param input  new input, run through {@code preProcessInput}; {@code null} reuses the current input
     * @param params unused
     */
    @Override
    public void iterate(INDArray input, Object[] params) {
        float corruptionLevel = this.conf.getCorruptionLevel();
        if (input != null) {
            this.input = this.preProcessInput(input);
        }
        // Use this.input rather than the argument: the argument may be null.
        this.lastMiniBatchSize = this.input.rows();
        NeuralNetworkGradient gradient = this.getGradient(new Object[]{Float.valueOf(corruptionLevel), Float.valueOf(this.conf.getLr()), 0});
        this.addGradientInPlace(gradient);
    }

    /**
     * Plots the network gradient every {@code conf.getRenderWeightsEveryNumEpochs()} iterations.
     * A non-positive setting disables plotting.
     */
    @Override
    public void iterationDone(int iteration) {
        int plotEpochs = this.conf.getRenderWeightsEveryNumEpochs();
        if (plotEpochs <= 0) {
            return;
        }
        // iteration 0 needs no special case: 0 % plotEpochs == 0.
        if (iteration % plotEpochs == 0) {
            NeuralNetPlotter plotter = new NeuralNetPlotter();
            // The params array is not read by this class's getGradient (see its Javadoc).
            plotter.plotNetworkGradient(this, this.getGradient(new Object[]{0.3, 0.001, 1000}), this.getInput().rows());
        }
    }

    /**
     * Computes the parameter update for the current input.
     *
     * <p>The corruption level, learning rate and iteration count are read from {@code conf}.
     * {@code params} is ignored. The learning rate is also pushed into any non-null AdaGrad
     * instances as their master step size.
     *
     * <p>The gradient is derived from {@code input - reconstruction}. It therefore points in the
     * direction that reduces reconstruction error and is <em>added</em> to the parameters by
     * {@link #train} and {@link #iterate}.
     *
     * @param params ignored by this implementation
     * @return weight, visible-bias and hidden-bias updates, post-processed by
     *         {@code updateGradientAccordingToParams}
     */
    @Override
    public NeuralNetworkGradient getGradient(Object[] params) {
        float corruptionLevel = this.conf.getCorruptionLevel();
        float lr = this.conf.getLr();
        int iteration = this.conf.getNumIterations();
        if (this.wAdaGrad != null) {
            this.wAdaGrad.setMasterStepSize((double)lr);
        }
        if (this.hBiasAdaGrad != null) {
            this.hBiasAdaGrad.setMasterStepSize((double)lr);
        }
        if (this.vBiasAdaGrad != null) {
            this.vBiasAdaGrad.setMasterStepSize((double)lr);
        }
        INDArray corruptedX = this.getCorruptedInput(this.input, corruptionLevel);
        INDArray y = this.getHiddenValues(corruptedX);
        INDArray z = this.getReconstructedInput(y);
        INDArray visibleLoss = this.input.sub(z);
        INDArray hiddenLoss;
        if (this.conf.getSparsity() == 0.0f) {
            // Back-propagate through the sigmoid: derivative y * (1 - y).
            hiddenLoss = visibleLoss.mmul(this.W).mul(y).mul(y.rsub((Number)1));
        } else {
            // NOTE(review): uses y * (y - sparsity) instead of the sigmoid derivative; confirm intent.
            hiddenLoss = visibleLoss.mmul(this.W).mul(y).mul(y.add((Number)Float.valueOf(-this.conf.getSparsity())));
        }
        // W is tied, so its gradient has an encoder term and a decoder term.
        INDArray wGradient = corruptedX.transpose().mmul(hiddenLoss).add(visibleLoss.transpose().mmul(y));
        INDArray hBiasGradient = hiddenLoss.mean(0);
        INDArray vBiasGradient = visibleLoss.mean(0);
        NeuralNetworkGradient gradient = new NeuralNetworkGradient(wGradient, vBiasGradient, hBiasGradient);
        this.updateGradientAccordingToParams(gradient, iteration, lr);
        return gradient;
    }

    /** Adds each component of {@code gradient} to the matching parameter, in place. */
    private void addGradientInPlace(NeuralNetworkGradient gradient) {
        this.vBias.addi(gradient.getvBiasGradient());
        this.W.addi(gradient.getwGradient());
        this.hBias.addi(gradient.gethBiasGradient());
    }

    /** Fluent builder for {@link DenoisingAutoEncoder}; each setter delegates to the base builder. */
    public static class Builder
    extends BaseNeuralNetwork.Builder<DenoisingAutoEncoder> {
        public Builder() {
            this.clazz = DenoisingAutoEncoder.class;
        }

        public Builder withClazz(Class<? extends BaseNeuralNetwork> clazz) {
            super.withClazz(clazz);
            return this;
        }

        public Builder withInput(INDArray input) {
            super.withInput(input);
            return this;
        }

        public Builder withWeights(INDArray W) {
            super.withWeights(W);
            return this;
        }

        public Builder withVisibleBias(INDArray vBias) {
            super.withVisibleBias(vBias);
            return this;
        }

        public Builder withHBias(INDArray hBias) {
            super.withHBias(hBias);
            return this;
        }
    }
}

