/*
 * Decompiled with CFR 0.152.
 */
package tagbio.umap;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;
import tagbio.umap.CooMatrix;
import tagbio.umap.CsrMatrix;
import tagbio.umap.Curve;
import tagbio.umap.DefaultMatrix;
import tagbio.umap.FlatTree;
import tagbio.umap.Heap;
import tagbio.umap.IndexedDistances;
import tagbio.umap.MathUtils;
import tagbio.umap.Matrix;
import tagbio.umap.NearestNeighborDescent;
import tagbio.umap.NearestNeighborSearch;
import tagbio.umap.PairwiseDistances;
import tagbio.umap.ParallelNearestNeighborDescent;
import tagbio.umap.RandomProjectionTree;
import tagbio.umap.SearchGraph;
import tagbio.umap.UmapProgress;
import tagbio.umap.Utils;
import tagbio.umap.metric.CategoricalMetric;
import tagbio.umap.metric.EuclideanMetric;
import tagbio.umap.metric.Metric;
import tagbio.umap.metric.PrecomputedMetric;
import tagbio.umap.metric.ReducedEuclideanMetric;

/**
 * Uniform Manifold Approximation and Projection (UMAP) dimensionality reduction.
 *
 * <p>Configure an instance through the fluent {@code set*} methods, then call one of the
 * {@code fitTransform} overloads to compute an embedding of the training data. After fitting,
 * {@link #transform(Matrix)} places new samples into the same embedding space.
 */
public class Umap {
    /** Convergence tolerance for the binary search performed in {@code smoothKnnDist}. */
    private static final float SMOOTH_K_TOLERANCE = 1.0E-5f;
    /** Sigma is never allowed to drop below this fraction of the relevant mean neighbour distance. */
    private static final float MIN_K_DIST_SCALE = 0.001f;
    /** Below this many rows an exact pairwise distance matrix is used instead of approximate NN search. */
    private static final int SMALL_PROBLEM_THRESHOLD = 4096;
    /** Gradient components are clipped to the range [-CLIP_VALUE, CLIP_VALUE]. */
    private static final float CLIP_VALUE = 4.0f;
    /** Default number of binary search iterations in {@code smoothKnnDist}. */
    private static final int SMOOTH_K_ITERATIONS = 64;

    // ---- user configuration (see the corresponding setters) ----
    private boolean mAngularRpForest = false;
    private int mNNeighbors = 15;
    private int mNComponents = 2;
    private Integer mNEpochs = null;
    private Metric mMetric = EuclideanMetric.SINGLETON;
    private float mLearningRate = 1.0f;
    private float mRepulsionStrength = 1.0f;
    private float mMinDist = 0.1f;
    private float mSpread = 1.0f;
    private float mSetOpMixRatio = 1.0f;
    private int mLocalConnectivity = 1;
    private int mNegativeSampleRate = 5;
    private float mTransformQueueSize = 4.0f;
    private Metric mTargetMetric = CategoricalMetric.SINGLETON;
    private int mTargetNNeighbors = -1;
    private float mTargetWeight = 0.5f;
    private boolean mVerbose = false;
    private Random mRandom = new Random(42L);
    private int mThreads = 1;

    // ---- state derived during fit() and reused by transform() ----
    private float mInitialAlpha;
    private int mRunNNeighbors;
    private float mRunA;
    private float mRunB;
    private Matrix mRawData;
    private SearchGraph mSearchGraph = null;
    private int[][] mKnnIndices;
    private float[][] mKnnDists;
    private List<FlatTree> mRpForest;
    private boolean mSmallData;
    private Matrix mGraph;
    private Matrix mEmbedding;
    private NearestNeighborSearch mSearch;

    /**
     * Computes, for every row of {@code distances}, a normalisation factor sigma and a local
     * connectivity distance rho used when converting distances into membership strengths.
     *
     * @param distances         neighbour distances, one row per sample
     * @param k                 number of neighbours; the search targets a membership sum of
     *                          {@code log2(k) * bandwidth}
     * @param nIter             maximum number of binary search iterations per row
     * @param localConnectivity index (1-based) of the positive distance used as rho; values
     *                          {@code <= 0} yield rho = 0
     * @param bandwidth         scale applied to the target membership sum
     * @return a two element array {@code {sigmas, rhos}}
     */
    private static float[][] smoothKnnDist(float[][] distances, float k, int nIter, int localConnectivity, float bandwidth) {
        final float target = (float) (MathUtils.log2(k) * (double) bandwidth);
        final float[] rho = new float[distances.length];
        final float[] sigmas = new float[distances.length];
        final float meanDistances = MathUtils.mean(distances);
        for (int i = 0; i < distances.length; ++i) {
            final float[] ithDistances = distances[i];
            rho[i] = localRho(ithDistances, localConnectivity);
            float sigma = searchSigma(ithDistances, rho[i], target, nIter);
            // Bound sigma away from zero, relative to this row's mean distance when the row has a
            // positive rho and relative to the global mean distance otherwise.
            final float meanReference = rho[i] > 0.0f ? MathUtils.mean(ithDistances) : meanDistances;
            final float floor = MIN_K_DIST_SCALE * meanReference;
            if (sigma < floor) {
                sigma = floor;
            }
            sigmas[i] = sigma;
        }
        return new float[][]{sigmas, rho};
    }

