/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.parallelism;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.ModelAdapter;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.LoadBalanceMode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class InplaceParallelInference
extends ParallelInference {
    private static final Logger log = LoggerFactory.getLogger(InplaceParallelInference.class);
    protected List<ModelHolder> holders = new CopyOnWriteArrayList<ModelHolder>();
    protected ModelSelector selector = new ModelSelector();
    protected final Object locker = new Object();

    @Override
    protected void init() {
        for (int e = 0; e < Nd4j.getAffinityManager().getNumberOfDevices(); ++e) {
            ModelHolder h = ModelHolder.builder().sourceModel(this.model).workers(this.workers).loadBalanceMode(this.loadBalanceMode).targetDeviceId(e).rootDevice(e == Nd4j.getAffinityManager().getDeviceForCurrentThread()).build();
            h.init();
            this.holders.add(h);
            this.selector.addModelHolder(e, h);
        }
    }

    @Override
    public synchronized void updateModel(@NonNull Model model) {
        if (model == null) {
            throw new NullPointerException("model is marked @NonNull but is null");
        }
        for (ModelHolder h : this.holders) {
            h.updateModel(model);
        }
    }

    @Override
    protected synchronized Model[] getCurrentModelsFromWorkers() {
        Model[] models = new Model[this.holders.size()];
        int cnt = 0;
        for (ModelHolder h : this.holders) {
            models[cnt++] = h.sourceModel;
        }
        return models;
    }

    @Override
    public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
        return this.selector.output(input, inputMasks);
    }

    public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) {
        if (adapter == null) {
            throw new NullPointerException("adapter is marked @NonNull but is null");
        }
        ModelHolder holder = this.selector.getModelForThisThread();
        Model model = null;
        boolean acquired = false;
        try {
            model = holder.acquireModel();
            acquired = true;
            Object object = adapter.apply(model, input, inputMasks, labelsMasks);
            return (T)object;
        }
        catch (InterruptedException e) {
            throw new RuntimeException(e);
        }
        finally {
            if (model != null && acquired) {
                holder.releaseModel(model);
            }
        }
    }

    protected static class ModelHolder {
        protected Model sourceModel;
        protected int workers;
        protected List<Model> replicas;
        protected boolean rootDevice;
        protected LoadBalanceMode loadBalanceMode;
        protected int targetDeviceId;
        protected final AtomicLong position = new AtomicLong(0L);
        protected final ReentrantReadWriteLock modelLock = new ReentrantReadWriteLock();
        protected final BlockingQueue<Model> queue = new LinkedBlockingQueue<Model>();
        protected transient boolean isCG;
        protected transient boolean isMLN;

        protected synchronized void init() {
            INDArray params;
            if (this.workers < 1) {
                throw new ND4JIllegalStateException("Workers must be positive value");
            }
            this.replicas.clear();
            this.isCG = this.sourceModel instanceof ComputationGraph;
            this.isMLN = this.sourceModel instanceof MultiLayerNetwork;
            INDArray iNDArray = params = this.rootDevice ? this.sourceModel.params() : this.sourceModel.params().unsafeDuplication(true);
            if (!this.rootDevice) {
                Nd4j.getAffinityManager().replicateToDevice(Integer.valueOf(this.targetDeviceId), params);
            }
            for (int e = 0; e < this.workers; ++e) {
                ComputationGraph model;
                if (this.sourceModel instanceof ComputationGraph) {
                    model = new ComputationGraph(ComputationGraphConfiguration.fromJson((String)((ComputationGraph)this.sourceModel).getConfiguration().toJson()));
                    model.init(params, false);
                    Nd4j.getExecutioner().commit();
                    this.replicas.add((Model)model);
                    if (this.loadBalanceMode != LoadBalanceMode.FIFO) continue;
                    this.queue.add((Model)model);
                    continue;
                }
                if (!(this.sourceModel instanceof MultiLayerNetwork)) continue;
                model = new MultiLayerNetwork(MultiLayerConfiguration.fromJson((String)((MultiLayerNetwork)this.sourceModel).getLayerWiseConfigurations().toJson()));
                model.init(params, false);
                Nd4j.getExecutioner().commit();
                this.replicas.add((Model)model);
                if (this.loadBalanceMode != LoadBalanceMode.FIFO) continue;
                this.queue.add((Model)model);
            }
        }

        protected Model acquireModel() throws InterruptedException {
            try {
                this.modelLock.readLock().lock();
                switch (this.loadBalanceMode) {
                    case FIFO: {
                        Model model = this.queue.take();
                        return model;
                    }
                    case ROUND_ROBIN: {
                        Model model = this.replicas.get((int)(this.position.getAndIncrement() % (long)this.replicas.size()));
                        return model;
                    }
                }
                throw new ND4JIllegalStateException("Unknown LoadBalanceMode was specified: [" + (Object)((Object)this.loadBalanceMode) + "]");
            }
            finally {
                this.modelLock.readLock().unlock();
            }
        }

        /*
         * Enabled force condition propagation
         * Lifted jumps to return sites
         */
        protected void releaseModel(Model model) {
            try {
                this.modelLock.readLock().lock();
                switch (this.loadBalanceMode) {
                    case FIFO: {
                        this.queue.add(model);
                        return;
                    }
                    case ROUND_ROBIN: {
                        return;
                    }
                    default: {
                        throw new ND4JIllegalStateException("Unknown LoadBalanceMode was specified: [" + (Object)((Object)this.loadBalanceMode) + "]");
                    }
                }
            }
            finally {
                this.modelLock.readLock().unlock();
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        protected INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
            try {
                this.modelLock.readLock().lock();
                if (this.isCG) {
                    INDArray[] output;
                    Model model = this.acquireModel();
                    try {
                        output = ((ComputationGraph)model).output(false, input, inputMasks);
                    }
                    finally {
                        this.releaseModel(model);
                    }
                    INDArray[] iNDArrayArray = output;
                    return iNDArrayArray;
                }
                if (this.isMLN) {
                    INDArray result;
                    if (input.length > 1 || inputMasks != null && inputMasks.length > 1) {
                        throw new ND4JIllegalStateException("MultilayerNetwork can't have multiple inputs");
                    }
                    Model model = this.acquireModel();
                    try {
                        result = ((MultiLayerNetwork)model).output(input[0], false, inputMasks == null ? null : inputMasks[0], null);
                    }
                    finally {
                        this.releaseModel(model);
                    }
                    INDArray[] iNDArrayArray = new INDArray[]{result};
                    return iNDArrayArray;
                }
                try {
                    throw new UnsupportedOperationException();
                }
                catch (InterruptedException e) {
                    throw new RuntimeException(e);
                }
            }
            finally {
                this.modelLock.readLock().unlock();
            }
        }

        protected void updateModel(@NonNull Model model) {
            if (model == null) {
                throw new NullPointerException("model is marked @NonNull but is null");
            }
            try {
                this.modelLock.writeLock().lock();
                this.sourceModel = model;
                this.init();
            }
            finally {
                this.modelLock.writeLock().unlock();
            }
        }

        private static int $default$workers() {
            return 4;
        }

        private static List<Model> $default$replicas() {
            return new ArrayList<Model>();
        }

        private static boolean $default$rootDevice() {
            return true;
        }

        private static LoadBalanceMode $default$loadBalanceMode() {
            return LoadBalanceMode.ROUND_ROBIN;
        }

        private static boolean $default$isCG() {
            return false;
        }

        private static boolean $default$isMLN() {
            return false;
        }

        public static ModelHolderBuilder builder() {
            return new ModelHolderBuilder();
        }

        public ModelHolder() {
            this.workers = ModelHolder.$default$workers();
            this.replicas = ModelHolder.$default$replicas();
            this.rootDevice = ModelHolder.$default$rootDevice();
            this.loadBalanceMode = ModelHolder.$default$loadBalanceMode();
            this.isCG = ModelHolder.$default$isCG();
            this.isMLN = ModelHolder.$default$isMLN();
        }

        public ModelHolder(Model sourceModel, int workers, List<Model> replicas, boolean rootDevice, LoadBalanceMode loadBalanceMode, int targetDeviceId, boolean isCG, boolean isMLN) {
            this.sourceModel = sourceModel;
            this.workers = workers;
            this.replicas = replicas;
            this.rootDevice = rootDevice;
            this.loadBalanceMode = loadBalanceMode;
            this.targetDeviceId = targetDeviceId;
            this.isCG = isCG;
            this.isMLN = isMLN;
        }

        public static class ModelHolderBuilder {
            private Model sourceModel;
            private boolean workers$set;
            private int workers;
            private boolean replicas$set;
            private List<Model> replicas;
            private boolean rootDevice$set;
            private boolean rootDevice;
            private boolean loadBalanceMode$set;
            private LoadBalanceMode loadBalanceMode;
            private int targetDeviceId;
            private boolean isCG$set;
            private boolean isCG;
            private boolean isMLN$set;
            private boolean isMLN;

            ModelHolderBuilder() {
            }

            public ModelHolderBuilder sourceModel(Model sourceModel) {
                this.sourceModel = sourceModel;
                return this;
            }

            public ModelHolderBuilder workers(int workers) {
                this.workers = workers;
                this.workers$set = true;
                return this;
            }

            public ModelHolderBuilder replicas(List<Model> replicas) {
                this.replicas = replicas;
                this.replicas$set = true;
                return this;
            }

            public ModelHolderBuilder rootDevice(boolean rootDevice) {
                this.rootDevice = rootDevice;
                this.rootDevice$set = true;
                return this;
            }

            public ModelHolderBuilder loadBalanceMode(LoadBalanceMode loadBalanceMode) {
                this.loadBalanceMode = loadBalanceMode;
                this.loadBalanceMode$set = true;
                return this;
            }

            public ModelHolderBuilder targetDeviceId(int targetDeviceId) {
                this.targetDeviceId = targetDeviceId;
                return this;
            }

            public ModelHolderBuilder isCG(boolean isCG) {
                this.isCG = isCG;
                this.isCG$set = true;
                return this;
            }

            public ModelHolderBuilder isMLN(boolean isMLN) {
                this.isMLN = isMLN;
                this.isMLN$set = true;
                return this;
            }

            public ModelHolder build() {
                int workers = this.workers;
                if (!this.workers$set) {
                    workers = ModelHolder.$default$workers();
                }
                List replicas = this.replicas;
                if (!this.replicas$set) {
                    replicas = ModelHolder.$default$replicas();
                }
                boolean rootDevice = this.rootDevice;
                if (!this.rootDevice$set) {
                    rootDevice = ModelHolder.$default$rootDevice();
                }
                LoadBalanceMode loadBalanceMode = this.loadBalanceMode;
                if (!this.loadBalanceMode$set) {
                    loadBalanceMode = ModelHolder.$default$loadBalanceMode();
                }
                boolean isCG = this.isCG;
                if (!this.isCG$set) {
                    isCG = ModelHolder.$default$isCG();
                }
                boolean isMLN = this.isMLN;
                if (!this.isMLN$set) {
                    isMLN = ModelHolder.$default$isMLN();
                }
                return new ModelHolder(this.sourceModel, workers, replicas, rootDevice, loadBalanceMode, this.targetDeviceId, isCG, isMLN);
            }

            public String toString() {
                return "InplaceParallelInference.ModelHolder.ModelHolderBuilder(sourceModel=" + this.sourceModel + ", workers=" + this.workers + ", replicas=" + this.replicas + ", rootDevice=" + this.rootDevice + ", loadBalanceMode=" + (Object)((Object)this.loadBalanceMode) + ", targetDeviceId=" + this.targetDeviceId + ", isCG=" + this.isCG + ", isMLN=" + this.isMLN + ")";
            }
        }
    }

    protected static class ModelSelector {
        protected Map<Integer, ModelHolder> map = new HashMap<Integer, ModelHolder>();
        protected final LoadBalanceMode loadBalanceMode;

        public ModelSelector() {
            this(LoadBalanceMode.ROUND_ROBIN);
        }

        public ModelSelector(LoadBalanceMode loadBalanceMode) {
            this.loadBalanceMode = loadBalanceMode;
        }

        protected void addModelHolder(@NonNull Integer device, @NonNull ModelHolder holder) {
            if (device == null) {
                throw new NullPointerException("device is marked @NonNull but is null");
            }
            if (holder == null) {
                throw new NullPointerException("holder is marked @NonNull but is null");
            }
            this.map.put(device, holder);
        }

        public ModelHolder getModelForThread(long threadId) {
            Integer device = Nd4j.getAffinityManager().getDeviceForThread(threadId);
            ModelHolder q = this.map.get(device);
            return q;
        }

        public INDArray[] output(INDArray[] input, INDArray[] inputMasks) {
            return this.getModelForThisThread().output(input, inputMasks);
        }

        public ModelHolder getModelForThisThread() {
            return this.getModelForThread(Thread.currentThread().getId());
        }
    }
}

