/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import com.google.common.base.Function;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.plot.Tsne;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Barnes-Hut variant of t-SNE.
 *
 * <p>When {@code theta == 0}, {@link #fit()} delegates to the exact {@link Tsne} implementation
 * ({@code super.calculate}). Otherwise, input similarities are computed only over each row's
 * {@code (int) (3 * perplexity)} nearest neighbours, which are found with a {@link VPTree}. They
 * are stored in a sparse row-pointer layout ({@code rows}, {@code cols}, {@code vals}) and then
 * symmetrized. The embedding {@code Y} is optimised with attractive/repulsive forces computed by
 * an {@link SpTree}.
 */
public class BarnesHutTsne extends Tsne implements Model {
    /** Number of input rows; set by {@link #computeGaussianPerplexity(INDArray, double)}. */
    private int N;
    /** Threshold passed to {@code SpTree#computeNonEdgeForces}; 0 selects exact t-SNE in {@link #fit()}. */
    private double theta;
    /**
     * Row pointers of length {@code N + 1}. The entries of row {@code n} occupy positions
     * {@code rows[n]} .. {@code rows[n + 1] - 1} of {@link #cols} and {@link #vals}.
     */
    private INDArray rows;
    /** Column (neighbour) index of each sparse entry. */
    private INDArray cols;
    /** Similarity value of each sparse entry. */
    private INDArray vals;
    private String simiarlityFunction = "cosinesimilarity";
    private boolean invert = true;
    /** Input data, one record per row. */
    private INDArray x;
    /** Dimensionality of the embedding {@code Y}. */
    private int numDimensions = 0;
    /** Key under which {@link #gradient()} stores the gradient with respect to {@code Y}. */
    public static final String Y_GRAD = "yIncs";
    private SpTree tree;
    /** Per-coordinate adaptive gains used by {@link #update(INDArray, String)}. */
    private INDArray gains;
    /** Running (momentum) update applied to {@code Y}. */
    private INDArray yIncs;

    public BarnesHutTsne(INDArray x, INDArray y, int numDimensions, double perplexity, double theta, int maxIter, int stopLyingIteration, int momentumSwitchIteration, double momentum, double finalMomentum, double learningRate) {
        this.Y = y;
        this.x = x;
        this.numDimensions = numDimensions;
        this.perplexity = perplexity;
        this.theta = theta;
        this.maxIter = maxIter;
        this.stopLyingIteration = stopLyingIteration;
        this.momentum = momentum;
        this.finalMomentum = finalMomentum;
        this.learningRate = learningRate;
        this.switchMomentumIteration = momentumSwitchIteration;
    }

    public BarnesHutTsne(INDArray x, INDArray y, int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter, double realMin, double initialMomentum, double finalMomentum, double momentum, int switchMomentumIteration, boolean normalize, boolean usePca, int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad, double perplexity, double minGain) {
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.normalize = normalize;
        this.useAdaGrad = useAdaGrad;
        this.usePca = usePca;
        this.stopLyingIteration = stopLyingIteration;
        this.learningRate = learningRate;
        this.switchMomentumIteration = switchMomentumIteration;
        this.tolerance = tolerance;
        this.perplexity = perplexity;
        this.minGain = minGain;
        this.Y = y;
        this.x = x;
        this.numDimensions = numDimensions;
        this.simiarlityFunction = simiarlityFunction;
        this.theta = theta;
        this.invert = invert;
    }

    public String getSimiarlityFunction() {
        return this.simiarlityFunction;
    }

    public void setSimiarlityFunction(String simiarlityFunction) {
        this.simiarlityFunction = simiarlityFunction;
    }

    public boolean isInvert() {
        return this.invert;
    }

    public void setInvert(boolean invert) {
        this.invert = invert;
    }

    public double getTheta() {
        return this.theta;
    }

    public double getPerplexity() {
        return this.perplexity;
    }

    /**
     * Computes sparse, row-normalised Gaussian input similarities over each row's
     * {@code k = (int) (3 * u)} nearest neighbours.
     *
     * <p>For every row, a binary search over the Gaussian precision (beta) runs until
     * {@code |H - log(u)| < tolerance}, where H is the second value returned by
     * {@link #computeGaussianKernel}, or until 200 refinements have been made. Results are
     * written to {@code rows}, {@code cols} and {@code vals}, and each row owns exactly
     * {@code k} entries. {@code N} is set to {@code d.rows()}.
     *
     * @param d input data, one record per row
     * @param u target perplexity
     * @return the similarity values (the {@code vals} field), {@code N * k} entries
     * @throws IllegalStateException if {@code u > k}, or if the neighbour search returns
     *         fewer than {@code k + 1} points for some row
     */
    public INDArray computeGaussianPerplexity(INDArray d, double u) {
        this.N = d.rows();
        int k = (int) (3.0 * u);
        if (u > (double) k) {
            throw new IllegalStateException("Illegal k value " + k + ": perplexity " + u + " is greater than k");
        }
        this.rows = Nd4j.zeros(1, this.N + 1);
        this.cols = Nd4j.zeros(1, this.N * k);
        this.vals = Nd4j.zeros(1, this.N * k);
        for (int n = 0; n < this.N; ++n) {
            this.rows.putScalar(n + 1, this.rows.getDouble(n) + (double) k);
        }
        INDArray beta = Nd4j.ones(this.N, 1);
        double logU = FastMath.log(u);
        VPTree vpTree = new VPTree(d, this.simiarlityFunction, this.invert);
        logger.info("Calculating probabilities of data similarities...");
        for (int i = 0; i < this.N; ++i) {
            if (i % 500 == 0) {
                logger.info("Handled " + i + " records");
            }
            double betaMin = -Double.MAX_VALUE;
            double betaMax = Double.MAX_VALUE;
            ArrayList<DataPoint> results = new ArrayList<DataPoint>();
            ArrayList<Double> distanceList = new ArrayList<Double>();
            tree(vpTree, d, i, k, results, distanceList);
            if (results.size() < k + 1 || distanceList.size() < k + 1) {
                throw new IllegalStateException("Expected " + (k + 1) + " nearest neighbours for record " + i
                        + " but the search returned " + results.size() + "; too few records for perplexity " + u);
            }
            // The kernel must be evaluated on the neighbour distances, not on a matrix built
            // from the neighbour points themselves. Index 0 is the query point itself.
            double[] neighbourDistances = new double[distanceList.size()];
            for (int j = 0; j < neighbourDistances.length; ++j) {
                neighbourDistances[j] = distanceList.get(j);
            }
            INDArray cArr = Nd4j.create(neighbourDistances);

            double betas = beta.getDouble(i);
            Pair<INDArray, Double> pair = this.computeGaussianKernel(cArr, betas, k);
            double hDiff = pair.getSecond() - logU;
            int tries = 0;
            while (tries < 200 && !(FastMath.abs(hDiff) < this.tolerance)) {
                if (hDiff > 0.0) {
                    betaMin = betas;
                    betas = (betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE) ? betas * 2.0 : (betas + betaMax) / 2.0;
                } else {
                    betaMax = betas;
                    betas = (betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE) ? betas / 2.0 : (betas + betaMin) / 2.0;
                }
                pair = this.computeGaussianKernel(cArr, betas, k);
                hDiff = pair.getSecond() - logU;
                ++tries;
            }
            beta.putScalar(i, betas);

            // Use the kernel row from the final beta, not the initial one.
            INDArray currP = pair.getFirst();
            double rowSum = currP.sum(Integer.MAX_VALUE).getDouble(0);
            if (rowSum > 0.0) {
                currP.divi((Number) rowSum);
            }
            int rowStart = this.rows.getInt(i);
            for (int l = 0; l < k; ++l) {
                this.cols.putScalar(rowStart + l, (double) results.get(l + 1).getIndex());
                this.vals.putScalar(rowStart + l, currP.getDouble(l));
            }
        }
        return this.vals;
    }

    /** Runs a (k)-nearest-neighbour search for record {@code i} of {@code d}, filling both output lists. */
    private static void tree(VPTree vpTree, INDArray d, int i, int count, ArrayList<DataPoint> results, ArrayList<Double> distances) {
        vpTree.search(new DataPoint(i, d.slice(i)), count + 1, results, distances);
    }

    @Override
    public INDArray input() {
        return this.x;
    }

    @Override
    public void validateInput() {
    }

    @Override
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    @Override
    public INDArray getParam(String param) {
        return null;
    }

    @Override
    public void initParams() {
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return null;
    }

    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
    }

    @Override
    public void setParam(String key, INDArray val) {
    }

    @Override
    public void clear() {
    }

    protected Pair<Double, INDArray> gradient(INDArray p) {
        throw new UnsupportedOperationException();
    }

    /**
     * Symmetrizes a sparse matrix stored in row-pointer form as {@code (P + P^T) / 2}.
     *
     * <p>As a side effect, the {@code rows}, {@code cols} and {@code vals} fields are replaced by
     * the symmetric row pointers, column indices and values. This keeps the three fields
     * mutually consistent for {@link #gradient()} and {@link #score()}.
     *
     * @param rowP row pointers, {@code N + 1} entries
     * @param colP column index of each entry
     * @param valP value of each entry
     * @return the symmetric values (also stored in {@code vals})
     */
    public INDArray symmetrized(INDArray rowP, INDArray colP, INDArray valP) {
        // Pass 1: count the entries of every row of the symmetric matrix.
        INDArray rowCounts = Nd4j.create(this.N);
        for (int n = 0; n < this.N; ++n) {
            int end = rowP.getInt(n + 1);
            for (int i = rowP.getInt(n); i < end; ++i) {
                int col = colP.getInt(i);
                rowCounts.putScalar(n, rowCounts.getDouble(n) + 1.0);
                if (findEntry(rowP, colP, col, n) < 0) {
                    // (col, n) is absent, so the mirrored entry adds a new slot to row col.
                    rowCounts.putScalar(col, rowCounts.getDouble(col) + 1.0);
                }
            }
        }
        int numElements = (int) Math.round(rowCounts.sum(Integer.MAX_VALUE).getDouble(0));
        INDArray offset = Nd4j.create(this.N);
        INDArray symRowP = Nd4j.create(this.N + 1);
        INDArray symColP = Nd4j.create(numElements);
        INDArray symValP = Nd4j.create(numElements);
        for (int n = 0; n < this.N; ++n) {
            symRowP.putScalar(n + 1, symRowP.getDouble(n) + rowCounts.getDouble(n));
        }

        // Pass 2: fill the symmetric matrix.
        for (int n = 0; n < this.N; ++n) {
            int end = rowP.getInt(n + 1);
            for (int i = rowP.getInt(n); i < end; ++i) {
                int col = colP.getInt(i);
                int m = findEntry(rowP, colP, col, n);
                boolean present = m >= 0;
                // Pairs present in both directions are written once, from the smaller row index.
                boolean write = !present || n <= col;
                if (write) {
                    double value = present ? valP.getDouble(i) + valP.getDouble(m) : valP.getDouble(i);
                    int nSlot = symRowP.getInt(n) + offset.getInt(n);
                    int colSlot = symRowP.getInt(col) + offset.getInt(col);
                    symColP.putScalar(nSlot, col);
                    symColP.putScalar(colSlot, n);
                    symValP.putScalar(nSlot, value);
                    symValP.putScalar(colSlot, value);
                    offset.putScalar(n, offset.getDouble(n) + 1.0);
                    if (col != n) {
                        offset.putScalar(col, offset.getDouble(col) + 1.0);
                    }
                }
            }
        }
        symValP.divi((Number) 2.0);
        this.rows = symRowP;
        this.cols = symColP;
        this.vals = symValP;
        return symValP;
    }

    /**
     * Returns the position in {@code colP} of the first entry of row {@code row} whose column
     * equals {@code target}, or -1 if there is none.
     */
    private static int findEntry(INDArray rowP, INDArray colP, int row, int target) {
        int end = rowP.getInt(row + 1);
        for (int m = rowP.getInt(row); m < end; ++m) {
            if (colP.getInt(m) == target) {
                return m;
            }
        }
        return -1;
    }

    /**
     * Evaluates {@code exp(-beta * distances[m + 1])} for {@code m = 0 .. k-1}. Index 0 is
     * skipped; the caller puts the query point itself there.
     *
     * @param distances neighbour distances, at least {@code k + 1} entries
     * @param beta Gaussian precision
     * @param k number of neighbours
     * @return the unnormalised kernel row, and
     *         {@code H = sum(beta * d * p) / sum(p) + log(sum(p))}
     */
    public Pair<INDArray, Double> computeGaussianKernel(INDArray distances, double beta, int k) {
        INDArray currP = Nd4j.create(k);
        for (int m = 0; m < k; ++m) {
            currP.putScalar(m, FastMath.exp(-beta * distances.getDouble(m + 1)));
        }
        // Double.MIN_VALUE keeps H finite when every kernel value underflows to zero.
        double sum = currP.sum(Integer.MAX_VALUE).getDouble(0) + Double.MIN_VALUE;
        double h = 0.0;
        for (int m = 0; m < k; ++m) {
            h += beta * (distances.getDouble(m + 1) * currP.getDouble(m));
        }
        h = h / sum + FastMath.log(sum);
        return new Pair<INDArray, Double>(currP, h);
    }

    /**
     * Fits the embedding {@code Y} to the input {@code x}.
     *
     * <p>With {@code theta == 0}, this uses the exact algorithm from {@link Tsne}. Otherwise it
     * computes sparse symmetric similarities, multiplies them by 12 until
     * {@code stopLyingIteration}, and runs {@code maxIter} gradient steps. Momentum switches to
     * {@code finalMomentum} at {@code switchMomentumIteration}.
     */
    @Override
    public void fit() {
        if (this.theta == 0.0) {
            this.Y = super.calculate(this.x, this.numDimensions, this.perplexity);
            return;
        }
        if (this.Y == null) {
            this.Y = Nd4j.randn((int) this.x.rows(), (int) this.numDimensions, (Random) Nd4j.getRandom()).muli((Number) Float.valueOf(0.001f));
        }
        // Optimizer state from an earlier fit may not match the current Y, so start fresh.
        this.yIncs = null;
        this.gains = null;
        this.computeGaussianPerplexity(this.x, this.perplexity);
        // symmetrized(...) also replaces rows/cols with the symmetric sparsity pattern.
        INDArray symmetricP = this.symmetrized(this.rows, this.cols, this.vals);
        double total = symmetricP.sum(Integer.MAX_VALUE).getDouble(0);
        this.vals = symmetricP.divi((Number) total);
        this.vals.muli((Number) 12);
        for (int i = 0; i < this.maxIter; ++i) {
            this.step(this.vals, i);
            if (i == this.switchMomentumIteration) {
                this.momentum = this.finalMomentum;
            }
            if (i == this.stopLyingIteration) {
                this.vals.divi((Number) 12);
            }
            if (this.iterationListener != null) {
                this.iterationListener.iterationDone(this, i);
            }
            logger.info("Error at iteration " + i + " is " + this.score());
        }
    }

    /** Performs one optimisation step: computes the gradient and applies it to {@code Y}. */
    public void step(INDArray p, int i) {
        this.update(this.gradient().getGradientFor(Y_GRAD), Y_GRAD);
    }

    /**
     * Applies a gradient to {@code Y} using adaptive gains and momentum.
     *
     * <p>Each gain grows by 0.2 where the gradient sign differs from the sign of the previous
     * update, and shrinks by a factor of 0.8 where the signs match. Gains are clamped below at
     * {@code minGain}.
     *
     * @param gradient gradient with respect to {@code Y}
     * @param paramType ignored
     */
    @Override
    public void update(INDArray gradient, String paramType) {
        INDArray yGrads = gradient;
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros((int[]) this.Y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones((int[]) this.Y.shape());
        }
        // 1 where sign(grad) != sign(previous update), 0 elsewhere; signMatches is its complement.
        INDArray signDiffers = Transforms.sign(yGrads, true).neqi(Transforms.sign(this.yIncs, true));
        INDArray signMatches = signDiffers.rsub((Number) 1.0);
        this.gains = this.gains.add((Number) 0.2).muli(signDiffers)
                .addi(this.gains.mul((Number) 0.8).muli(signMatches));
        BooleanIndexing.applyWhere((INDArray) this.gains, (Condition) Conditions.lessThan((Number) this.minGain), (Function) new Value((Number) this.minGain));
        INDArray gradChange = this.gains.mul(yGrads);
        if (this.useAdaGrad) {
            if (this.adaGrad == null) {
                this.adaGrad = new AdaGrad();
            }
            gradChange = this.adaGrad.getGradient(gradChange, 0);
        } else {
            gradChange.muli((Number) this.learningRate);
        }
        this.yIncs.muli((Number) this.momentum).subi(gradChange);
        this.Y.addi(this.yIncs);
    }

    /**
     * Fits the data and writes one line per labelled row of {@code Y}: the coordinates joined
     * by commas, then {@code ",<label> "} and a newline. Rows whose label is null are skipped.
     *
     * @throws IOException if the file cannot be written
     */
    @Override
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        this.fit(matrix, nDims);
        BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)));
        try {
            for (int i = 0; i < this.Y.rows() && i < labels.size(); ++i) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = this.Y.getRow(i);
                for (int j = 0; j < wordVector.length(); ++j) {
                    if (j > 0) {
                        sb.append(",");
                    }
                    sb.append(wordVector.getDouble(j));
                }
                sb.append(",").append(word).append(" ").append("\n");
                write.write(sb.toString());
            }
            write.flush();
        } finally {
            write.close();
        }
    }

    /**
     * Returns {@code sum p * log((p + eps) / (q + eps))} over the sparse entries. Here
     * {@code q = 1 / (1 + |y_n - y_m|^2) / sumQ}, and sumQ is accumulated by an {@link SpTree}
     * built from the current {@code Y}.
     */
    @Override
    public double score() {
        SpTree scoreTree = new SpTree(this.Y);
        INDArray buff = Nd4j.create(this.numDimensions);
        AtomicDouble sumQ = new AtomicDouble(0.0);
        for (int n = 0; n < this.N; ++n) {
            scoreTree.computeNonEdgeForces(n, this.theta, buff, sumQ);
        }
        double c = 0.0;
        for (int n = 0; n < this.N; ++n) {
            int begin = this.rows.getInt(n);
            int end = this.rows.getInt(n + 1);
            for (int i = begin; i < end; ++i) {
                int neighbour = this.cols.getInt(i);
                buff.assign(this.Y.slice(n));
                buff.subi(this.Y.slice(neighbour));
                double q = Transforms.pow((INDArray) buff, (Number) 2).sum(Integer.MAX_VALUE).getDouble(0);
                q = 1.0 / (1.0 + q) / sumQ.doubleValue();
                double p = this.vals.getDouble(i);
                c += p * FastMath.log((p + Nd4j.EPS_THRESHOLD) / (q + Nd4j.EPS_THRESHOLD));
            }
        }
        return c;
    }

    @Override
    public void computeGradientAndScore() {
    }

    @Override
    public void accumulateScore(double accum) {
    }

    @Override
    public INDArray params() {
        return null;
    }

    @Override
    public int numParams() {
        return 0;
    }

    @Override
    public int numParams(boolean backwards) {
        return 0;
    }

    @Override
    public void setParams(INDArray params) {
    }

    @Override
    public void setParamsViewArray(INDArray params) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setBackpropGradientsViewArray(INDArray gradients) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void applyLearningRateScoreDecay() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    @Override
    public void fit(INDArray data) {
        this.x = data;
        this.fit();
    }

    public void fit(INDArray data, int nDims) {
        this.x = data;
        this.numDimensions = nDims;
        this.fit();
    }

    @Override
    public void iterate(INDArray input) {
    }

    /**
     * Computes the gradient with respect to {@code Y} as {@code posF - negF / sumQ}. The tree is
     * rebuilt from the current {@code Y} on every call, because {@code Y} moves between steps.
     *
     * @return a gradient with the {@link #Y_GRAD} entry set
     */
    @Override
    public Gradient gradient() {
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros((int[]) this.Y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones((int[]) this.Y.shape());
        }
        AtomicDouble sumQ = new AtomicDouble(0.0);
        INDArray posF = Nd4j.create((int[]) this.Y.shape());
        INDArray negF = Nd4j.create((int[]) this.Y.shape());
        this.tree = new SpTree(this.Y);
        this.tree.computeEdgeForces(this.rows, this.cols, this.vals, this.N, posF);
        for (int n = 0; n < this.N; ++n) {
            this.tree.computeNonEdgeForces(n, this.theta, negF.slice(n), sumQ);
        }
        INDArray dC = posF.subi(negF.divi((Number) sumQ.doubleValue()));
        DefaultGradient ret = new DefaultGradient();
        ret.gradientForVariable().put(Y_GRAD, dC);
        return ret;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.gradient(), this.score());
    }

    @Override
    public int batchSize() {
        return 0;
    }

    @Override
    public NeuralNetConfiguration conf() {
        return null;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
    }

    /** Builder for {@link BarnesHutTsne}; {@code theta} defaults to 0 (exact t-SNE). */
    public static class Builder extends Tsne.Builder {
        private double theta = 0.0;
        private boolean invert = true;
        private String similarityFunction = "cosinesimilarity";

        public Builder similarityFunction(String similarityFunction) {
            this.similarityFunction = similarityFunction;
            return this;
        }

        public Builder invertDistanceMetric(boolean invert) {
            this.invert = invert;
            return this;
        }

        public Builder theta(double theta) {
            this.theta = theta;
            return this;
        }

        @Override
        public Builder minGain(double minGain) {
            super.minGain(minGain);
            return this;
        }

        @Override
        public Builder perplexity(double perplexity) {
            super.perplexity(perplexity);
            return this;
        }

        @Override
        public Builder useAdaGrad(boolean useAdaGrad) {
            super.useAdaGrad(useAdaGrad);
            return this;
        }

        @Override
        public Builder learningRate(double learningRate) {
            super.learningRate(learningRate);
            return this;
        }

        @Override
        public Builder tolerance(double tolerance) {
            super.tolerance(tolerance);
            return this;
        }

        @Override
        public Builder stopLyingIteration(int stopLyingIteration) {
            super.stopLyingIteration(stopLyingIteration);
            return this;
        }

        @Override
        public Builder usePca(boolean usePca) {
            super.usePca(usePca);
            return this;
        }

        @Override
        public Builder normalize(boolean normalize) {
            super.normalize(normalize);
            return this;
        }

        @Override
        public Builder setMaxIter(int maxIter) {
            super.setMaxIter(maxIter);
            return this;
        }

        @Override
        public Builder setRealMin(double realMin) {
            super.setRealMin(realMin);
            return this;
        }

        @Override
        public Builder setInitialMomentum(double initialMomentum) {
            super.setInitialMomentum(initialMomentum);
            return this;
        }

        @Override
        public Builder setFinalMomentum(double finalMomentum) {
            super.setFinalMomentum(finalMomentum);
            return this;
        }

        @Override
        public Builder setMomentum(double momentum) {
            super.setMomentum(momentum);
            return this;
        }

        @Override
        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            super.setSwitchMomentumIteration(switchMomentumIteration);
            return this;
        }

        /** Builds an instance with no data or embedding yet and a 2-dimensional target. */
        @Override
        public BarnesHutTsne build() {
            return new BarnesHutTsne(null, null, 2, this.similarityFunction, this.theta, this.invert, this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca, this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity, this.minGain);
        }
    }
}

