/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.inference;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.inference.streaming.StreamingBlock;
import ai.djl.inference.streaming.StreamingTranslator;
import ai.djl.metric.Dimension;
import ai.djl.metric.Metrics;
import ai.djl.metric.Unit;
import ai.djl.ndarray.LazyNDArray;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.nn.Block;
import ai.djl.training.ParameterStore;
import ai.djl.translate.Batchifier;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Runs inference with a {@link Model}.
 *
 * <p>Inputs are converted to {@link NDList}s by the {@link Translator}, passed through the model's
 * {@link Block}, and the resulting {@link NDList}s are converted back to outputs by the same
 * translator. When {@link #setMetrics(Metrics)} has been called, per-stage timings are recorded.
 *
 * <p>The predictor owns a sub-manager of the model's {@link NDManager}; call {@link #close()} when
 * the predictor is no longer needed.
 *
 * @param <I> the input type
 * @param <O> the output type
 */
public class Predictor<I, O> implements AutoCloseable {

    private static final Logger logger = LoggerFactory.getLogger(Predictor.class);

    protected Translator<I, O> translator;
    /** Start time (from {@link System#nanoTime()}) of the stage currently being measured. */
    protected long timestamp;
    /** Whether {@link Translator#prepare} has already been invoked for this predictor. */
    protected boolean prepared;
    protected Model model;
    protected NDManager manager;
    protected Metrics metrics;
    protected Block block;
    protected ParameterStore parameterStore;
    protected Dimension dimension;

    /**
     * Creates a predictor for {@code model}.
     *
     * @param model the model whose block is run
     * @param translator converts between user inputs/outputs and {@link NDList}s
     * @param device the device passed to the predictor's sub-manager
     * @param copy the copy flag handed to the {@link ParameterStore}; it is forced to {@code true}
     *     when {@code device} differs from the device of the model's manager
     */
    public Predictor(Model model, Translator<I, O> translator, Device device, boolean copy) {
        boolean copyParameters = !device.equals(model.getNDManager().getDevice()) || copy;
        this.model = model;
        this.manager = model.getNDManager().newSubManager(device);
        this.manager.setName("predictor");
        this.translator = translator;
        this.block = model.getBlock();
        this.parameterStore = new ParameterStore(this.manager, copyParameters);
        this.dimension = new Dimension("Model", model.getProperty("metric_dimension", "model"));
    }

    /**
     * Predicts a single input.
     *
     * @param input the input
     * @return the first element of {@link #batchPredict} applied to a singleton list
     * @throws TranslateException if translation or inference fails
     */
    public O predict(I input) throws TranslateException {
        return batchPredict(Collections.singletonList(input)).get(0);
    }

    /**
     * Runs the block forward (inference mode) on already-translated input.
     *
     * @param ctx the translator context for this call
     * @param ndList the model input
     * @return the block output, or an empty {@link NDList} when {@code ndList} is empty
     * @throws TranslateException declared for subclasses that override this hook
     */
    protected NDList predictInternal(TranslatorContext ctx, NDList ndList)
            throws TranslateException {
        logger.trace("Predictor input data: {}", ndList);
        if (ndList.isEmpty()) {
            return new NDList();
        }
        return block.forward(parameterStore, ndList, false);
    }

    /**
     * Predicts a list of inputs.
     *
     * <p>Without a {@link Batchifier}, each input is processed individually; otherwise the whole
     * list is handed to the translator's batch methods in one pass.
     *
     * @param inputs the inputs
     * @return one output per input, as produced by the translator
     * @throws TranslateException if preparation, translation or inference fails; any other
     *     exception is wrapped in a {@link TranslateException}
     */
    public List<O> batchPredict(List<I> inputs) throws TranslateException {
        try (PredictorContext context = new PredictorContext(model, manager, metrics)) {
            ensurePrepared(context);
            if (translator.getBatchifier() == null) {
                List<O> ret = new ArrayList<>(inputs.size());
                for (I input : inputs) {
                    long begin = timestamp = System.nanoTime();
                    NDList ndList = translator.processInput(context, input);
                    preprocessEnd(ndList, 1);
                    NDList result = predictInternal(context, ndList);
                    predictEnd(result, 1);
                    ret.add(translator.processOutput(context, result));
                    postProcessEnd(begin, 1);
                }
                return ret;
            }

            int batchSize = inputs.size();
            long begin = timestamp = System.nanoTime();
            NDList ndList = translator.batchProcessInput(context, inputs);
            preprocessEnd(ndList, batchSize);
            NDList result = predictInternal(context, ndList);
            predictEnd(result, batchSize);
            List<O> ret = translator.batchProcessOutput(context, result);
            postProcessEnd(begin, batchSize);
            return ret;
        } catch (TranslateException e) {
            throw e;
        } catch (Exception e) {
            throw new TranslateException(e);
        }
    }

    /**
     * Predicts a single input and returns the output as a stream.
     *
     * <p>The per-call context is closed when the underlying stream is closed. If setting up the
     * stream fails, the context is closed immediately instead of being leaked.
     *
     * @param input the input
     * @return the streamed output produced by the {@link StreamingTranslator}
     * @throws IllegalStateException if the block is not a {@link StreamingBlock} or the translator
     *     is not a {@link StreamingTranslator}
     * @throws TranslateException if preparation, translation or inference fails
     */
    public StreamingTranslator.StreamOutput<O> streamingPredict(I input)
            throws TranslateException {
        String error = streamingSupportError();
        if (error != null) {
            throw new IllegalStateException(error);
        }
        StreamingBlock streamingBlock = (StreamingBlock) block;
        @SuppressWarnings("unchecked") // guarded by streamingSupportError()
        StreamingTranslator<I, O> streamingTranslator = (StreamingTranslator<I, O>) translator;

        PredictorContext context = null;
        try {
            context = new PredictorContext(model, manager, metrics);
            return openStream(context, input, streamingBlock, streamingTranslator);
        } catch (TranslateException e) {
            closeAfterFailure(context, e);
            throw e;
        } catch (Exception e) {
            TranslateException wrapped = new TranslateException(e);
            closeAfterFailure(context, wrapped);
            throw wrapped;
        }
    }

    /** Builds the output stream for {@link #streamingPredict}; the stream closes {@code context}. */
    private StreamingTranslator.StreamOutput<O> openStream(
            PredictorContext context,
            I input,
            StreamingBlock streamingBlock,
            StreamingTranslator<I, O> streamingTranslator)
            throws Exception {
        ensurePrepared(context);
        Batchifier batchifier = translator.getBatchifier();
        if (batchifier == null) {
            NDList ndList = translator.processInput(context, input);
            Stream<NDList> outputs =
                    streamingBlock
                            .forwardStream(parameterStore, ndList, false)
                            .onClose(context::close);
            return streamingTranslator.processStreamOutput(context, outputs);
        }

        // Batched models need a singleton batch in, and each streamed result unbatchified.
        NDList inputBatch = processInputs(context, Collections.singletonList(input));
        Stream<NDList> outputs =
                streamingBlock
                        .forwardStream(parameterStore, inputBatch, false)
                        .map(
                                result -> {
                                    NDList[] unbatched = batchifier.unbatchify(result);
                                    if (unbatched.length != 1) {
                                        throw new IllegalStateException(
                                                "Unexpected number of outputs from model");
                                    }
                                    return unbatched[0];
                                })
                        .onClose(context::close);
        return streamingTranslator.processStreamOutput(context, outputs);
    }

    /** Closes a context after a failed stream setup, keeping {@code failure} as the primary error. */
    private static void closeAfterFailure(PredictorContext context, Exception failure) {
        if (context == null) {
            return;
        }
        try {
            context.close();
        } catch (RuntimeException closeError) {
            failure.addSuppressed(closeError);
        }
    }

    /**
     * Returns whether {@link #streamingPredict} can be used with this predictor.
     *
     * @return {@code true} if both the block and the translator support streaming
     */
    public boolean supportsStreaming() {
        return streamingSupportError() == null;
    }

    /** Returns the reason streaming is unsupported, or {@code null} if it is supported. */
    private String streamingSupportError() {
        if (!(block instanceof StreamingBlock)) {
            return "streamingPredict() can only be called with a StreamingBlock";
        }
        if (!(translator instanceof StreamingTranslator)) {
            return "streamingPredict() can only be called with a StreamingTranslator";
        }
        return null;
    }

    /**
     * Sets the metrics collector; {@code null} disables metric collection.
     *
     * @param metrics the metrics collector
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /** Invokes {@link Translator#prepare} once, on the first prediction. */
    private void ensurePrepared(TranslatorContext ctx) throws Exception {
        if (!prepared) {
            translator.prepare(ctx);
            prepared = true;
        }
    }

    /** Calls {@code waitToRead()} on every {@link LazyNDArray} in {@code list}. */
    private void waitToRead(NDList list) {
        for (NDArray array : list) {
            if (array instanceof LazyNDArray) {
                ((LazyNDArray) array).waitToRead();
            }
        }
    }

    /** Translates each input individually and batchifies the results. */
    private NDList processInputs(TranslatorContext ctx, List<I> inputs) throws Exception {
        int batchSize = inputs.size();
        NDList[] preprocessed = new NDList[batchSize];
        for (int i = 0; i < batchSize; ++i) {
            preprocessed[i] = translator.processInput(ctx, inputs.get(i));
        }
        return translator.getBatchifier().batchify(preprocessed);
    }

    /** Records the "Preprocess" metric (microseconds per item) when metrics are enabled. */
    private void preprocessEnd(NDList list, int batchSize) {
        if (metrics != null) {
            waitToRead(list);
            recordStage("Preprocess", batchSize);
        }
    }

    /** Records the "Inference" metric (microseconds per item) when metrics are enabled. */
    private void predictEnd(NDList list, int batchSize) {
        if (metrics != null) {
            waitToRead(list);
            recordStage("Inference", batchSize);
        }
    }

    /**
     * Records the "Postprocess" metric (per item) and the total "Prediction" metric (whole call,
     * measured from {@code begin}) when metrics are enabled.
     */
    private void postProcessEnd(long begin, int batchSize) {
        if (metrics != null) {
            long now = recordStage("Postprocess", batchSize);
            long prediction = (now - begin) / 1000L;
            metrics.addMetric("Prediction", prediction, Unit.MICROSECONDS, dimension);
        }
    }

    /**
     * Records the time since {@link #timestamp} as a per-item duration in microseconds and starts
     * the next stage.
     *
     * @return the new value of {@link #timestamp}
     */
    private long recordStage(String metricName, int batchSize) {
        long now = System.nanoTime();
        // An empty batch would otherwise throw ArithmeticException ("/ by zero").
        long duration = (now - timestamp) / 1000L / Math.max(1, batchSize);
        timestamp = now;
        metrics.addMetric(metricName, duration, Unit.MICROSECONDS, dimension);
        return now;
    }

    /** Closes the predictor's sub-manager. */
    @Override
    public void close() {
        manager.close();
    }

    /**
     * Closes a predictor that was never closed explicitly.
     *
     * <p>{@code manager} may be {@code null} if the constructor failed before assigning it.
     */
    @SuppressWarnings("deprecation")
    @Override
    protected void finalize() throws Throwable {
        try {
            if (manager != null && manager.isOpen()) {
                // The leak warning is only emitted when debug logging is enabled.
                if (logger.isDebugEnabled()) {
                    logger.warn("Predictor for {} was not closed explicitly.", model.getName());
                }
                close();
            }
        } finally {
            super.finalize();
        }
    }

    /**
     * Per-call {@link TranslatorContext}.
     *
     * <p>It owns a sub-manager of the predictor's manager (returned by {@link #getNDManager()}).
     * That sub-manager is released in {@link #close()}.
     */
    public static final class PredictorContext implements TranslatorContext {

        private final Model model;
        private final NDManager predictorManager;
        private final Metrics metrics;
        private final NDManager ctxManager;
        /** Backed by a ConcurrentHashMap, so {@code null} keys and values are rejected. */
        private final Map<String, Object> attachments;

        /**
         * Creates a context.
         *
         * @param model the model being run
         * @param predictorManager the predictor's manager; a sub-manager is created from it
         * @param metrics the metrics collector, possibly {@code null}
         */
        public PredictorContext(Model model, NDManager predictorManager, Metrics metrics) {
            this.model = model;
            this.predictorManager = predictorManager;
            this.metrics = metrics;
            this.ctxManager = predictorManager.newSubManager();
            this.ctxManager.setName("predictor ctx");
            this.attachments = new ConcurrentHashMap<>();
        }

        /** {@inheritDoc} */
        @Override
        public Model getModel() {
            return model;
        }

        /** Returns the per-call sub-manager. */
        @Override
        public NDManager getNDManager() {
            return ctxManager;
        }

        /** Returns the predictor's manager. */
        @Override
        public NDManager getPredictorManager() {
            return predictorManager;
        }

        /** Returns the model's block. */
        @Override
        public Block getBlock() {
            return model.getBlock();
        }

        /** {@inheritDoc} */
        @Override
        public Metrics getMetrics() {
            return metrics;
        }

        /** Closes the per-call sub-manager. */
        @Override
        public void close() {
            ctxManager.close();
        }

        /** {@inheritDoc} */
        @Override
        public Object getAttachment(String key) {
            return attachments.get(key);
        }

        /** {@inheritDoc} */
        @Override
        public void setAttachment(String key, Object value) {
            attachments.put(key, value);
        }
    }
}

