/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.List;
import java.util.Observer;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.InplaceParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.parallelism.inference.InferenceObservable;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObservable;
import org.deeplearning4j.parallelism.inference.observers.BasicInferenceObserver;
import org.deeplearning4j.parallelism.inference.observers.BatchedInferenceObservable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs inference for a {@link MultiLayerNetwork} or {@link ComputationGraph} on a pool of daemon
 * worker threads, each bound to one device and holding its own replica of the model.
 *
 * <p>Requests reach the workers through a bounded {@link BlockingQueue}:
 * <ul>
 *   <li>{@link InferenceMode#SEQUENTIAL}: every {@code output(...)} call becomes its own queue
 *       entry;</li>
 *   <li>{@link InferenceMode#BATCHED}: calls are appended to a shared
 *       {@link BatchedInferenceObservable} until its counter reaches {@code batchLimit} or it is
 *       locked, and only newly created observables are enqueued;</li>
 *   <li>{@link InferenceMode#INPLACE}: {@link Builder#build()} returns an
 *       {@link InplaceParallelInference} instead of this class.</li>
 * </ul>
 *
 * <p>Instances are created through {@link Builder}.
 */
public class ParallelInference {
    private static final Logger log = LoggerFactory.getLogger(ParallelInference.class);

    /** Default worker count: the number of devices reported by the affinity manager. */
    public static final int DEFAULT_NUM_WORKERS = Nd4j.getAffinityManager().getNumberOfDevices();
    public static final int DEFAULT_BATCH_LIMIT = 32;
    public static final InferenceMode DEFAULT_INFERENCE_MODE = InferenceMode.BATCHED;
    public static final int DEFAULT_QUEUE_LIMIT = 64;

    /** Model replicated by the workers; replaced directly by updateModel() only while no workers exist. */
    protected Model model;
    /** Forwarded to {@link ObservablesProvider}, which stores but does not read it. */
    protected long nanos;
    protected int workers;
    protected int batchLimit;
    protected InferenceMode inferenceMode;
    protected int queueLimit;
    /** Not read by this class; presumably consumed by subclasses — confirm before relying on it. */
    protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO;

    private BlockingQueue<InferenceObservable> observables;
    /** Serialises copying of parameters into non-root worker replicas. */
    private final Object locker = new Object();
    private InferenceWorker[] zoo;
    private ObservablesProvider provider;

    protected ParallelInference() {
    }

    /**
     * Replaces the model used for inference.
     *
     * <p>If workers exist, each of them rebuilds its replica from {@code model}; otherwise the
     * model is stored and used by the next {@link #init()}.
     *
     * @param model the new model
     * @throws NullPointerException if {@code model} is {@code null}
     */
    public void updateModel(@NonNull Model model) {
        if (model == null) {
            throw new NullPointerException("model is marked @NonNull but is null");
        }
        InferenceWorker[] current = this.zoo;
        if (current != null) {
            for (InferenceWorker worker : current) {
                // slots are nulled one by one while shutdown() is running
                if (worker != null) {
                    worker.updateModel(model);
                }
            }
        } else {
            this.model = model;
        }
    }

    /**
     * Returns the replica currently held by each worker, in worker order.
     *
     * @return one entry per worker ({@code null} for a worker that is gone or not yet
     *     initialised), or an empty array when no workers exist
     */
    protected Model[] getCurrentModelsFromWorkers() {
        InferenceWorker[] current = this.zoo;
        if (current == null) {
            return new Model[0];
        }
        Model[] models = new Model[current.length];
        for (int i = 0; i < current.length; i++) {
            models[i] = current[i] == null ? null : current[i].replicatedModel;
        }
        return models;
    }

    /**
     * Creates the request queue and starts {@link #workers} daemon worker threads.
     *
     * <p>Worker {@code i} is bound to device {@code i % numDevices}. The first worker that lands
     * on the calling thread's device is the "root" worker and shares {@link #model} instead of
     * copying it. In BATCHED mode an {@link ObservablesProvider} is also created.
     */
    protected void init() {
        this.observables = new LinkedBlockingQueue<>(this.queueLimit);
        int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
        int currentDevice = Nd4j.getAffinityManager().getDeviceForCurrentThread();
        boolean rootAssigned = false;
        this.zoo = new InferenceWorker[this.workers];
        for (int i = 0; i < this.workers; ++i) {
            int device = i % numDevices;
            boolean isRoot = !rootAssigned && device == currentDevice;
            rootAssigned |= isRoot;
            // InferenceWorker marks itself as a daemon thread in its constructor
            this.zoo[i] = new InferenceWorker(i, this.model, this.observables, isRoot, device);
            this.zoo[i].start();
        }
        if (this.inferenceMode == InferenceMode.BATCHED) {
            log.info("Initializing ObservablesProvider...");
            this.provider = new ObservablesProvider(this.nanos, this.batchLimit, this.observables);
        }
    }

    /**
     * @param workerIdx index of the worker
     * @return number of requests that worker has taken from the queue so far
     */
    protected long getWorkerCounter(int workerIdx) {
        return this.zoo[workerIdx].getCounterValue();
    }

    /**
     * Stops all workers and waits for their threads to terminate. Does nothing if there are no
     * workers. Afterwards {@code System.gc()} is requested.
     */
    public synchronized void shutdown() {
        if (this.zoo == null) {
            return;
        }
        for (int i = 0; i < this.zoo.length; ++i) {
            InferenceWorker worker = this.zoo[i];
            if (worker == null) {
                continue;
            }
            worker.interrupt();
            worker.shutdown();
            this.zoo[i] = null;
        }
        this.zoo = null;
        System.gc();
    }

    /** Wraps {@code input} with {@code Nd4j.create} and runs single-output inference on it. */
    public INDArray output(double[] input) {
        return this.output(Nd4j.create(input));
    }

    /** Wraps {@code input} with {@code Nd4j.create} and runs single-output inference on it. */
    public INDArray output(float[] input) {
        return this.output(Nd4j.create(input));
    }

    /** Runs single-output inference on {@code input} without a mask. */
    public INDArray output(INDArray input) {
        return this.output(input, (INDArray) null);
    }

    /**
     * Runs inference on a single input and returns the network's only output.
     *
     * @param input features
     * @param inputMask optional features mask, may be {@code null}
     * @return the single output array
     * @throws IllegalArgumentException if the network produces more than one output array
     */
    public INDArray output(INDArray input, INDArray inputMask) {
        INDArray[] masks = inputMask == null ? null : new INDArray[] {inputMask};
        INDArray[] out = this.output(new INDArray[] {input}, masks);
        if (out.length != 1) {
            throw new IllegalArgumentException("Network has multiple (" + out.length + ") output arrays, but only a single output can be returned using this method. Use for output(INDArray[] input, INDArray[] inputMasks) for multi-output nets");
        }
        return out[0];
    }

    /** Runs single-output inference on the features (and features mask) of {@code dataSet}. */
    public INDArray output(DataSet dataSet) {
        return this.output(dataSet.getFeatures(), dataSet.getFeaturesMaskArray());
    }

    /** Runs inference on one or more inputs without masks. */
    public INDArray[] output(INDArray... input) {
        return this.output(input, (INDArray[]) null);
    }

    /**
     * Submits the inputs to the workers and blocks until the result is available.
     *
     * @param input input arrays
     * @param inputMasks optional mask arrays, may be {@code null}
     * @return the output arrays reported back by the worker
     * @throws RuntimeException wrapping any failure while enqueueing or waiting; the thread's
     *     interrupt status is restored if the failure was an interruption
     */
    public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
        // flush pending ops on this thread before the arrays are handed to worker threads
        Nd4j.getExecutioner().commit();
        BasicInferenceObserver observer = new BasicInferenceObserver();
        InferenceObservable observable;
        if (this.inferenceMode == InferenceMode.SEQUENTIAL) {
            observable = new BasicInferenceObservable(input, inputMasks);
            observable.addObserver(observer);
            try {
                this.observables.put(observable);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        } else {
            observable = this.provider.setInput(observer, input, inputMasks);
        }
        try {
            observer.waitTillDone();
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new RuntimeException(e);
        }
        return observable.getOutput();
    }

    /** Adapter-based inference without masks; see {@link #output(ModelAdapter, INDArray[], INDArray[])}. */
    public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray... inputs) {
        if (adapter == null) {
            throw new NullPointerException("adapter is marked @NonNull but is null");
        }
        return this.output(adapter, inputs, (INDArray[]) null);
    }

    /**
     * Adapter-based inference. This class does not support it and always throws; presumably
     * {@link InplaceParallelInference} overrides it.
     *
     * @throws NullPointerException if {@code adapter} is {@code null}
     * @throws ND4JIllegalStateException always, in this class
     */
    public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks) {
        if (adapter == null) {
            throw new NullPointerException("adapter is marked @NonNull but is null");
        }
        throw new ND4JIllegalStateException("Adapted mode requires Inplace inference mode");
    }

    /**
     * Merges incoming inputs into {@link BatchedInferenceObservable}s for BATCHED mode.
     *
     * <p>A new observable is started when there is none, when the current one's counter has
     * reached {@code batchLimit}, or when it is locked. Only newly started observables are put on
     * the target queue. The put happens while the provider's lock is held, so other callers wait
     * if the queue is full.
     */
    protected static class ObservablesProvider {
        private final BlockingQueue<InferenceObservable> targetQueue;
        /** Stored but not read by this class. */
        private final long nanos;
        private final int batchLimit;
        private volatile BatchedInferenceObservable currentObservable;
        private final Object locker = new Object();

        protected ObservablesProvider(long nanos, int batchLimit, @NonNull BlockingQueue<InferenceObservable> queue) {
            if (queue == null) {
                throw new NullPointerException("queue is marked @NonNull but is null");
            }
            this.targetQueue = queue;
            this.nanos = nanos;
            this.batchLimit = batchLimit;
        }

        protected InferenceObservable setInput(@NonNull Observer observer, INDArray input) {
            if (observer == null) {
                throw new NullPointerException("observer is marked @NonNull but is null");
            }
            return this.setInput(observer, new INDArray[] {input}, (INDArray[]) null);
        }

        protected InferenceObservable setInput(@NonNull Observer observer, INDArray... input) {
            if (observer == null) {
                throw new NullPointerException("observer is marked @NonNull but is null");
            }
            return this.setInput(observer, input, (INDArray[]) null);
        }

        /**
         * Adds {@code input}/{@code inputMask} and {@code observer} to the current batch, starting
         * and enqueueing a new batch if needed.
         *
         * @return the observable the input was added to
         * @throws RuntimeException if interrupted while enqueueing (interrupt status restored)
         */
        protected InferenceObservable setInput(@NonNull Observer observer, INDArray[] input, INDArray[] inputMask) {
            if (observer == null) {
                throw new NullPointerException("observer is marked @NonNull but is null");
            }
            synchronized (this.locker) {
                boolean isNew = false;
                if (this.currentObservable == null
                        || this.currentObservable.getCounter() >= this.batchLimit
                        || this.currentObservable.isLocked()) {
                    isNew = true;
                    this.currentObservable = new BatchedInferenceObservable();
                }
                this.currentObservable.addInput(input, inputMask);
                this.currentObservable.addObserver(observer);
                if (isNew) {
                    try {
                        this.targetQueue.put(this.currentObservable);
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                        throw new RuntimeException(e);
                    }
                }
                return this.currentObservable;
            }
        }
    }

    /**
     * Daemon thread that takes requests from the shared queue and runs them on its own replica of
     * the model, after binding itself to {@code deviceId}.
     */
    private class InferenceWorker extends Thread {
        private final BlockingQueue<InferenceObservable> inputQueue;
        private final AtomicBoolean shouldWork = new AtomicBoolean(true);
        private final AtomicLong counter = new AtomicLong(0L);
        private final boolean rootDevice;
        private final int deviceId;
        /** Readers run inference on the replica; writers replace it. */
        private final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();
        /** Model the replica is built from. */
        private Model protoModel;
        /** Model used for inference; volatile because other threads read it without the lock. */
        private volatile Model replicatedModel;

        private InferenceWorker(int id, @NonNull Model model, BlockingQueue<InferenceObservable> inputQueue, boolean rootDevice, int deviceId) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            if (inputQueue == null) {
                throw new NullPointerException("inputQueue is marked @NonNull but is null");
            }
            this.inputQueue = inputQueue;
            this.protoModel = model;
            this.rootDevice = rootDevice;
            this.deviceId = deviceId;
            this.setDaemon(true);
            this.setName("InferenceThread-" + id);
        }

        protected long getCounterValue() {
            return this.counter.get();
        }

        /**
         * Replaces the prototype model and rebuilds the replica under the write lock. In-flight
         * batches therefore finish on the old replica first.
         */
        protected void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            this.modelLock.writeLock().lock();
            try {
                this.protoModel = model;
                this.initializeReplicaModel();
            } finally {
                this.modelLock.writeLock().unlock();
            }
        }

        /**
         * Builds {@link #replicatedModel} from {@link #protoModel}.
         *
         * <p>The root worker shares the prototype object directly. Other workers rebuild the
         * network from its JSON configuration and copy the parameters. The replica is published
         * only after it is fully initialised. Other model types leave the replica unchanged.
         */
        protected void initializeReplicaModel() {
            if (this.protoModel instanceof ComputationGraph) {
                if (this.rootDevice) {
                    this.replicatedModel = this.protoModel;
                } else {
                    ComputationGraph source = (ComputationGraph) this.protoModel;
                    ComputationGraph replica = new ComputationGraph(
                            ComputationGraphConfiguration.fromJson(source.getConfiguration().toJson()));
                    replica.init();
                    this.copyParamsInto(replica);
                    this.replicatedModel = replica;
                }
            } else if (this.protoModel instanceof MultiLayerNetwork) {
                if (this.rootDevice) {
                    this.replicatedModel = this.protoModel;
                } else {
                    MultiLayerNetwork source = (MultiLayerNetwork) this.protoModel;
                    MultiLayerNetwork replica = new MultiLayerNetwork(
                            MultiLayerConfiguration.fromJson(source.getLayerWiseConfigurations().toJson()));
                    replica.init();
                    this.copyParamsInto(replica);
                    this.replicatedModel = replica;
                }
            }
        }

        /** Copies the prototype's parameters into {@code replica}, serialised across all workers. */
        private void copyParamsInto(Model replica) {
            synchronized (ParallelInference.this.locker) {
                replica.setParams(this.protoModel.params().unsafeDuplication(true));
                Nd4j.getExecutioner().commit();
            }
        }

        @Override
        public void run() {
            Nd4j.getAffinityManager().unsafeSetDevice(this.deviceId);
            try {
                // Build the first replica under the write lock so that a concurrent updateModel()
                // cannot be overwritten by this possibly older copy.
                this.modelLock.writeLock().lock();
                try {
                    this.initializeReplicaModel();
                } finally {
                    this.modelLock.writeLock().unlock();
                }
                while (this.shouldWork.get()) {
                    InferenceObservable request = this.inputQueue.take();
                    this.counter.incrementAndGet();
                    this.processRequest(request);
                }
            } catch (InterruptedException e) {
                // shutdown path: interrupt() wakes the thread out of take()
                Thread.currentThread().interrupt();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        /**
         * Runs every input batch of {@code request}. Reports either the outputs or the failure back
         * to the request, so the waiting caller is always released.
         */
        private void processRequest(InferenceObservable request) {
            try {
                List<Pair<INDArray[], INDArray[]>> batches = request.getInputBatches();
                ArrayList<INDArray[]> out = new ArrayList<>(batches.size());
                for (Pair<INDArray[], INDArray[]> batch : batches) {
                    out.add(this.runBatch(batch));
                }
                request.setOutputBatches(out);
            } catch (Exception e) {
                request.setOutputException(e);
            }
        }

        /**
         * Runs one batch on the current replica while holding the read lock. The model type is
         * checked per batch because updateModel() may swap in a different kind of network.
         *
         * @throws UnsupportedOperationException if the replica is neither a ComputationGraph nor a
         *     MultiLayerNetwork
         */
        private INDArray[] runBatch(Pair<INDArray[], INDArray[]> batch) {
            this.modelLock.readLock().lock();
            try {
                Model current = this.replicatedModel;
                if (current instanceof ComputationGraph) {
                    return ((ComputationGraph) current).output(false, batch.getFirst(), batch.getSecond());
                }
                if (current instanceof MultiLayerNetwork) {
                    // MultiLayerNetwork has a single input: use the first features/mask array
                    INDArray features = batch.getFirst()[0];
                    INDArray featuresMask = batch.getSecond() == null ? null : batch.getSecond()[0];
                    INDArray output = ((MultiLayerNetwork) current).output(features, false, featuresMask, null);
                    return new INDArray[] {output};
                }
                throw new UnsupportedOperationException("Unsupported model type for ParallelInference: "
                        + (current == null ? "null" : current.getClass().getName()));
            } finally {
                Nd4j.getExecutioner().commit();
                this.modelLock.readLock().unlock();
            }
        }

        /**
         * Asks the worker to stop, wakes it if it is blocked on the queue, and waits until the
         * thread has terminated. If the calling thread is interrupted while waiting, it keeps
         * waiting and restores its interrupt status afterwards.
         */
        protected void shutdown() {
            this.shouldWork.set(false);
            this.interrupt();
            boolean callerInterrupted = false;
            while (this.isAlive()) {
                try {
                    this.join();
                } catch (InterruptedException e) {
                    callerInterrupted = true;
                }
            }
            if (callerInterrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /** Fluent builder for {@link ParallelInference} (or {@link InplaceParallelInference} in INPLACE mode). */
    public static class Builder {
        private final Model model;
        private int workers = DEFAULT_NUM_WORKERS;
        private int batchLimit = DEFAULT_BATCH_LIMIT;
        private InferenceMode inferenceMode = DEFAULT_INFERENCE_MODE;
        private int queueLimit = DEFAULT_QUEUE_LIMIT;
        protected LoadBalanceMode loadBalanceMode = LoadBalanceMode.FIFO;

        public Builder(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            this.model = model;
        }

        public Builder inferenceMode(@NonNull InferenceMode inferenceMode) {
            if (inferenceMode == null) {
                throw new NullPointerException("inferenceMode is marked @NonNull but is null");
            }
            this.inferenceMode = inferenceMode;
            return this;
        }

        public Builder loadBalanceMode(@NonNull LoadBalanceMode loadBalanceMode) {
            if (loadBalanceMode == null) {
                throw new NullPointerException("loadBalanceMode is marked @NonNull but is null");
            }
            this.loadBalanceMode = loadBalanceMode;
            return this;
        }

        /** @throws IllegalStateException if {@code workers < 1} */
        public Builder workers(int workers) {
            if (workers < 1) {
                throw new IllegalStateException("Workers should be positive value");
            }
            this.workers = workers;
            return this;
        }

        /** @throws IllegalStateException if {@code limit < 1} */
        public Builder batchLimit(int limit) {
            if (limit < 1) {
                throw new IllegalStateException("Batch limit should be positive value");
            }
            this.batchLimit = limit;
            return this;
        }

        /** @throws IllegalStateException if {@code limit < 1} */
        public Builder queueLimit(int limit) {
            if (limit < 1) {
                throw new IllegalStateException("Queue limit should be positive value");
            }
            this.queueLimit = limit;
            return this;
        }

        /**
         * Creates and initialises the inference object. INPLACE mode yields an
         * {@link InplaceParallelInference}; batch and queue limits are not passed to it.
         */
        public ParallelInference build() {
            if (this.inferenceMode == InferenceMode.INPLACE) {
                InplaceParallelInference inf = new InplaceParallelInference();
                inf.inferenceMode = this.inferenceMode;
                inf.model = this.model;
                inf.workers = this.workers;
                inf.loadBalanceMode = this.loadBalanceMode;
                inf.init();
                return inf;
            }
            ParallelInference inference = new ParallelInference();
            inference.batchLimit = this.batchLimit;
            inference.queueLimit = this.queueLimit;
            inference.inferenceMode = this.inferenceMode;
            inference.model = this.model;
            inference.workers = this.workers;
            inference.loadBalanceMode = this.loadBalanceMode;
            inference.init();
            return inference;
        }
    }
}

