/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import com.google.common.base.Function;
import com.google.common.util.concurrent.AtomicDouble;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.quadtree.QuadTree;
import org.deeplearning4j.clustering.vptree.VpTreeNode;
import org.deeplearning4j.clustering.vptree.VpTreePoint;
import org.deeplearning4j.clustering.vptree.VpTreePointINDArray;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.plot.Tsne;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.ops.transforms.Transforms;

public class BarnesHutTsne
extends Tsne
implements Model {
    private int n;
    private int d;
    private double perplexity;
    private double theta;
    private INDArray rows;
    private INDArray cols;
    private INDArray vals;
    private INDArray p;
    private INDArray x;
    private int numDimensions = 0;
    public static final String Y_GRAD = "yIncs";

    public BarnesHutTsne(INDArray x, int n, int d, INDArray y, int numDimensions, double perplexity, double theta, int maxIter, int stopLyingIteration, int momentumSwitchIteration, double momentum, double finalMomentum, double learningRate) {
        this.n = n;
        this.d = d;
        this.y = y;
        this.x = x;
        this.numDimensions = numDimensions;
        this.perplexity = perplexity;
        this.theta = theta;
        this.maxIter = maxIter;
        this.stopLyingIteration = stopLyingIteration;
        this.momentum = momentum;
        this.finalMomentum = finalMomentum;
        this.learningRate = learningRate;
        this.switchMomentumIteration = momentumSwitchIteration;
    }

    @Override
    public INDArray computeGaussianPerplexity(final INDArray d, double u) {
        int N = d.rows();
        final int k = (int)(3.0 * u);
        this.rows = Nd4j.zeros((int)(N + 1));
        this.cols = Nd4j.zeros((int)N, (int)k);
        this.vals = Nd4j.zeros((int)N, (int)k);
        for (int n = 1; n < N; ++n) {
            this.rows.putScalar(n, this.rows.getDouble(n - 1) + (double)k);
        }
        final INDArray beta = Nd4j.ones((int)N, (int)1);
        final double logU = Math.log(u);
        final List<VpTreePointINDArray> list = VpTreePointINDArray.dataPoints(d);
        final VpTreeNode<VpTreePointINDArray> tree = VpTreeNode.buildVpTree(list);
        log.info("Calculating probabilities of data similarities...");
        ExecutorService service = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        int i = 0;
        while (i < N) {
            if (i % 500 == 0) {
                log.info("Handled " + i + " records");
            }
            final int j = i++;
            service.submit(new Runnable(){

                @Override
                public void run() {
                    double betaMin = Double.NEGATIVE_INFINITY;
                    double betaMax = Double.POSITIVE_INFINITY;
                    Counter<VpTreePoint> c = tree.findNearByPointsWithDistancesK((VpTreePoint)list.get(j), k + 1);
                    INDArray row = d.slice(j);
                    Pair<INDArray, INDArray> pair = BarnesHutTsne.this.hBeta(row, BarnesHutTsne.this.toNDArray(c), beta.getDouble(j));
                    INDArray currP = pair.getSecond();
                    INDArray hDiff = pair.getFirst().sub((Number)logU);
                    for (int tries = 0; BooleanIndexing.and((INDArray)Transforms.abs((INDArray)hDiff), (Condition)Conditions.greaterThan((Number)BarnesHutTsne.this.tolerance)) && tries < 50; ++tries) {
                        if (BooleanIndexing.and((INDArray)hDiff, (Condition)Conditions.greaterThan((Number)0))) {
                            if (Double.isInfinite(betaMax)) {
                                beta.putScalar(j, beta.getDouble(j) * 2.0);
                            } else {
                                beta.putScalar(j, (beta.getDouble(j) + betaMax) / 2.0);
                            }
                            betaMin = beta.getDouble(j);
                        } else {
                            if (Double.isInfinite(betaMin)) {
                                beta.putScalar(j, beta.getDouble(j) / 2.0);
                            } else {
                                beta.putScalar(j, (beta.getDouble(j) + betaMin) / 2.0);
                            }
                            betaMax = beta.getDouble(j);
                        }
                        pair = BarnesHutTsne.this.hBeta(row, BarnesHutTsne.this.toNDArray(c), beta.getDouble(j));
                        hDiff = pair.getFirst().subi((Number)logU);
                    }
                    INDArray currPAssign = currP.div(currP.sum(Integer.MAX_VALUE));
                    INDArray indices = BarnesHutTsne.this.toIndex(c);
                    for (int i = 0; i < k; ++i) {
                        BarnesHutTsne.this.cols.putScalar(new int[]{BarnesHutTsne.this.rows.getInt(new int[]{BarnesHutTsne.this.n}), i}, indices.getDouble(i + 1));
                        BarnesHutTsne.this.vals.putScalar(new int[]{BarnesHutTsne.this.rows.getInt(new int[]{BarnesHutTsne.this.n}), i}, currPAssign.getDouble(i));
                    }
                    BarnesHutTsne.this.cols.slice(j).assign(BarnesHutTsne.this.toIndex(c));
                }
            });
        }
        try {
            service.shutdown();
            service.awaitTermination(1L, TimeUnit.DAYS);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        return this.vals;
    }

    @Override
    public INDArray input() {
        return this.x;
    }

    @Override
    public void validateInput() {
    }

    @Override
    protected Pair<Double, INDArray> gradient(INDArray p) {
        this.p = p;
        return new Pair<Double, INDArray>(this.score(), this.getGradient().gradientLookupTable().get(Y_GRAD));
    }

    @Override
    public INDArray getYGradient(int n, INDArray PQ, INDArray qu) {
        INDArray yGrads = Nd4j.create((int[])this.y.shape());
        for (int i = 0; i < n; ++i) {
            INDArray sum1 = Nd4j.tile((INDArray)PQ.getRow(i).mul(qu.getRow(i)), (int[])new int[]{this.y.columns(), 1}).transpose().mul(this.y.getRow(i).broadcast(this.y.shape()).sub(this.y)).sum(0);
            yGrads.putRow(i, sum1);
        }
        return yGrads;
    }

    private INDArray toIndex(Counter<VpTreePointINDArray> counter) {
        INDArray ret = Nd4j.create((int)counter.size());
        List<VpTreePointINDArray> list = counter.getSortedKeys();
        for (int i = 0; i < list.size(); ++i) {
            ret.putScalar(i, list.get(i).getIndex());
        }
        return ret;
    }

    private INDArray toNDArray(Counter<VpTreePointINDArray> counter) {
        INDArray ret = Nd4j.create((int)counter.size());
        List<VpTreePointINDArray> list = counter.getSortedKeys();
        for (int i = 0; i < list.size(); ++i) {
            ret.putScalar(i, counter.getCount(list.get(i)));
        }
        return ret;
    }

    public Pair<INDArray, INDArray> hBeta(INDArray d, INDArray distances, double beta) {
        INDArray P = Transforms.exp((INDArray)d.neg().muli((Number)beta).muli(distances));
        INDArray sum = P.sum(Integer.MAX_VALUE);
        INDArray otherSum = d.mul(P).sum(0);
        INDArray H = Transforms.log((INDArray)sum).addi(otherSum.muli((Number)beta).muli(distances).divi(sum));
        P.divi(sum);
        return new Pair<INDArray, INDArray>(H, P);
    }

    @Override
    public void fit() {
        boolean exact;
        boolean bl = exact = this.theta == 0.0;
        if (exact) {
            this.y = super.calculate(this.x, this.numDimensions, this.perplexity);
        } else {
            INDArray p = this.computeGaussianPerplexity(this.x, this.perplexity);
            for (int i = 0; i < this.maxIter; ++i) {
                this.step(p, i);
                if (i == this.switchMomentumIteration) {
                    this.momentum = this.finalMomentum;
                }
                if (i == this.stopLyingIteration) {
                    p.divi((Number)4);
                }
                if (this.iterationListener == null) continue;
                this.iterationListener.iterationDone(i);
            }
        }
    }

    @Override
    public void update(Gradient gradient) {
    }

    @Override
    public double score() {
        int QT_NO_DIMS = 2;
        QuadTree tree = new QuadTree(this.y);
        INDArray buff = Nd4j.create((int)QT_NO_DIMS);
        AtomicDouble sum_Q = new AtomicDouble(0.0);
        for (int n = 0; n < this.y.rows(); ++n) {
            tree.computeNonEdgeForces(n, this.theta, buff, sum_Q);
        }
        double C = 0.0;
        for (int n = 0; n < this.y.rows(); ++n) {
            INDArray row1 = this.rows.slice(n);
            int begin = row1.getInt(new int[]{0});
            int end = row1.getInt(new int[]{1});
            for (int i = begin; i < end; ++i) {
                buff.assign(this.y.slice(n));
                buff.subi(this.cols.getRow(i));
                double Q = Nd4j.getBlasWrapper().dot(buff, buff);
                Q = 1.0 / (1.0 + Q) / sum_Q.doubleValue();
                double val = this.vals.getDouble(i, 0);
                C += val * Math.log((val + (double)1.4E-45f) / (Q + 3.4028234663852886E38));
            }
        }
        return C;
    }

    @Override
    public INDArray transform(INDArray data) {
        return null;
    }

    @Override
    public INDArray params() {
        return null;
    }

    @Override
    public int numParams() {
        return 0;
    }

    @Override
    public void setParams(INDArray params) {
    }

    @Override
    public void fit(INDArray data) {
    }

    @Override
    public void iterate(INDArray input) {
    }

    @Override
    public Gradient getGradient() {
        INDArray dC;
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros((int[])this.y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones((int[])this.y.shape());
        }
        AtomicDouble sum_Q = new AtomicDouble(0.0);
        INDArray pos_f = Nd4j.create((int)this.p.rows(), (int)this.p.columns());
        INDArray neg_f = Nd4j.create((int)this.p.rows(), (int)this.p.columns());
        QuadTree quad = new QuadTree(this.p);
        quad.computeEdgeForces(this.rows, this.cols, this.p, this.p.rows(), pos_f);
        for (int n = 0; n < this.p.rows(); ++n) {
            quad.computeNonEdgeForces(n, this.theta, neg_f, sum_Q);
        }
        INDArray yGrads = dC = pos_f.subi(neg_f.divi((Number)sum_Q));
        this.gains = this.gains.add((Number)0.2).muli(yGrads.cond(Conditions.greaterThan((Number)0)).neqi(this.yIncs.cond(Conditions.greaterThan((Number)0)))).addi(this.gains.mul((Number)0.8).muli(yGrads.cond(Conditions.greaterThan((Number)0)).eqi(this.yIncs.cond(Conditions.greaterThan((Number)0)))));
        BooleanIndexing.applyWhere((INDArray)this.gains, (Condition)Conditions.lessThan((Number)this.minGain), (Function)new Value((Number)this.minGain));
        INDArray gradChange = this.gains.mul(yGrads);
        if (this.useAdaGrad) {
            gradChange = this.adaGrad.getGradient(gradChange);
        } else {
            gradChange.muli((Number)this.learningRate);
        }
        this.yIncs.muli((Number)this.momentum).subi(gradChange);
        DefaultGradient ret = new DefaultGradient();
        ret.gradientLookupTable().put(Y_GRAD, this.yIncs);
        return ret;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return null;
    }

    @Override
    public int batchSize() {
        return 0;
    }

    @Override
    public NeuralNetConfiguration conf() {
        return null;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
    }
}