    /**
     * Returns rho for one row: the {@code localConnectivity}-th positive distance, or the largest
     * positive distance when there are fewer than that, or 0 when there are none or when
     * {@code localConnectivity <= 0}. Because {@code localConnectivity} is integral there is no
     * fractional interpolation between neighbours.
     */
    private static float localRho(float[] row, int localConnectivity) {
        if (localConnectivity <= 0) {
            return 0.0f;
        }
        final float[] nonZeroDists = MathUtils.filterPositive(row);
        if (nonZeroDists.length >= localConnectivity) {
            return nonZeroDists[localConnectivity - 1];
        }
        if (nonZeroDists.length > 0) {
            return MathUtils.max(nonZeroDists);
        }
        return 0.0f;
    }

    /**
     * Binary searches for the sigma such that the sum of {@code exp(-(d - rho) / sigma)} over the
     * row (entries with {@code d <= rho} contribute 1) is within {@link #SMOOTH_K_TOLERANCE} of
     * {@code target}, doubling sigma while no upper bound is known.
     */
    private static float searchSigma(float[] row, float rho, float target, int nIter) {
        float lo = 0.0f;
        float hi = Float.POSITIVE_INFINITY;
        float mid = 1.0f;
        for (int n = 0; n < nIter; ++n) {
            double pSum = 0.0;
            // Column 0 is skipped, presumably because it holds the sample's distance to itself.
            for (int j = 1; j < row.length; ++j) {
                final double d = row[j] - rho;
                pSum += d > 0.0 ? Math.exp(-(d / (double) mid)) : 1.0;
            }
            if (Math.abs(pSum - (double) target) < (double) SMOOTH_K_TOLERANCE) {
                break;
            }
            if (pSum > (double) target) {
                hi = mid;
                mid = (lo + hi) / 2.0f;
            } else {
                lo = mid;
                mid = hi == Float.POSITIVE_INFINITY ? mid * 2.0f : (lo + hi) / 2.0f;
            }
        }
        return mid;
    }

    /** {@code smoothKnnDist} with the default iteration count and a bandwidth of 1. */
    static float[][] smoothKnnDist(float[][] distances, float k, int localConnectivity) {
        return smoothKnnDist(distances, k, SMOOTH_K_ITERATIONS, localConnectivity, 1.0f);
    }

    /**
     * Finds the {@code nNeighbors} nearest neighbours of every instance.
     *
     * <p>For {@link PrecomputedMetric} the instances are treated as a distance matrix. Otherwise a
     * random projection forest seeds an NN-descent search.
     *
     * @param angular forces angular random projection trees even when the metric itself is not angular
     * @return neighbour indices, distances and the forest (empty for precomputed distances)
     * @throws UnsupportedOperationException for sparse ({@link CsrMatrix}) input with a non-precomputed metric
     */
    static IndexedDistances nearestNeighbors(Matrix instances, int nNeighbors, Metric metric, boolean angular, Random random, int threads, boolean verbose) {
        if (verbose) {
            Utils.message("Finding nearest neighbors");
        }
        final int[][] knnIndices;
        final float[][] knnDists;
        final List<FlatTree> rpForest;
        if (metric.equals(PrecomputedMetric.SINGLETON)) {
            knnIndices = Utils.fastKnnIndices(instances, nNeighbors);
            knnDists = new float[knnIndices.length][nNeighbors];
            for (int i = 0; i < knnDists.length; ++i) {
                for (int j = 0; j < nNeighbors; ++j) {
                    knnDists[i][j] = instances.get(i, knnIndices[i][j]);
                }
            }
            rpForest = Collections.emptyList();
        } else {
            if (instances instanceof CsrMatrix) {
                throw new UnsupportedOperationException("Nearest neighbor search is not supported for sparse input");
            }
            // Use angular trees when explicitly requested or when the metric requires them.
            final boolean isAngular = angular || metric.isAngular();
            final NearestNeighborDescent nnDescent = threads == 1
                ? new NearestNeighborDescent(metric)
                : new ParallelNearestNeighborDescent(metric, threads);
            final int nTrees = 5 + (int) Math.round(Math.pow(instances.rows(), 0.5) / 20.0);
            final int nIters = Math.max(5, (int) Math.round(MathUtils.log2(instances.rows())));
            UmapProgress.incTotal(nIters + nTrees + 2);
            if (verbose) {
                Utils.message("Building random projection forest with " + nTrees + " trees");
            }
            rpForest = RandomProjectionTree.makeForest(instances, nNeighbors, nTrees, random, isAngular, threads);
            if (verbose) {
                Utils.message("Total number of values in forest: " + countForestValues(rpForest));
                Utils.message("NN descent for " + nIters + " iterations");
            }
            nnDescent.setVerbose(verbose);
            final Heap nn = nnDescent.descent(instances, nNeighbors, random, 60, true, nIters, rpForest);
            knnIndices = nn.indices();
            knnDists = nn.weights();
            if (MathUtils.containsNegative(knnIndices)) {
                Utils.message("Failed to correctly find nearest neighbors for some samples. Results may be less than ideal. Try re-running with different parameters.");
            }
        }
        if (verbose) {
            Utils.message("Finished nearest neighbor search");
        }
        return new IndexedDistances(knnIndices, knnDists, rpForest);
    }

