/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.scaleout.perform.models.glove;

import java.io.ByteArrayOutputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.canova.api.conf.Configuration;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.models.glove.CoOccurrences;
import org.deeplearning4j.models.glove.GloveWeightLookupTable;
import org.deeplearning4j.models.word2vec.VocabWord;
import org.deeplearning4j.scaleout.api.statetracker.StateTracker;
import org.deeplearning4j.scaleout.job.Job;
import org.deeplearning4j.scaleout.perform.WorkerPerformer;
import org.deeplearning4j.scaleout.perform.models.glove.GloveJobAggregator;
import org.deeplearning4j.scaleout.perform.models.glove.GlovePerformerFactory;
import org.deeplearning4j.scaleout.perform.models.glove.GloveWork;
import org.deeplearning4j.scaleout.statetracker.hazelcast.HazelCastStateTracker;
import org.deeplearning4j.text.invertedindex.InvertedIndex;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GlovePerformer
implements WorkerPerformer {
    public static final String NAME_SPACE = "org.deeplearning4j.scaleout.perform.models.glove";
    public static final String VECTOR_LENGTH = "org.deeplearning4j.scaleout.perform.models.glove.length";
    public static final String NUM_WORDS = "org.deeplearning4j.scaleout.perform.models.glove.numwords";
    public static final String TABLE = "org.deeplearning4j.scaleout.perform.models.glove.table";
    public static final String ALPHA = "org.deeplearning4j.scaleout.perform.models.glove.alpha";
    public static final String ITERATIONS = "org.deeplearning4j.scaleout.perform.models.glove.iterations";
    public static final String X_MAX = "org.deeplearning4j.scaleout.perform.models.glove.xmax";
    public static final String MAX_COUNT = "org.deeplearning4j.scaleout.perform.models.glove.maxcount";
    public static final String LOOKUPTABLE_SIZE = "org.deeplearning4j.scaleout.perform.models.glove.lookuptablesize";
    private StateTracker stateTracker;
    private double xMax = 0.75;
    private static final Logger log = LoggerFactory.getLogger(GlovePerformer.class);
    private CoOccurrences coOccurrences;
    private double maxCount = 100.0;
    private int[] lookupTableSize;
    private int[] biasShape;

    public GlovePerformer(StateTracker stateTracker) {
        this.stateTracker = stateTracker;
    }

    public GlovePerformer() {
    }

    public void perform(Job job) {
        if (job.getWork() instanceof GloveWork) {
            GloveWork work = (GloveWork)job.getWork();
            if (work == null) {
                return;
            }
            List<Pair<VocabWord, VocabWord>> sentences = work.getCoOccurrences();
            for (Pair<VocabWord, VocabWord> coc : sentences) {
                this.iterateSample(work, (VocabWord)coc.getFirst(), (VocabWord)coc.getSecond(), this.coOccurrences.count(((VocabWord)coc.getFirst()).getWord(), ((VocabWord)coc.getSecond()).getWord()));
            }
            job.setResult((Serializable)((Object)Arrays.asList(work.addDeltas())));
        }
    }

    public void update(Object ... o) {
    }

    public void setup(Configuration conf) {
        this.xMax = conf.getFloat(X_MAX, 0.75f);
        this.maxCount = conf.getFloat(MAX_COUNT, 100.0f);
        this.lookupTableSize = this.getInts(conf, LOOKUPTABLE_SIZE);
        this.biasShape = new int[]{this.lookupTableSize[1]};
        String connectionString = conf.get("org.deeplearning4j.scaleout.statetracker.connectionstring");
        log.info("Creating state tracker with connection string " + connectionString);
        if (this.stateTracker == null) {
            try {
                this.stateTracker = new HazelCastStateTracker(connectionString);
            }
            catch (Exception e) {
                e.printStackTrace();
            }
        }
        this.coOccurrences = (CoOccurrences)this.stateTracker.get("cooccurrences");
        if (this.coOccurrences == null) {
            throw new IllegalStateException("Please specify co occurrences");
        }
    }

    private int[] getInts(Configuration conf, String key) {
        String[] strs = conf.getStrings(key);
        int[] ret = new int[strs.length];
        for (int i = 0; i < ret.length; ++i) {
            ret[i] = Integer.parseInt(strs[i]);
        }
        return ret;
    }

    public static void configure(GloveWeightLookupTable table, InvertedIndex index, Configuration conf) {
        if (table.getSyn0() == null) {
            throw new IllegalStateException("Unable to configure glove: missing look up table size. Please call table.resetWeights() first");
        }
        conf.setInt(VECTOR_LENGTH, table.getVectorLength());
        conf.setFloat(ALPHA, (float)table.getLr().get());
        conf.setStrings(LOOKUPTABLE_SIZE, new String[]{String.valueOf(table.getSyn0().rows()), String.valueOf(table.getSyn0().columns())});
        conf.setInt(NUM_WORDS, index.totalWords());
        conf.set("org.deeplearning4j.scaleout.aggregator", GloveJobAggregator.class.getName());
        conf.set("org.deeplearning4j.scaleout.perform.workerperformer", GlovePerformerFactory.class.getName());
        table.resetWeights();
        if (table.getNegative() > 0.0) {
            ByteArrayOutputStream bis = new ByteArrayOutputStream();
            try {
                DataOutputStream ois = new DataOutputStream(bis);
                Nd4j.write((INDArray)table.getTable(), (DataOutputStream)ois);
            }
            catch (IOException e) {
                e.printStackTrace();
            }
            conf.set(TABLE, new String(bis.toByteArray()));
        }
    }

    public double iterateSample(GloveWork work, VocabWord w1, VocabWord w2, double score) {
        double fDiff;
        INDArray w1Vector = work.getOriginalVectors().get(w1.getWord());
        INDArray w2Vector = work.getOriginalVectors().get(w2.getWord());
        double prediction = Nd4j.getBlasWrapper().dot(w1Vector, w2Vector);
        double weight = Math.pow(Math.min(1.0, score / this.maxCount), this.xMax);
        double gradient = fDiff = score > this.xMax ? prediction : weight * ((prediction += work.getBiases().get(w1.getWord()) + work.getBiases().get(w2.getWord())) - Math.log(score));
        this.update(work, w1, w1Vector, w2Vector, gradient);
        this.update(work, w2, w2Vector, w1Vector, gradient);
        return fDiff;
    }

    private void update(GloveWork gloveWork, VocabWord w1, INDArray wordVector, INDArray contextVector, double gradient) {
        INDArray grad1 = contextVector.mul((Number)gradient);
        INDArray update = gloveWork.getAdaGrad(w1.getWord()).getGradient(grad1, 0, this.lookupTableSize);
        wordVector.subi(update);
        double w1Bias = gloveWork.getBias(w1.getWord());
        double biasGradient = gloveWork.getBiasAdaGrad(w1.getWord()).getGradient(gradient, 0, this.biasShape);
        double update2 = w1Bias - biasGradient;
        gloveWork.updateBias(w1.getWord(), update2);
    }
}

