/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.zoo.nlp.sentimentanalysis;

import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.modality.nlp.SimpleVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.StackBatchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.net.URL;
import java.util.Arrays;
import java.util.List;

public class PtDistilBertTranslator
implements Translator<String, Classifications> {
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;

    public Batchifier getBatchifier() {
        return new StackBatchifier();
    }

    public void prepare(NDManager manager, Model model) throws IOException {
        URL url = model.getArtifact("distilbert-base-uncased-finetuned-sst-2-english-vocab.txt");
        this.vocabulary = SimpleVocabulary.builder().optMinFrequency(1).addFromTextFile(url).optUnknownToken("[UNK]").build();
        this.tokenizer = new BertTokenizer();
    }

    public Classifications processOutput(TranslatorContext ctx, NDList list) {
        NDArray raw = list.singletonOrThrow();
        NDArray computed = raw.exp().div(raw.exp().sum(new int[]{0}, true));
        return new Classifications(Arrays.asList("Negative", "Positive"), computed);
    }

    public NDList processInput(TranslatorContext ctx, String input) {
        List tokens = this.tokenizer.tokenize(input);
        long[] indices = tokens.stream().mapToLong(arg_0 -> ((Vocabulary)this.vocabulary).getIndex(arg_0)).toArray();
        long[] attentionMask = new long[tokens.size()];
        Arrays.fill(attentionMask, 1L);
        NDManager manager = ctx.getNDManager();
        NDArray indicesArray = manager.create(indices);
        NDArray attentionMaskArray = manager.create(attentionMask);
        return new NDList(new NDArray[]{indicesArray, attentionMaskArray});
    }
}

