/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.plot;

import com.google.common.base.Function;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStream;
import java.io.Serializable;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dimensionalityreduction.PCA;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.indexing.functions.Value;
import org.nd4j.linalg.indexing.functions.Zero;
import org.nd4j.linalg.learning.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/**
 * t-SNE style embedding of the rows of a matrix into a low number of dimensions.
 *
 * <p>Input affinities are Gaussian conditional distributions whose per-row precision is found by a
 * binary search matching a target perplexity; they are symmetrised and normalised. Output
 * affinities use the kernel {@code 1 / (1 + squared distance)}. The embedding is optimised by
 * gradient descent with momentum and per-element gains, optionally scaled through AdaGrad, with the
 * input affinities multiplied by 4 until {@link #stopLyingIteration}.
 *
 * <p>Loading this class copies the classpath resources {@code /scripts/tsne.py} and
 * {@code /scripts/render.py} to {@code /tmp}; class initialisation fails with an
 * {@link IllegalStateException} if either copy fails.
 *
 * <p>Optimisation state ({@link #y}, {@link #gains}, {@link #yIncs}, {@link #momentum},
 * {@link #adaGrad}) lives in mutable fields without synchronisation.
 */
public class Tsne
implements Serializable {
    /** Number of gradient-descent iterations run by {@link #calculate}. */
    protected int maxIter = 1000;
    /** Floor used to replace NaNs in the input affinities and small output affinities. */
    protected double realMin = 1.0E-12;
    /** Momentum used until {@link #switchMomentumIteration}. */
    protected double initialMomentum = 0.5;
    /** Momentum used once {@link #switchMomentumIteration} has been reached. */
    protected double finalMomentum = 0.8;
    /** Lower bound applied to every per-element gain. */
    protected double minGain = 0.01;
    /** Momentum currently applied to the position increments. */
    protected double momentum = this.initialMomentum;
    /** Iteration at which {@link #momentum} switches to {@link #finalMomentum}. */
    protected int switchMomentumIteration = 100;
    /** Min/max-scale and column-centre the input; also passed to PCA as its normalize flag. */
    protected boolean normalize = true;
    /** Reduce the input with PCA to at most 50 columns before embedding. */
    protected boolean usePca = false;
    /** Iteration at which the factor-4 exaggeration of the input affinities is removed. */
    protected int stopLyingIteration = 250;
    /** Entropy tolerance for the per-row perplexity binary search. */
    protected double tolerance = 1.0E-5;
    /** AdaGrad master step size, or a plain gradient multiplier when AdaGrad is disabled. */
    protected double learningRate = 500.0;
    protected AdaGrad adaGrad;
    protected boolean useAdaGrad = true;
    /** Target perplexity used by {@link #plot}. */
    protected double perplexity = 30.0;
    protected INDArray gains;
    protected INDArray yIncs;
    /** Current embedding, one row per input row. */
    protected INDArray y;
    protected transient IterationListener iterationListener;
    protected static ClassPathResource r = new ClassPathResource("/scripts/tsne.py");
    protected static ClassPathResource r2 = new ClassPathResource("/scripts/render.py");
    protected static Logger log;

    public Tsne() {
    }

    /**
     * Copies the bundled python scripts to {@code /tmp/tsne.py} and {@code /tmp/render.py}.
     *
     * @throws IllegalStateException if a resource cannot be read or a file cannot be written
     */
    protected static void loadIntoTmp() {
        copyResourceToFile(r, new File("/tmp/tsne.py"));
        copyResourceToFile(r2, new File("/tmp/render.py"));
    }

    /** Writes the lines of a classpath resource to {@code target}, always closing the stream. */
    private static void copyResourceToFile(ClassPathResource resource, File target) {
        InputStream in = null;
        try {
            in = resource.getInputStream();
            List<?> lines = IOUtils.readLines(in);
            FileUtils.writeLines(target, lines);
        }
        catch (IOException e) {
            // Keep the original cause so the failing path / IO error is visible.
            throw new IllegalStateException("Unable to load python file", e);
        }
        finally {
            IOUtils.closeQuietly(in);
        }
    }

    public Tsne(int maxIter, double realMin, double initialMomentum, double finalMomentum, double momentum, int switchMomentumIteration, boolean normalize, boolean usePca, int stopLyingIteration, double tolerance, double learningRate, boolean useAdaGrad, double perplexity, double minGain) {
        this.tolerance = tolerance;
        this.minGain = minGain;
        this.useAdaGrad = useAdaGrad;
        this.learningRate = learningRate;
        this.stopLyingIteration = stopLyingIteration;
        this.maxIter = maxIter;
        this.realMin = realMin;
        this.normalize = normalize;
        this.initialMomentum = initialMomentum;
        this.usePca = usePca;
        this.finalMomentum = finalMomentum;
        this.momentum = momentum;
        this.switchMomentumIteration = switchMomentumIteration;
        this.perplexity = perplexity;
    }

    /**
     * Computes the Gaussian conditional distribution and its entropy for one row of distances.
     *
     * @param d    distances from one point to the others (a single row)
     * @param beta precision of the Gaussian kernel
     * @return {@code (H, P)} where {@code P = exp(-beta * d) / sum(exp(-beta * d))} and
     *         {@code H = log(sum) + beta * sum(d * exp(-beta * d)) / sum}
     */
    public Pair<INDArray, INDArray> hBeta(INDArray d, double beta) {
        INDArray P = Transforms.exp((INDArray)d.neg().muli((Number)beta));
        INDArray sum = P.sum(Integer.MAX_VALUE);
        // Full reduction, like `sum` above: the entropy formula needs the scalar total of d * P.
        INDArray otherSum = d.mul(P).sum(Integer.MAX_VALUE);
        INDArray H = Transforms.log((INDArray)sum).addi(otherSum.muli((Number)beta).divi(sum));
        P.divi(sum);
        return new Pair<INDArray, INDArray>(H, P);
    }

    /**
     * Builds the symmetrised input affinity matrix for a matrix of pairwise distances.
     *
     * <p>For every row a precision {@code beta} is binary-searched (in a thread pool, one task per
     * row) so that the row's entropy matches {@code log(u)}. The resulting conditional matrix
     * {@code p} is symmetrised as {@code p + p^T}, normalised to sum to one and floored at 1e-12.
     * A row whose search throws is logged and left as zeros.
     *
     * @param d square matrix of pairwise (squared) distances
     * @param u target perplexity
     * @return the symmetrised, normalised affinity matrix
     */
    public INDArray computeGaussianPerplexity(final INDArray d, double u) {
        int n = d.rows();
        final INDArray p = Nd4j.zeros((int)n, (int)n);
        final INDArray beta = Nd4j.ones((int)n, (int)1);
        final double logU = Math.log(u);
        log.info("Calculating probabilities of data similarities..");
        ExecutorService service = Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
        for (int i = 0; i < n; ++i) {
            if (i % 500 == 0) {
                log.info("Handled " + i + " records");
            }
            final int j = i;
            service.submit(new Runnable(){

                @Override
                public void run() {
                    try {
                        Tsne.this.searchBetaForRow(d, p, beta, j, logU);
                    }
                    catch (RuntimeException e) {
                        // The Future returned by submit() is discarded, so surface failures here
                        // rather than silently leaving row j of p at zero.
                        log.error("Perplexity search failed for row " + j, e);
                    }
                }
            });
        }
        try {
            service.shutdown();
            service.awaitTermination(1L, TimeUnit.DAYS);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        d.data().flush();
        log.info("Mean value of sigma " + Transforms.sqrt((INDArray)beta.rdiv((Number)1)).mean(Integer.MAX_VALUE));
        BooleanIndexing.applyWhere((INDArray)p, (Condition)Conditions.isNan(), (Function)new Value((Number)this.realMin));
        INDArray permute = p.transpose();
        INDArray pOut = p.add(permute);
        BooleanIndexing.applyWhere((INDArray)pOut, (Condition)Conditions.isNan(), (Function)new Value((Number)this.realMin));
        pOut.divi(pOut.sum(Integer.MAX_VALUE));
        BooleanIndexing.applyWhere((INDArray)pOut, (Condition)Conditions.lessThan((Number)1.0E-12), (Function)new Value((Number)1.0E-12));
        return pOut;
    }

    /**
     * Binary-searches {@code beta[j]} until the entropy of row {@code j} is within
     * {@link #tolerance} of {@code logU} (at most 50 refinements), then stores the row's
     * distribution in row {@code j} of {@code p}, excluding the diagonal entry.
     * Touches only row {@code j} of {@code p} and element {@code j} of {@code beta}.
     */
    private void searchBetaForRow(INDArray d, INDArray p, INDArray beta, int j, double logU) {
        double betaMin = Double.NEGATIVE_INFINITY;
        double betaMax = Double.POSITIVE_INFINITY;
        // Every column except j, so a point is never its own neighbour.
        NDArrayIndex[] range = new NDArrayIndex[]{NDArrayIndex.concat((NDArrayIndex[])new NDArrayIndex[]{NDArrayIndex.interval((int)0, (int)j), NDArrayIndex.interval((int)(j + 1), (int)d.columns())})};
        INDArray row = d.slice(j).get(range);
        Pair<INDArray, INDArray> pair = this.hBeta(row, beta.getDouble(j));
        INDArray hDiff = pair.getFirst().sub((Number)logU);
        for (int tries = 0; BooleanIndexing.and((INDArray)Transforms.abs((INDArray)hDiff), (Condition)Conditions.greaterThan((Number)this.tolerance)) && tries < 50; ++tries) {
            // The bound must be the value *before* the update; recording the updated value
            // collapses the bracket and stalls the bisection.
            double current = beta.getDouble(j);
            if (BooleanIndexing.and((INDArray)hDiff, (Condition)Conditions.greaterThan((Number)0))) {
                // Entropy too high: sharpen the distribution by increasing beta.
                betaMin = current;
                beta.putScalar(j, Double.isInfinite(betaMax) ? current * 2.0 : (current + betaMax) / 2.0);
            } else {
                // Entropy too low: widen the distribution by decreasing beta.
                betaMax = current;
                beta.putScalar(j, Double.isInfinite(betaMin) ? current / 2.0 : (current + betaMin) / 2.0);
            }
            pair = this.hBeta(row, beta.getDouble(j));
            hDiff = pair.getFirst().subi((Number)logU);
        }
        p.slice(j).put(range, pair.getSecond());
    }

    /**
     * Embeds the rows of {@code X} into {@code nDims} dimensions.
     *
     * <p>With {@link #usePca}, {@code X} is replaced by its PCA projection (at most 50 columns).
     * Otherwise, with {@link #normalize}, {@code X} is min/max scaled and column-centred
     * <em>in place</em>, i.e. the caller's array is modified.
     *
     * <p>An embedding already present in {@link #y} (e.g. from {@link #setY}) is used as the
     * starting point; otherwise a Gaussian initialisation (seed 123) scaled by 1e-3 is used.
     * The iteration listener, if set, is notified after every iteration.
     *
     * @param X          input data, one row per point
     * @param nDims      target dimensionality, clamped to the number of columns of {@code X}
     * @param perplexity target perplexity of the input affinities
     * @return the embedding (the same array as {@link #getY()})
     */
    public INDArray calculate(INDArray X, int nDims, double perplexity) {
        if (this.usePca) {
            X = PCA.pca((INDArray)X, (int)Math.min(50, X.columns()), (boolean)this.normalize);
        } else if (this.normalize) {
            X.subi(X.min(Integer.MAX_VALUE));
            X = X.divi(X.max(Integer.MAX_VALUE));
            X = X.subiRowVector(X.mean(0));
        }
        if (nDims > X.columns()) {
            nDims = X.columns();
        }
        // Pairwise squared distances via ||a||^2 + ||b||^2 - 2 a.b
        INDArray sumX = Transforms.pow((INDArray)X, (Number)2).sum(1);
        INDArray D = X.mmul(X.transpose()).muli((Number)-2).addiRowVector(sumX).transpose().addiRowVector(sumX);
        X.data().flush();
        if (this.y == null) {
            this.y = Nd4j.randn((int)X.rows(), (int)nDims, (RandomGenerator)new MersenneTwister(123)).muli((Number)Float.valueOf(0.001f));
        }
        INDArray p = this.computeGaussianPerplexity(D, perplexity);
        D.data().flush();
        // Early exaggeration; undone at stopLyingIteration.
        p.muli((Number)4);
        if (this.useAdaGrad && this.adaGrad == null) {
            this.adaGrad = new AdaGrad(this.y.shape());
            this.adaGrad.setMasterStepSize(this.learningRate);
        }
        for (int i = 0; i < this.maxIter; ++i) {
            this.step(p, i);
            if (i == this.switchMomentumIteration) {
                this.momentum = this.finalMomentum;
            }
            if (i == this.stopLyingIteration) {
                p.divi((Number)4);
            }
            if (this.iterationListener == null) continue;
            this.iterationListener.iterationDone(i);
        }
        return this.y;
    }

    /**
     * Computes the cost and the new position increments for the current embedding.
     *
     * <p>Output affinities use {@code 1 / (1 + squared distance)} with a zero diagonal, normalised
     * to sum to one and floored at {@link #realMin}. Gains grow by 0.2 where the sign of the
     * gradient differs from the sign of the previous increment and shrink by a factor 0.8 where it
     * matches, bounded below by {@link #minGain}. The increment is
     * {@code momentum * yIncs - step}, where {@code step} is the gain-scaled gradient passed
     * through AdaGrad or multiplied by {@link #learningRate}.
     *
     * @param p input affinities (possibly exaggerated)
     * @return the cost {@code sum(p * log(p / q))} and {@link #yIncs} (updated in place)
     */
    protected Pair<Double, INDArray> gradient(INDArray p) {
        INDArray sumY = Transforms.pow((INDArray)this.y, (Number)2).sum(1);
        if (this.yIncs == null) {
            this.yIncs = Nd4j.zeros((int[])this.y.shape());
        }
        if (this.gains == null) {
            this.gains = Nd4j.ones((int[])this.y.shape());
        }
        INDArray qu = this.y.mmul(this.y.transpose()).muli((Number)-2).addiRowVector(sumY).transpose().addiRowVector(sumY).addi((Number)1).rdivi((Number)1);
        int n = this.y.rows();
        Nd4j.doAlongDiagonal((INDArray)qu, (Function)new Zero());
        INDArray q = qu.div(qu.sum(Integer.MAX_VALUE));
        BooleanIndexing.applyWhere((INDArray)q, (Condition)Conditions.lessThan((Number)this.realMin), (Function)new Value((Number)this.realMin));
        INDArray PQ = p.sub(q);
        INDArray yGrads = this.getYGradient(n, PQ, qu);
        this.gains = this.gains.add((Number)0.2).muli(yGrads.cond(Conditions.greaterThan((Number)0)).neqi(this.yIncs.cond(Conditions.greaterThan((Number)0)))).addi(this.gains.mul((Number)0.8).muli(yGrads.cond(Conditions.greaterThan((Number)0)).eqi(this.yIncs.cond(Conditions.greaterThan((Number)0)))));
        BooleanIndexing.applyWhere((INDArray)this.gains, (Condition)Conditions.lessThan((Number)this.minGain), (Function)new Value((Number)this.minGain));
        INDArray gradChange = this.gains.mul(yGrads);
        if (this.useAdaGrad) {
            gradChange = this.adaGrad.getGradient(gradChange);
        } else {
            gradChange.muli((Number)this.learningRate);
        }
        this.yIncs.muli((Number)this.momentum).subi(gradChange);
        double cost = p.mul(Transforms.log((INDArray)p.div(q), (boolean)false)).sum(Integer.MAX_VALUE).getDouble(0);
        return new Pair<Double, INDArray>(cost, this.yIncs);
    }

    /**
     * Computes the gradient row by row:
     * {@code dY[i] = sum_j PQ[i,j] * qu[i,j] * (y[i] - y[j])}.
     *
     * @param n  number of rows of the embedding
     * @param PQ difference between input and output affinities
     * @param qu unnormalised output affinities
     * @return gradient with the same shape as {@link #y}
     */
    public INDArray getYGradient(int n, INDArray PQ, INDArray qu) {
        INDArray yGrads = Nd4j.create((int[])this.y.shape());
        for (int i = 0; i < n; ++i) {
            INDArray sum1 = Nd4j.tile((INDArray)PQ.getRow(i).mul(qu.getRow(i)), (int[])new int[]{this.y.columns(), 1}).transpose().mul(this.y.getRow(i).broadcast(this.y.shape()).sub(this.y)).sum(0);
            yGrads.putRow(i, sum1);
        }
        return yGrads;
    }

    /**
     * Performs one optimisation iteration: computes the increments, logs the cost, moves the
     * embedding by the increments once and re-centres it at the origin.
     *
     * @param p input affinities
     * @param i iteration number (used for logging)
     */
    public void step(INDArray p, int i) {
        Pair<Double, INDArray> costGradient = this.gradient(p);
        INDArray yIncs = costGradient.getSecond();
        log.info("Cost at iteration " + i + " was " + costGradient.getFirst());
        // Apply the increment exactly once (it was previously added twice per step).
        this.y.addi(yIncs);
        // Keep the embedding centred; a single subtraction leaves the column means at zero.
        this.y.subiRowVector(this.y.mean(0));
    }

    /** Runs {@link #plot(INDArray, int, List, String)} writing to {@code coords.csv}. */
    public void plot(INDArray matrix, int nDims, List<String> labels) throws IOException {
        this.plot(matrix, nDims, labels, "coords.csv");
    }

    /**
     * Computes the embedding of {@code matrix} and appends one line per labelled row to
     * {@code path}: the coordinates separated by commas, then a comma, the label and a space.
     * Rows whose label is {@code null}, or that have no label, are skipped.
     *
     * @throws IOException if the file cannot be opened or written
     */
    public void plot(INDArray matrix, int nDims, List<String> labels, String path) throws IOException {
        this.calculate(matrix, nDims, this.perplexity);
        BufferedWriter write = new BufferedWriter(new FileWriter(new File(path), true));
        try {
            for (int i = 0; i < this.y.rows() && i < labels.size(); ++i) {
                String word = labels.get(i);
                if (word == null) {
                    continue;
                }
                StringBuilder sb = new StringBuilder();
                INDArray wordVector = this.y.getRow(i);
                for (int j = 0; j < wordVector.length(); ++j) {
                    sb.append(wordVector.getDouble(j));
                    if (j < wordVector.length() - 1) {
                        sb.append(",");
                    }
                }
                sb.append(",");
                sb.append(word);
                sb.append(" ");
                sb.append("\n");
                write.write(sb.toString());
            }
            write.flush();
        }
        finally {
            // Close even when a write fails, so the file handle is not leaked.
            write.close();
        }
    }

    public INDArray getY() {
        return this.y;
    }

    /** Sets the embedding used as the starting point of the next {@link #calculate} call. */
    public void setY(INDArray y) {
        this.y = y;
    }

    public IterationListener getIterationListener() {
        return this.iterationListener;
    }

    public void setIterationListener(IterationListener iterationListener) {
        this.iterationListener = iterationListener;
    }

    static {
        Tsne.loadIntoTmp();
        log = LoggerFactory.getLogger(Tsne.class);
    }

    /**
     * Fluent builder for {@link Tsne}.
     *
     * <p>Note: several defaults differ from the field defaults of {@link Tsne}
     * (stopLyingIteration 100 vs 250, learningRate 0.1 vs 500, minGain 0.1 vs 0.01), and some are
     * float literals widened to double, so they are not exactly the written decimal values.
     */
    public static class Builder {
        protected int maxIter = 1000;
        protected double realMin = 1.0E-12f;
        protected double initialMomentum = 0.5;
        protected double finalMomentum = 0.8f;
        protected double momentum = 0.5;
        protected int switchMomentumIteration = 100;
        protected boolean normalize = true;
        protected boolean usePca = false;
        protected int stopLyingIteration = 100;
        protected double tolerance = 1.0E-5f;
        protected double learningRate = 0.1f;
        protected boolean useAdaGrad = true;
        protected double perplexity = 30.0;
        protected double minGain = 0.1f;

        public Builder minGain(double minGain) {
            this.minGain = minGain;
            return this;
        }

        public Builder perplexity(double perplexity) {
            this.perplexity = perplexity;
            return this;
        }

        public Builder useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        public Builder learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        public Builder tolerance(double tolerance) {
            this.tolerance = tolerance;
            return this;
        }

        public Builder stopLyingIteration(int stopLyingIteration) {
            this.stopLyingIteration = stopLyingIteration;
            return this;
        }

        public Builder usePca(boolean usePca) {
            this.usePca = usePca;
            return this;
        }

        public Builder normalize(boolean normalize) {
            this.normalize = normalize;
            return this;
        }

        public Builder setMaxIter(int maxIter) {
            this.maxIter = maxIter;
            return this;
        }

        public Builder setRealMin(double realMin) {
            this.realMin = realMin;
            return this;
        }

        public Builder setInitialMomentum(double initialMomentum) {
            this.initialMomentum = initialMomentum;
            return this;
        }

        public Builder setFinalMomentum(double finalMomentum) {
            this.finalMomentum = finalMomentum;
            return this;
        }

        public Builder setMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        public Builder setSwitchMomentumIteration(int switchMomentumIteration) {
            this.switchMomentumIteration = switchMomentumIteration;
            return this;
        }

        /** Creates a {@link Tsne} from the values configured on this builder. */
        public Tsne build() {
            return new Tsne(this.maxIter, this.realMin, this.initialMomentum, this.finalMomentum, this.momentum, this.switchMomentumIteration, this.normalize, this.usePca, this.stopLyingIteration, this.tolerance, this.learningRate, this.useAdaGrad, this.perplexity, this.minGain);
        }
    }
}

