/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.graph.models.deepwalk;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicLong;
import org.deeplearning4j.graph.api.IGraph;
import org.deeplearning4j.graph.api.IVertexSequence;
import org.deeplearning4j.graph.api.NoEdgeHandling;
import org.deeplearning4j.graph.api.Vertex;
import org.deeplearning4j.graph.iterator.GraphWalkIterator;
import org.deeplearning4j.graph.iterator.parallel.GraphWalkIteratorProvider;
import org.deeplearning4j.graph.iterator.parallel.RandomWalkGraphIteratorProvider;
import org.deeplearning4j.graph.models.deepwalk.GraphHuffman;
import org.deeplearning4j.graph.models.embeddings.GraphVectorLookupTable;
import org.deeplearning4j.graph.models.embeddings.GraphVectorsImpl;
import org.deeplearning4j.graph.models.embeddings.InMemoryGraphLookupTable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;
import org.threadly.concurrent.future.FutureUtils;
import org.threadly.concurrent.future.ListenableFuture;

/**
 * DeepWalk graph embedding: trains a per-vertex vector lookup table by running skip-gram updates
 * over random walks on a graph.
 *
 * <p>Create instances with {@link Builder}. The model must be initialized before the
 * iterator-based {@code fit} overloads are used. Initialization happens either by calling one of
 * the {@code initialize} methods explicitly, or implicitly via {@link #fit(IGraph, int)}.
 *
 * @param <V> vertex value type
 * @param <E> edge value type
 */
