/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.recurrent;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.berkeley.Triple;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * LSTM layer that runs an input sequence step by step and decodes every hidden state through a
 * linear layer ({@code decoderweights}, {@code decoderbias}).
 *
 * <p>At each time step the input row {@code hIn[t] = [1, x_t, h_{t-1}]} is multiplied by
 * {@code recurrentweights} to give the gate pre-activations {@code iFog[t]} (4d columns):
 * <ul>
 *   <li>{@code [0, d)} input gate, {@code [d, 2d)} forget gate, {@code [2d, 3d)} output gate
 *       (sigmoid);</li>
 *   <li>{@code [3d, 4d)} cell candidate (tanh).</li>
 * </ul>
 * {@code d} is the row count of {@code decoderweights}. The {@code hIn} layout assumes the per-step
 * input width is also {@code d} and that {@code recurrentweights} has {@code 1 + 2d} rows —
 * NOTE(review): confirm against the parameter initializer.
 *
 * <p>Instances are stateful: {@link #activate(INDArray)} caches intermediates in fields that
 * {@link #backward(INDArray)} reads and then clears, so do not share an instance between threads
 * without external synchronization.
 */
public class LSTM extends BaseLayer {
    /** Maximum number of tokens emitted by {@link #predict(INDArray, INDArray)}. */
    private static final int MAX_PREDICTION_STEPS = 20;

    /** Gate pre-activations, one row per time step (n x 4d). */
    private INDArray iFog;
    /** Gate activations: sigmoid on the first 3d columns, tanh on the last d (n x 4d). */
    private INDArray iFogF;
    /** Cell state per time step (n x d). */
    private INDArray c;
    /** Input sequence, one row per time step (dropout-masked when dropout is enabled). */
    private INDArray x;
    /** Per-step LSTM input rows {@code [1, x_t, h_{t-1}]}. */
    private INDArray hIn;
    /** Hidden state per time step (n x d), dropout-masked when dropout is enabled. */
    private INDArray hOut;
    /** Input dropout mask, already scaled by 1 / keep-probability. */
    private INDArray u;
    /** Hidden-state dropout mask, already scaled by 1 / keep-probability. */
    private INDArray u2;
    /** Leading part of the last sequence given to forward/fit. */
    private INDArray xi;
    /** Remaining part of the last sequence given to forward/fit. */
    private INDArray xs;

    public LSTM(NeuralNetConfiguration conf) {
        super(conf);
    }

    /**
     * Runs the forward pass over {@code xi} stacked vertically on top of {@code xs}.
     *
     * @param xi leading input row(s)
     * @param xs remaining input rows
     * @return decoder outputs, see {@link #activate(INDArray)}
     */
    public INDArray forward(INDArray xi, INDArray xs) {
        this.xs = xs;
        this.xi = xi;
        this.x = Nd4j.vstack(new INDArray[]{xi, xs});
        return this.activate(this.x);
    }

    /**
     * Backpropagates through the forward pass cached by {@link #activate(INDArray)}, then clears
     * that cache.
     *
     * @param y gradient with respect to the decoder outputs returned by {@code activate} (one row
     *     per time step 1..n-1) — NOTE(review): confirm callers pass the loss gradient here
     * @return gradients for {@code decoderbias}, {@code decoderweights} and {@code recurrentweights}
     */
    public Gradient backward(INDArray y) {
        INDArray decoderWeights = this.getParam("decoderweights");
        INDArray recurrentWeights = this.getParam("recurrentweights");
        // Time step 0 produced no output row, so its output gradient is zero.
        INDArray dY = Nd4j.vstack(new INDArray[]{Nd4j.zeros(y.columns()), y});
        INDArray dWd = this.hOut.transpose().mmul(dY);
        // The bias gradient is the column sum of the output gradient (not of dWd).
        INDArray dBd = Nd4j.sum(dY, 0);
        INDArray dHout = dY.mmul(decoderWeights.transpose());
        if (this.conf.getDropOut() > 0.0) {
            dHout.muli(this.u2);
        }
        INDArray dIFog = Nd4j.zeros(this.iFog.shape());
        INDArray dIFogF = Nd4j.zeros(this.iFogF.shape());
        INDArray dRecurrentWeights = Nd4j.zeros(recurrentWeights.shape());
        INDArray dHin = Nd4j.zeros(this.hIn.shape());
        INDArray dC = Nd4j.zeros(this.c.shape());
        int n = this.hOut.rows();
        int d = this.hOut.columns();
        boolean tanhCell = "tanh".equals(this.conf.getActivationFunction());
        NDArrayIndex[] inputGate = {NDArrayIndex.interval(0, d)};
        NDArrayIndex[] forgetGate = {NDArrayIndex.interval(d, 2 * d)};
        NDArrayIndex[] outputGate = {NDArrayIndex.interval(2 * d, 3 * d)};
        NDArrayIndex[] sigmoidGates = {NDArrayIndex.interval(0, 3 * d)};
        NDArrayIndex[] candidate = {NDArrayIndex.interval(3 * d, this.iFogF.columns())};
        NDArrayIndex[] previousHidden = {NDArrayIndex.interval(1 + d, dHin.columns())};
        // Walk back through every step, including t = 0.
        for (int t = n - 1; t >= 0; --t) {
            INDArray gatesT = this.iFogF.slice(t);
            INDArray dGatesT = dIFogF.slice(t);
            INDArray dPreT = dIFog.slice(t);
            INDArray dHoutT = dHout.slice(t);
            INDArray dCT = dC.slice(t);
            // h_t = o * act(c_t)
            if (tanhCell) {
                INDArray tanhCt = Transforms.tanh(this.c.slice(t));
                dGatesT.put(outputGate, tanhCt.mul(dHoutT));
                dCT.addi(Transforms.pow(tanhCt, 2).rsubi(1.0).muli(gatesT.get(outputGate).mul(dHoutT)));
            } else {
                dGatesT.put(outputGate, this.c.slice(t).mul(dHoutT));
                dCT.addi(gatesT.get(outputGate).mul(dHoutT));
            }
            // c_t = i * g + f * c_{t-1}
            if (t > 0) {
                dGatesT.put(forgetGate, this.c.slice(t - 1).mul(dCT));
                dC.slice(t - 1).addi(gatesT.get(forgetGate).mul(dCT));
            }
            dGatesT.put(inputGate, gatesT.get(candidate).mul(dCT));
            dGatesT.put(candidate, gatesT.get(inputGate).mul(dCT));
            // Back through the nonlinearities into the pre-activation gradient dIFog:
            // tanh' = 1 - g^2, sigmoid' = s * (1 - s).
            dPreT.put(candidate, Transforms.pow(gatesT.get(candidate), 2).rsubi(1.0).muli(dGatesT.get(candidate)));
            INDArray sig = gatesT.get(sigmoidGates);
            dPreT.put(sigmoidGates, sig.mul(sig.rsub(1.0)).muli(dGatesT.get(sigmoidGates)));
            dRecurrentWeights.addi(this.hIn.slice(t).transpose().mmul(dPreT));
            dHin.slice(t).assign(dPreT.mmul(recurrentWeights.transpose()));
            // Only the h_{t-1} part of hIn feeds back in time; input gradients are not returned.
            if (t > 0) {
                dHout.slice(t - 1).addi(dHin.slice(t).get(previousHidden));
            }
        }
        this.clear();
        DefaultGradient gradient = new DefaultGradient();
        gradient.gradientForVariable().put("decoderbias", dBd);
        gradient.gradientForVariable().put("decoderweights", dWd);
        gradient.gradientForVariable().put("recurrentweights", dRecurrentWeights);
        return gradient;
    }

    /**
     * Runs the LSTM over a sequence (one row per time step) and decodes the hidden states.
     *
     * <p>Caches the input, gate values, cell and hidden states and, when dropout is enabled, the
     * dropout masks, for use by {@link #backward(INDArray)}.
     *
     * @param input sequence, one row per time step; when {@code null} the sequence cached by
     *     {@link #forward(INDArray, INDArray)} is used
     * @return {@code hOut[1:] * decoderweights + decoderbias}; time step 0 produces no output row
     */
    @Override
    public INDArray activate(INDArray input) {
        INDArray decoderWeights = this.getParam("decoderweights");
        INDArray recurrentWeights = this.getParam("recurrentweights");
        INDArray decoderBias = this.getParam("decoderbias");
        if (input != null) {
            this.x = input;
        }
        double dropOut = this.conf.getDropOut();
        if (dropOut > 0.0) {
            double scale = 1.0 / (1.0 - dropOut);
            this.u = Nd4j.rand(this.x.shape()).lti(1.0 - dropOut).muli(scale);
            // Non-in-place so the caller's array is not overwritten with its masked copy.
            this.x = this.x.mul(this.u);
        }
        int n = this.x.rows();
        int d = decoderWeights.rows();
        this.hIn = Nd4j.zeros(n, recurrentWeights.rows());
        this.hOut = Nd4j.zeros(n, d);
        this.iFog = Nd4j.zeros(n, d * 4);
        this.iFogF = Nd4j.zeros(this.iFog.shape());
        this.c = Nd4j.zeros(n, d);
        boolean tanhCell = "tanh".equals(this.conf.getActivationFunction());
        NDArrayIndex[] inputGate = {NDArrayIndex.interval(0, d)};
        NDArrayIndex[] forgetGate = {NDArrayIndex.interval(d, 2 * d)};
        NDArrayIndex[] outputGate = {NDArrayIndex.interval(2 * d, 3 * d)};
        NDArrayIndex[] sigmoidGates = {NDArrayIndex.interval(0, 3 * d)};
        // Candidate block runs to the last column inclusive.
        NDArrayIndex[] candidate = {NDArrayIndex.interval(3 * d, this.iFogF.columns())};
        for (int t = 0; t < n; ++t) {
            INDArray prev = t == 0 ? Nd4j.zeros(d) : this.hOut.getRow(t - 1);
            // hIn[t] = [1 (bias), x_t, h_{t-1}]
            this.hIn.put(t, 0, 1.0);
            this.hIn.slice(t).put(new NDArrayIndex[]{NDArrayIndex.interval(1, 1 + d)}, this.x.slice(t));
            this.hIn.slice(t).put(new NDArrayIndex[]{NDArrayIndex.interval(1 + d, this.hIn.columns())}, prev);
            this.iFog.putRow(t, this.hIn.slice(t).mmul(recurrentWeights));
            INDArray preT = this.iFog.slice(t);
            INDArray gatesT = this.iFogF.slice(t);
            gatesT.put(sigmoidGates, Transforms.sigmoid(preT.get(sigmoidGates)));
            gatesT.put(candidate, Transforms.tanh(preT.get(candidate)));
            INDArray cellT = gatesT.get(inputGate).mul(gatesT.get(candidate));
            if (t > 0) {
                cellT.addi(gatesT.get(forgetGate).mul(this.c.getRow(t - 1)));
            }
            this.c.putRow(t, cellT);
            INDArray activatedCell = tanhCell ? Transforms.tanh(cellT) : cellT;
            this.hOut.putRow(t, gatesT.get(outputGate).mul(activatedCell));
        }
        if (dropOut > 0.0) {
            double scale = 1.0 / (1.0 - dropOut);
            this.u2 = Nd4j.rand(this.hOut.shape()).lti(1.0 - dropOut).muli(scale);
            this.hOut.muli(this.u2);
        }
        return this.hOut.get(new NDArrayIndex[]{NDArrayIndex.interval(1, this.hOut.rows())}).mmul(decoderWeights).addiRowVector(decoderBias);
    }

    /**
     * Generates token sequences for a leading input by beam search.
     *
     * @param xi leading input, fed to the LSTM from a zero hidden and cell state
     * @param ws token embedding table; row i is the input fed after emitting token i
     * @return (token indices, accumulated log-probability) pairs, best first
     */
    public Collection<Pair<List<Integer>, Double>> predict(INDArray xi, INDArray ws) {
        int d = this.getParam("decoderweights").rows();
        Triple<INDArray, INDArray, INDArray> yhc = this.lstmTick(xi, Nd4j.zeros(d), Nd4j.zeros(d));
        BeamSearch search = new BeamSearch(MAX_PREDICTION_STEPS, ws, yhc.getSecond(), yhc.getThird());
        return search.search();
    }

    /**
     * Drops the forward-pass cache. {@code xi} and {@code xs} are kept on purpose:
     * {@link #gradientAndScore()} calls {@link #score()} after {@link #gradient()}, and
     * {@code score()} re-runs the forward pass from them.
     */
    @Override
    public void clear() {
        this.u = null;
        this.hIn = null;
        this.hOut = null;
        this.iFog = null;
        this.iFogF = null;
        this.c = null;
        this.x = null;
        this.u2 = null;
    }

    /**
     * Converts raw decoder scores into log-probabilities with a max-shifted softmax.
     *
     * @param y decoder scores (flattened before use; not modified)
     * @return {@code log(softmax(y) + EPS_THRESHOLD)} as a flat vector
     */
    private static INDArray logProbabilities(INDArray y) {
        INDArray flat = y.ravel();
        double max = flat.max(Integer.MAX_VALUE).getDouble(0);
        // Shift by the max so exp() cannot overflow.
        INDArray exp = Transforms.exp(flat.sub(max));
        double total = exp.sum(Integer.MAX_VALUE).getDouble(0);
        return Transforms.log(exp.divi(total).addi(Nd4j.EPS_THRESHOLD));
    }

    /**
     * Returns the indices of the {@code k} largest entries of a flat vector, largest first,
     * without modifying it.
     *
     * @param values vector to rank
     * @param k number of indices wanted; capped at {@code values.length()}
     * @return indices in decreasing order of value
     */
    private static int[] topK(INDArray values, int k) {
        int length = values.length();
        int count = Math.max(0, Math.min(k, length));
        int[] result = new int[count];
        boolean[] taken = new boolean[length];
        for (int r = 0; r < count; r++) {
            int best = -1;
            double bestValue = Double.NEGATIVE_INFINITY;
            for (int i = 0; i < length; i++) {
                if (taken[i]) {
                    continue;
                }
                double v = values.getDouble(i);
                if (best < 0 || v > bestValue) {
                    best = i;
                    bestValue = v;
                }
            }
            taken[best] = true;
            result[r] = best;
        }
        return result;
    }

    /**
     * Picks the most likely token.
     *
     * @param y decoder scores for a single step
     * @return (argmax index, its log-probability)
     */
    private Pair<Integer, Double> yMax(INDArray y) {
        INDArray logProbs = logProbabilities(y);
        int ix = topK(logProbs, 1)[0];
        return new Pair<Integer, Double>(ix, logProbs.getDouble(ix));
    }

    /**
     * Advances the LSTM by one step without touching the cached forward-pass fields.
     *
     * @param x input for this step (assumed length d)
     * @param hPrev previous hidden state (length d)
     * @param cPrev previous cell state (length d)
     * @return (decoder scores, new hidden state, new cell state)
     */
    private Triple<INDArray, INDArray, INDArray> lstmTick(INDArray x, INDArray hPrev, INDArray cPrev) {
        INDArray decoderWeights = this.getParam("decoderweights");
        INDArray recurrentWeights = this.getParam("recurrentweights");
        INDArray decoderBias = this.getParam("decoderbias");
        int d = decoderWeights.rows();
        // hIn = [1 (bias), x, hPrev]
        INDArray hIn = Nd4j.zeros(recurrentWeights.rows());
        hIn.putScalar(0, 1.0);
        hIn.put(new NDArrayIndex[]{NDArrayIndex.interval(1, 1 + d)}, x);
        hIn.put(new NDArrayIndex[]{NDArrayIndex.interval(1 + d, hIn.columns())}, hPrev);
        INDArray iFog = hIn.mmul(recurrentWeights);
        INDArray gates = Nd4j.zeros(iFog.columns());
        NDArrayIndex[] sigmoidGates = {NDArrayIndex.interval(0, 3 * d)};
        NDArrayIndex[] candidate = {NDArrayIndex.interval(3 * d, iFog.columns())};
        gates.put(sigmoidGates, Transforms.sigmoid(iFog.get(sigmoidGates)));
        gates.put(candidate, Transforms.tanh(iFog.get(candidate)));
        INDArray inputGate = gates.get(new NDArrayIndex[]{NDArrayIndex.interval(0, d)});
        INDArray forgetGate = gates.get(new NDArrayIndex[]{NDArrayIndex.interval(d, 2 * d)});
        INDArray outputGate = gates.get(new NDArrayIndex[]{NDArrayIndex.interval(2 * d, 3 * d)});
        // c = i * g + f * cPrev
        INDArray c = inputGate.mul(gates.get(candidate)).addi(forgetGate.mul(cPrev));
        INDArray activatedCell = "tanh".equals(this.conf.getActivationFunction()) ? Transforms.tanh(c) : c;
        INDArray hOut = outputGate.mul(activatedCell);
        INDArray y = hOut.mmul(decoderWeights).addiRowVector(decoderBias);
        return new Triple<INDArray, INDArray, INDArray>(y, hOut, c);
    }

    @Override
    public void fit() {
        Solver solver = new Solver.Builder().model(this).configure(this.conf()).listeners(this.getIterationListeners()).build();
        solver.optimize();
    }

    /** Adds {@code gradient.gradient()} to the flattened parameters. */
    @Override
    public void update(Gradient gradient) {
        this.setParams(this.params().addi(gradient.gradient()));
    }

    /**
     * Scores the cached {@code xi}/{@code xs} sequence.
     *
     * <p>NOTE(review): this feeds the output of the softmax transform's {@code derivative()} op
     * into the loss and uses {@code xs} as the labels — confirm this is intended.
     */
    @Override
    public double score() {
        INDArray forward = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", this.forward(this.xi, this.xs)).derivative(), 1);
        return LossFunctions.score((INDArray)this.xs, (LossFunctions.LossFunction)this.conf.getLossFunction(), (INDArray)forward, (double)this.conf.getL2(), (boolean)this.conf.isUseRegularization());
    }

    /**
     * NOTE(review): {@code data} is ignored; this transforms the cached {@code xi}/{@code xs}.
     */
    @Override
    public INDArray transform(INDArray data) {
        return Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", this.forward(this.xi, this.xs)).derivative(), 1);
    }

    /**
     * Copies a flat parameter vector into the layer's parameter arrays, in the order
     * {@code recurrentweights}, {@code decoderweights}, {@code decoderbias}.
     * Each element is placed by its absolute offset from the start of its parameter array.
     *
     * @param params flattened parameters
     */
    @Override
    public void setParams(INDArray params) {
        INDArray recurrentLinear = this.getParam("recurrentweights").linearView();
        INDArray decoderLinear = this.getParam("decoderweights").linearView();
        INDArray biasLinear = this.getParam("decoderbias").linearView();
        int recurrentEnd = recurrentLinear.length();
        int decoderEnd = recurrentEnd + decoderLinear.length();
        for (int i = 0; i < params.length(); ++i) {
            double value = params.getDouble(i);
            if (i < recurrentEnd) {
                recurrentLinear.putScalar(i, value);
            } else if (i < decoderEnd) {
                decoderLinear.putScalar(i - recurrentEnd, value);
            } else {
                biasLinear.putScalar(i - decoderEnd, value);
            }
        }
    }

    /**
     * Trains on one sequence: row 0 becomes {@code xi}, the remaining rows become {@code xs}.
     */
    @Override
    public void fit(INDArray data) {
        this.xi = data.slice(0);
        NDArrayIndex[] everythingElse = new NDArrayIndex[]{NDArrayIndex.interval(1, data.rows()), NDArrayIndex.interval(0, data.columns())};
        this.xs = data.get(everythingElse);
        Solver solver = new Solver.Builder().configure(this.conf).model(this).listeners(this.getIterationListeners()).build();
        solver.optimize();
    }

    /** No-op for this layer. */
    @Override
    public void iterate(INDArray input) {
    }

    /**
     * Runs forward on the cached sequence and backpropagates.
     *
     * <p>NOTE(review): the array passed to {@link #backward(INDArray)} is the output of the softmax
     * transform's {@code derivative()} op, while {@code backward} treats its argument as the
     * gradient of the decoder outputs — confirm.
     */
    @Override
    public Gradient gradient() {
        INDArray forward = this.forward(this.xi, this.xs);
        INDArray probas = Nd4j.getExecutioner().execAndReturn(Nd4j.getOpFactory().createTransform("softmax", forward).derivative(), 1);
        return this.backward(probas);
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.gradient(), this.score());
    }

    @Override
    public int batchSize() {
        return this.xi.rows();
    }

    /** One beam-search hypothesis: accumulated log-probability, emitted tokens and LSTM state. */
    private static class Beam implements Comparable<Beam> {
        private double logProba = 0.0;
        private List<Integer> indices;
        private INDArray hidden;
        private INDArray c;

        public Beam(double logProba, List<Integer> indices, INDArray hidden, INDArray c) {
            this.logProba = logProba;
            this.indices = indices;
            this.hidden = hidden;
            this.c = c;
        }

        public double getLogProba() {
            return this.logProba;
        }

        public void setLogProba(double logProba) {
            this.logProba = logProba;
        }

        public List<Integer> getIndices() {
            return this.indices;
        }

        public void setIndices(List<Integer> indices) {
            this.indices = indices;
        }

        public INDArray getHidden() {
            return this.hidden;
        }

        public void setHidden(INDArray hidden) {
            this.hidden = hidden;
        }

        public INDArray getC() {
            return this.c;
        }

        public void setC(INDArray c) {
            this.c = c;
        }

        /** Orders by decreasing log-probability so the best hypothesis sorts first. */
        @Override
        public int compareTo(Beam other) {
            return Double.compare(other.logProba, this.logProba);
        }
    }

    /**
     * Decodes tokens from an initial LSTM state, by beam search when {@code beamSize > 1} and
     * greedily otherwise. Token 0 is fed as the first input and also marks the end of a
     * hypothesis once emitted.
     */
    private class BeamSearch {
        private List<Beam> beams = new ArrayList<Beam>();
        /** Decoding steps performed so far. */
        private int nSteps = 0;
        /** Upper bound on decoding steps. */
        private final int maxSteps;
        private INDArray h;
        private INDArray c;
        /** Token embedding table; row i is the input after emitting token i. */
        private INDArray ws;
        private int beamSize = 5;

        /**
         * @param nSteps maximum number of decoding steps
         * @param ws token embedding table
         * @param h initial hidden state
         * @param c initial cell state
         */
        public BeamSearch(int nSteps, INDArray ws, INDArray h, INDArray c) {
            this.maxSteps = nSteps;
            this.h = h;
            this.c = c;
            this.ws = ws;
            this.beams.add(new Beam(0.0, new ArrayList<Integer>(), h, c));
        }

        /** @return (token indices, log-probability) pairs, best first */
        public Collection<Pair<List<Integer>, Double>> search() {
            return this.beamSize > 1 ? this.beamSearch() : this.greedySearch();
        }

        /** Keeps the {@code beamSize} most likely hypotheses at every step. */
        private Collection<Pair<List<Integer>, Double>> beamSearch() {
            boolean expanded;
            do {
                List<Beam> candidates = new ArrayList<Beam>();
                expanded = false;
                for (Beam beam : this.beams) {
                    List<Integer> indices = beam.getIndices();
                    // An empty hypothesis starts from token 0; check emptiness before get().
                    int ixPrev = indices.isEmpty() ? 0 : indices.get(indices.size() - 1);
                    if (ixPrev == 0 && !indices.isEmpty()) {
                        // Finished hypothesis: keeps competing but is not extended.
                        candidates.add(beam);
                        continue;
                    }
                    expanded = true;
                    Triple<INDArray, INDArray, INDArray> yhc = LSTM.this.lstmTick(this.ws.slice(ixPrev), beam.getHidden(), beam.getC());
                    INDArray logProbs = logProbabilities(yhc.getFirst());
                    for (int idx : topK(logProbs, this.beamSize)) {
                        List<Integer> extended = new ArrayList<Integer>(indices);
                        extended.add(idx);
                        candidates.add(new Beam(beam.getLogProba() + logProbs.getDouble(idx), extended, yhc.getSecond(), yhc.getThird()));
                    }
                }
                // The best beamSize candidates become the beams for the next step.
                Collections.sort(candidates);
                this.beams = new ArrayList<Beam>(candidates.subList(0, Math.min(this.beamSize, candidates.size())));
                ++this.nSteps;
            } while (expanded && this.nSteps < this.maxSteps);
            List<Pair<List<Integer>, Double>> ret = new ArrayList<Pair<List<Integer>, Double>>();
            for (Beam b : this.beams) {
                ret.add(new Pair<List<Integer>, Double>(b.getIndices(), b.getLogProba()));
            }
            return ret;
        }

        /** Emits the argmax token each step until token 0 or {@code maxSteps}. */
        private Collection<Pair<List<Integer>, Double>> greedySearch() {
            int ixPrev = 0;
            double predictedLogProba = 0.0;
            List<Integer> predix = new ArrayList<Integer>();
            INDArray hidden = this.h;
            INDArray cell = this.c;
            do {
                Triple<INDArray, INDArray, INDArray> yhc = LSTM.this.lstmTick(this.ws.slice(ixPrev), hidden, cell);
                hidden = yhc.getSecond();
                cell = yhc.getThird();
                Pair<Integer, Double> best = LSTM.this.yMax(yhc.getFirst());
                ixPrev = best.getFirst();
                predix.add(ixPrev);
                predictedLogProba += best.getSecond();
                ++this.nSteps;
            } while (ixPrev != 0 && this.nSteps < this.maxSteps);
            return Collections.<Pair<List<Integer>, Double>>singletonList(new Pair<List<Integer>, Double>(predix, predictedLogProba));
        }
    }
}