    /** Counts the non-negative entries stored in the index arrays of every tree of the forest. */
    private static long countForestValues(List<FlatTree> forest) {
        long count = 0L;
        for (final FlatTree tree : forest) {
            for (final int[] indices : tree.getIndices()) {
                for (final int index : indices) {
                    if (index >= 0) {
                        ++count;
                    }
                }
            }
        }
        return count;
    }

    /**
     * Builds the directed membership-strength graph in COO form.
     *
     * <p>Entry (i, knnIndices[i][j]) has weight 0 for a self-edge, 1 when the distance does not
     * exceed rho[i], and {@code exp(-(d - rho[i]) / sigma[i])} otherwise. Neighbour slots equal to
     * -1 are left as zero entries (row 0, column 0, value 0).
     */
    static CooMatrix computeMembershipStrengths(int[][] knnIndices, float[][] knnDists, float[] sigmas, float[] rhos, int rowCount, int colCount) {
        final int nSamples = knnIndices.length;
        final int nNeighbors = nSamples == 0 ? 0 : knnIndices[0].length;
        final int size = nSamples * nNeighbors;
        final int[] rows = new int[size];
        final int[] cols = new int[size];
        final float[] vals = new float[size];
        for (int i = 0; i < nSamples; ++i) {
            for (int j = 0; j < nNeighbors; ++j) {
                final int neighbor = knnIndices[i][j];
                if (neighbor == -1) {
                    continue;
                }
                final float val;
                if (neighbor == i) {
                    val = 0.0f;
                } else {
                    final float excess = knnDists[i][j] - rhos[i];
                    val = excess <= 0.0f ? 1.0f : (float) Math.exp(-(excess / sigmas[i]));
                }
                final int slot = i * nNeighbors + j;
                rows[slot] = i;
                cols[slot] = neighbor;
                vals[slot] = val;
            }
        }
        return new CooMatrix(vals, rows, cols, rowCount, colCount);
    }

    /**
     * Builds the symmetric fuzzy graph of the instances.
     *
     * <p>With A the membership matrix, the result is
     * {@code mix * (A + A^T - A∘A^T) + (1 - mix) * (A∘A^T)}, where ∘ is the element-wise product
     * and {@code mix = setOpMixRatio}. If {@code knnIndices} or {@code knnDists} is null, the
     * neighbours are computed with {@link #nearestNeighbors}.
     */
    static Matrix fuzzySimplicialSet(Matrix instances, int nNeighbors, Random random, Metric metric, int[][] knnIndices, float[][] knnDists, boolean angular, float setOpMixRatio, int localConnectivity, int threads, boolean verbose) {
        if (knnIndices == null || knnDists == null) {
            final IndexedDistances nn = nearestNeighbors(instances, nNeighbors, metric, angular, random, threads, verbose);
            knnIndices = nn.getIndices();
            knnDists = nn.getDistances();
        }
        final float[][] sigmasRhos = smoothKnnDist(knnDists, nNeighbors, localConnectivity);
        final float[] sigmas = sigmasRhos[0];
        final float[] rhos = sigmasRhos[1];
        final Matrix result = computeMembershipStrengths(knnIndices, knnDists, sigmas, rhos, instances.rows(), instances.rows()).eliminateZeros();
        final Matrix prodMatrix = result.hadamardMultiplyTranspose();
        return result.addTranspose()
            .subtract(prodMatrix)
            .multiply(setOpMixRatio)
            .add(prodMatrix.multiply(1.0f - setOpMixRatio))
            .eliminateZeros();
    }

    /** Row-normalises the graph, then symmetrises it as {@code N + N^T - N∘N^T}. */
    private static Matrix resetLocalConnectivity(Matrix simplicialSet) {
        final Matrix nss = simplicialSet.rowNormalize();
        final Matrix prodMatrix = nss.hadamardMultiplyTranspose();
        return nss.addTranspose().subtract(prodMatrix).eliminateZeros();
    }

    /**
     * Combines the graph with categorical labels, mutating {@code simplicialSet} via
     * {@code fastIntersection}, and then resets local connectivity.
     */
    private static Matrix categoricalSimplicialSetIntersection(CooMatrix simplicialSet, float[] target, float unknownDist, float farDist) {
        simplicialSet.fastIntersection(target, unknownDist, farDist);
        return resetLocalConnectivity(simplicialSet.eliminateZeros());
    }

    /**
     * Intersects two fuzzy graphs. The result starts from the sparsity pattern of their sum.
     * NOTE(review): values are presumably filled in by {@code CsrMatrix.intersect} using
     * {@code weight}; confirm against that method.
     */
    private static Matrix generalSimplicialSetIntersection(Matrix simplicialSet1, Matrix simplicialSet2, float weight) {
        final CooMatrix result = simplicialSet1.add(simplicialSet2).toCoo();
        final CsrMatrix left = simplicialSet1.toCsr();
        final CsrMatrix right = simplicialSet2.toCsr();
        left.intersect(right, result, weight);
        return result;
    }

    /**
     * Returns, for each edge weight, the number of epochs between samples of that edge:
     * {@code nEpochs / (nEpochs * w / max(w))}. Edges whose scaled weight is not positive get -1.
     */
    static float[] makeEpochsPerSample(float[] weights, int nEpochs) {
        final float[] result = new float[weights.length];
        Arrays.fill(result, -1.0f);
        final float[] nSamples = MathUtils.multiply(MathUtils.divide(weights, MathUtils.max(weights)), nEpochs);
        for (int k = 0; k < nSamples.length; ++k) {
            if (nSamples[k] > 0.0f) {
                result[k] = (float) nEpochs / nSamples[k];
            }
        }
        return result;
    }

