/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.embeddings.inmemory;

import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider;
import com.google.common.util.concurrent.AtomicDouble;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import lombok.NonNull;
import org.apache.commons.io.FileUtils;
import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.models.embeddings.WeightLookupTable;
import org.deeplearning4j.models.sequencevectors.sequence.SequenceElement;
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
import org.deeplearning4j.plot.Tsne;
import org.deeplearning4j.plot.dropwizard.ObjectMapperProvider;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.FloatBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.AdaGrad;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link WeightLookupTable} whose weights live entirely in memory as ND4J arrays.
 *
 * <ul>
 *   <li>{@code syn0}: one row per vocabulary element; these are the vectors returned by
 *       {@link #vector(String)} and {@link #vectors()}.</li>
 *   <li>{@code syn1}: same shape as {@code syn0}; rows are indexed by the "points" of an
 *       element and updated in the code/point loop of {@link #iterateSample}.</li>
 *   <li>{@code syn1Neg}: same shape as {@code syn0}; only allocated when {@code negative > 0}
 *       and updated by the negative-sampling part of {@link #iterateSample}.</li>
 * </ul>
 *
 * @param <T> vocabulary element type
 */
public class InMemoryLookupTable<T extends SequenceElement>
implements WeightLookupTable<T> {
    private static final Logger log = LoggerFactory.getLogger(InMemoryLookupTable.class);
    protected INDArray syn0;
    protected INDArray syn1;
    protected int vectorLength;
    protected transient Random rng = Nd4j.getRandom();
    protected AtomicDouble lr = new AtomicDouble(0.025);
    /** Precomputed logistic sigmoid sampled on [-MAX_EXP, MAX_EXP); see {@link #initExpTable()}. */
    protected double[] expTable;
    protected static double MAX_EXP = 6.0;
    protected long seed = 123L;
    /** Sampling table of word indices built by {@link #makeTable(int, double)}. */
    protected INDArray table;
    protected INDArray syn1Neg;
    protected boolean useAdaGrad;
    /** Number of negative samples per positive pair; 0 disables negative sampling. */
    protected double negative = 0.0;
    protected VocabCache<T> vocab;
    protected Map<Integer, INDArray> codes = new ConcurrentHashMap<Integer, INDArray>();
    protected AdaGrad adaGrad;
    protected Long tableId;

    public InMemoryLookupTable() {
    }

    /**
     * Creates a lookup table and precomputes the sigmoid table.
     *
     * @param vocab        vocabulary the rows of the weight matrices correspond to
     * @param vectorLength number of columns (embedding size)
     * @param useAdaGrad   whether to allocate an {@link AdaGrad} updater
     * @param lr           initial learning rate
     * @param gen          random generator used by {@link #resetWeights(boolean)}
     * @param negative     number of negative samples (0 disables negative sampling)
     */
    public InMemoryLookupTable(VocabCache vocab, int vectorLength, boolean useAdaGrad, double lr, Random gen, double negative) {
        this.vocab = vocab;
        this.vectorLength = vectorLength;
        this.useAdaGrad = useAdaGrad;
        this.lr.set(lr);
        this.rng = gen;
        this.negative = negative;
        this.initExpTable();
        if (useAdaGrad) {
            this.initAdaGrad();
        }
    }

    /** Allocates an AdaGrad updater shaped {@code [numWords + 1, vectorLength]}. */
    protected void initAdaGrad() {
        this.adaGrad = new AdaGrad(new int[]{this.vocab.numWords() + 1, this.vectorLength}, this.lr.get());
    }

    public double[] getExpTable() {
        return this.expTable;
    }

    public void setExpTable(double[] expTable) {
        this.expTable = expTable;
    }

    /** Returns the AdaGrad-adjusted gradient, lazily creating the updater on first use. */
    @Override
    public double getGradient(int column, double gradient) {
        if (this.adaGrad == null) {
            this.initAdaGrad();
        }
        return this.adaGrad.getGradient(gradient, column, this.syn0.shape());
    }

    @Override
    public int layerSize() {
        return this.vectorLength;
    }

    /**
     * Re-seeds the generator and (re)initialises the weight matrices.
     *
     * <p>{@code syn0} is filled with {@code (U[0,1) - 0.5) / vectorLength}; {@code syn1} is
     * zero-filled with the same shape. Each matrix is only replaced when it is missing or
     * {@code reset} is true. Negative-sampling state is always rebuilt.
     *
     * @param reset force re-initialisation even if the matrices already exist
     */
    @Override
    public void resetWeights(boolean reset) {
        if (this.rng == null) {
            this.rng = Nd4j.getRandom();
        }
        this.rng.setSeed(this.seed);
        if (this.syn0 == null || reset) {
            int[] shape = new int[]{this.vocab.numWords(), this.vectorLength};
            this.syn0 = Nd4j.rand(shape, this.rng).subi((Number) 0.5).divi((Number) this.vectorLength);
        }
        if (this.syn1 == null || reset) {
            this.syn1 = Nd4j.create(this.syn0.shape());
        }
        this.initNegative();
    }

    /**
     * Runs t-SNE over the first {@code numWords} vectors and writes the result to {@code file}.
     *
     * <p>At most {@code vocab.numWords()} rows are plotted. The matrix is sized to match, so
     * rows and labels always line up.
     */
    @Override
    public void plotVocab(Tsne tsne, int numWords, File file) {
        // Size the matrix to what is actually filled; otherwise trailing zero rows would have no label.
        int rows = Math.min(numWords, this.vocab.numWords());
        INDArray array = Nd4j.create(rows, this.vectorLength);
        try {
            ArrayList<String> labels = new ArrayList<String>();
            for (int i = 0; i < rows; ++i) {
                labels.add(this.vocab.wordAtIndex(i));
                array.putRow(i, this.syn0.slice(i));
            }
            tsne.plot(array, 2, labels, file.getAbsolutePath());
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void plotVocab(int numWords, File file) {
        Tsne tsne = new Tsne.Builder().normalize(false).setFinalMomentum((double) 0.8f).setMaxIter(1000).build();
        this.plotVocab(tsne, numWords, file);
    }

    @Override
    public void plotVocab(int numWords, UiConnectionInfo connectionInfo) {
        Tsne tsne = new Tsne.Builder().normalize(false).setFinalMomentum((double) 0.8f).setMaxIter(1000).build();
        this.plotVocab(tsne, numWords, connectionInfo);
    }

    /**
     * Runs t-SNE into a temporary file and POSTs its lines as JSON to the UI server's
     * {@code coords} endpoint.
     *
     * <p>The HTTP client and response are always closed. Any failure is rethrown wrapped in a
     * {@link RuntimeException}.
     */
    @Override
    public void plotVocab(Tsne tsne, int numWords, UiConnectionInfo connectionInfo) {
        Client client = null;
        try {
            File file = File.createTempFile("tsne", "temp");
            file.deleteOnExit();
            this.plotVocab(tsne, numWords, file);
            List<?> lines = FileUtils.readLines(file);
            client = ClientBuilder.newClient().register(JacksonJsonProvider.class).register(new ObjectMapperProvider());
            WebTarget target = client.target(connectionInfo.getFirstPart())
                    .path(connectionInfo.getSecondPart("api"))
                    .path("coords")
                    .queryParam("sid", connectionInfo.getSessionId());
            Response resp = target.request("application/json")
                    .accept("application/json")
                    .post(Entity.entity(lines, "application/json"));
            try {
                log.debug("{}", resp);
            }
            finally {
                resp.close();
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            if (client != null) {
                client.close();
            }
        }
    }

    @Override
    public void putCode(int codeIndex, INDArray code) {
        this.codes.put(codeIndex, code);
    }

    @Override
    public INDArray loadCodes(int[] codes) {
        return this.syn1.getRows(codes);
    }

    /**
     * Allocates {@code syn1Neg} and the sampling table when negative sampling is enabled.
     * This works even if {@code expTable} has not been initialised (default constructor).
     */
    protected void initNegative() {
        if (this.negative > 0.0) {
            this.syn1Neg = Nd4j.zeros(this.syn0.shape());
            int expTableSize = this.expTable == null ? 0 : this.expTable.length;
            this.makeTable(Math.max(expTableSize, 100000), 0.75);
        }
    }

    /**
     * Precomputes {@code sigmoid(x)} for 100000 evenly spaced {@code x} in
     * {@code [-MAX_EXP, MAX_EXP)}.
     */
    protected void initExpTable() {
        this.expTable = new double[100000];
        for (int i = 0; i < this.expTable.length; ++i) {
            double tmp = FastMath.exp(((double) i / (double) this.expTable.length * 2.0 - 1.0) * MAX_EXP);
            this.expTable[i] = tmp / (tmp + 1.0);
        }
    }

    /**
     * Performs one training step for the pair ({@code w1}, {@code w2}) and updates
     * {@code syn0}, {@code syn1} and (when enabled) {@code syn1Neg} in place.
     *
     * <p>The step is skipped when {@code w2} is null or has a negative index, when both
     * elements share an index, or when either label is {@code "STOP"} or {@code "UNK"}.
     *
     * @param w1         element whose codes/points drive the {@code syn1} update
     * @param w2         element whose {@code syn0} row is updated
     * @param nextRandom linear-congruential random state, advanced for every negative sample
     * @param alpha      learning rate used when AdaGrad is disabled
     */
    @Override
    @Deprecated
    public void iterateSample(T w1, T w2, AtomicLong nextRandom, double alpha) {
        if (w2 == null || w2.getIndex() < 0 || w1.getIndex() == w2.getIndex()
                || w1.getLabel().equals("STOP") || w2.getLabel().equals("STOP")
                || w1.getLabel().equals("UNK") || w2.getLabel().equals("UNK")) {
            return;
        }
        boolean isDouble = this.syn0.data().dataType() == DataBuffer.Type.DOUBLE;
        double expScale = (double) this.expTable.length / MAX_EXP / 2.0;
        INDArray l1 = this.syn0.slice(w2.getIndex());
        INDArray neu1e = Nd4j.create(this.vectorLength);

        // Loop over w1's codes/points against rows of syn1.
        for (int i = 0; i < w1.getCodeLength(); ++i) {
            int code = w1.getCodes().get(i);
            int point = w1.getPoints().get(i);
            if (point >= this.syn0.rows() || point < 0) {
                throw new IllegalStateException("Illegal point " + point);
            }
            INDArray syn1 = this.syn1.slice(point);
            double dot = Nd4j.getBlasWrapper().dot(l1, syn1);
            if (dot < -MAX_EXP || dot >= MAX_EXP) {
                continue;
            }
            int idx = (int) ((dot + MAX_EXP) * expScale);
            if (idx >= this.expTable.length) {
                continue;
            }
            double f = this.expTable[idx];
            double g = this.useAdaGrad
                    ? w1.getGradient(i, (double) (1 - code) - f, this.lr.get())
                    : ((double) (1 - code) - f) * alpha;
            Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, syn1, neu1e);
            Nd4j.getBlasWrapper().level1().axpy(syn1.length(), g, l1, syn1);
        }

        // Negative sampling: d == 0 is the positive example (w1), the rest are drawn from the table.
        if (this.negative > 0.0) {
            int target = w1.getIndex();
            for (int d = 0; (double) d < this.negative + 1.0; ++d) {
                int label;
                if (d == 0) {
                    label = 1;
                } else {
                    nextRandom.set(nextRandom.get() * 25214903917L + 11L);
                    int idx = Math.abs((int) (nextRandom.get() >> 16) % this.table.length());
                    target = this.table.getInt(idx);
                    if (target <= 0) {
                        // |x % m| < m, so abs() is safe here and the target is always in [1, numWords - 1].
                        target = (int) Math.abs(nextRandom.get() % (this.vocab.numWords() - 1)) + 1;
                    }
                    if (target == w1.getIndex()) {
                        continue;
                    }
                    label = 0;
                }
                if (target >= this.syn1Neg.rows() || target < 0) {
                    continue;
                }
                INDArray negRow = this.syn1Neg.slice(target);
                double f = Nd4j.getBlasWrapper().dot(l1, negRow);
                double g;
                if (f > MAX_EXP) {
                    g = this.useAdaGrad ? w1.getGradient(target, label - 1, alpha) : (double) (label - 1) * alpha;
                } else if (f < -MAX_EXP) {
                    g = (double) label * (this.useAdaGrad ? w1.getGradient(target, alpha, alpha) : alpha);
                } else {
                    // f == MAX_EXP would index one past the end of expTable; clamp to the last entry.
                    int expIdx = Math.min((int) ((f + MAX_EXP) * expScale), this.expTable.length - 1);
                    double err = (double) label - this.expTable[expIdx];
                    g = this.useAdaGrad ? w1.getGradient(target, err, alpha) : err * alpha;
                }
                if (isDouble) {
                    Nd4j.getBlasWrapper().axpy(g, negRow, neu1e);
                    Nd4j.getBlasWrapper().axpy(g, l1, negRow);
                } else {
                    Nd4j.getBlasWrapper().axpy((float) g, negRow, neu1e);
                    Nd4j.getBlasWrapper().axpy((float) g, l1, negRow);
                }
            }
        }

        // Apply the accumulated error to w2's vector.
        if (isDouble) {
            Nd4j.getBlasWrapper().axpy(1.0, neu1e, l1);
        } else {
            Nd4j.getBlasWrapper().axpy(1.0f, neu1e, l1);
        }
    }

    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    public double getNegative() {
        return this.negative;
    }

    public void setNegative(double negative) {
        this.negative = negative;
    }

    /** No-op in this implementation. */
    @Override
    public void iterate(T w1, T w2) {
    }

    @Override
    public void resetWeights() {
        this.resetWeights(true);
    }

    /**
     * Builds the sampling {@code table}. Each word index gets a share of the {@code tableSize}
     * slots proportional to {@code frequency^power / sum(frequency^power)}.
     *
     * @param tableSize number of slots in the table
     * @param power     exponent applied to word frequencies
     */
    protected void makeTable(int tableSize, double power) {
        int vocabSize = this.syn0.rows();
        this.table = Nd4j.create((DataBuffer) new FloatBuffer((long) tableSize));
        double trainWordsPow = 0.0;
        for (String w : this.vocab.words()) {
            trainWordsPow += Math.pow(this.vocab.wordFrequency(w), power);
        }
        int wordIdx = 0;
        String word = this.vocab.wordAtIndex(wordIdx);
        // d1 is the cumulative probability mass of the words up to and including wordIdx.
        double d1 = Math.pow(this.vocab.wordFrequency(word), power) / trainWordsPow;
        for (int i = 0; i < tableSize; ++i) {
            this.table.putScalar(i, wordIdx);
            double mul = (double) i / (double) tableSize;
            if (!(mul > d1)) {
                continue;
            }
            if (wordIdx < vocabSize - 1) {
                ++wordIdx;
            }
            word = this.vocab.wordAtIndex(wordIdx);
            if (word != null) {
                d1 += Math.pow(this.vocab.wordFrequency(word), power) / trainWordsPow;
            }
        }
    }

    /**
     * Overwrites the {@code syn0} row for {@code word} with {@code vector}.
     *
     * @throws IllegalArgumentException if either argument is null or the word is not in the vocabulary
     */
    @Override
    public void putVector(String word, INDArray vector) {
        if (word == null) {
            throw new IllegalArgumentException("No null words allowed");
        }
        if (vector == null) {
            throw new IllegalArgumentException("No null vectors allowed");
        }
        int idx = this.vocab.indexOf(word);
        if (idx < 0) {
            // Without this check a negative index would reach slice() and could touch an unrelated row.
            throw new IllegalArgumentException("Word not found in vocabulary: " + word);
        }
        this.syn0.slice(idx).assign(vector);
    }

    public INDArray getTable() {
        return this.table;
    }

    public void setTable(INDArray table) {
        this.table = table;
    }

    public INDArray getSyn1Neg() {
        return this.syn1Neg;
    }

    public void setSyn1Neg(INDArray syn1Neg) {
        this.syn1Neg = syn1Neg;
    }

    /**
     * Returns the {@code syn0} row for {@code word}.
     *
     * <p>Falls back to the row for {@code "UNK"} if the word is unknown. Returns null if
     * {@code word} is null or neither the word nor {@code "UNK"} is in the vocabulary.
     */
    @Override
    public INDArray vector(String word) {
        if (word == null) {
            return null;
        }
        int idx = this.vocab.indexOf(word);
        if (idx < 0 && (idx = this.vocab.indexOf("UNK")) < 0) {
            return null;
        }
        return this.syn0.getRow(idx);
    }

    @Override
    public void setLearningRate(double lr) {
        this.lr.set(lr);
    }

    @Override
    public Iterator<INDArray> vectors() {
        return new WeightIterator();
    }

    @Override
    public INDArray getWeights() {
        return this.syn0;
    }

    public INDArray getSyn0() {
        return this.syn0;
    }

    public void setSyn0(INDArray syn0) {
        this.syn0 = syn0;
    }

    public INDArray getSyn1() {
        return this.syn1;
    }

    public void setSyn1(INDArray syn1) {
        this.syn1 = syn1;
    }

    @Override
    public VocabCache<T> getVocabCache() {
        return this.vocab;
    }

    public void setVectorLength(int vectorLength) {
        this.vectorLength = vectorLength;
    }

    @Deprecated
    public AtomicDouble getLr() {
        return this.lr;
    }

    public void setLr(AtomicDouble lr) {
        this.lr = lr;
    }

    public VocabCache getVocab() {
        return this.vocab;
    }

    public void setVocab(VocabCache vocab) {
        this.vocab = vocab;
    }

    public Map<Integer, INDArray> getCodes() {
        return this.codes;
    }

    public void setCodes(Map<Integer, INDArray> codes) {
        this.codes = codes;
    }

    public String toString() {
        return "InMemoryLookupTable{syn0=" + this.syn0 + ", syn1=" + this.syn1 + ", vectorLength=" + this.vectorLength + ", rng=" + this.rng + ", lr=" + this.lr + ", expTable=" + Arrays.toString(this.expTable) + ", seed=" + this.seed + ", table=" + this.table + ", syn1Neg=" + this.syn1Neg + ", useAdaGrad=" + this.useAdaGrad + ", negative=" + this.negative + ", vocab=" + this.vocab + ", codes=" + this.codes + '}';
    }

    /**
     * Resets this table's weights, then copies every {@code syn0} row from {@code srcTable}.
     * The {@code syn1} rows are copied too when the source has a {@code syn1}.
     *
     * @throws IllegalStateException if vector lengths differ, the source has no {@code syn0},
     *                               or the source has more rows than this table
     */
    public void consume(InMemoryLookupTable<T> srcTable) {
        if (srcTable.vectorLength != this.vectorLength) {
            throw new IllegalStateException("You can't consume lookupTable with different vector lengths");
        }
        if (srcTable.syn0 == null) {
            throw new IllegalStateException("Source lookupTable Syn0 is NULL");
        }
        this.resetWeights(true);
        if (srcTable.syn0.rows() > this.syn0.rows()) {
            throw new IllegalStateException("You can't consume lookupTable with built for larger vocabulary without updating your vocabulary first");
        }
        boolean copySyn1 = srcTable.syn1 != null;
        for (int x = 0; x < srcTable.syn0.rows(); ++x) {
            this.syn0.putRow(x, srcTable.syn0.getRow(x).dup());
            if (copySyn1) {
                this.syn1.putRow(x, srcTable.syn1.getRow(x).dup());
            }
        }
    }

    @Override
    public Long getTableId() {
        return this.tableId;
    }

    @Override
    public void setTableId(Long tableId) {
        this.tableId = tableId;
    }

    /** Fluent builder for {@link InMemoryLookupTable}; a vocabulary cache is mandatory. */
    public static class Builder<T extends SequenceElement> {
        protected int vectorLength = 100;
        protected boolean useAdaGrad = false;
        protected double lr = 0.025;
        protected Random gen = Nd4j.getRandom();
        protected long seed = 123L;
        protected double negative = 0.0;
        protected VocabCache<T> vocabCache;

        public Builder<T> cache(@NonNull VocabCache<T> vocab) {
            if (vocab == null) {
                throw new NullPointerException("vocab");
            }
            this.vocabCache = vocab;
            return this;
        }

        public Builder<T> negative(double negative) {
            this.negative = negative;
            return this;
        }

        public Builder<T> vectorLength(int vectorLength) {
            this.vectorLength = vectorLength;
            return this;
        }

        public Builder<T> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        @Deprecated
        public Builder<T> lr(double lr) {
            this.lr = lr;
            return this;
        }

        public Builder<T> gen(Random gen) {
            this.gen = gen;
            return this;
        }

        public Builder<T> seed(long seed) {
            this.seed = seed;
            return this;
        }

        /**
         * @return a new lookup table configured from this builder
         * @throws IllegalStateException if no vocabulary cache was supplied
         */
        public WeightLookupTable<T> build() {
            if (this.vocabCache == null) {
                throw new IllegalStateException("Vocab cache must be specified");
            }
            InMemoryLookupTable<T> table = new InMemoryLookupTable<T>(this.vocabCache, this.vectorLength, this.useAdaGrad, this.lr, this.gen, this.negative);
            table.seed = this.seed;
            return table;
        }
    }

    /** Iterates over the {@code syn0} rows in index order. */
    protected class WeightIterator
    implements Iterator<INDArray> {
        protected int currIndex = 0;

        protected WeightIterator() {
        }

        @Override
        public boolean hasNext() {
            return this.currIndex < InMemoryLookupTable.this.syn0.rows();
        }

        @Override
        public INDArray next() {
            if (!this.hasNext()) {
                throw new NoSuchElementException();
            }
            INDArray ret = InMemoryLookupTable.this.syn0.slice(this.currIndex);
            ++this.currIndex;
            return ret;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException();
        }
    }
}

