/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.Collection;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.deeplearning4j.datasets.iterator.AsyncDataSetIterator;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Data-parallel training wrapper around a {@link MultiLayerNetwork} or {@link ComputationGraph}.
 *
 * <p>A fixed pool of {@link Trainer} daemon threads each fits its own replica of the wrapped model.
 * {@link #fit(DataSetIterator)} hands one minibatch to each worker per round and waits for the
 * round to finish. Every {@code averagingFrequency} complete rounds, it averages the replicas'
 * parameters (and optionally their updater state) back into the wrapped model. Instances are
 * created with {@link Builder}.
 */
public class ParallelWrapper {
    private static final Logger logger = LoggerFactory.getLogger(ParallelWrapper.class);

    /** Master model: the averaged parameters are written into it. */
    private final Model model;
    /** Number of trainer threads (and replicas). */
    private final int workers;
    /** Passed to {@link AsyncDataSetIterator} when wrapping the source; 0 disables wrapping. */
    private final int prefetchSize;
    /** Averaging happens on complete rounds whose 1-based index is a multiple of this value. */
    private int averagingFrequency = 1;
    /** One trainer thread per worker. */
    private final Trainer[] zoo;
    /** Rounds dispatched during the current fit call; reset at the end of every fit. */
    private final AtomicLong iterationsCounter = new AtomicLong(0L);
    private boolean reportScore = false;
    private boolean averageUpdaters = true;
    private boolean legacyAveraging = false;

    /**
     * Creates the wrapper and immediately starts {@code workers} trainer threads.
     *
     * @param model        master model to train
     * @param workers      number of trainer threads; must be at least 1
     * @param prefetchSize prefetch buffer size for the source iterator; 0 disables prefetching
     * @throws IllegalArgumentException if {@code workers < 1}
     */
    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        if (workers < 1) {
            throw new IllegalArgumentException("Number of workers must be at least 1, got " + workers);
        }
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;
        // The returned updater is discarded: the call is made only for its side effect,
        // presumably forcing the master model to create its updater up front (TODO confirm).
        if (this.model instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) this.model).getUpdater();
        } else if (this.model instanceof ComputationGraph) {
            ((ComputationGraph) this.model).getUpdater();
        }
        this.zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; ++cnt) {
            this.zoo[cnt] = new Trainer(cnt, model);
            this.zoo[cnt].start();
        }
    }

    /**
     * Trains the wrapped model on every minibatch of {@code source}.
     *
     * <p>The source is reset first. It is wrapped in an {@link AsyncDataSetIterator} unless
     * prefetching is disabled or the source already is an {@link AsyncDataSetIterator} or
     * {@link ListDataSetIterator}. Minibatches are dealt to workers in order, and each round waits
     * until every fed worker has finished. Averaging into the master model happens only after
     * complete rounds (every worker fed) whose index is a multiple of {@code averagingFrequency}. A
     * trailing partial round is trained by the replicas but not averaged.
     *
     * @param source minibatch source; must not be null
     * @throws NullPointerException  if {@code source} is null
     * @throws IllegalStateException if a trainer thread has died, or the calling thread was
     *                               interrupted while waiting for the workers
     */
    public synchronized void fit(@NonNull DataSetIterator source) {
        if (source == null) {
            throw new NullPointerException("source");
        }
        source.reset();
        DataSetIterator iterator = this.prefetchSize > 0
                && !(source instanceof AsyncDataSetIterator)
                && !(source instanceof ListDataSetIterator)
                ? new AsyncDataSetIterator(source, this.prefetchSize)
                : source;

        // Workers fed in the current round; only touched by this (synchronized) thread.
        int fed = 0;
        while (iterator.hasNext()) {
            DataSet dataSet = (DataSet) iterator.next();
            this.zoo[fed++].feedDataSet(dataSet);

            boolean roundComplete = fed == this.workers;
            if (!roundComplete && iterator.hasNext()) {
                continue;
            }

            long round = this.iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < fed; ++cnt) {
                this.zoo[cnt].waitTillRunning();
            }

            if (roundComplete && round % this.averagingFrequency == 0L) {
                averageReplicas(fed);
            }
            fed = 0;
        }
        logger.debug("Iterations passed: {}", this.iterationsCounter.get());
        this.iterationsCounter.set(0L);
    }

    /**
     * Merges the first {@code replicas} workers' state into the master model: parameters, score,
     * and (if enabled) updater state. In legacy mode it then pushes the master state back to all
     * workers. All workers must be idle when this is called.
     */
    private void averageReplicas(int replicas) {
        double score = averageParameters(replicas);
        if (this.reportScore) {
            logger.info("Averaged score: {}", score);
        }
        if (this.model instanceof MultiLayerNetwork) {
            if (this.averageUpdaters) {
                averageMultiLayerUpdaters(replicas);
            }
            ((MultiLayerNetwork) this.model).setScore(score);
        } else if (this.model instanceof ComputationGraph) {
            if (this.averageUpdaters) {
                averageGraphUpdaters(replicas);
            }
            ((ComputationGraph) this.model).setScore(score);
        }
        if (this.legacyAveraging) {
            // Legacy mode averages only into the master, so the result must be copied back.
            // Updater state is copied only when it was actually averaged; otherwise the replicas
            // would be reset to the master's stale updater state on every averaging step.
            for (Trainer trainer : this.zoo) {
                trainer.updateModel(this.model, this.averageUpdaters);
            }
        }
    }

    /** Averages replica parameters into the master model and returns the replicas' mean score. */
    private double averageParameters(int replicas) {
        double score = 0.0;
        if (!this.legacyAveraging) {
            Collection<INDArray> params = new ArrayList<>(replicas);
            for (int cnt = 0; cnt < replicas; ++cnt) {
                params.add(this.zoo[cnt].getModel().params());
                score += this.zoo[cnt].getModel().score();
            }
            // Writes the mean into the master params. Presumably it also propagates the mean back
            // into every replica array, which is why this mode skips the updateModel pass.
            Nd4j.averageAndPropagate(this.model.params(), params);
        } else {
            INDArray params = Nd4j.zeros(this.model.params().shape());
            for (int cnt = 0; cnt < replicas; ++cnt) {
                params.addi(this.zoo[cnt].getModel().params());
                score += this.zoo[cnt].getModel().score();
            }
            params.divi(replicas);
            this.model.setParams(params);
        }
        return score / replicas;
    }

    /** Averages MultiLayerNetwork updater state into the master; no-op when it has no state. */
    private void averageMultiLayerUpdaters(int replicas) {
        Updater updater = ((MultiLayerNetwork) this.model).getUpdater();
        if (updater == null || updater.getStateViewArray() == null) {
            return;
        }
        if (!this.legacyAveraging) {
            Collection<INDArray> states = new ArrayList<>(replicas);
            for (int cnt = 0; cnt < replicas; ++cnt) {
                states.add(((MultiLayerNetwork) this.zoo[cnt].getModel()).getUpdater().getStateViewArray());
            }
            Nd4j.averageAndPropagate(updater.getStateViewArray(), states);
        } else {
            INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
            for (int cnt = 0; cnt < replicas; ++cnt) {
                // addi only reads its argument, so no defensive dup() is needed here
                state.addi(((MultiLayerNetwork) this.zoo[cnt].getModel()).getUpdater().getStateViewArray());
            }
            state.divi(replicas);
            updater.setStateViewArray((MultiLayerNetwork) this.model, state, false);
        }
    }

    /** Averages ComputationGraph updater state into the master; no-op when it has no state. */
    private void averageGraphUpdaters(int replicas) {
        ComputationGraphUpdater updater = ((ComputationGraph) this.model).getUpdater();
        if (updater == null || updater.getStateViewArray() == null) {
            return;
        }
        if (!this.legacyAveraging) {
            Collection<INDArray> states = new ArrayList<>(replicas);
            for (int cnt = 0; cnt < replicas; ++cnt) {
                states.add(((ComputationGraph) this.zoo[cnt].getModel()).getUpdater().getStateViewArray());
            }
            Nd4j.averageAndPropagate(updater.getStateViewArray(), states);
        } else {
            INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
            for (int cnt = 0; cnt < replicas; ++cnt) {
                state.addi(((ComputationGraph) this.zoo[cnt].getModel()).getUpdater().getStateViewArray());
            }
            state.divi(replicas);
            updater.setStateViewArray(state);
        }
    }

    /**
     * Daemon worker thread that owns one replica of the master model and fits it on every
     * {@link DataSet} placed in its queue.
     */
    private static class Trainer extends Thread {
        private final Model originalModel;
        /** Written by this thread in run(); read by the fit thread once the worker is idle. */
        private volatile Model replicatedModel;
        private final LinkedBlockingQueue<DataSet> queue = new LinkedBlockingQueue<>();
        /** Number of fed minibatches not yet fully processed. */
        private final AtomicInteger running = new AtomicInteger(0);
        private final int threadId;
        /** Set if run() terminated abnormally; makes waitTillRunning fail instead of hanging. */
        private volatile Throwable failure;

        public Trainer(int threadId, Model model) {
            this.threadId = threadId;
            this.setDaemon(true);
            this.setName("ParallelWrapper trainer " + threadId);
            // The replica itself is built in run(), on this worker thread.
            this.originalModel = model;
        }

        /** Queues a minibatch for this worker and marks it busy. */
        public void feedDataSet(@NonNull DataSet dataSet) {
            if (dataSet == null) {
                throw new NullPointerException("dataSet");
            }
            // Increment before enqueueing so the worker can never decrement first.
            this.running.incrementAndGet();
            this.queue.add(dataSet);
        }

        /** Returns this worker's replica; null until run() has built it. */
        public Model getModel() {
            return this.replicatedModel;
        }

        /**
         * Overwrites the replica's parameters with a copy of {@code model}'s parameters and, when
         * requested and available, with a copy of its updater state. Call only while the worker
         * is idle.
         *
         * @param model            source model; expected to be the same kind as the replica
         * @param copyUpdaterState whether to also copy the updater state view
         */
        public void updateModel(@NonNull Model model, boolean copyUpdaterState) {
            if (model == null) {
                throw new NullPointerException("model");
            }
            Model replica = this.replicatedModel;
            if (replica instanceof MultiLayerNetwork) {
                replica.setParams(model.params().dup());
                if (copyUpdaterState) {
                    Updater source = ((MultiLayerNetwork) model).getUpdater();
                    INDArray view = source == null ? null : source.getStateViewArray();
                    if (view != null) {
                        ((MultiLayerNetwork) replica).getUpdater()
                                .setStateViewArray((MultiLayerNetwork) replica, view.dup(), false);
                    }
                }
            } else if (replica instanceof ComputationGraph) {
                replica.setParams(model.params().dup());
                if (copyUpdaterState) {
                    ComputationGraphUpdater source = ((ComputationGraph) model).getUpdater();
                    INDArray view = source == null ? null : source.getStateViewArray();
                    if (view != null) {
                        ((ComputationGraph) replica).getUpdater().setStateViewArray(view.dup());
                    }
                }
            }
        }

        /** Returns true while this worker still has queued or in-progress minibatches. */
        public boolean isRunning() {
            return this.running.get() != 0;
        }

        @Override
        public void run() {
            try {
                this.replicatedModel = createReplica();
                while (true) {
                    DataSet dataSet = this.queue.poll(1L, TimeUnit.SECONDS);
                    if (dataSet == null) {
                        continue;
                    }
                    Model replica = this.replicatedModel;
                    if (replica instanceof MultiLayerNetwork) {
                        ((MultiLayerNetwork) replica).fit(dataSet);
                    } else if (replica instanceof ComputationGraph) {
                        ((ComputationGraph) replica).fit(dataSet);
                    }
                    this.running.decrementAndGet();
                }
            } catch (InterruptedException e) {
                this.failure = e;
                Thread.currentThread().interrupt();
            } catch (Throwable t) {
                // Record instead of swallowing: otherwise `running` never drains and fit() hangs.
                this.failure = t;
                logger.error("ParallelWrapper trainer {} failed", this.threadId, t);
            }
        }

        /**
         * Builds a fresh network from a clone of the master configuration, then seeds it with a
         * copy of the master's parameters. Without the copy, each replica would start from its own
         * random initialization and the master's starting weights would be lost on the first
         * averaging.
         */
        private Model createReplica() {
            Model replica = null;
            if (this.originalModel instanceof MultiLayerNetwork) {
                MultiLayerConfiguration conf =
                        ((MultiLayerNetwork) this.originalModel).getLayerWiseConfigurations().clone();
                MultiLayerNetwork network = new MultiLayerNetwork(conf);
                network.init();
                replica = network;
            } else if (this.originalModel instanceof ComputationGraph) {
                ComputationGraph graph =
                        new ComputationGraph(((ComputationGraph) this.originalModel).getConfiguration().clone());
                graph.init();
                replica = graph;
            }
            if (replica != null) {
                replica.setParams(this.originalModel.params().dup());
            }
            return replica;
        }

        /**
         * Blocks until every minibatch fed to this worker has been processed.
         *
         * @throws IllegalStateException if the worker thread has died, or the caller is interrupted
         */
        public void waitTillRunning() {
            while (this.running.get() != 0) {
                Throwable cause = this.failure;
                if (cause != null) {
                    throw new IllegalStateException(getName() + " has terminated", cause);
                }
                try {
                    Thread.sleep(10L);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new IllegalStateException("Interrupted while waiting for " + getName(), e);
                }
            }
        }
    }

    /** Fluent builder for {@link ParallelWrapper}. */
    public static class Builder {
        private Model model;
        private int workers = 2;
        private int prefetchSize = 16;
        private int averagingFrequency = 1;
        private boolean reportScore = false;
        private boolean averageUpdaters = true;
        private boolean legacyAveraging = true;

        /** Starts a builder wrapping a {@link MultiLayerNetwork}. */
        public Builder(@NonNull MultiLayerNetwork mln) {
            if (mln == null) {
                throw new NullPointerException("mln");
            }
            this.model = mln;
        }

        /** Starts a builder wrapping a {@link ComputationGraph}. */
        public Builder(@NonNull ComputationGraph graph) {
            if (graph == null) {
                throw new NullPointerException("graph");
            }
            this.model = graph;
        }

        /**
         * Sets the number of trainer threads.
         *
         * @throws RuntimeException if {@code num < 2}
         */
        public Builder workers(int num) {
            if (num < 2) {
                throw new RuntimeException("Number of workers can't be lower than 2!");
            }
            this.workers = num;
            return this;
        }

        /**
         * Averages replicas after every {@code freq} complete rounds.
         *
         * @throws IllegalArgumentException if {@code freq < 1}; a value of 0 used to fail later
         *                                  inside fit() with an ArithmeticException
         */
        public Builder averagingFrequency(int freq) {
            if (freq < 1) {
                throw new IllegalArgumentException("Averaging frequency must be at least 1, got " + freq);
            }
            this.averagingFrequency = freq;
            return this;
        }

        /** Whether updater state is averaged alongside the parameters. */
        public Builder averageUpdaters(boolean reallyAverage) {
            this.averageUpdaters = reallyAverage;
            return this;
        }

        /** Sets the async prefetch buffer size; negative values are clamped to 0 (disabled). */
        public Builder prefetchBuffer(int size) {
            this.prefetchSize = Math.max(size, 0);
            return this;
        }

        /** Selects accumulate-and-divide averaging plus explicit propagation back to the workers. */
        public Builder useLegacyAveraging(boolean reallyUse) {
            this.legacyAveraging = reallyUse;
            return this;
        }

        /** Logs the replicas' mean score at INFO level after each averaging step. */
        public Builder reportScoreAfterAveraging(boolean reallyReport) {
            this.reportScore = reallyReport;
            return this;
        }

        /** Creates the wrapper; this starts its trainer threads immediately. */
        public ParallelWrapper build() {
            ParallelWrapper wrapper = new ParallelWrapper(this.model, this.workers, this.prefetchSize);
            wrapper.averagingFrequency = this.averagingFrequency;
            wrapper.reportScore = this.reportScore;
            wrapper.averageUpdaters = this.averageUpdaters;
            wrapper.legacyAveraging = this.legacyAveraging;
            return wrapper;
        }
    }
}