    /** Clips {@code val} to the range [-4, 4]; NaN is returned unchanged. */
    static float clip(float val) {
        return val > CLIP_VALUE ? CLIP_VALUE : (val < -CLIP_VALUE ? -CLIP_VALUE : val);
    }

    /** Default fitting epoch count: 500 for graphs of up to 10000 rows, otherwise 200. */
    private static int defaultFitEpochs(int graphRows) {
        return graphRows <= 10000 ? 500 : 200;
    }

    /**
     * Runs stochastic gradient descent on the layout, modifying {@code headEmbedding} in place
     * through the row arrays returned by {@code row()}.
     *
     * <p>Each edge i (head[i] -> tail[i]) receives an attractive update whenever its scheduled
     * epoch is reached, followed by a number of repulsive updates against randomly chosen vertices
     * of {@code tailEmbedding}. The learning rate decays linearly from {@code initialAlpha}.
     *
     * @return {@code headEmbedding}
     * @throws UnsupportedOperationException if {@code headEmbedding} is not a {@link DefaultMatrix}
     */
    private Matrix optimizeLayout(Matrix headEmbedding, Matrix tailEmbedding, int[] head, int[] tail, int nEpochs, int nVertices, float[] epochsPerSample, float a, float b, Random random, float gamma, float initialAlpha, float negativeSampleRate, boolean verbose) {
        if (!(headEmbedding instanceof DefaultMatrix)) {
            throw new UnsupportedOperationException("Require matrix we can set entries on");
        }
        final int dim = headEmbedding.cols();
        // The tail endpoint is also moved when both embeddings have the same number of rows
        // (during fitting they are the very same matrix).
        final boolean moveOther = headEmbedding.rows() == tailEmbedding.rows();
        final float[] epochsPerNegativeSample = MathUtils.divide(epochsPerSample, negativeSampleRate);
        final float[] epochOfNextNegativeSample = Arrays.copyOf(epochsPerNegativeSample, epochsPerNegativeSample.length);
        final float[] epochOfNextSample = Arrays.copyOf(epochsPerSample, epochsPerSample.length);
        // Guard against a zero modulus when fewer than 10 epochs are requested.
        final int reportInterval = Math.max(1, nEpochs / 10);
        float alpha = initialAlpha;
        for (int n = 0; n < nEpochs; ++n) {
            for (int i = 0; i < epochsPerSample.length; ++i) {
                if (!(epochOfNextSample[i] <= (float) n)) {
                    continue;
                }
                final int j = head[i];
                final int k = tail[i];
                final float[] current = headEmbedding.row(j);
                float[] other = tailEmbedding.row(k);
                float distSquared = ReducedEuclideanMetric.SINGLETON.distance(current, other);
                // Attractive gradient along the edge.
                float gradCoeff = distSquared > 0.0f
                    ? (float) (-2.0 * (double) a * (double) b * Math.pow(distSquared, (double) b - 1.0)
                        / ((double) a * Math.pow(distSquared, b) + 1.0))
                    : 0.0f;
                for (int d = 0; d < dim; ++d) {
                    final float gradD = clip(gradCoeff * (current[d] - other[d]));
                    current[d] += gradD * alpha;
                    if (moveOther) {
                        other[d] -= gradD * alpha;
                    }
                }
                epochOfNextSample[i] += epochsPerSample[i];

                // Repulsive gradients against randomly sampled vertices.
                final int nNegSamples = (int) (((float) n - epochOfNextNegativeSample[i]) / epochsPerNegativeSample[i]);
                for (int p = 0; p < nNegSamples; ++p) {
                    final int kr = random.nextInt(nVertices);
                    other = tailEmbedding.row(kr);
                    distSquared = ReducedEuclideanMetric.SINGLETON.distance(current, other);
                    if (distSquared > 0.0f) {
                        gradCoeff = 2.0f * gamma * b / (float) ((0.001 + (double) distSquared) * ((double) a * Math.pow(distSquared, b) + 1.0));
                    } else if (j == kr) {
                        continue;
                    } else {
                        gradCoeff = 0.0f;
                    }
                    for (int d = 0; d < dim; ++d) {
                        final float gradD = gradCoeff > 0.0f ? clip(gradCoeff * (current[d] - other[d])) : CLIP_VALUE;
                        current[d] += gradD * alpha;
                    }
                }
                epochOfNextNegativeSample[i] += (float) nNegSamples * epochsPerNegativeSample[i];
            }
            alpha = initialAlpha * (1.0f - (float) n / (float) nEpochs);
            if (verbose && n % reportInterval == 0) {
                Utils.message("Completed " + n + "/" + nEpochs);
            }
            UmapProgress.update();
        }
        return headEmbedding;
    }

