/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.wordvectors;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.embeddings.inmemory.InMemoryLookupTable;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.text.stopwords.StopWords;
import org.deeplearning4j.util.MathUtils;
import org.deeplearning4j.util.SetUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Base {@link WordVectors} implementation that answers vocabulary, similarity and
 * nearest-neighbour queries from a {@link VocabCache} and a {@link WeightLookupTable}.
 *
 * <p>When the lookup table is an {@link InMemoryLookupTable}, nearest-neighbour queries are
 * answered with matrix operations over its {@code syn0} weights; otherwise they fall back to
 * computing {@link Transforms#cosineSim} against every vocabulary word one at a time.
 */
public class WordVectorsImpl
implements WordVectors {
    /** Minimum word frequency; not read by this class itself. */
    protected int minWordFrequency = 5;
    /** Source of the word vectors. */
    protected WeightLookupTable lookupTable;
    /** Vocabulary used to map words to indices and back. */
    protected VocabCache vocab;
    /** Vector size; not read by this class itself (queries use {@code lookupTable().layerSize()}). */
    protected int layerSize = 100;
    /** Token whose vector is used for words that have no vocabulary index. */
    public static final String UNK = "UNK";
    /** Stop-word list loaded from {@link StopWords}; not read by this class itself. */
    protected List<String> stopWords = StopWords.getStopWords();

    /** Counter key for correctly answered analogy questions in {@link #accuracy}. */
    private static final String CORRECT = "correct";
    /** Counter key for incorrectly answered analogy questions in {@link #accuracy}. */
    private static final String WRONG = "wrong";

    /**
     * Returns whether {@code word} has an index in the vocabulary.
     */
    @Override
    public boolean hasWord(String word) {
        return this.vocab().indexOf(word) >= 0;
    }

    /**
     * Finds the words closest to the sum of the {@code positive} vectors minus the
     * {@code negative} vectors. The query words themselves are never returned.
     *
     * @param positive words whose vectors are added
     * @param negative words whose vectors are subtracted
     * @param top maximum number of words to return
     * @return up to {@code top} words; empty if any query word is not in the vocabulary
     */
    @Override
    public Collection<String> wordsNearestSum(List<String> positive, List<String> negative, int top) {
        Set<String> union = SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative));
        // Same guard as wordsNearest: an unknown word would make vector(s) null and addi() fail.
        if (!this.allInVocab(union)) {
            return new ArrayList<String>();
        }
        INDArray words = Nd4j.create((int) this.lookupTable().layerSize());
        for (String s : positive) {
            words.addi(this.lookupTable().vector(s));
        }
        for (String s : negative) {
            words.addi(this.lookupTable().vector(s).mul(-1));
        }
        if (this.lookupTable() instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) this.lookupTable()).getSyn0();
            // NOTE(review): scores each row by its dot product with the query reweighted by the
            // reciprocal of syn0's norm2 along dimension 0; this is not plain cosine
            // similarity — confirm this weighting is intended.
            INDArray weights = syn0.norm2(0).rdivi(1).muli(words);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
            return this.collectTopWords(sorted[0], top, union);
        }
        return this.nearestByCosine(words, top, union);
    }

    /**
     * Finds the words closest to {@code word}, excluding {@code word} itself.
     *
     * @param word query word
     * @param n maximum number of words to return
     * @return up to {@code n} words; empty if the lookup table has no vector for {@code word}
     */
    @Override
    public Collection<String> wordsNearestSum(String word, int n) {
        INDArray raw = this.getWordVectorMatrix(word);
        // Check for a missing vector before normalizing it; unitVec(null) cannot succeed.
        if (raw == null) {
            return new ArrayList<String>();
        }
        INDArray vec = Transforms.unitVec(raw);
        Set<String> exclude = new HashSet<String>(Arrays.asList(word));
        if (this.lookupTable() instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) this.lookupTable()).getSyn0();
            INDArray weights = syn0.norm2(0).rdivi(1).muli(vec);
            INDArray distances = syn0.mulRowVector(weights).sum(1);
            INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
            // Exclude by word rather than via vocab().wordFor(word).getIndex(), which throws
            // when the word has a vector but no vocabulary entry.
            return this.collectTopWords(sorted[0], n, exclude);
        }
        return this.nearestByCosine(vec, n, exclude);
    }

    /**
     * Scores word analogies, reporting the percentage answered correctly per section.
     *
     * <p>A line starting with {@code ":"} opens a new section and is used as its key in the
     * result. Every other line is read as {@code "a b c d"}, and the question is answered
     * correctly when the word nearest to {@code b - a + c} is {@code d}. Lines with fewer than
     * four whitespace-separated tokens are skipped. Sections with no scored questions, and
     * questions that appear before the first section header, are not reported.
     *
     * @param questions analogy file lines
     * @return section header to accuracy in percent (0-100)
     */
    @Override
    public Map<String, Double> accuracy(List<String> questions) {
        HashMap<String, Double> accuracy = new HashMap<String, Double>();
        Counter<String> counts = new Counter<String>();
        String section = null;
        for (String s : questions) {
            if (s.startsWith(":")) {
                // Close the previous section under its own header before starting the next one.
                recordSectionAccuracy(accuracy, section, counts);
                section = s;
                counts.clear();
                continue;
            }
            String[] split = s.trim().split("\\s+");
            if (split.length < 4) {
                continue;
            }
            List<String> positive = Arrays.asList(split[1], split[2]);
            List<String> negative = Arrays.asList(split[0]);
            String expected = split[3];
            // wordsNearest returns an empty collection when a word is out of vocabulary;
            // count that as a miss instead of failing on iterator().next().
            Collection<String> nearest = this.wordsNearest(positive, negative, 1);
            boolean hit = !nearest.isEmpty() && expected.equals(nearest.iterator().next());
            counts.incrementCount(hit ? CORRECT : WRONG, 1.0);
        }
        recordSectionAccuracy(accuracy, section, counts);
        return accuracy;
    }

    /**
     * Stores {@code 100 * correct / (correct + wrong)} from {@code counts} under {@code section}.
     * Does nothing when there is no current section or no question was counted.
     */
    private static void recordSectionAccuracy(Map<String, Double> accuracy, String section, Counter<String> counts) {
        if (section == null) {
            return;
        }
        double correct = counts.getCount(CORRECT);
        double total = correct + counts.getCount(WRONG);
        if (total > 0) {
            accuracy.put(section, 100.0 * correct / total);
        }
    }

    /**
     * Returns the vocabulary index of {@code word}, as reported by {@link VocabCache#indexOf}.
     */
    @Override
    public int indexOf(String word) {
        return this.vocab().indexOf(word);
    }

    /**
     * Returns the vocabulary words whose {@link MathUtils#stringSimilarity} to {@code word} is at
     * least {@code accuracy}.
     */
    @Override
    public List<String> similarWordsInVocabTo(String word, double accuracy) {
        ArrayList<String> ret = new ArrayList<String>();
        for (String s : this.vocab().words()) {
            if (MathUtils.stringSimilarity(new String[]{word, s}) >= accuracy) {
                ret.add(s);
            }
        }
        return ret;
    }

    /**
     * Returns the vector for {@code word} as a flat {@code double[]}, using the {@link #UNK}
     * vector when the word has no vocabulary index.
     */
    @Override
    public double[] getWordVector(String word) {
        String key = this.vocab().indexOf(word) < 0 ? UNK : word;
        return this.lookupTable().vector(key).ravel().data().asDouble();
    }

    /**
     * Returns the vector for {@code word} divided by its L2 norm, using the {@link #UNK} vector
     * when the word has no vocabulary index.
     *
     * @return the normalized vector, or {@code null} if the lookup table has no vector
     */
    @Override
    public INDArray getWordVectorMatrixNormalized(String word) {
        String key = this.vocab().indexOf(word) < 0 ? UNK : word;
        INDArray r = this.lookupTable().vector(key);
        if (r == null) {
            return null;
        }
        // The UNK fallback is normalized too, so both paths return vectors of the same form.
        return r.div(Nd4j.getBlasWrapper().nrm2(r));
    }

    /**
     * Returns the raw vector stored in the lookup table for {@code word}.
     */
    @Override
    public INDArray getWordVectorMatrix(String word) {
        return this.lookupTable().vector(word);
    }

    /**
     * Finds the words whose vectors are most similar to the mean of the {@code positive}
     * vectors and the negated {@code negative} vectors. The query words themselves are never
     * returned.
     *
     * @param positive words contributing their vectors
     * @param negative words contributing their negated vectors
     * @param top maximum number of words to return
     * @return up to {@code top} words; empty if no query words are given or any query word is
     *     not in the vocabulary
     */
    @Override
    public Collection<String> wordsNearest(List<String> positive, List<String> negative, int top) {
        Set<String> union = SetUtils.union(new HashSet<String>(positive), new HashSet<String>(negative));
        if (union.isEmpty() || !this.allInVocab(union)) {
            return new ArrayList<String>();
        }
        INDArray words = Nd4j.create(positive.size() + negative.size(), (int) this.lookupTable().layerSize());
        int row = 0;
        for (String s : positive) {
            words.putRow(row++, this.lookupTable().vector(s));
        }
        for (String s : negative) {
            words.putRow(row++, this.lookupTable().vector(s).mul(-1));
        }
        INDArray mean = words.isMatrix() ? words.mean(0) : words;
        if (this.lookupTable() instanceof InMemoryLookupTable) {
            INDArray syn0 = ((InMemoryLookupTable) this.lookupTable()).getSyn0();
            INDArray distances = syn0.mmul(mean.transpose());
            // Divide each row's dot product by that row's L2 norm (cosine similarity up to the
            // constant |mean|). Dividing by distances' own norm2 would collapse every score to
            // its sign and make the ranking meaningless.
            distances.diviColumnVector(syn0.norm2(1));
            INDArray[] sorted = Nd4j.sortWithIndices(distances, 0, false);
            return this.collectTopWords(sorted[0], top, union);
        }
        return this.nearestByCosine(mean, top, union);
    }

    /**
     * Finds the words closest to {@code word}, excluding {@code word} itself.
     */
    @Override
    public Collection<String> wordsNearest(String word, int n) {
        return this.wordsNearest(Arrays.asList(word), new ArrayList<String>(), n);
    }

    /**
     * Returns the cosine similarity of the vectors for {@code word} and {@code word2}.
     *
     * @return {@code 1.0} when the two strings are equal, {@code -1.0} when either word has no
     *     vector, otherwise the dot product of the two unit-normalized vectors
     */
    @Override
    public double similarity(String word, String word2) {
        if (word.equals(word2)) {
            return 1.0;
        }
        INDArray first = this.getWordVectorMatrix(word);
        INDArray second = this.getWordVectorMatrix(word2);
        // Check for missing vectors before normalizing them.
        if (first == null || second == null) {
            return -1.0;
        }
        return Nd4j.getBlasWrapper().dot(Transforms.unitVec(first), Transforms.unitVec(second));
    }

    @Override
    public VocabCache vocab() {
        return this.vocab;
    }

    @Override
    public WeightLookupTable lookupTable() {
        return this.lookupTable;
    }

    public void setLookupTable(WeightLookupTable lookupTable) {
        this.lookupTable = lookupTable;
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    /** Returns whether every word in {@code words} is contained in the vocabulary. */
    private boolean allInVocab(Collection<String> words) {
        for (String w : words) {
            if (!this.vocab().containsWord(w)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Walks {@code sortedIndices} in order and collects the vocabulary word at each index,
     * skipping words in {@code exclude}, until {@code top} words are collected or the indices
     * run out.
     */
    private List<String> collectTopWords(INDArray sortedIndices, int top, Set<String> exclude) {
        List<String> ret = new ArrayList<String>();
        int length = sortedIndices.length();
        for (int i = 0; i < length && ret.size() < top; ++i) {
            String candidate = this.vocab().wordAtIndex(sortedIndices.getInt(i));
            if (!exclude.contains(candidate)) {
                ret.add(candidate);
            }
        }
        return ret;
    }

    /**
     * Fallback ranking for lookup tables other than {@link InMemoryLookupTable}: computes the
     * cosine similarity of {@code query} with every vocabulary word not in {@code exclude} and
     * keeps the {@code top} highest-scoring keys.
     */
    private Collection<String> nearestByCosine(INDArray query, int top, Set<String> exclude) {
        Counter<String> distances = new Counter<String>();
        for (String s : this.vocab().words()) {
            if (exclude.contains(s)) {
                continue;
            }
            double sim = Transforms.cosineSim(query, this.getWordVectorMatrix(s));
            distances.incrementCount(s, sim);
        }
        distances.keepTopNKeys(top);
        return distances.keySet();
    }
}

