/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.Solver;
import org.deeplearning4j.util.Dropout;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.ops.transforms.Transforms;

public class ImageLSTM
extends BaseLayer<org.deeplearning4j.nn.conf.layers.ImageLSTM> {
    // Gate pre-activations; activate() allocates one row per time step and 4 * hiddenLayerSize columns.
    private INDArray iFogZ;
    // Gate activations, same shape as iFogZ; activate() fills the first 3 * hiddenLayerSize columns
    // with sigmoid values and the remaining columns with tanh values.
    private INDArray iFogA;
    // Memory-cell values, one row per time step (allocated in activate()).
    private INDArray memCellActivations;
    // Per-step input rows that activate() multiplies with the "RW" parameter.
    private INDArray hIn;
    // Per-step hidden outputs produced by activate().
    private INDArray hOut;
    // Result of the last activate() call (hOut rows 1.. multiplied by "W" plus "b").
    private INDArray outputActivations;
    // Dropout mask applied to the input in activate().
    private INDArray u;
    // Mask returned by Dropout.applyDropout for hOut in activate().
    private INDArray u2;
    // First part of the input as stored by setInput(xi, xs) / fit().
    private INDArray xi;
    // Remaining part of the input as stored by setInput(xi, xs) / fit().
    private INDArray xs;

    public ImageLSTM(NeuralNetConfiguration conf) {
        super(conf);
        throw new UnsupportedOperationException("Layer disabled: Version in development and will be provided in a later release.");
    }

    public ImageLSTM(NeuralNetConfiguration conf, INDArray input) {
        super(conf, input);
        throw new UnsupportedOperationException("Layer disabled: Version in development and will be provided in a later release.");
    }

    /**
     * Stores both input parts and sets the layer input to their vertical stack
     * ({@code xi} on top of {@code xs}).
     *
     * @param xi first input part, kept in {@code this.xi}
     * @param xs second input part, kept in {@code this.xs}
     */
    public void setInput(INDArray xi, INDArray xs) {
        this.xi = xi;
        this.xs = xs;
        INDArray[] parts = new INDArray[]{xi, xs};
        INDArray stacked = Nd4j.vstack((INDArray[])parts);
        this.setInput(stacked);
    }

    /**
     * Walks the cached forward-pass state backwards in time and builds gradients for the
     * "W", "RW" and "b" parameters. Relies on the fields filled by {@link #activate(boolean)}
     * (hIn, hOut, iFogZ, iFogA, memCellActivations, and the dropout masks u / u2), and calls
     * {@link #clear()} before returning.
     *
     * @param gradient gradient holder; only its "b" entry is read (used as {@code delta})
     * @param esilon   not read by this method
     * @return a pair of the new gradient (entries "W", "RW", "b") and {@code dHout}
     */
    public Pair<Gradient, INDArray> backpropGradient(Gradient gradient, INDArray esilon) {
        INDArray inputWeights = this.getParam("W");
        INDArray recurrentWeights = this.getParam("RW");
        INDArray dHin = Nd4j.zeros((int[])this.hIn.shape());
        INDArray dX = Nd4j.zeros((int[])this.input.shape());
        INDArray delta = gradient.getGradientFor("b");
        // NOTE(review): both products below are element-wise (mul), not matrix products — confirm intended.
        INDArray inputWeightGradients = this.hOut.transpose().mul(delta);
        INDArray biasGradients = Nd4j.sum((INDArray)delta, (int)0);
        INDArray dHout = inputWeights.mul(delta);
        // Prepend a zero row to dHout.
        dHout = Nd4j.vstack((INDArray[])new INDArray[]{Nd4j.zeros((int)dHout.columns()), dHout});
        // NOTE(review): non-short-circuit '&' — both operands are always evaluated; '&&' is probably meant.
        if (this.conf.isUseDropConnect() & this.conf.getLayer().getDropOut() > 0.0) {
            dHout.muli(this.u2);
        }
        INDArray dIFogZ = Nd4j.zeros((int[])this.iFogZ.shape());
        INDArray dIFogA = Nd4j.zeros((int[])this.iFogA.shape());
        INDArray recurrentWeightGradients = Nd4j.zeros((int[])recurrentWeights.shape());
        INDArray dC = Nd4j.zeros((int[])this.memCellActivations.shape());
        int sequenceLen = this.hOut.rows();
        int hiddenLayerSize = this.hOut.columns();
        // NOTE(review): the loop stops before t == 0, so the 'if (t > 0)' guards inside are always
        // true and step 0 is never back-propagated — presumably the bound should be 't >= 0'; confirm.
        for (int t = sequenceLen - 1; t > 0; --t) {
            // Columns [2h, 3h) of the gate arrays: multiplied with the (optionally tanh-squashed) cell state.
            if (this.conf.getLayer().getActivationFunction().equals("tanh")) {
                INDArray tanhCt = Transforms.tanh((INDArray)this.memCellActivations.slice(t));
                dIFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))}, tanhCt.mul(dHout.slice(t)));
                dC.slice(t).addi(Transforms.pow((INDArray)tanhCt, (Number)2).rsubi((Number)1).muli(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))}).mul(dHout.slice(t))));
            } else {
                dIFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))}, this.memCellActivations.slice(t).mul(dHout.slice(t)));
                dC.slice(t).addi(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))}).mul(dHout.slice(t)));
            }
            // Columns [h, 2h): scale the previous cell state; also pushes dC one step back in time.
            if (t > 0) {
                dIFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))}, this.memCellActivations.slice(t - 1).mul(dC.slice(t)));
                dC.slice(t - 1).addi(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))}).mul(dC.slice(t)));
            }
            // Columns [0, h) and [3h, end) are multiplied together in the forward pass, so each
            // receives the other's activation times dC.
            dIFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)hiddenLayerSize)}, this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)this.iFogA.columns())}).mul(dC.slice(t)));
            dIFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)dIFogA.columns())}, this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)hiddenLayerSize)}).mul(dC.slice(t)));
            // Back through the non-linearities: (1 - a^2) for the tanh segment, a * (1 - a) for the sigmoid segment.
            dIFogZ.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)dIFogZ.columns())}, Transforms.pow((INDArray)this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)this.iFogA.columns())}), (Number)2).rsubi((Number)1).mul(dIFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)dIFogA.columns())})));
            INDArray activations = this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)(3 * hiddenLayerSize))});
            dIFogZ.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)(3 * hiddenLayerSize))}, activations.mul(activations.rsub((Number)1)).mul(dIFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)(3 * hiddenLayerSize))})));
            // Accumulate the recurrent-weight gradient and push the error back into the step input.
            recurrentWeightGradients.addi(this.hIn.slice(t).transpose().mmul(dIFogZ.slice(t)));
            dHin.slice(t).assign(dIFogZ.slice(t).mmul(recurrentWeights.transpose()));
            // Columns [1, 1 + h) of dHin go to dX; the columns after that feed dHout of the previous step.
            INDArray get = dHin.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)1, (int)(1 + hiddenLayerSize))});
            dX.slice(t).assign(get);
            if (t > 0) {
                dHout.slice(t - 1).addi(dHin.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(1 + hiddenLayerSize), (int)dHin.columns())}));
            }
            // NOTE(review): the whole dX array is multiplied by the mask on every iteration, so the
            // mask is applied repeatedly — looks like this belongs after the loop; confirm.
            if (!(this.conf.isUseDropConnect() & this.conf.getLayer().getDropOut() > 0.0)) continue;
            dX.muli(this.u);
        }
        this.clear();
        DefaultGradient retGradient = new DefaultGradient();
        retGradient.gradientForVariable().put("W", inputWeightGradients);
        retGradient.gradientForVariable().put("RW", recurrentWeightGradients);
        retGradient.gradientForVariable().put("b", biasGradients);
        // NOTE(review): dX is computed above but never returned; the second element is dHout.
        return new Pair<Gradient, INDArray>(retGradient, dHout);
    }

    /**
     * Runs the forward pass over every row (time step) of the current input and caches the
     * intermediate state (hIn, hOut, iFogZ, iFogA, memCellActivations) in fields for later use.
     *
     * <p>Gate layout in iFogZ / iFogA, with h = hiddenLayerSize: columns [0, 3h) go through a
     * sigmoid, columns [3h, 4h) through tanh. The cell state is
     * {@code a[0,h) * a[3h,4h)} plus, for t &gt; 0, {@code a[h,2h) * c[t-1]}; the hidden output
     * is {@code a[2h,3h)} times the cell state (tanh-squashed when the configured activation
     * function is "tanh").
     *
     * @param training enables the drop-connect mask on hOut (together with the configuration flags)
     * @return hOut rows 1.. multiplied by "W" with "b" added to every row
     */
    @Override
    public INDArray activate(boolean training) {
        INDArray decoderWeights = this.getParam("W");
        INDArray recurrentWeights = this.getParam("RW");
        INDArray decoderBias = this.getParam("b");
        // NOTE(review): input dropout is applied whenever dropOut > 0, independent of 'training',
        // and it mutates this.input in place — confirm intended.
        if (this.conf.getLayer().getDropOut() > 0.0) {
            double scale = 1.0 / (1.0 - this.conf.getLayer().getDropOut());
            this.u = Nd4j.rand((int[])this.input.shape()).lti((Number)(1.0 - this.conf.getLayer().getDropOut())).muli((Number)scale);
            this.input.muli(this.u);
        }
        int sequenceLen = this.input.size(0);
        int hiddenLayerSize = decoderWeights.size(0);
        int recurrentSize = recurrentWeights.size(0);
        this.hIn = Nd4j.zeros((int)sequenceLen, (int)recurrentSize);
        this.hOut = Nd4j.zeros((int)sequenceLen, (int)hiddenLayerSize);
        this.iFogZ = Nd4j.zeros((int)sequenceLen, (int)(hiddenLayerSize * 4));
        this.iFogA = Nd4j.zeros((int[])this.iFogZ.shape());
        this.memCellActivations = Nd4j.zeros((int)sequenceLen, (int)hiddenLayerSize);
        for (int t = 0; t < sequenceLen; ++t) {
            INDArray prevOutputActivations = t == 0 ? Nd4j.zeros((int)hiddenLayerSize) : this.hOut.slice(t - 1);
            INDArray prevMemCellActivations = t == 0 ? Nd4j.zeros((int)hiddenLayerSize) : this.memCellActivations.slice(t - 1);
            // Build the step input row from a constant 1, the input row and the previous hidden output.
            // NOTE(review): the index arguments used on hIn.slice(t) here (t as row index, a second
            // interval on an already-sliced row) look inconsistent with lstmTick — verify.
            this.hIn.slice(t).put(t, 0, (Number)1);
            this.hIn.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)1, (int)(1 + hiddenLayerSize)), NDArrayIndex.interval((int)t, (int)(t + 1))}, this.input.slice(t));
            this.hIn.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(1 + hiddenLayerSize), (int)this.hIn.columns()), NDArrayIndex.interval((int)0, (int)1)}, prevOutputActivations);
            this.iFogZ.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)(hiddenLayerSize * 4)), NDArrayIndex.interval((int)0, (int)1)}, this.hIn.slice(t).mmul(recurrentWeights));
            this.iFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)(3 * hiddenLayerSize))}, Transforms.sigmoid((INDArray)this.iFogZ.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)(3 * hiddenLayerSize))})));
            // The tanh segment runs to the last column. The previous 'columns() - 1' bound left the
            // final unit at zero, unlike every other use of this segment (cell update below, backprop).
            this.iFogA.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)this.iFogA.columns())}, Transforms.tanh((INDArray)this.iFogZ.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)this.iFogZ.columns())})));
            this.memCellActivations.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)hiddenLayerSize)}, this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)0, (int)hiddenLayerSize)}).mul(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(3 * hiddenLayerSize), (int)this.iFogA.columns())})));
            if (t > 0) {
                this.memCellActivations.slice(t).addi(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)hiddenLayerSize, (int)(2 * hiddenLayerSize))}).mul(prevMemCellActivations));
            }
            if (this.conf.getLayer().getActivationFunction().equals("tanh")) {
                this.hOut.slice(t).assign(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))}).mul(Transforms.tanh((INDArray)this.memCellActivations.slice(t))));
            } else {
                this.hOut.slice(t).assign(this.iFogA.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval((int)(2 * hiddenLayerSize), (int)(3 * hiddenLayerSize))}).mul(this.memCellActivations.slice(t)));
            }
        }
        if (this.conf.isUseDropConnect() && training && this.conf.getLayer().getDropOut() > 0.0) {
            this.u2 = Dropout.applyDropout(this.hOut, this.conf.getLayer().getDropOut(), this.u2);
            this.hOut.muli(this.u2);
        }
        // The first hidden row is skipped before decoding.
        this.outputActivations = this.hOut.get(new INDArrayIndex[]{NDArrayIndex.interval((int)1, (int)this.hOut.rows())}).mmul(decoderWeights).addiRowVector(decoderBias);
        return this.outputActivations;
    }

    /**
     * Runs one LSTM tick on {@code xi} from all-zero hidden and cell states, then starts a
     * {@link BeamSearch} (constructed with 20 steps) from the resulting states.
     *
     * @param xi input for the first tick
     * @param ws array whose rows the search feeds back in as inputs
     * @return the index sequences found by the search, each paired with its log probability
     */
    public Collection<Pair<List<Integer>, Double>> predict(INDArray xi, INDArray ws) {
        int hiddenSize = this.getParam("W").rows();
        INDArray zeroHidden = Nd4j.zeros((int)hiddenSize);
        INDArray zeroCell = Nd4j.zeros((int)hiddenSize);
        Triple<INDArray, INDArray, INDArray> firstTick = this.lstmTick(xi, zeroHidden, zeroCell);
        return new BeamSearch(20, ws, firstTick.getSecond(), firstTick.getThird()).search();
    }

    /**
     * Drops the references to the input, the cached forward-pass arrays and both dropout masks.
     */
    @Override
    public void clear() {
        // NOTE(review): hOut, xi and xs are not reset here — confirm whether that is intentional.
        this.input = null;
        this.hIn = null;
        this.memCellActivations = null;
        this.outputActivations = null;
        this.iFogA = null;
        this.iFogZ = null;
        this.u2 = null;
        this.u = null;
    }

    /**
     * Returns the index with the highest log-softmax score in {@code y}, together with that score.
     *
     * @param y raw scores; read through its linear view
     * @return pair of (best index, log probability of that index)
     */
    private Pair<Integer, Double> yMax(INDArray y) {
        INDArray y1 = y.linearView();
        double max = y.max(new int[]{Integer.MAX_VALUE}).getDouble(0);
        // Shift by the maximum before exponentiating (y - max, as the beam search does); the
        // previous 'rsub' computed max - y, which inverts the ranking. sub() leaves y untouched.
        INDArray e1 = Transforms.exp((INDArray)y1.sub((Number)max));
        INDArray p1 = e1.divi(e1.sum(new int[]{Integer.MAX_VALUE}));
        y1 = Transforms.log((INDArray)p1.addi((Number)Nd4j.EPS_THRESHOLD));
        // Descending sort so that rank 0 is the arg-max (an ascending sort put the minimum first).
        // sortWithIndices is used as {indices, sortedValues}, matching its use in BeamSearch.
        INDArray[] sorted = Nd4j.sortWithIndices((INDArray)y1, (int)0, (boolean)false);
        int ix = sorted[0].getInt(new int[]{0});
        // Read the score by rank from the sorted values, not by the original index.
        return new Pair<Integer, Double>(ix, sorted[1].getDouble(0));
    }

    /**
     * Advances the LSTM by a single step without touching any cached layer state.
     *
     * <p>With d = rows of "W": the step input row is [1 | x | hPrev], the gate pre-activations are
     * that row times "RW", columns [0, 3d) go through a sigmoid and columns [3d, 4d) through tanh.
     * The new cell state is {@code a[0,d) * a[3d,4d) + a[d,2d) * cPrev}; the new hidden state is
     * {@code a[2d,3d)} times the cell state (tanh-squashed when the configured activation function
     * is "tanh").
     *
     * @param x     input for this step
     * @param hPrev hidden state of the previous step
     * @param cPrev cell state of the previous step
     * @return triple of (decoded output {@code h * W + b}, new hidden state, new cell state)
     */
    private Triple<INDArray, INDArray, INDArray> lstmTick(INDArray x, INDArray hPrev, INDArray cPrev) {
        INDArray decoderWeights = this.getParam("W");
        INDArray recurrentWeights = this.getParam("RW");
        INDArray decoderBias = this.getParam("b");
        int t = 0;
        int d = decoderWeights.rows();

        // Row of ones, then x and hPrev written over everything but the leading bias column.
        INDArray hIn = Nd4j.zeros(1, recurrentWeights.rows());
        hIn.putRow(0, Nd4j.ones(hIn.columns()));
        hIn.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval(1, 1 + d)}, x);
        hIn.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval(1 + d, hIn.columns())}, hPrev);

        // Store the full pre-activation row (previously only its first scalar was kept).
        INDArray iFog = Nd4j.zeros(1, d * 4);
        INDArray iFogf = Nd4j.zeros(iFog.shape());
        iFog.putRow(t, hIn.slice(t).mmul(recurrentWeights));

        // Activate from the local pre-activations. The old code read the unrelated field
        // this.iFogA for the sigmoid part and the still-zero iFogf for the tanh part.
        iFogf.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval(0, 3 * d)},
                Transforms.sigmoid(iFog.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval(0, 3 * d)})));
        iFogf.slice(t).put(new INDArrayIndex[]{NDArrayIndex.interval(3 * d, iFogf.columns())},
                Transforms.tanh(iFog.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval(3 * d, iFog.columns())})));

        INDArray inputGate = iFogf.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval(0, d)});
        INDArray forgetGate = iFogf.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval(d, 2 * d)});
        INDArray outputGate = iFogf.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval(2 * d, 3 * d)});
        INDArray candidate = iFogf.slice(t).get(new INDArrayIndex[]{NDArrayIndex.interval(3 * d, iFogf.columns())});

        // c = i * g + f * cPrev. The previous chained form evaluated (i * g + f) * cPrev.
        // mul() allocates, so neither the gate views nor cPrev are modified.
        INDArray c = inputGate.mul(candidate).addi(forgetGate.mul(cPrev));

        // Keep the hidden state local: writing into this.outputActivations (possibly null, and
        // shared between beams) made every returned hidden state alias the same array.
        INDArray h;
        if (this.conf.getLayer().getActivationFunction().equals("tanh")) {
            h = outputGate.mul(Transforms.tanh(c));
        } else {
            h = outputGate.mul(c);
        }
        INDArray y = h.mmul(decoderWeights).addiRowVector(decoderBias);
        return new Triple<INDArray, INDArray, INDArray>(y, h, c);
    }

    /**
     * L2 penalty over the "RW" and "W" parameters.
     *
     * @return {@code 0.5 * l2 * (sum(RW^2) + sum(W^2))}, or 0 when regularization is off or the
     *         configured L2 coefficient is not positive
     */
    @Override
    public double calcL2() {
        boolean penaltyDisabled = !this.conf.isUseRegularization() || this.conf.getL2() <= 0.0;
        if (penaltyDisabled) {
            return 0.0;
        }
        double recurrentSquares = Transforms.pow((INDArray)this.getParam("RW"), (Number)2)
                .sum(new int[]{Integer.MAX_VALUE})
                .getDouble(0);
        double decoderSquares = Transforms.pow((INDArray)this.getParam("W"), (Number)2)
                .sum(new int[]{Integer.MAX_VALUE})
                .getDouble(0);
        double sumOfSquares = recurrentSquares + decoderSquares;
        return 0.5 * this.conf.getL2() * sumOfSquares;
    }

    /**
     * L1 penalty over the "RW" and "W" parameters.
     *
     * @return {@code l1 * (sum(|RW|) + sum(|W|))}, or 0 when regularization is off or the
     *         configured L1 coefficient is not positive
     */
    @Override
    public double calcL1() {
        boolean penaltyDisabled = !this.conf.isUseRegularization() || this.conf.getL1() <= 0.0;
        if (penaltyDisabled) {
            return 0.0;
        }
        double recurrentAbs = Transforms.abs((INDArray)this.getParam("RW"))
                .sum(new int[]{Integer.MAX_VALUE})
                .getDouble(0);
        double decoderAbs = Transforms.abs((INDArray)this.getParam("W"))
                .sum(new int[]{Integer.MAX_VALUE})
                .getDouble(0);
        double sumOfMagnitudes = recurrentAbs + decoderAbs;
        return this.conf.getL1() * sumOfMagnitudes;
    }

    /**
     * @return always {@link Layer.Type#RECURRENT}
     */
    @Override
    public Layer.Type type() {
        final Layer.Type layerKind = Layer.Type.RECURRENT;
        return layerKind;
    }

    /**
     * Not supported by this layer.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Layer transpose() {
        final String reason = "Not yet implemented";
        throw new UnsupportedOperationException(reason);
    }

    /**
     * Splits {@code data} into its first slice (stored as {@code xi}) and all remaining rows
     * (stored as {@code xs}), then builds a {@link Solver} for this layer and runs it.
     *
     * @param data the training data; slice 0 becomes xi, rows 1.. become xs
     */
    @Override
    public void fit(INDArray data) {
        this.xi = data.slice(0);
        INDArrayIndex rowsAfterFirst = NDArrayIndex.interval((int)1, (int)data.rows());
        INDArrayIndex everyColumn = NDArrayIndex.interval((int)0, (int)data.columns());
        this.xs = data.get(new INDArrayIndex[]{rowsAfterFirst, everyColumn});
        new Solver.Builder()
                .configure(this.conf)
                .model(this)
                .listeners(this.getListeners())
                .build()
                .optimize();
    }

    /**
     * @return the number of rows of {@code xi}
     */
    @Override
    public int batchSize() {
        final int xiRows = this.xi.rows();
        return xiRows;
    }

    /**
     * Mutable holder for one search hypothesis: the indices chosen so far, their accumulated
     * log probability, and the hidden / cell arrays stored alongside them.
     */
    private static class Beam {
        private INDArray c;
        private INDArray hidden;
        private List<Integer> indices;
        // A double field defaults to 0.0 and the only constructor assigns it anyway.
        private double logProba;

        public Beam(double logProba, List<Integer> indices, INDArray hidden, INDArray c) {
            this.c = c;
            this.hidden = hidden;
            this.indices = indices;
            this.logProba = logProba;
        }

        /** @return the accumulated log probability */
        public double getLogProba() {
            return logProba;
        }

        /** @param logProba the new accumulated log probability */
        public void setLogProba(double logProba) {
            this.logProba = logProba;
        }

        /** @return the stored index list (not copied) */
        public List<Integer> getIndices() {
            return indices;
        }

        /** @param indices the index list to store (not copied) */
        public void setIndices(List<Integer> indices) {
            this.indices = indices;
        }

        /** @return the stored hidden array */
        public INDArray getHidden() {
            return hidden;
        }

        /** @param hidden the hidden array to store */
        public void setHidden(INDArray hidden) {
            this.hidden = hidden;
        }

        /** @return the stored cell array */
        public INDArray getC() {
            return c;
        }

        /** @param c the cell array to store */
        public void setC(INDArray c) {
            this.c = c;
        }
    }

    /**
     * Decodes an index sequence by repeatedly calling {@code lstmTick} on rows of {@code ws}.
     * With {@code beamSize > 1} it keeps the {@code beamSize} best hypotheses per step; otherwise
     * it follows the single best index greedily. Index 0 ends a hypothesis, and at most
     * {@code nSteps} steps are taken.
     */
    private class BeamSearch {
        private List<Beam> beams = new ArrayList<Beam>();
        // Upper bound on the number of decoding steps (the step counter itself is a local).
        private int nSteps = 0;
        private INDArray h;
        private INDArray c;
        private INDArray ws;
        private int beamSize = 5;

        /**
         * @param nSteps maximum number of decoding steps
         * @param ws     array whose row {@code i} is fed to the LSTM after index {@code i} was chosen
         * @param h      initial hidden state
         * @param c      initial cell state
         */
        public BeamSearch(int nSteps, INDArray ws, INDArray h, INDArray c) {
            this.nSteps = nSteps;
            this.h = h;
            this.c = c;
            this.ws = ws;
            this.beams.add(new Beam(0.0, new ArrayList<Integer>(), h, c));
        }

        /**
         * Runs the search.
         *
         * @return the surviving hypotheses as (indices, log probability) pairs; a single pair in
         *         the greedy case
         */
        public Collection<Pair<List<Integer>, Double>> search() {
            if (this.beamSize > 1) {
                return this.searchWithBeams();
            }
            return this.searchGreedily();
        }

        /** Beam variant: expand every open hypothesis, then keep the best {@code beamSize}. */
        private Collection<Pair<List<Integer>, Double>> searchWithBeams() {
            int step = 0;
            do {
                List<Beam> candidates = new ArrayList<Beam>();
                for (Beam beam : this.beams) {
                    List<Integer> history = beam.getIndices();
                    // The initial beam has no history; start it from index 0 instead of
                    // reading element -1 (which threw before the emptiness check ran).
                    int ixPrev = history.isEmpty() ? 0 : history.get(history.size() - 1);
                    if (ixPrev == 0 && !history.isEmpty()) {
                        // Finished hypothesis: carry it over unchanged.
                        candidates.add(beam);
                        continue;
                    }
                    Triple<INDArray, INDArray, INDArray> yhc = ImageLSTM.this.lstmTick(this.ws.slice(ixPrev), beam.getHidden(), beam.getC());
                    INDArray y1 = yhc.getFirst().ravel();
                    double maxy1 = y1.max(new int[]{Integer.MAX_VALUE}).getDouble(0);
                    INDArray e1 = Transforms.exp((INDArray)y1.subi((Number)maxy1));
                    INDArray p1 = e1.divi(Nd4j.sum((INDArray)e1, (int)Integer.MAX_VALUE));
                    y1 = Transforms.log((INDArray)p1.addi((Number)Nd4j.EPS_THRESHOLD));
                    // Used as {indices, sortedValues} in descending order. Scores are read from
                    // the sorted values by rank, which stays correct whether or not the call
                    // also reorders y1 in place — TODO confirm against the ND4J version in use.
                    INDArray[] ranked = Nd4j.sortWithIndices((INDArray)y1, (int)0, (boolean)false);
                    for (int rank = 0; rank < this.beamSize; ++rank) {
                        int idx = ranked[0].getInt(new int[]{rank});
                        List<Integer> extended = new ArrayList<Integer>(history);
                        extended.add(idx);
                        candidates.add(new Beam(beam.getLogProba() + ranked[1].getDouble(rank), extended, yhc.getSecond(), yhc.getThird()));
                    }
                }
                // Keep only the best hypotheses; previously the candidates were discarded and
                // the beam list never changed.
                Collections.sort(candidates, new Comparator<Beam>() {
                    @Override
                    public int compare(Beam left, Beam right) {
                        return Double.compare(right.getLogProba(), left.getLogProba());
                    }
                });
                int keep = Math.min(this.beamSize, candidates.size());
                this.beams = new ArrayList<Beam>(candidates.subList(0, keep));
                ++step;
            } while (step < this.nSteps);

            List<Pair<List<Integer>, Double>> ret = new ArrayList<Pair<List<Integer>, Double>>();
            for (Beam b : this.beams) {
                ret.add(new Pair<List<Integer>, Double>(b.getIndices(), b.getLogProba()));
            }
            return ret;
        }

        /** Greedy variant: always follow the single best index until index 0 or the step limit. */
        private Collection<Pair<List<Integer>, Double>> searchGreedily() {
            int ixPrev = 0;
            double predictedLogProba = 0.0;
            List<Integer> predix = new ArrayList<Integer>();
            INDArray hidden = this.h;
            INDArray cell = this.c;
            int step = 0;
            do {
                Triple<INDArray, INDArray, INDArray> yhc = ImageLSTM.this.lstmTick(this.ws.slice(ixPrev), hidden, cell);
                Pair<Integer, Double> best = ImageLSTM.this.yMax(yhc.getFirst());
                // Feed the chosen index and the new states into the next step; before, none of
                // them were updated, so the loop could never advance.
                ixPrev = best.getFirst();
                predix.add(ixPrev);
                predictedLogProba += best.getSecond();
                hidden = yhc.getSecond();
                cell = yhc.getThird();
                ++step;
            } while (ixPrev != 0 && step < this.nSteps);
            return Collections.singletonList(new Pair<List<Integer>, Double>(predix, predictedLogProba));
        }
    }
}