    /**
     * Embeds the fuzzy graph into {@code nComponents} dimensions.
     *
     * <p>Edges lighter than {@code max(weight) / nEpochs} are dropped. The layout starts from
     * uniform random coordinates in [-10, 10] and is then optimised.
     *
     * @param nEpochs number of epochs; values {@code <= 0} select the size-based default
     * @param init    initialisation method; only {@code "random"} is supported
     * @throws UnsupportedOperationException for any other {@code init}
     */
    private Matrix simplicialSetEmbedding(Matrix data, Matrix graphIn, int nComponents, float initialAlpha, float a, float b, float gamma, int negativeSampleRate, int nEpochs, String init, Random random, Metric metric, boolean verbose) {
        if (!"random".equals(init)) {
            throw new UnsupportedOperationException("Unsupported embedding initialisation: " + init);
        }
        CooMatrix graph = graphIn.toCoo();
        final int nVertices = graph.cols();
        if (nEpochs <= 0) {
            nEpochs = defaultFitEpochs(graph.rows());
        }
        final float[] graphData = graph.data();
        MathUtils.zeroEntriesBelowLimit(graphData, MathUtils.max(graphData) / (float) nEpochs);
        graph = graph.eliminateZeros().toCoo();
        final DefaultMatrix embedding = new DefaultMatrix(MathUtils.uniform(random, -10.0f, 10.0f, graph.rows(), nComponents));
        final float[] epochsPerSample = makeEpochsPerSample(graph.data(), nEpochs);
        return optimizeLayout(embedding, embedding, graph.row(), graph.col(), nEpochs, nVertices, epochsPerSample, a, b, random, gamma, initialAlpha, negativeSampleRate, verbose);
    }

    /**
     * Initial positions for new samples: for each sample, the weighted sum of the embeddings of
     * its training neighbours.
     */
    private static Matrix initTransform(int[][] indices, float[][] weights, Matrix embedding) {
        final int dims = embedding.cols();
        final float[][] result = new float[indices.length][dims];
        for (int i = 0; i < indices.length; ++i) {
            for (int j = 0; j < indices[i].length; ++j) {
                for (int d = 0; d < dims; ++d) {
                    result[i][d] += weights[i][j] * embedding.get(indices[i][j], d);
                }
            }
        }
        return new DefaultMatrix(result);
    }

    /** Fits the curve parameters {@code {a, b}} for the given spread and minimum distance. */
    private static float[] findAbParams(float spread, float minDist) {
        return Curve.curveFit(spread, minDist);
    }

    /**
     * Sets the size of the local neighbourhood used for manifold approximation.
     *
     * @throws IllegalArgumentException if {@code neighbors < 2}
     */
    public Umap setNumberNearestNeighbours(int neighbors) {
        if (neighbors < 2) {
            throw new IllegalArgumentException("Number of neighbors must be greater than 2.");
        }
        this.mNNeighbors = neighbors;
        return this;
    }

    /**
     * Sets the dimensionality of the embedding.
     *
     * @throws IllegalArgumentException if {@code components < 1}
     */
    public Umap setNumberComponents(int components) {
        if (components < 1) {
            throw new IllegalArgumentException("Number of components must be greater than 0.");
        }
        this.mNComponents = components;
        return this;
    }

    /**
     * Sets the number of optimisation epochs; {@code null} selects a size-based default.
     *
     * @throws IllegalArgumentException if {@code epochs <= 10}
     */
    public Umap setNumberEpochs(Integer epochs) {
        if (epochs != null && epochs <= 10) {
            throw new IllegalArgumentException("Epochs must be larger than 10.");
        }
        this.mNEpochs = epochs;
        return this;
    }

    /**
     * Sets the metric used to compare input instances.
     *
     * @throws NullPointerException if {@code metric} is null
     */
    public Umap setMetric(Metric metric) {
        if (metric == null) {
            throw new NullPointerException("Null metric not permitted.");
        }
        this.mMetric = metric;
        return this;
    }

    /** Sets the input metric by name, as resolved by {@code Metric.getMetric}. */
    public Umap setMetric(String metric) {
        this.setMetric(Metric.getMetric(metric));
        return this;
    }

    /**
     * Sets the initial learning rate of the layout optimisation.
     *
     * @throws IllegalArgumentException if {@code rate <= 0}
     */
    public Umap setLearningRate(float rate) {
        if ((double) rate <= 0.0) {
            throw new IllegalArgumentException("Learning rate must be positive.");
        }
        this.mLearningRate = rate;
        return this;
    }

    /**
     * Sets the weight (gamma) applied to repulsive updates.
     *
     * @throws IllegalArgumentException if negative
     */
    public Umap setRepulsionStrength(float repulsionStrength) {
        if ((double) repulsionStrength < 0.0) {
            throw new IllegalArgumentException("Repulsion strength cannot be negative.");
        }
        this.mRepulsionStrength = repulsionStrength;
        return this;
    }

    /**
     * Sets the minimum distance used when fitting the curve parameters; it must not exceed the spread.
     *
     * @throws IllegalArgumentException if negative
     */
    public Umap setMinDist(float minDist) {
        if ((double) minDist < 0.0) {
            throw new IllegalArgumentException("Minimum distance must be greater than 0.0.");
        }
        this.mMinDist = minDist;
        return this;
    }

    /** Sets the spread used when fitting the curve parameters; checked against minDist at fit time. */
    public Umap setSpread(float spread) {
        this.mSpread = spread;
        return this;
    }

    /**
     * Sets the mix between fuzzy union (1.0) and fuzzy intersection (0.0) when symmetrising the graph.
     *
     * @throws IllegalArgumentException if outside [0, 1]
     */
    public Umap setSetOpMixRatio(float setOpMixRatio) {
        if ((double) setOpMixRatio < 0.0 || (double) setOpMixRatio > 1.0) {
            throw new IllegalArgumentException("Set operation mixing ratio be between 0.0 and 1.0.");
        }
        this.mSetOpMixRatio = setOpMixRatio;
        return this;
    }

