/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.graph.models.embeddings;

import java.util.Comparator;
import java.util.PriorityQueue;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.models.GraphVectors;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.blas.Level1;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * {@link GraphVectors} implementation that answers vector and similarity queries by
 * delegating to a {@link GraphVectorLookupTable}, and exposes the {@link IGraph} it was
 * constructed with.
 *
 * @param <V> type parameter forwarded to {@link IGraph} and {@link Vertex}
 * @param <E> type parameter forwarded to {@link IGraph}
 */
public class GraphVectorsImpl<V, E> implements GraphVectors<V, E> {
    /** Graph associated with these vectors; {@code null} until set when the no-arg constructor is used. */
    protected IGraph<V, E> graph;
    /** Source of all vertex vectors; {@code null} until set when the no-arg constructor is used. */
    protected GraphVectorLookupTable lookupTable;

    /**
     * Creates an instance backed by the given graph and lookup table.
     *
     * @param graph       graph returned by {@link #getGraph()}
     * @param lookupTable table used for every vector lookup
     */
    public GraphVectorsImpl(IGraph<V, E> graph, GraphVectorLookupTable lookupTable) {
        this.graph = graph;
        this.lookupTable = lookupTable;
    }

    /**
     * Creates an instance with no graph and no lookup table. The {@code protected} fields
     * must be assigned before any query method is called.
     */
    public GraphVectorsImpl() {
    }

    /** @return the graph this instance was given */
    @Override
    public IGraph<V, E> getGraph() {
        return this.graph;
    }

    /** @return the number of vertices reported by the lookup table */
    @Override
    public int numVertices() {
        return this.lookupTable.getNumVertices();
    }

    /** @return the vector size reported by the lookup table */
    @Override
    public int getVectorSize() {
        return this.lookupTable.vectorSize();
    }

    /**
     * Looks up the vector of a vertex via its {@link Vertex#vertexID()}.
     *
     * @param vertex vertex whose vector is requested
     * @return the array returned by the lookup table for that vertex ID
     */
    @Override
    public INDArray getVertexVector(Vertex<V> vertex) {
        return this.lookupTable.getVector(vertex.vertexID());
    }

    /**
     * Looks up the vector of a vertex by index.
     *
     * @param vertexIdx index of the vertex
     * @return the array returned by the lookup table for that index
     */
    @Override
    public INDArray getVertexVector(int vertexIdx) {
        return this.lookupTable.getVector(vertexIdx);
    }

    /**
     * Finds the vertices whose vectors have the highest cosine similarity to the vector
     * of {@code vertexIdx}. The query vertex itself is never part of the result.
     *
     * @param vertexIdx index of the query vertex
     * @param top       maximum number of neighbours to return
     * @return vertex indices ordered from most to least similar; the array has
     *         {@code min(top, number of other vertices)} elements
     * @throws IllegalArgumentException if {@code top} is negative
     */
    @Override
    public int[] verticesNearest(int vertexIdx, int top) {
        if (top < 0) {
            throw new IllegalArgumentException("top must be non-negative, got " + top);
        }

        INDArray vec = this.lookupTable.getVector(vertexIdx).dup();
        double norm2 = vec.norm2Number().doubleValue();
        int vertexCount = this.numVertices();

        // PriorityQueue rejects an initial capacity below 1.
        PriorityQueue<Pair<Double, Integer>> pq =
                new PriorityQueue<Pair<Double, Integer>>(Math.max(1, vertexCount), new PairComparator());
        Level1 l1 = Nd4j.getBlasWrapper().level1();
        for (int i = 0; i < vertexCount; ++i) {
            if (i == vertexIdx) {
                continue;
            }
            INDArray other = this.lookupTable.getVector(i);
            double denominator = norm2 * other.norm2Number().doubleValue();
            // A zero norm would give NaN, which Double.compare orders above every real
            // value and would therefore be reported as the nearest vertex; score it 0.0.
            double cosineSim = denominator == 0.0
                    ? 0.0
                    : l1.dot(vec.length(), 1.0, vec, other) / denominator;
            pq.add(new Pair<Double, Integer>(cosineSim, i));
        }

        // Never drain more entries than were queued (at most vertexCount - 1).
        int resultSize = Math.min(top, pq.size());
        int[] out = new int[resultSize];
        for (int i = 0; i < resultSize; ++i) {
            out[i] = pq.remove().getSecond();
        }
        return out;
    }

    /**
     * Cosine similarity between the vectors of two vertices, identified by their
     * {@link Vertex#vertexID()}.
     *
     * @param vertex1 first vertex
     * @param vertex2 second vertex
     * @return result of {@link #similarity(int, int)} for the two vertex IDs
     */
    @Override
    public double similarity(Vertex<V> vertex1, Vertex<V> vertex2) {
        return this.similarity(vertex1.vertexID(), vertex2.vertexID());
    }

    /**
     * Similarity between the vectors of two vertices, computed as the dot product of the
     * two vectors after each is passed through {@link Transforms#unitVec(INDArray)}.
     *
     * @param vertexIdx1 index of the first vertex
     * @param vertexIdx2 index of the second vertex
     * @return {@code 1.0} when both indices are equal (no lookup is performed),
     *         otherwise the dot product of the normalized vectors
     */
    @Override
    public double similarity(int vertexIdx1, int vertexIdx2) {
        if (vertexIdx1 == vertexIdx2) {
            return 1.0;
        }
        // Normalize copies, as verticesNearest does with dup(), so that a query cannot
        // rescale the arrays handed back by the lookup table.
        // NOTE(review): only matters if unitVec scales its argument in place - confirm.
        INDArray vector = Transforms.unitVec(this.getVertexVector(vertexIdx1).dup());
        INDArray vector2 = Transforms.unitVec(this.getVertexVector(vertexIdx2).dup());
        return Nd4j.getBlasWrapper().dot(vector, vector2);
    }

    /**
     * Orders (similarity, vertex index) pairs by descending similarity, so the head of a
     * {@link PriorityQueue} using it is the most similar vertex.
     */
    private static class PairComparator implements Comparator<Pair<Double, Integer>> {
        private PairComparator() {
        }

        @Override
        public int compare(Pair<Double, Integer> o1, Pair<Double, Integer> o2) {
            // Arguments swapped to obtain descending order.
            return Double.compare(o2.getFirst(), o1.getFirst());
        }
    }
}

