/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import com.google.common.base.Function;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.SpTree;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.plot.Tsne;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.learning.legacy.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Barnes-Hut approximation of t-SNE.
 *
 * <p>When {@code theta == 0} {@link #fit()} delegates to the exact {@link Tsne} implementation. Otherwise the input
 * affinities are computed over each point's {@code (int) (3 * perplexity)} nearest neighbours (found with a
 * {@link VPTree}). They are then symmetrized into a sparse row-pointer / column-index / value representation, and the
 * embedding {@code Y} is optimised with attractive/repulsive forces estimated through an {@link SpTree}.
 *
 * <p>{@link Model} is implemented only nominally: most parameter / configuration methods are no-ops or return
 * {@code null}. All state is held in mutable fields and no synchronization is performed.
 */
public class BarnesHutTsne
implements Model {
    private static final Logger log = LoggerFactory.getLogger(BarnesHutTsne.class);
    /** Key under which {@link #gradient()} publishes dC/dY in the returned {@link Gradient}. */
    public static final String Y_GRAD = "yIncs";

    protected int maxIter = 1000;
    protected double realMin = Nd4j.EPS_THRESHOLD;
    protected double initialMomentum = 0.5;
    protected double finalMomentum = 0.8;
    protected double minGain = 0.01;
    // Must stay declared after initialMomentum (forward-reference rule); the constructor overwrites it anyway.
    protected double momentum = this.initialMomentum;
    protected int switchMomentumIteration = 100;
    protected boolean normalize = true;
    protected boolean usePca = false;
    protected int stopLyingIteration = 250;
    protected double tolerance = 1.0E-5;
    protected double learningRate = 500.0;
    protected AdaGrad adaGrad;
    protected boolean useAdaGrad = true;
    protected double perplexity = 30.0;
    /** Current low-dimensional embedding, one row per input point. */
    protected INDArray Y;
    /** Number of input points (set by {@link #computeGaussianPerplexity}). */
    private int N;
    private double theta;
    /** Sparse affinity matrix: row pointers (length N + 1). */
    private INDArray rows;
    /** Sparse affinity matrix: column index of each stored entry. */
    private INDArray cols;
    /** Sparse affinity matrix: value of each stored entry. */
    private INDArray vals;
    private String simiarlityFunction = "cosinesimilarity";
    private boolean invert = true;
    /** Input data being embedded. */
    private INDArray x;
    private int numDimensions = 0;
    /** Space-partitioning tree over {@code Y}, rebuilt on every gradient evaluation. */
    private SpTree tree;
    /** Per-coordinate adaptive gains. */
    private INDArray gains;
    /** Previous update of {@code Y} (momentum buffer). */
    private INDArray yIncs;
    protected transient IterationListener iterationListener;

    public BarnesHutTsne(int numDimensions, String simiarlityFunction, double theta, boolean invert, int maxIter,
                    double realMin, double initialMomentum, double finalMomentum, double momentum,
                    int switchMomentumIteration, boolean normalize, int stopLyingIteration, double tolerance,
                    double learningRate, boolean useAdaGrad, double perplexity, IterationListener iterationListener,
                    double minGain) {
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.initialMomentum = initialMomentum;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.normalize = normalize;
        this.useAdaGrad = useAdaGrad;
        this.stopLyingIteration = stopLyingIteration;
        this.learningRate = learningRate;
        this.switchMomentumIteration = switchMomentumIteration;
        this.tolerance = tolerance;
        this.perplexity = perplexity;
        this.minGain = minGain;
        this.numDimensions = numDimensions;
        this.simiarlityFunction = simiarlityFunction;
        this.theta = theta;
        this.iterationListener = iterationListener;
        this.invert = invert;
    }

    public String getSimiarlityFunction() {
        return this.simiarlityFunction;
    }

    public void setSimiarlityFunction(String simiarlityFunction) {
        this.simiarlityFunction = simiarlityFunction;
    }

    public boolean isInvert() {
        return this.invert;
    }

    public void setInvert(boolean invert) {
        this.invert = invert;
    }

    public double getTheta() {
        return this.theta;
    }

    public double getPerplexity() {
        return this.perplexity;
    }

    public int getNumDimensions() {
        return this.numDimensions;
    }

    public void setNumDimensions(int numDimensions) {
        this.numDimensions = numDimensions;
    }

    /**
     * Computes sparse conditional input affinities over each point's {@code k = (int) (3 * u)} nearest neighbours.
     *
     * <p>For every row of {@code d}, a binary search over the Gaussian precision beta (at most 200 steps) runs until
     * the kernel entropy is within {@link #tolerance} of {@code log(u)}. The normalised kernel is then stored. Results
     * are written to the {@code rows} (row pointers), {@code cols} (neighbour indices) and {@code vals} fields, and
     * {@code N} is set to {@code d.rows()}.
     *
     * @param d input data, one point per row
     * @param u target perplexity
     * @return the value array ({@code vals}), holding {@code N * k} entries
     * @throws IllegalStateException if {@code u > k}, or if the neighbour search returns fewer than {@code k + 1}
     *         results for some point (presumably because the data set has too few rows)
     */
    public INDArray computeGaussianPerplexity(INDArray d, double u) {
        this.N = d.rows();
        int k = (int) (3.0 * u);
        if (u > k) {
            throw new IllegalStateException("Illegal k value " + k + " less than perplexity " + u);
        }
        this.rows = Nd4j.zeros(1, this.N + 1);
        this.cols = Nd4j.zeros(1, this.N * k);
        this.vals = Nd4j.zeros(1, this.N * k);
        // Every row stores exactly k neighbours, so row pointers advance by k.
        for (int n = 0; n < this.N; ++n) {
            this.rows.putScalar(n + 1, this.rows.getDouble(n) + k);
        }
        INDArray beta = Nd4j.ones(this.N, 1);
        double logU = FastMath.log(u);
        VPTree vpTree = new VPTree(d, this.simiarlityFunction, this.invert);
        log.info("Calculating probabilities of data similarities...");
        for (int i = 0; i < this.N; ++i) {
            if (i % 500 == 0) {
                log.info("Handled " + i + " records");
            }
            List<DataPoint> results = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            vpTree.search(d.slice(i), k + 1, results, distances);
            if (results.size() < k + 1 || distances.size() < k + 1) {
                throw new IllegalStateException("Neighbour search for record " + i + " returned " + results.size()
                                + " points, but " + (k + 1) + " are required");
            }
            // The kernel works on neighbour *distances*. Entry 0 is skipped by the kernel
            // (presumably the query point itself).
            double[] neighbourDistances = new double[k + 1];
            for (int j = 0; j <= k; ++j) {
                neighbourDistances[j] = distances.get(j);
            }
            INDArray cArr = Nd4j.create(neighbourDistances);

            double betaI = beta.getDouble(i);
            double betaMin = -Double.MAX_VALUE;
            double betaMax = Double.MAX_VALUE;
            Pair<INDArray, Double> pair = this.computeGaussianKernel(cArr, betaI, k);
            double hDiff = pair.getSecond() - logU;
            for (int tries = 0; tries < 200 && !(hDiff < this.tolerance && -hDiff < this.tolerance); ++tries) {
                if (hDiff > 0.0) {
                    // Entropy too high: the kernel is too wide, so increase the precision.
                    betaMin = betaI;
                    boolean unbounded = betaMax == Double.MAX_VALUE || betaMax == -Double.MAX_VALUE;
                    betaI = unbounded ? betaI * 2.0 : (betaI + betaMax) / 2.0;
                } else {
                    // Entropy too low: the kernel is too narrow, so decrease the precision.
                    betaMax = betaI;
                    boolean unbounded = betaMin == -Double.MAX_VALUE || betaMin == Double.MAX_VALUE;
                    betaI = unbounded ? betaI / 2.0 : (betaI + betaMin) / 2.0;
                }
                pair = this.computeGaussianKernel(cArr, betaI, k);
                hDiff = pair.getSecond() - logU;
            }
            beta.putScalar(i, betaI);

            // Use the kernel belonging to the beta found by the search, normalised to sum to one.
            INDArray currP = pair.getFirst();
            currP.divi(currP.sum(Integer.MAX_VALUE).getDouble(0));

            int rowStart = this.rows.getInt(i);
            for (int l = 0; l < k; ++l) {
                this.cols.putScalar(rowStart + l, results.get(l + 1).getIndex());
                this.vals.putScalar(rowStart + l, currP.getDouble(l));
            }
        }
        return this.vals;
    }

    public INDArray input() {
        return this.x;
    }

    /** No-op. */
    public void validateInput() {
    }

    /** Not supported: always returns {@code null}. */
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    /** Not supported: always returns {@code null}. */
    public INDArray getParam(String param) {
        return null;
    }

    /** No-op. */
    public void initParams() {
    }

    /** No-op: listeners are not registered through this method. */
    public void addListeners(IterationListener ... listener) {
    }

    /** Not supported: always returns {@code null}. */
    public Map<String, INDArray> paramTable() {
        return null;
    }

    /** Not supported: always returns {@code null}. */
    public Map<String, INDArray> paramTable(boolean backprapParamsOnly) {
        return null;
    }

    /** No-op. */
    public void setParamTable(Map<String, INDArray> paramTable) {
    }

    /** No-op. */
    public void setParam(String key, INDArray val) {
    }

    /** No-op. */
    public void clear() {
    }

    /** Always throws {@link UnsupportedOperationException}. */
    protected Pair<Double, INDArray> gradient(INDArray p) {
        throw new UnsupportedOperationException();
    }

    /**
     * Symmetrizes a sparse affinity matrix given in row-pointer / column-index / value form.
     *
     * <p>Computes {@code (P + P^T) / 2}. Only the resulting values are returned. The matching row pointers and column
     * indices are not exposed by this method, and the {@code rows}/{@code cols} fields are not modified. Uses the
     * {@code N} field as the number of rows.
     *
     * @param rowP row pointers (length N + 1)
     * @param colP column index of each entry
     * @param valP value of each entry
     * @return the symmetrized values
     */
    public INDArray symmetrized(INDArray rowP, INDArray colP, INDArray valP) {
        return this.symmetrize(rowP, colP, valP)[2];
    }

    /**
     * Symmetrizes a sparse matrix and returns {@code {symRowP, symColP, symValP}}, which describe the same sparse
     * matrix. The algorithm mirrors {@code symmetrizeMatrix} from the reference bhtsne implementation.
     */
    private INDArray[] symmetrize(INDArray rowP, INDArray colP, INDArray valP) {
        // Pass 1: count how many entries every row of the symmetric matrix will hold.
        int[] rowCounts = new int[this.N];
        for (int n = 0; n < this.N; ++n) {
            int end = rowP.getInt(n + 1);
            for (int i = rowP.getInt(n); i < end; ++i) {
                int col = colP.getInt(i);
                rowCounts[n]++;
                if (findEntry(rowP, colP, col, n) < 0) {
                    // (col, n) is absent, so it must be created in row col as well.
                    rowCounts[col]++;
                }
            }
        }

        int[] symRowOffsets = new int[this.N + 1];
        for (int n = 0; n < this.N; ++n) {
            symRowOffsets[n + 1] = symRowOffsets[n] + rowCounts[n];
        }
        int numElements = symRowOffsets[this.N];

        INDArray symRowP = Nd4j.create(this.N + 1);
        for (int n = 0; n <= this.N; ++n) {
            symRowP.putScalar(n, symRowOffsets[n]);
        }
        INDArray symColP = Nd4j.create(numElements);
        INDArray symValP = Nd4j.create(numElements);

        // Pass 2: fill entries. offset[r] counts how many slots of row r are already used.
        int[] offset = new int[this.N];
        for (int n = 0; n < this.N; ++n) {
            int end = rowP.getInt(n + 1);
            for (int i = rowP.getInt(n); i < end; ++i) {
                int col = colP.getInt(i);
                int mirror = findEntry(rowP, colP, col, n);
                double value;
                if (mirror >= 0) {
                    // Both (n, col) and (col, n) exist: emit the summed pair once, from the lower row index.
                    if (n >= col) {
                        continue;
                    }
                    value = valP.getDouble(i) + valP.getDouble(mirror);
                } else {
                    value = valP.getDouble(i);
                }
                int slotN = symRowOffsets[n] + offset[n];
                int slotCol = symRowOffsets[col] + offset[col];
                symColP.putScalar(slotN, col);
                symColP.putScalar(slotCol, n);
                symValP.putScalar(slotN, value);
                symValP.putScalar(slotCol, value);
                offset[n]++;
                if (col != n) {
                    offset[col]++;
                }
            }
        }
        symValP.divi(2.0);
        return new INDArray[] {symRowP, symColP, symValP};
    }

    /** Returns the position in {@code colP} of entry ({@code row}, {@code col}), or -1 if it is not stored. */
    private static int findEntry(INDArray rowP, INDArray colP, int row, int col) {
        int end = rowP.getInt(row + 1);
        for (int m = rowP.getInt(row); m < end; ++m) {
            if (colP.getInt(m) == col) {
                return m;
            }
        }
        return -1;
    }

    /**
     * Evaluates the Gaussian kernel {@code exp(-beta * distance)} over the k neighbours stored in
     * {@code distances[1..k]}, together with the kernel's entropy.
     *
     * @param distances neighbour distances; index 0 is skipped
     * @param beta kernel precision
     * @param k number of neighbours
     * @return the unnormalised kernel values (length k) and the entropy {@code H = beta * <d>_P + log(sum P)}
     */
    public Pair<INDArray, Double> computeGaussianKernel(INDArray distances, double beta, int k) {
        INDArray currP = Nd4j.create(k);
        for (int m = 0; m < k; ++m) {
            currP.putScalar(m, FastMath.exp(-beta * distances.getDouble(m + 1)));
        }
        double sum = currP.sum(Integer.MAX_VALUE).getDouble(0);
        double h = 0.0;
        for (int m = 0; m < k; ++m) {
            h += beta * (distances.getDouble(m + 1) * currP.getDouble(m));
        }
        h = h / sum + FastMath.log(sum);
        return new Pair<>(currP, h);
    }

    /** No-op. */
    public void init() {
    }

    /** No-op: use the constructor's listener instead. */
    public void setListeners(Collection<IterationListener> listeners) {
    }

    /** No-op: use the constructor's listener instead. */
    public void setListeners(IterationListener ... listeners) {
    }

    /**
     * Fits the embedding {@code Y} to the input {@code x}.
     *
     * <p>With {@code theta == 0} the exact {@link Tsne} is used. Otherwise the sparse affinities are computed,
     * symmetrized and normalised to sum to one. Gradient steps are then run for {@link #maxIter} iterations. The
     * affinities are multiplied by 12 ("early exaggeration") until {@link #stopLyingIteration}, and momentum switches to
     * {@link #finalMomentum} at {@link #switchMomentumIteration}.
     */
    public void fit() {
        if (this.theta == 0.0) {
            log.debug("theta == 0, using decomposed version, might be slow");
            Tsne decomposedTsne = new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum,
                            this.minGain, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca,
                            this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad,
                            this.perplexity);
            this.Y = decomposedTsne.calculate(this.x, this.numDimensions, this.perplexity);
        } else {
            if (this.Y == null) {
                this.Y = Nd4j.randn(this.x.rows(), this.numDimensions, Nd4j.getRandom()).muli(0.001f);
            }
            this.computeGaussianPerplexity(this.x, this.perplexity);
            // Keep row pointers, column indices and values consistent: all three must describe the symmetric matrix.
            INDArray[] sym = this.symmetrize(this.rows, this.cols, this.vals);
            this.rows = sym[0];
            this.cols = sym[1];
            INDArray symVals = sym[2];
            this.vals = symVals.divi(symVals.sum(Integer.MAX_VALUE).getDouble(0));
            this.vals.muli(12);
            for (int i = 0; i < this.maxIter; ++i) {
                this.step(this.vals, i);
                if (i == this.switchMomentumIteration) {
                    this.momentum = this.finalMomentum;
                }
                if (i == this.stopLyingIteration) {
                    this.vals.divi(12);
                }
                if (this.iterationListener != null) {
                    this.iterationListener.iterationDone(this, i);
                }
                // score() builds a tree and walks every edge, so only pay for it when the message is emitted.
                if (log.isInfoEnabled()) {
                    log.info("Error at iteration " + i + " is " + this.score());
                }
            }
        }
    }

    /** No-op. */
    public void update(Gradient gradient) {
    }

    /** Performs one optimisation step on {@code Y}. Both arguments are ignored. */
    public void step(INDArray p, int i) {
        this.update(this.gradient().getGradientFor(Y_GRAD), Y_GRAD);
    }

    /**
     * Applies a gradient to {@code Y} using adaptive gains and momentum.
     *
     * <p>A gain grows by 0.2 where the gradient's sign differs from the previous increment's sign. Otherwise it is
     * multiplied by 0.8. Gains are then clamped below at {@link #minGain}. The gain-scaled gradient is scaled either by
     * AdaGrad or by {@link #learningRate}, and combined with momentum into {@code yIncs}, which is added to {@code Y}.
     *
     * @param gradient dC/dY, same shape as {@code Y}
     * @param paramType ignored
     */
    public void update(INDArray gradient, String paramType) {
        INDArray yGrads = gradient;
        INDArray gradSign = Transforms.sign(yGrads);
        INDArray incSign = Transforms.sign(this.yIncs);
        INDArray signsDiffer = gradSign.dup().neqi(incSign);
        INDArray signsAgree = gradSign.eqi(incSign);
        this.gains = this.gains.add(0.2).muli(signsDiffer).addi(this.gains.mul(0.8).muli(signsAgree));
        BooleanIndexing.applyWhere(this.gains, Conditions.lessThan(this.minGain), new Value(this.minGain));
        INDArray gradChange = this.gains.mul(yGrads);
        if (this.useAdaGrad) {
            if (this.adaGrad == null) {
                this.adaGrad = new AdaGrad(gradient.shape(), this.learningRate);
                this.adaGrad.setStateViewArray(Nd4j.zeros(gradient.shape()).reshape(1, gradChange.length()),
                                gradChange.shape(), gradient.ordering(), true);
            }
            gradChange = this.adaGrad.getGradient(gradChange, 0);
        } else {
            gradChange.muli(this.learningRate);
        }
        this.yIncs.muli(this.momentum).subi(gradChange);
        this.Y.addi(this.yIncs);
    }

    /**
     * Writes the embedding to {@code path}, one line per labelled row. Each line holds the coordinates joined by
     * {@code ","}, followed by {@code ","}, the label, a space and a newline. Rows whose label is {@code null} are
     * skipped. Only the first {@code min(Y.rows(), labels.size())} rows are considered. The file is written with the
     * platform default charset ({@link FileWriter}).
     *
     * @param labels one label per row of the embedding
     * @param path output file path
     * @throws IOException if the file cannot be written
     */
    public void saveAsFile(List<String> labels, String path) throws IOException {
        try (BufferedWriter write = new BufferedWriter(new FileWriter(new File(path)))) {
            for (int i = 0; i < this.Y.rows() && i < labels.size(); ++i) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = this.Y.getRow(i);
                for (int j = 0; j < wordVector.length(); ++j) {
                    sb.append(wordVector.getDouble(j));
                    if (j < wordVector.length() - 1) {
                        sb.append(",");
                    }
                }
                sb.append(",");
                sb.append(word);
                sb.append(" ");
                sb.append("\n");
                write.write(sb.toString());
            }
            write.flush();
        }
    }

    /**
     * Fits {@code matrix} into {@code nDims} dimensions, then writes the result with {@link #saveAsFile}.
     *
     * @deprecated call {@link #fit(INDArray)} and {@link #saveAsFile(List, String)} separately
     */
    @Deprecated
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        this.fit(matrix, nDims);
        this.saveAsFile(labels, path);
    }

    /**
     * Returns the KL divergence {@code sum p * log((p + eps) / (q + eps))} over the stored sparse affinities.
     * {@code q} is the Student-t output affinity, normalised with the Barnes-Hut estimate of its partition function.
     * A fresh {@link SpTree} is built over the current {@code Y}, so the score reflects the latest update.
     */
    public double score() {
        SpTree scoreTree = new SpTree(this.Y);
        INDArray buff = Nd4j.create(this.numDimensions);
        AtomicDouble sumQ = new AtomicDouble(0.0);
        for (int n = 0; n < this.N; ++n) {
            scoreTree.computeNonEdgeForces(n, this.theta, buff, sumQ);
        }
        double c = 0.0;
        for (int n = 0; n < this.N; ++n) {
            int end = this.rows.getInt(n + 1);
            for (int i = this.rows.getInt(n); i < end; ++i) {
                int other = this.cols.getInt(i);
                buff.assign(this.Y.slice(n));
                buff.subi(this.Y.slice(other));
                double squaredDistance = Transforms.pow(buff, 2).sum(Integer.MAX_VALUE).getDouble(0);
                double q = 1.0 / (1.0 + squaredDistance) / sumQ.doubleValue();
                double p = this.vals.getDouble(i);
                c += p * FastMath.log((p + Nd4j.EPS_THRESHOLD) / (q + Nd4j.EPS_THRESHOLD));
            }
        }
        return c;
    }

    /** No-op. */
    public void computeGradientAndScore() {
    }

    /** No-op. */
    public void accumulateScore(double accum) {
    }

    /** Not supported: always returns {@code null}. */
    public INDArray params() {
        return null;
    }

    /** Always returns 0. */
    public int numParams() {
        return 0;
    }

    /** Always returns 0. */
    public int numParams(boolean backwards) {
        return 0;
    }

    /** No-op. */
    public void setParams(INDArray params) {
    }

    /** Always throws {@link UnsupportedOperationException}. */
    public void setParamsViewArray(INDArray params) {
        throw new UnsupportedOperationException();
    }

    /** Always throws {@link UnsupportedOperationException}. */
    public INDArray getGradientsViewArray() {
        throw new UnsupportedOperationException();
    }

    /** Always throws {@link UnsupportedOperationException}. */
    public void setBackpropGradientsViewArray(INDArray gradients) {
        throw new UnsupportedOperationException();
    }

    /** Always throws {@link UnsupportedOperationException}. */
    public void applyLearningRateScoreDecay() {
        throw new UnsupportedOperationException("Not yet implemented");
    }

    /** Sets the input to {@code data} and runs {@link #fit()}. */
    public void fit(INDArray data) {
        this.x = data;
        this.fit();
    }

    /**
     * Sets the input and target dimensionality, then runs {@link #fit()}.
     *
     * @deprecated configure the dimensionality via the builder and call {@link #fit(INDArray)}
     */
    @Deprecated
    public void fit(INDArray data, int nDims) {
        this.x = data;
        this.numDimensions = nDims;
        this.fit();
    }

    /** No-op. */
    public void iterate(INDArray input) {
    }

    /**
     * Computes the Barnes-Hut gradient {@code dC/dY = posF - negF / sumQ}. The result is published under
     * {@link #Y_GRAD}. Lazily initialises the momentum buffer and gains. The space-partitioning tree is rebuilt on
     * every call because {@code Y} changes between iterations.
     */
    public Gradient gradient() {
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros(this.Y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones(this.Y.shape());
        }
        AtomicDouble sumQ = new AtomicDouble(0.0);
        INDArray posF = Nd4j.create(this.Y.shape());
        INDArray negF = Nd4j.create(this.Y.shape());
        // A tree built for an earlier Y would use stale cell bounds and centres of mass.
        this.tree = new SpTree(this.Y);
        this.tree.computeEdgeForces(this.rows, this.cols, this.vals, this.N, posF);
        for (int n = 0; n < this.N; ++n) {
            this.tree.computeNonEdgeForces(n, this.theta, negF.slice(n), sumQ);
        }
        INDArray dC = posF.subi(negF.divi(sumQ));
        DefaultGradient ret = new DefaultGradient();
        ret.gradientForVariable().put(Y_GRAD, dC);
        return ret;
    }

    /** Returns {@link #gradient()} paired with {@link #score()}. */
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(this.gradient(), this.score());
    }

    /** Always returns 0. */
    public int batchSize() {
        return 0;
    }

    /** Not supported: always returns {@code null}. */
    public NeuralNetConfiguration conf() {
        return null;
    }

    /** No-op. */
    public void setConf(NeuralNetConfiguration conf) {
    }

    /** Returns the current embedding {@code Y}. */
    public INDArray getData() {
        return this.Y;
    }

    /** Replaces the embedding {@code Y}; a non-null value is used as the starting point by {@link #fit()}. */
    public void setData(INDArray data) {
        this.Y = data;
    }

    /** Fluent builder for {@link BarnesHutTsne}. Its defaults differ from the field initialisers of the outer class. */
    public static class Builder {
        private int maxIter = 1000;
        private double realMin = 1.0E-12f;
        private double initialMomentum = 0.5;
        private double finalMomentum = 0.8f;
        private double momentum = 0.5;
        private int switchMomentumIteration = 100;
        private boolean normalize = true;
        private int stopLyingIteration = 100;
        private double tolerance = 1.0E-5f;
        private double learningRate = 0.1f;
        private boolean useAdaGrad = false;
        private double perplexity = 30.0;
        private double minGain = 0.1f;
        private double theta = 0.5;
        private boolean invert = true;
        private int numDim = 2;
        private String similarityFunction = "cosinesimilarity";

        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        public Builder similarityFunction(String similarityFunction) {
            this.similarityFunction = similarityFunction;
            return this;
        }

        public Builder invertDistanceMetric(boolean invert) {
            this.invert = invert;
            return this;
        }

        public Builder theta(double theta) {
            this.theta = theta;
            return this;
        }

        public Builder numDimension(int numDim) {
            this.numDim = numDim;
            return this;
        }

        /** Builds the model; no iteration listener is attached. */
        public BarnesHutTsne build() {
            return new BarnesHutTsne(this.numDim, this.similarityFunction, this.theta, this.invert, this.maxIter,
                            this.realMin, this.initialMomentum, this.finalMomentum, this.momentum,
                            this.switchMomentumIteration, this.normalize, this.stopLyingIteration, this.tolerance,
                            this.learningRate, this.useAdaGrad, this.perplexity, null, this.minGain);
        }
    }
}