    /** Sets which positive neighbour distance (1-based) is used as rho in the membership computation. */
    public Umap setLocalConnectivity(int localConnectivity) {
        this.mLocalConnectivity = localConnectivity;
        return this;
    }

    /**
     * Sets how many negative samples are drawn per positive edge sample.
     *
     * @throws IllegalArgumentException if not positive
     */
    public Umap setNegativeSampleRate(int negativeSampleRate) {
        if (negativeSampleRate <= 0) {
            throw new IllegalArgumentException("Negative sample rate must be positive.");
        }
        this.mNegativeSampleRate = negativeSampleRate;
        return this;
    }

    /**
     * Sets the metric used on target labels during supervised fitting.
     *
     * @throws NullPointerException if {@code targetMetric} is null
     */
    public Umap setTargetMetric(Metric targetMetric) {
        if (targetMetric == null) {
            throw new NullPointerException("Null target metric not permitted.");
        }
        this.mTargetMetric = targetMetric;
        return this;
    }

    /** Sets the target metric by name, as resolved by {@code Metric.getMetric}. */
    public Umap setTargetMetric(String targetMetric) {
        this.setTargetMetric(Metric.getMetric(targetMetric));
        return this;
    }

    /** Enables or disables progress messages. */
    public Umap setVerbose(boolean verbose) {
        this.mVerbose = verbose;
        return this;
    }

    /**
     * Sets the random number generator used throughout fitting and transformation.
     *
     * @throws NullPointerException if {@code random} is null
     */
    public Umap setRandom(Random random) {
        if (random == null) {
            throw new NullPointerException("Null random not permitted.");
        }
        this.mRandom = random;
        return this;
    }

    /** Re-seeds the current random number generator. */
    public Umap setSeed(long seed) {
        this.mRandom.setSeed(seed);
        return this;
    }

    /** Sets the multiple of the neighbour count used as the search queue size in {@link #transform(Matrix)}. */
    public Umap setTransformQueueSize(float transformQueueSize) {
        this.mTransformQueueSize = transformQueueSize;
        return this;
    }

    /** Forces angular random projection trees even when the metric itself is not angular. */
    public Umap setAngularRpForest(boolean angularRpForest) {
        this.mAngularRpForest = angularRpForest;
        return this;
    }

    /**
     * Sets the neighbour count used on target labels; -1 means "same as for the instances".
     *
     * @throws IllegalArgumentException if less than 2 and not -1
     */
    public Umap setTargetNNeighbors(int targetNNeighbors) {
        if (targetNNeighbors < 2 && targetNNeighbors != -1) {
            throw new IllegalArgumentException("targetNNeighbors must be greater than 2");
        }
        this.mTargetNNeighbors = targetNNeighbors;
        return this;
    }

    /** Sets the weight of the target graph when intersecting it with the data graph. */
    public Umap setTargetWeight(float targetWeight) {
        this.mTargetWeight = targetWeight;
        return this;
    }

    /**
     * Sets the number of threads used for nearest neighbour search.
     *
     * @throws IllegalArgumentException if less than 1
     */
    public Umap setThreads(int threads) {
        if (threads < 1) {
            throw new IllegalArgumentException("threads must be at least 1");
        }
        this.mThreads = threads;
        return this;
    }

    /** Checks constraints that involve more than one parameter. */
    private void validateParameters() {
        if (this.mMinDist > this.mSpread) {
            throw new IllegalArgumentException("minDist must be less than or equal to spread");
        }
    }

    /** Clears state derived from a previous fit so it cannot leak into this one. */
    private void resetFittedState() {
        this.mEmbedding = null;
        this.mGraph = null;
        this.mSearchGraph = null;
        this.mSearch = null;
        this.mKnnIndices = null;
        this.mKnnDists = null;
        this.mRpForest = null;
    }

    /**
     * Fits the model to {@code instances}, optionally supervised by per-instance targets {@code y}.
     *
     * @throws IllegalArgumentException if the instances contain non-finite values or are empty, if
     *                                  {@code y} does not have one value per row, or if the
     *                                  parameters are inconsistent
     */
    private void fit(Matrix instances, float[] y) {
        if (!instances.isFinite()) {
            throw new IllegalArgumentException("Supplied matrix of instances contains non-finite elements");
        }
        if (instances.rows() == 0) {
            throw new IllegalArgumentException("Supplied matrix of instances is empty");
        }
        if (y != null && instances.rows() != y.length) {
            throw new IllegalArgumentException("Length of x =  " + instances.rows() + ", length of y = " + y.length + ", while it must be equal.");
        }
        this.validateParameters();
        this.resetFittedState();
        UmapProgress.reset(5);
        if (this.mVerbose) {
            Utils.message("Starting fitting for " + instances.rows() + " instances with " + instances.cols() + " attributes");
        }
        this.mRawData = instances;
        final float[] ab = findAbParams(this.mSpread, this.mMinDist);
        this.mRunA = ab[0];
        this.mRunB = ab[1];
        this.mInitialAlpha = this.mLearningRate;
        UmapProgress.update();

        if (instances.rows() <= this.mNNeighbors) {
            if (instances.rows() == 1) {
                // A single sample is embedded at the origin.
                this.mEmbedding = new DefaultMatrix(new float[1][this.mNComponents]);
                UmapProgress.finished();
                return;
            }
            Utils.message("nNeighbors is larger than the dataset size; truncating to X.length - 1");
            this.mRunNNeighbors = instances.rows() - 1;
        } else {
            this.mRunNNeighbors = this.mNNeighbors;
        }
        if (this.mVerbose) {
            Utils.message("Construct fuzzy simplicial set: " + instances.rows());
        }
        UmapProgress.update();

        this.buildGraph(instances);
        UmapProgress.update();

        if (y != null) {
            this.intersectWithTarget(y);
        }
        UmapProgress.incTotal(this.mNEpochs == null ? defaultFitEpochs(this.mGraph.rows()) : this.mNEpochs);
        UmapProgress.update();

        // 0 lets simplicialSetEmbedding choose the size-based default.
        final int nEpochs = this.mNEpochs == null ? 0 : this.mNEpochs;
        if (this.mVerbose) {
            Utils.message("Construct embedding");
        }
        this.mEmbedding = this.simplicialSetEmbedding(this.mRawData, this.mGraph, this.mNComponents, this.mInitialAlpha, this.mRunA, this.mRunB, this.mRepulsionStrength, this.mNegativeSampleRate, nEpochs, "random", this.mRandom, this.mMetric, this.mVerbose);
        if (this.mVerbose) {
            Utils.message("Finished embedding");
        }
        UmapProgress.finished();
    }