public class DeepWalk<V, E>
extends GraphVectorsImpl<V, E> {
    /** Number of processed walks between two progress log messages. */
    public static final int STATUS_UPDATE_FREQUENCY = 1000;
    // Static so the logger is never part of the serialized instance state.
    private static final Logger log = LoggerFactory.getLogger(DeepWalk.class);
    private int vectorSize;
    private int windowSize;
    private double learningRate;
    private boolean initCalled = false;
    private long seed;
    /** Number of walk iterators requested and worker threads used by the parallel {@code fit}. */
    private int nThreads = Runtime.getRuntime().availableProcessors();
    /** Walks processed so far (drives progress logging only); restored in {@link #readObject}. */
    private transient AtomicLong walkCounter = new AtomicLong(0L);

    /** @return the length of each learned vertex vector */
    @Override
    public int getVectorSize() {
        return this.vectorSize;
    }

    /** @return the number of context positions used on each side of a centre vertex */
    public int getWindowSize() {
        return this.windowSize;
    }

    /** @return the learning rate configured for this model */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the learning rate. If the lookup table already exists, the new rate is pushed to it as
     * well.
     *
     * @param learningRate new learning rate
     */
    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
        if (this.lookupTable != null) {
            this.lookupTable.setLearningRate(learningRate);
        }
    }

    /**
     * Initializes the model from a graph, using each vertex's degree.
     *
     * @param graph graph whose vertices (IDs {@code 0..numVertices-1}) will be embedded
     */
    public void initialize(IGraph<V, E> graph) {
        int nVertices = graph.numVertices();
        int[] degrees = new int[nVertices];
        for (int i = 0; i < nVertices; ++i) {
            degrees[i] = graph.getVertexDegree(i);
        }
        this.initialize(degrees);
    }

    /**
     * Builds the Huffman tree from the vertex degrees. Then creates an in-memory lookup table with
     * one entry per vertex and vectors of length {@link #getVectorSize()}.
     *
     * @param graphVertexDegrees degree of each vertex, indexed by vertex ID
     */
    public void initialize(int[] graphVertexDegrees) {
        log.info("Initializing: Creating Huffman tree and lookup table...");
        GraphHuffman gh = new GraphHuffman(graphVertexDegrees.length);
        gh.buildTree(graphVertexDegrees);
        this.lookupTable = new InMemoryGraphLookupTable(graphVertexDegrees.length, this.vectorSize, gh, this.learningRate);
        this.initCalled = true;
        log.info("Initialization complete");
    }

    /**
     * Fits the model on random walks generated from {@code graph}. Initializes the model first if
     * that has not happened yet.
     *
     * <p>Walks are produced with this model's seed. Disconnected vertices are handled with
     * {@link NoEdgeHandling#SELF_LOOP_ON_DISCONNECTED}.
     *
     * @param graph graph to learn from
     * @param walkLength number of steps in each random walk
     */
    public void fit(IGraph<V, E> graph, int walkLength) {
        if (!this.initCalled) {
            this.initialize(graph);
        }
        RandomWalkGraphIteratorProvider<V> iteratorProvider = new RandomWalkGraphIteratorProvider<V>(graph, walkLength, this.seed, NoEdgeHandling.SELF_LOOP_ON_DISCONNECTED);
        this.fit(iteratorProvider);
    }

    /**
     * Fits the model in parallel. Requests {@code nThreads} walk iterators from the provider and
     * trains on each one in its own task on a thread pool of {@code nThreads} threads. The method
     * blocks until all tasks complete or the first one fails.
     *
     * <p>NOTE(review): tasks update the shared lookup table without locking here; thread-safety
     * of concurrent updates depends on the lookup table implementation.
     *
     * @param iteratorProvider source of the per-thread walk iterators
     * @throws UnsupportedOperationException if the model has not been initialized
     * @throws RuntimeException wrapping any failure of a training task, or an interruption
     */
    public void fit(GraphWalkIteratorProvider<V> iteratorProvider) {
        if (!this.initCalled) {
            throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
        }
        List<GraphWalkIterator<V>> iteratorList = iteratorProvider.getGraphWalkIterators(this.nThreads);
        PriorityScheduler scheduler = new PriorityScheduler(this.nThreads);
        List<ListenableFuture<Void>> futures = new ArrayList<>(iteratorList.size());
        for (GraphWalkIterator<V> iter : iteratorList) {
            futures.add(scheduler.submit(new LearningCallable(iter)));
        }
        // All work is submitted; no further tasks will be added to this scheduler.
        scheduler.shutdown();
        try {
            FutureUtils.blockTillAllCompleteOrFirstError(futures);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
        catch (ExecutionException e) {
            log.error("Error fitting model", e.getCause());
            throw new RuntimeException(e);
        }
    }

    /**
     * Fits the model on every walk produced by {@code iterator}, on the calling thread.
     *
     * @param iterator source of vertex sequences (walks)
     * @throws UnsupportedOperationException if the model has not been initialized
     */
    public void fit(GraphWalkIterator<V> iterator) {
        if (!this.initCalled) {
            throw new UnsupportedOperationException("DeepWalk not initialized (call initialize before fit)");
        }
        // A walk of walkLength steps normally visits walkLength + 1 vertices; use that as the
        // initial buffer size (clamped so a bogus walk length cannot produce a negative size).
        int expectedVertices = Math.max(iterator.walkLength() + 1, 1);
        while (iterator.hasNext()) {
            int[] walk = toVertexIds(iterator.next(), expectedVertices);
            this.skipGram(walk);
            long processed = this.walkCounter.incrementAndGet();
            if (processed % STATUS_UPDATE_FREQUENCY == 0L) {
                log.info("Processed {} random walks on graph", processed);
            }
        }
    }

    /**
     * Drains a vertex sequence into an array of vertex IDs sized exactly to the number of
     * vertices visited. Padding entries would otherwise be trained on as vertex 0, and a longer
     * sequence would overflow a fixed-size buffer.
     */
    private static int[] toVertexIds(IVertexSequence<?> sequence, int initialCapacity) {
        int[] ids = new int[initialCapacity];
        int count = 0;
        while (sequence.hasNext()) {
            if (count == ids.length) {
                ids = Arrays.copyOf(ids, ids.length * 2);
            }
            ids[count++] = sequence.next().vertexID();
        }
        return count == ids.length ? ids : Arrays.copyOf(ids, count);
    }

    /**
     * Applies skip-gram updates for one walk.
     *
     * <p>For every centre position that has a full window of {@code windowSize} positions on both
     * sides, calls {@code lookupTable.iterate(centre, context)} once per context position in that
     * window. Positions closer than {@code windowSize} to either end are never used as centres.
     */
    private void skipGram(int[] walk) {
        for (int mid = this.windowSize; mid < walk.length - this.windowSize; ++mid) {
            for (int pos = mid - this.windowSize; pos <= mid + this.windowSize; ++pos) {
                if (pos == mid) continue;
                this.lookupTable.iterate(walk[mid], walk[pos]);
            }
        }
    }

    /** @return the vertex vector lookup table (null before initialization) */
    public GraphVectorLookupTable lookupTable() {
        return this.lookupTable;
    }

    /**
     * Restores the transient progress counter after Java deserialization. The field initializer
     * does not run on deserialization, so without this hook {@code fit} would throw a
     * NullPointerException. This hook is only invoked if the class is actually serializable.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        in.defaultReadObject();
        this.walkCounter = new AtomicLong(0L);
    }

    /** Task that trains the enclosing model on one walk iterator. */
    private class LearningCallable
    implements Callable<Void> {
        private final GraphWalkIterator<V> iterator;

        LearningCallable(GraphWalkIterator<V> iterator) {
            this.iterator = iterator;
        }

        @Override
        public Void call() throws Exception {
            DeepWalk.this.fit(this.iterator);
            return null;
        }
    }

    /**
     * Builder for {@link DeepWalk}. Defaults:
     * <ul>
     *   <li>vector size 100</li>
     *   <li>learning rate 0.01</li>
     *   <li>window size 2</li>
     *   <li>seed = current time in milliseconds</li>
     * </ul>
     */
    public static class Builder<V, E> {
        private int vectorSize = 100;
        private long seed = System.currentTimeMillis();
        private double learningRate = 0.01;
        private int windowSize = 2;

        /** Sets the length of each vertex vector. */
        public Builder<V, E> vectorSize(int vectorSize) {
            this.vectorSize = vectorSize;
            return this;
        }

        /** Sets the learning rate. */
        public Builder<V, E> learningRate(double learningRate) {
            this.learningRate = learningRate;
            return this;
        }

        /** Sets the number of context positions used on each side of a centre vertex. */
        public Builder<V, E> windowSize(int windowSize) {
            this.windowSize = windowSize;
            return this;
        }

        /** Sets the seed used for random walk generation in {@link DeepWalk#fit(IGraph, int)}. */
        public Builder<V, E> seed(long seed) {
            this.seed = seed;
            return this;
        }

        /** @return a new, uninitialized {@link DeepWalk} with this builder's configuration */
        public DeepWalk<V, E> build() {
            DeepWalk<V, E> dw = new DeepWalk<>();
            dw.vectorSize = this.vectorSize;
            dw.windowSize = this.windowSize;
            dw.learningRate = this.learningRate;
            dw.seed = this.seed;
            return dw;
        }
    }
}

