/*
 * Decompiled with CFR 0.152.
 */
package ch.epfl.bbp.uima.word2vec;

import ch.epfl.bbp.uima.utils.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.io.BufferedInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.EOFException;
import java.io.FileInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * In-memory word2vec model.
 *
 * <p>Loads vectors from the binary word2vec format and scales each one to unit L2 length. It then
 * answers nearest-neighbour and analogy queries by dot product, which for unit vectors equals
 * cosine similarity.
 */
public class Word2Vec {
    /**
     * Word to unit-length vector, filled by {@link #loadModel(String)}.
     * Successive loads accumulate into this same map.
     */
    protected HashMap<String, float[]> vocabulary = new HashMap<String, float[]>();
    /** Vocabulary size declared in the header of the last loaded model. */
    protected int vocabSize;
    /** Vector dimensionality declared in the header of the last loaded model. */
    protected int vectorSize;
    /** Maximum number of results returned by {@link #distance} and the {@code analogy} methods. */
    protected int topNSize = 40;
    /** Initial capacity of the token buffer used by {@link #readString}; longer tokens are still read in full. */
    private static final int MAX_SIZE = 50;
    /** Encoding used to decode words read from model files. */
    private static final String CHARSET = "UTF-8";

    /**
     * Loads a model in the binary word2vec format.
     *
     * <p>The file starts with a header {@code "<vocabSize> <vectorSize>"}. For each word it then
     * contains the word, a space, {@code vectorSize} little-endian 32-bit floats, and one separator
     * byte. Each vector is scaled to unit length before being stored. All-zero vectors are stored
     * unchanged rather than turned into NaNs.
     *
     * @param path path of the binary model file
     * @return this instance, for chaining
     * @throws IOException if the file cannot be read or ends before the declared vocabulary is complete
     */
    public Word2Vec loadModel(String path) throws IOException {
        Preconditions.checkFileExists(path);
        // Open before the try block: if opening fails there is nothing to close, so the real
        // exception is not masked by a NullPointerException in finally.
        DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)));
        try {
            this.vocabSize = Integer.parseInt(readString(dis));
            this.vectorSize = Integer.parseInt(readString(dis));
            for (int i = 0; i < this.vocabSize; ++i) {
                String word = readString(dis);
                float[] vector = new float[this.vectorSize];
                double len = 0.0;
                for (int j = 0; j < this.vectorSize; ++j) {
                    float value = readFloat(dis);
                    len += (double) (value * value);
                    vector[j] = value;
                }
                len = Math.sqrt(len);
                // Dividing a zero vector by its zero norm would produce NaNs in every component.
                if (len > 0.0) {
                    for (int j = 0; j < vector.length; ++j) {
                        vector[j] = (float) ((double) vector[j] / len);
                    }
                }
                this.vocabulary.put(word, vector);
                // Skip the single separator byte that follows each vector.
                dis.read();
            }
        } finally {
            // Closing the outermost stream closes the whole chain.
            dis.close();
        }
        return this;
    }

    /**
     * Finds the words most similar to {@code word}.
     *
     * @param word query word
     * @return up to {@link #getTopNSize()} other words ordered by decreasing similarity,
     *     or {@code null} if {@code word} is not in the vocabulary
     */
    public Set<WordEntry> distance(String word) {
        float[] wordVector = this.getWordVector(word);
        if (wordVector == null) {
            return null;
        }
        return this.nearest(wordVector, Arrays.asList(word));
    }

    /**
     * Solves the analogy "{@code word0} is to {@code word1} as {@code word2} is to ?".
     *
     * <p>It ranks words by similarity to {@code word1 - word0 + word2}, excluding the three inputs.
     *
     * @return up to {@link #getTopNSize()} candidates ordered by decreasing similarity,
     *     or {@code null} if any of the three words is unknown
     */
    public TreeSet<WordEntry> analogy(String word0, String word1, String word2) {
        float[] wv0 = this.getWordVector(word0);
        float[] wv1 = this.getWordVector(word1);
        float[] wv2 = this.getWordVector(word2);
        if (wv1 == null || wv2 == null || wv0 == null) {
            return null;
        }
        float[] wordVector = new float[this.vectorSize];
        for (int i = 0; i < this.vectorSize; ++i) {
            wordVector[i] = wv1[i] - wv0[i] + wv2[i];
        }
        return this.nearest(wordVector, Arrays.asList(word0, word1, word2));
    }

    /**
     * Evaluates a vector expression such as {@code "king - man + woman"}.
     *
     * <p>The query is made of words separated by {@code +} or {@code -}, with all tokens separated
     * by whitespace. The parsed query is echoed to standard output. The query words themselves are
     * excluded from the result.
     *
     * @param query an odd number of whitespace-separated tokens alternating word / operator / word
     * @return up to {@link #getTopNSize()} words ordered by decreasing similarity to the combined vector
     * @throws IllegalArgumentException if the token count is even, an operator is neither
     *     {@code +} nor {@code -}, or a word is not in the vocabulary
     */
    public TreeSet<WordEntry> analogy(String query) {
        String[] split = query.trim().split("\\s+");
        com.google.common.base.Preconditions.checkArgument(split.length % 2 == 1, "should be uneven nr: " + query);

        int nOperators = (split.length - 1) / 2;
        boolean[] operators = new boolean[nOperators];
        for (int k = 0; k < nOperators; ++k) {
            String op = split[2 * k + 1];
            if (op.equals("+")) {
                operators[k] = true;
            } else if (op.equals("-")) {
                operators[k] = false;
            } else {
                throw new IllegalArgumentException("illegal operator, was: '" + op + "'");
            }
        }

        int nWords = (split.length + 1) / 2;
        // Rows are replaced by the vocabulary vectors, so no per-row allocation is needed.
        float[][] ww = new float[nWords][];
        List<String> queryNames = Lists.newArrayList();
        for (int k = 0; k < nWords; ++k) {
            String word = split[2 * k];
            queryNames.add(word);
            ww[k] = this.getWordVector(word);
            if (ww[k] == null) {
                throw new IllegalArgumentException("no ww for: " + word);
            }
        }
        Preconditions.checkEquals(ww.length - 1, operators.length);

        // Echo the parsed query (operators print as true for '+' and false for '-').
        System.out.print("'" + queryNames.get(0) + "' ");
        for (int k = 0; k < operators.length; ++k) {
            System.out.print("'" + operators[k] + "' ");
            System.out.print("'" + queryNames.get(k + 1) + "' ");
        }
        System.out.println();

        float[] wordVector = new float[this.vectorSize];
        for (int d = 0; d < this.vectorSize; ++d) {
            float val = ww[0][d];
            for (int j = 0; j < operators.length; ++j) {
                if (operators[j]) {
                    val += ww[j + 1][d];
                } else {
                    val -= ww[j + 1][d];
                }
            }
            wordVector[d] = val;
        }
        return this.nearest(wordVector, queryNames);
    }

    /** Scores every vocabulary word not in {@code excluded} against {@code query} and keeps the best ones. */
    private TreeSet<WordEntry> nearest(float[] query, List<String> excluded) {
        List<WordEntry> best = new ArrayList<WordEntry>(Math.max(this.topNSize, 0));
        for (Map.Entry<String, float[]> entry : this.vocabulary.entrySet()) {
            String name = entry.getKey();
            if (excluded.contains(name)) {
                continue;
            }
            this.insertTopN(name, dot(query, entry.getValue()), best);
        }
        return new TreeSet<WordEntry>(best);
    }

    /** Dot product over the first {@code a.length} components, accumulated in float precision. */
    private static float dot(float[] a, float[] b) {
        float sum = 0.0f;
        for (int i = 0; i < a.length; ++i) {
            sum += a[i] * b[i];
        }
        return sum;
    }

    /**
     * Adds {@code (name, score)} to {@code wordsEntrys} while keeping at most {@link #topNSize} entries.
     *
     * <p>While the list is not full the entry is appended. Once full, it replaces the lowest-scoring
     * of the first {@code topNSize} entries, if {@code score} is higher.
     */
    protected void insertTopN(String name, float score, List<WordEntry> wordsEntrys) {
        if (this.topNSize <= 0) {
            // Nothing may be kept; avoids set(0, ...) on an empty list below.
            return;
        }
        if (wordsEntrys.size() < this.topNSize) {
            wordsEntrys.add(new WordEntry(name, score));
            return;
        }
        float min = Float.MAX_VALUE;
        int minOffe = 0;
        for (int i = 0; i < this.topNSize; ++i) {
            WordEntry wordEntry = wordsEntrys.get(i);
            if (wordEntry.score < min) {
                min = wordEntry.score;
                minOffe = i;
            }
        }
        if (score > min) {
            wordsEntrys.set(minOffe, new WordEntry(name, score));
        }
    }

    /** Returns the stored (unit-length) vector for {@code word}, or {@code null} if it is unknown. */
    public float[] getWordVector(String word) {
        return this.vocabulary.get(word);
    }

    /**
     * Reads exactly four bytes and decodes them as a little-endian IEEE-754 float.
     *
     * @throws EOFException if the stream ends before four bytes are available
     */
    protected static float readFloat(InputStream is) throws IOException {
        byte[] bytes = new byte[4];
        int off = 0;
        // InputStream.read may return fewer bytes than requested; loop until the float is complete.
        while (off < bytes.length) {
            int n = is.read(bytes, off, bytes.length - off);
            if (n < 0) {
                throw new EOFException("unexpected end of stream while reading a float");
            }
            off += n;
        }
        return getFloat(bytes);
    }

    /** Decodes {@code b[0..3]} as a little-endian IEEE-754 float. */
    protected static float getFloat(byte[] b) {
        int bits = (b[0] & 0xFF)
                | (b[1] & 0xFF) << 8
                | (b[2] & 0xFF) << 16
                | (b[3] & 0xFF) << 24;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Reads bytes up to the next space or newline, which is consumed but not returned.
     *
     * <p>The whole token is decoded as UTF-8 in one go, so multi-byte characters are never split.
     *
     * @throws EOFException if the stream ends before a terminator is found
     */
    protected static String readString(DataInputStream dis) throws IOException {
        ByteArrayOutputStream token = new ByteArrayOutputStream(MAX_SIZE);
        byte b = dis.readByte();
        while (b != ' ' && b != '\n') {
            token.write(b);
            b = dis.readByte();
        }
        return token.toString(CHARSET);
    }

    /** Returns the maximum number of results returned by similarity queries. */
    public int getTopNSize() {
        return this.topNSize;
    }

    /** Sets the maximum number of results returned by similarity queries. */
    public void setTopNSize(int topNSize) {
        this.topNSize = topNSize;
    }

    /** Returns the live internal word-to-vector map (not a copy). */
    public HashMap<String, float[]> getWordMap() {
        return this.vocabulary;
    }

    /** Returns the vocabulary size declared in the last loaded model's header. */
    public int getVocabSize() {
        return this.vocabSize;
    }

    /** Returns the vector dimensionality declared in the last loaded model's header. */
    public int getSize() {
        return this.vectorSize;
    }

    /**
     * Returns, in increasing order, the indices of the components of {@code word}'s vector that are
     * strictly greater than {@code threshold}.
     *
     * @return the matching indices, or an empty list if {@code word} is not in the vocabulary
     */
    public List<Integer> getMostFrequentTopics(String word, float threshold) {
        ArrayList<Integer> ret = new ArrayList<Integer>();
        float[] wordVector = this.getWordVector(word);
        if (wordVector == null) {
            return ret;
        }
        for (int i = 0; i < wordVector.length; ++i) {
            if (wordVector[i] > threshold) {
                ret.add(i);
            }
        }
        return ret;
    }

    /**
     * Reads a word-to-class-id file made of {@code word classId} pairs, each token terminated by a
     * space or newline.
     *
     * <p>NOTE(review): presumably the output of word2vec's {@code -classes} mode; confirm. Reading
     * stops at end of file. A truncated final pair is dropped. Surrounding whitespace such as a
     * trailing {@code \r} is trimmed from the class id.
     *
     * @throws IOException if the file cannot be read
     * @throws NumberFormatException if a class id is not an integer
     */
    public static HashMap<String, Integer> loadClassModel(String path) throws IOException {
        Preconditions.checkFileExists(path);
        HashMap<String, Integer> wordMap = Maps.newHashMap();
        DataInputStream dis = new DataInputStream(new BufferedInputStream(new FileInputStream(path)));
        try {
            while (true) {
                String word = readString(dis);
                wordMap.put(word, Integer.parseInt(readString(dis).trim()));
            }
        } catch (EOFException endOfFile) {
            // Normal termination: the file has been consumed.
        } finally {
            dis.close();
        }
        return wordMap;
    }

    /**
     * A scored query result.
     *
     * <p>Entries order by decreasing score, with ties broken by name. Equality is consistent with
     * that ordering.
     */
    public static class WordEntry
    implements Comparable<WordEntry> {
        public String name;
        public float score;

        public WordEntry(String name, float score) {
            this.name = name;
            this.score = score;
        }

        @Override
        public String toString() {
            return this.name + "\t" + this.score;
        }

        @Override
        public int compareTo(WordEntry o) {
            // Highest score first.
            int byScore = Float.compare(o.score, this.score);
            if (byScore != 0) {
                return byScore;
            }
            // Tie-break on name so the ordering is a valid total order (the old code never returned 0).
            if (this.name == null) {
                return o.name == null ? 0 : -1;
            }
            if (o.name == null) {
                return 1;
            }
            return this.name.compareTo(o.name);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj) {
                return true;
            }
            if (!(obj instanceof WordEntry)) {
                return false;
            }
            WordEntry other = (WordEntry) obj;
            return Float.compare(this.score, other.score) == 0
                    && (this.name == null ? other.name == null : this.name.equals(other.name));
        }

        @Override
        public int hashCode() {
            return 31 * (this.name == null ? 0 : this.name.hashCode()) + Float.floatToIntBits(this.score);
        }
    }
}