    /**
     * Builds {@link #mGraph}. Small inputs use exact pairwise distances. Larger inputs use
     * approximate nearest neighbours, whose results are kept for {@link #transform(Matrix)}.
     */
    private void buildGraph(Matrix instances) {
        if (instances.rows() < SMALL_PROBLEM_THRESHOLD) {
            this.mSmallData = true;
            final Matrix dmat = PairwiseDistances.pairwiseDistances(instances, this.mMetric);
            this.mGraph = fuzzySimplicialSet(dmat, this.mRunNNeighbors, this.mRandom, PrecomputedMetric.SINGLETON, null, null, this.mAngularRpForest, this.mSetOpMixRatio, this.mLocalConnectivity, this.mThreads, this.mVerbose);
        } else {
            this.mSmallData = false;
            final IndexedDistances nn = nearestNeighbors(instances, this.mRunNNeighbors, this.mMetric, this.mAngularRpForest, this.mRandom, this.mThreads, this.mVerbose);
            this.mKnnIndices = nn.getIndices();
            this.mKnnDists = nn.getDistances();
            this.mRpForest = nn.getForest();
            this.mGraph = fuzzySimplicialSet(instances, this.mRunNNeighbors, this.mRandom, this.mMetric, this.mKnnIndices, this.mKnnDists, this.mAngularRpForest, this.mSetOpMixRatio, this.mLocalConnectivity, this.mThreads, this.mVerbose);
            if (this.mMetric == PrecomputedMetric.SINGLETON) {
                Utils.message("Using precomputed metric; transform will be unavailable for new data");
            } else {
                this.mSearch = new NearestNeighborSearch(this.mMetric);
            }
        }
    }

    /** Combines {@link #mGraph} with a graph built from the target values {@code y}. */
    private void intersectWithTarget(float[] y) {
        if (CategoricalMetric.SINGLETON.equals(this.mTargetMetric)) {
            final float farDist = this.mTargetWeight < 1.0f ? 2.5f * (1.0f / (1.0f - this.mTargetWeight)) : 1.0E12f;
            this.mGraph = categoricalSimplicialSetIntersection(this.mGraph.toCoo(), y, 1.0f, farDist);
        } else {
            final int targetNNeighbors = this.mTargetNNeighbors == -1 ? this.mRunNNeighbors : this.mTargetNNeighbors;
            final Matrix targetGraph;
            if (y.length < SMALL_PROBLEM_THRESHOLD) {
                final Matrix ydmat = PairwiseDistances.pairwiseDistances(MathUtils.promoteTranspose(y), this.mTargetMetric);
                targetGraph = fuzzySimplicialSet(ydmat, targetNNeighbors, this.mRandom, PrecomputedMetric.SINGLETON, null, null, false, 1.0f, 1, this.mThreads, false);
            } else {
                targetGraph = fuzzySimplicialSet(MathUtils.promoteTranspose(y), targetNNeighbors, this.mRandom, this.mTargetMetric, null, null, false, 1.0f, 1, this.mThreads, false);
            }
            this.mGraph = resetLocalConnectivity(generalSimplicialSetIntersection(this.mGraph, targetGraph, this.mTargetWeight));
        }
    }

    /**
     * Fits the model, optionally supervised by one target value per instance, and returns the embedding.
     *
     * @param y target values, or null for unsupervised fitting
     */
    public Matrix fitTransform(Matrix instances, float[] y) {
        this.fit(instances, y);
        return this.mEmbedding;
    }

    /** Fits the model without targets and returns the embedding. */
    public Matrix fitTransform(Matrix instances) {
        return this.fitTransform(instances, null);
    }

    /** Array convenience overload of {@link #fitTransform(Matrix)}. */
    public float[][] fitTransform(float[][] instances) {
        return this.fitTransform(new DefaultMatrix(instances), null).toArray();
    }

