/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.zoo.nlp.qa;

import ai.djl.modality.nlp.DefaultVocabulary;
import ai.djl.modality.nlp.Vocabulary;
import ai.djl.modality.nlp.bert.BertFullTokenizer;
import ai.djl.modality.nlp.bert.BertToken;
import ai.djl.modality.nlp.bert.BertTokenizer;
import ai.djl.modality.nlp.qa.QAInput;
import ai.djl.modality.nlp.translator.QATranslator;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.TranslatorContext;
import java.io.IOException;
import java.util.List;
import java.util.Map;

public class PtBertQATranslator
extends QATranslator {
    private List<String> tokens;
    private Vocabulary vocabulary;
    private BertTokenizer tokenizer;

    PtBertQATranslator(Builder builder) {
        super((QATranslator.BaseBuilder)builder);
    }

    public void prepare(TranslatorContext ctx) throws IOException {
        this.vocabulary = DefaultVocabulary.builder().addFromTextFile(ctx.getModel().getArtifact(this.vocab)).optUnknownToken("[UNK]").build();
        this.tokenizer = this.tokenizerName == null ? new BertTokenizer() : new BertFullTokenizer(this.vocabulary, true);
    }

    public NDList processInput(TranslatorContext ctx, QAInput input) {
        String question = input.getQuestion();
        String paragraph = input.getParagraph();
        if (this.toLowerCase) {
            question = question.toLowerCase(this.locale);
            paragraph = paragraph.toLowerCase(this.locale);
        }
        BertToken token = this.padding ? this.tokenizer.encode(question, paragraph, this.maxLength) : this.tokenizer.encode(question, paragraph);
        this.tokens = token.getTokens();
        NDManager manager = ctx.getNDManager();
        long[] indices = this.tokens.stream().mapToLong(arg_0 -> ((Vocabulary)this.vocabulary).getIndex(arg_0)).toArray();
        long[] attentionMask = token.getAttentionMask().stream().mapToLong(i -> i).toArray();
        NDList ndList = new NDList(3);
        ndList.add((Object)manager.create(indices));
        ndList.add((Object)manager.create(attentionMask));
        if (this.includeTokenTypes) {
            long[] tokenTypes = token.getTokenTypes().stream().mapToLong(i -> i).toArray();
            ndList.add((Object)manager.create(tokenTypes));
        }
        return ndList;
    }

    public String processOutput(TranslatorContext ctx, NDList list) {
        int endIdx;
        NDArray startLogits = (NDArray)list.get(0);
        NDArray endLogits = (NDArray)list.get(1);
        int startIdx = (int)startLogits.argMax().getLong(new long[0]);
        if (startIdx >= (endIdx = (int)endLogits.argMax().getLong(new long[0]))) {
            return "";
        }
        return this.tokenizer.tokenToString(this.tokens.subList(startIdx, endIdx + 1));
    }

    public static Builder builder() {
        return new Builder();
    }

    public static Builder builder(Map<String, ?> arguments) {
        Builder builder = new Builder();
        builder.configure(arguments);
        return builder;
    }

    public static class Builder
    extends QATranslator.BaseBuilder<Builder> {
        protected Builder self() {
            return this;
        }

        protected PtBertQATranslator build() {
            return new PtBertQATranslator(this);
        }
    }
}

