/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.inference;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.metric.Metrics;
import ai.djl.ndarray.LazyNDArray;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.training.ParameterStore;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs inference on a {@link Model} by pushing inputs through a {@link Translator}
 * and the model's {@link Block}.
 *
 * <p>Each call to {@link #batchPredict(List)} converts the inputs to {@link NDList}s with the
 * translator, forwards them through the block, and converts the block output back to the
 * output type. When a {@link Metrics} instance has been supplied through
 * {@link #setMetrics(Metrics)}, the durations of the individual phases are recorded under the
 * names {@code "Preprocess"}, {@code "Inference"}, {@code "Postprocess"} and {@code "Total"},
 * all in nanoseconds.
 *
 * <p>The {@code prepared} flag and the {@code timestamp} used for metrics are plain instance
 * fields that are updated without any synchronization inside this class.
 *
 * @param <I> the input type accepted by the translator
 * @param <O> the output type produced by the translator
 */
public class Predictor<I, O> implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(Predictor.class);

    private final Translator<I, O> translator;
    private final Model model;
    /** Start of the phase currently being measured, as returned by {@link System#nanoTime()}. */
    private long timestamp;
    /** Whether {@link Translator#prepare} has already completed for this predictor. */
    private boolean prepared;

    protected NDManager manager;
    protected Metrics metrics;
    protected Block block;
    protected ParameterStore parameterStore;

    /**
     * Creates a predictor for the given model.
     *
     * <p>A sub-manager named {@code "predictor"} is created from the model's manager on
     * {@code device}, and a {@link ParameterStore} is created on top of it.
     *
     * @param model the model whose block is used for the forward pass
     * @param translator converts between the user types and {@link NDList}
     * @param device the device the predictor's manager is created on
     * @param copy flag forwarded to the {@link ParameterStore}; it is forced to {@code true}
     *     when {@code device} differs from the device of the model's manager
     */
    public Predictor(Model model, Translator<I, O> translator, Device device, boolean copy) {
        // The device comparison is evaluated first and unconditionally, as before.
        boolean copyParameters = !device.equals(model.getNDManager().getDevice()) || copy;
        this.model = model;
        this.manager = model.getNDManager().newSubManager(device);
        this.manager.setName("predictor");
        this.translator = translator;
        this.block = model.getBlock();
        this.parameterStore = new ParameterStore(this.manager, copyParameters);
    }

    /**
     * Predicts a single item by delegating to {@link #batchPredict(List)} with a one-element
     * list and returning the first result.
     *
     * @param input the input to run inference on
     * @return the translated output for {@code input}
     * @throws TranslateException if translation or inference fails
     */
    public O predict(I input) throws TranslateException {
        return batchPredict(Collections.singletonList(input)).get(0);
    }

    /**
     * Runs the forward pass of the block in inference mode ({@code training = false}).
     *
     * @param ctx the translator context of the current call (not used by this implementation)
     * @param ndList the pre-processed input
     * @return the raw output of the block
     * @throws TranslateException declared so that overriding implementations may throw it
     */
    protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
            throws TranslateException {
        logger.trace("Predictor input data: {}", ndList);
        return block.forward(parameterStore, ndList, false);
    }

    /**
     * Predicts a list of items.
     *
     * <p>On the first call the translator is prepared with the call's context. If the
     * translator's {@link Batchifier} is {@code null}, every input is pre-processed, forwarded
     * and post-processed on its own; otherwise all inputs are batchified, forwarded once and
     * the result is unbatchified.
     *
     * @param inputs the inputs to run inference on
     * @return one output per processed input, in iteration order of {@code inputs} for the
     *     unbatched path, or as produced by the batchifier for the batched path
     * @throws TranslateException if translation or inference fails; any other exception raised
     *     while processing is wrapped in a {@code TranslateException}
     */
    public List<O> batchPredict(List<I> inputs) throws TranslateException {
        long begin = System.nanoTime();
        try (PredictorContext context = new PredictorContext()) {
            if (!prepared) {
                translator.prepare(context);
                // Only flagged after prepare() returned normally, so a failure is retried.
                prepared = true;
            }

            Batchifier batchifier = translator.getBatchifier();
            if (batchifier == null) {
                List<O> ret = new ArrayList<>(inputs.size());
                for (I input : inputs) {
                    // Restart both clocks so "Total" is reported per item on this path.
                    timestamp = System.nanoTime();
                    begin = timestamp;
                    NDList ndList = translator.processInput(context, input);
                    preprocessEnd(ndList);
                    NDList result = predictInternal(context, ndList);
                    predictEnd(result);
                    ret.add(translator.processOutput(context, result));
                    postProcessEnd(begin);
                }
                return ret;
            }

            timestamp = System.nanoTime();
            NDList inputBatch = processInputs(context, batchifier, inputs);
            preprocessEnd(inputBatch);

            NDList result = predictInternal(context, inputBatch);
            predictEnd(result);

            List<O> ret = processOutputs(context, batchifier, result);
            // "Total" is measured from method entry here and therefore includes prepare().
            postProcessEnd(begin);
            return ret;
        } catch (TranslateException e) {
            throw e;
        } catch (Exception e) {
            throw new TranslateException(e);
        }
    }

    /**
     * Sets the metrics sink used to record phase durations.
     *
     * @param metrics the metrics instance, or {@code null} to disable recording
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /** Calls {@link LazyNDArray#waitToRead()} on every lazy array in {@code list}. */
    private void waitToRead(NDList list) {
        for (NDArray array : list) {
            if (array instanceof LazyNDArray) {
                ((LazyNDArray) array).waitToRead();
            }
        }
    }

    /** Pre-processes every input with the translator and batchifies the results. */
    private NDList processInputs(TranslatorContext ctx, Batchifier batchifier, List<I> inputs)
            throws Exception {
        int batchSize = inputs.size();
        NDList[] preprocessed = new NDList[batchSize];
        for (int i = 0; i < batchSize; ++i) {
            preprocessed[i] = translator.processInput(ctx, inputs.get(i));
        }
        return batchifier.batchify(preprocessed);
    }

    /** Unbatchifies the block output and post-processes every element with the translator. */
    private List<O> processOutputs(TranslatorContext ctx, Batchifier batchifier, NDList list)
            throws Exception {
        NDList[] unbatched = batchifier.unbatchify(list);
        List<O> outputs = new ArrayList<>(unbatched.length);
        for (NDList output : unbatched) {
            outputs.add(translator.processOutput(ctx, output));
        }
        return outputs;
    }

    /** Records the {@code "Preprocess"} duration and restarts the phase clock. */
    private void preprocessEnd(NDList list) {
        if (metrics != null) {
            // Wait on lazy arrays first so the measured time covers their evaluation.
            waitToRead(list);
            long now = System.nanoTime();
            long duration = now - timestamp;
            timestamp = now;
            metrics.addMetric("Preprocess", duration, "nano");
        }
    }

    /** Records the {@code "Inference"} duration and restarts the phase clock. */
    private void predictEnd(NDList list) {
        if (metrics != null) {
            // Wait on lazy arrays first so the measured time covers their evaluation.
            waitToRead(list);
            long now = System.nanoTime();
            long duration = now - timestamp;
            timestamp = now;
            metrics.addMetric("Inference", duration, "nano");
        }
    }

    /**
     * Records the {@code "Postprocess"} duration and the {@code "Total"} duration measured
     * from {@code begin}.
     */
    private void postProcessEnd(long begin) {
        if (metrics != null) {
            long now = System.nanoTime();
            long duration = now - timestamp;
            timestamp = now;
            metrics.addMetric("Postprocess", duration, "nano");
            metrics.addMetric("Total", now - begin, "nano");
        }
    }

    /** Closes the predictor's {@link NDManager}. */
    @Override
    public void close() {
        manager.close();
    }

    /**
     * Closes the predictor if its manager is still open when the object is finalized.
     *
     * <p>The warning about the missing explicit close is only emitted when debug logging is
     * enabled.
     */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        if (manager.isOpen()) {
            if (logger.isDebugEnabled()) {
                logger.warn("Predictor for {} was not closed explicitly.", model.getName());
            }
            close();
        }
        super.finalize();
    }

    /**
     * Translator context created for a single {@link #batchPredict(List)} call.
     *
     * <p>It owns a sub-manager of the predictor's manager, named {@code "predictor ctx"},
     * which is closed together with the context.
     */
    private class PredictorContext implements TranslatorContext {
        private final NDManager ctxManager;
        private final Map<String, Object> attachments;

        PredictorContext() {
            ctxManager = manager.newSubManager();
            ctxManager.setName("predictor ctx");
            attachments = new ConcurrentHashMap<>();
        }

        /** Returns the model of the enclosing predictor. */
        @Override
        public Model getModel() {
            return model;
        }

        /** Returns the manager owned by this context. */
        @Override
        public NDManager getNDManager() {
            return ctxManager;
        }

        /** Returns the manager of the enclosing predictor. */
        @Override
        public NDManager getPredictorManager() {
            return manager;
        }

        /** Returns the block of the enclosing predictor. */
        @Override
        public Block getBlock() {
            return block;
        }

        /** Returns the metrics of the enclosing predictor, which may be {@code null}. */
        @Override
        public Metrics getMetrics() {
            return metrics;
        }

        /** Closes the manager owned by this context. */
        @Override
        public void close() {
            ctxManager.close();
        }

        /** Returns the attachment stored under {@code key}, or {@code null} if there is none. */
        @Override
        public Object getAttachment(String key) {
            return attachments.get(key);
        }

        /**
         * Stores {@code value} under {@code key}. The backing {@link ConcurrentHashMap}
         * rejects {@code null} keys and values.
         */
        @Override
        public void setAttachment(String key, Object value) {
            attachments.put(key, value);
        }
    }
}