    /** Double-precision convenience overload; values are narrowed to float for fitting. */
    public double[][] fitTransform(double[][] instances) {
        final float[][] input = new float[instances.length][];
        for (int k = 0; k < instances.length; ++k) {
            input[k] = new float[instances[k].length];
            for (int j = 0; j < instances[k].length; ++j) {
                input[k][j] = (float) instances[k][j];
            }
        }
        final Matrix result = this.fitTransform(new DefaultMatrix(input), null);
        final double[][] output = new double[result.rows()][result.cols()];
        for (int k = 0; k < result.rows(); ++k) {
            for (int j = 0; j < result.cols(); ++j) {
                output[k][j] = result.get(k, j);
            }
        }
        return output;
    }

    /**
     * Embeds new instances into the space learnt by a previous fit.
     *
     * @throws IllegalStateException    if the model has not been fitted
     * @throws IllegalArgumentException if the model was fitted on a single sample, on sparse input,
     *                                  or with a precomputed metric
     */
    public Matrix transform(Matrix instances) {
        if (this.mEmbedding == null) {
            throw new IllegalStateException("Model has not been fitted; call fitTransform before transform.");
        }
        if (this.mEmbedding.rows() == 1) {
            throw new IllegalArgumentException("Transform unavailable when model was fit with only a single data sample.");
        }
        if (this.mRawData instanceof CsrMatrix) {
            throw new IllegalArgumentException("Transform not available for sparse input.");
        }
        if (this.mMetric instanceof PrecomputedMetric) {
            throw new IllegalArgumentException("Transform of new data not available for precomputed metric.");
        }
        UmapProgress.reset(4);
        final int[][] indices;
        final float[][] dists;
        if (this.mSmallData) {
            // Exact search: sort each row of the distance matrix to the training data.
            final Matrix distanceMatrix = PairwiseDistances.pairwiseDistances(instances, this.mRawData, this.mMetric);
            final int[][] order = new int[distanceMatrix.rows()][];
            for (int k = 0; k < order.length; ++k) {
                order[k] = MathUtils.argsort(Arrays.copyOf(distanceMatrix.row(k), distanceMatrix.cols()));
            }
            indices = MathUtils.subarray(order, this.mRunNNeighbors);
            dists = Utils.submatrix(distanceMatrix, indices, this.mRunNNeighbors);
        } else {
            final Heap init = NearestNeighborDescent.initialiseSearch(this.mRpForest, this.mRawData, instances, (int) ((float) this.mRunNNeighbors * this.mTransformQueueSize), this.mSearch, this.mRandom);
            if (this.mSearchGraph == null) {
                // Built once per fit from the training kNN graph, skipping zero-distance edges.
                this.mSearchGraph = new SearchGraph(this.mRawData.rows());
                for (int k = 0; k < this.mKnnIndices.length; ++k) {
                    for (int j = 0; j < this.mKnnIndices[k].length; ++j) {
                        if (this.mKnnDists[k][j] != 0.0f) {
                            this.mSearchGraph.set(k, this.mKnnIndices[k][j]);
                        }
                    }
                }
            }
            final Heap result = this.mSearch.initializedNndSearch(this.mRawData, this.mSearchGraph, init, instances).deheapSort();
            indices = MathUtils.subarray(result.indices(), this.mRunNNeighbors);
            dists = MathUtils.subarray(result.weights(), this.mRunNNeighbors);
        }
        UmapProgress.update();

        final int adjustedLocalConnectivity = Math.max(0, this.mLocalConnectivity - 1);
        final float[][] sigmasRhos = smoothKnnDist(dists, this.mRunNNeighbors, adjustedLocalConnectivity);
        final float[] sigmas = sigmasRhos[0];
        final float[] rhos = sigmasRhos[1];
        CooMatrix graph = computeMembershipStrengths(indices, dists, sigmas, rhos, instances.rows(), this.mRawData.rows());
        UmapProgress.update();

        // The L1-normalised weights only seed the initial positions; optimisation uses the raw graph.
        final CsrMatrix csrGraph = graph.toCsr().l1Normalize().toCsr();
        final int[][] inds = csrGraph.reshapeIndicies(instances.rows(), this.mRunNNeighbors);
        final float[][] weights = csrGraph.reshapeWeights(instances.rows(), this.mRunNNeighbors);
        final Matrix embedding = initTransform(inds, weights, this.mEmbedding);

        final int nEpochs = this.mNEpochs == null ? (graph.rows() <= 10000 ? 100 : 30) : this.mNEpochs;
        MathUtils.zeroEntriesBelowLimit(graph.data(), MathUtils.max(graph.data()) / (float) nEpochs);
        graph = graph.eliminateZeros().toCoo();
        final float[] epochsPerSample = makeEpochsPerSample(graph.data(), nEpochs);
        UmapProgress.update();
        UmapProgress.incTotal(nEpochs);
        // The training embedding is copied so that it is never modified by transform.
        final Matrix matrix = this.optimizeLayout(embedding, this.mEmbedding.copy(), graph.row(), graph.col(), nEpochs, graph.cols(), epochsPerSample, this.mRunA, this.mRunB, this.mRandom, this.mRepulsionStrength, this.mInitialAlpha, this.mNegativeSampleRate, this.mVerbose);
        UmapProgress.finished();
        return matrix;
    }

    /** Array convenience overload of {@link #transform(Matrix)}. */
    public float[][] transform(float[][] instances) {
        return this.transform(new DefaultMatrix(instances)).toArray();
    }
}

