package io.spokestack.spokestack.nlu.tensorflow;

import androidx.annotation.NonNull;
import io.spokestack.spokestack.nlu.NLUContext;
import io.spokestack.spokestack.nlu.Slot;
import io.spokestack.spokestack.nlu.tensorflow.Metadata;
import io.spokestack.spokestack.util.Tuple;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;

/* loaded from: input_file:io/spokestack/spokestack/nlu/tensorflow/TFNLUOutput.class */
/**
 * Converts raw model output buffers into intents and slots using the
 * labels declared in the model's {@link Metadata}.
 *
 * <p>Buffers are consumed sequentially with {@link ByteBuffer#getFloat()},
 * so each call advances the position of the buffer it is given.
 */
final class TFNLUOutput {
    /** Tag label marking a token that does not belong to any slot. */
    private static final String OUTSIDE_TAG = "o";

    /**
     * Number of leading characters removed from a non-outside tag label to
     * obtain the slot name (presumably a two-character "b_"/"i_" style
     * prefix; confirm against the model metadata).
     */
    private static final int TAG_PREFIX_LENGTH = 2;

    private final Metadata metadata;
    private Map<String, SlotParser> slotParsers = new HashMap<>();

    /**
     * Creates an output decoder for a model.
     *
     * @param metadata source of the intent and tag labels used for decoding
     */
    public TFNLUOutput(Metadata metadata) {
        this.metadata = metadata;
    }

    /**
     * Replaces the parsers used to convert raw slot strings into values.
     * The map is stored by reference, not copied.
     *
     * @param map slot parsers keyed by slot type
     */
    public void registerSlotParsers(Map<String, SlotParser> map) {
        this.slotParsers = map;
    }

    /**
     * Reads one float per known intent from the buffer and returns the
     * highest-scoring intent.
     *
     * @param byteBuffer buffer positioned at the first intent score
     * @return the winning intent paired with its score
     */
    public Tuple<Metadata.Intent, Float> getIntent(ByteBuffer byteBuffer) {
        Metadata.Intent[] intents = this.metadata.getIntents();
        Tuple<Integer, Float> best = bufferArgMax(byteBuffer, intents.length);
        return new Tuple<>(intents[best.first().intValue()], best.second());
    }

    /**
     * Decodes the tagger output into raw slot strings.
     *
     * <p>Consecutive tokens tagged with the same slot name form one range,
     * which is decoded back to text via
     * {@code encodedTokens.decodeRange(start, end, true)}. If a slot name
     * occurs in several separate ranges, their decoded values are joined
     * with a single space in order of appearance.
     *
     * @param nLUContext    context that receives a debug trace of the labels
     * @param encodedTokens encoded input; its id count determines how many
     *                      labels are read
     * @param byteBuffer    buffer positioned at the first tag score
     * @return decoded slot text keyed by slot name; empty if no token was
     *         tagged as part of a slot
     */
    public Map<String, String> getSlots(NLUContext nLUContext, EncodedTokens encodedTokens, ByteBuffer byteBuffer) {
        String[] labels = getLabels(byteBuffer, encodedTokens.getIds().size());
        nLUContext.traceDebug("Tag labels: %s", Arrays.toString(labels));

        Map<String, String> slots = new HashMap<>();
        for (Map.Entry<Integer, Integer> range : findSlotRanges(labels).entrySet()) {
            int start = range.getKey().intValue();
            int end = range.getValue().intValue();
            String value = encodedTokens.decodeRange(start, end, true);
            String slotName = labels[start].substring(TAG_PREFIX_LENGTH);
            String existing = slots.get(slotName);
            if (existing != null) {
                slots.put(slotName, existing + " " + value);
            } else {
                slots.put(slotName, value);
            }
        }
        return slots;
    }

    /**
     * Groups consecutive tokens tagged with the same slot name into ranges.
     *
     * <p>Every in-slot token records the end of its range immediately, so a
     * slot is captured even when it is one token long and is directly
     * followed by a different slot or by the end of the sequence.
     *
     * @param labels one tag label per token
     * @return start index (inclusive) mapped to end index (exclusive), in
     *         ascending order of start index
     */
    private static Map<Integer, Integer> findSlotRanges(String[] labels) {
        // Ranges are inserted in ascending start order and LinkedHashMap
        // keeps that order, so repeated slots are later joined in the order
        // they appear in the utterance.
        Map<Integer, Integer> ranges = new LinkedHashMap<>();
        String currentSlot = null;
        int currentStart = -1;
        for (int index = 0; index < labels.length; index++) {
            String label = labels[index];
            if (OUTSIDE_TAG.equals(label)) {
                // The end of the open range (if any) was already recorded
                // by its last token.
                currentSlot = null;
                continue;
            }
            String slotName = label.substring(TAG_PREFIX_LENGTH);
            if (!slotName.equals(currentSlot)) {
                currentSlot = slotName;
                currentStart = index;
            }
            ranges.put(Integer.valueOf(currentStart), Integer.valueOf(index + 1));
        }
        return ranges;
    }

    /**
     * Reads {@code i} groups of tag scores from the buffer and returns the
     * highest-scoring tag label of each group.
     *
     * @param byteBuffer buffer positioned at the first tag score
     * @param i          number of labels (tokens) to read
     * @return one tag label per token
     */
    String[] getLabels(ByteBuffer byteBuffer, int i) {
        String[] tags = this.metadata.getTags();
        String[] labels = new String[i];
        for (int token = 0; token < labels.length; token++) {
            labels[token] = tags[bufferArgMax(byteBuffer, tags.length).first().intValue()];
        }
        return labels;
    }

    /**
     * Reads the next {@code i} floats from the buffer and finds the largest.
     *
     * @param byteBuffer buffer to read from; its position is advanced
     * @param i          number of floats to read
     * @return index of the largest value paired with that value
     */
    private Tuple<Integer, Float> bufferArgMax(ByteBuffer byteBuffer, int i) {
        float[] scores = new float[i];
        for (int index = 0; index < i; index++) {
            scores[index] = byteBuffer.getFloat();
        }
        return argMax(scores);
    }

    /**
     * Finds the largest value in an array; the earliest index wins ties.
     *
     * @param fArr values to search; must not be empty
     * @return index of the largest value paired with that value
     * @throws ArrayIndexOutOfBoundsException if {@code fArr} is empty
     */
    private Tuple<Integer, Float> argMax(float[] fArr) {
        int bestIndex = 0;
        float bestValue = fArr[0];
        for (int index = 1; index < fArr.length; index++) {
            float candidate = fArr[index];
            if (candidate > bestValue) {
                bestIndex = index;
                bestValue = candidate;
            }
        }
        return new Tuple<>(Integer.valueOf(bestIndex), Float.valueOf(bestValue));
    }

    /**
     * Builds the final slot map for an intent.
     *
     * <p>The result starts with the intent's implicit slots. Each declared
     * slot that has a raw value in {@code map} is parsed and stored under
     * the parsed slot's name, replacing any implicit slot with that name.
     * Each declared slot without a raw value is added with null values
     * under its capture name, unless an entry with that name already
     * exists.
     *
     * @param intent the intent whose slot declarations drive parsing
     * @param map    raw slot strings keyed by slot name
     * @return parsed slots keyed by name
     * @throws IllegalArgumentException if no parser is registered for a
     *                                  slot's type or its parser fails
     */
    public Map<String, Slot> parseSlots(@NonNull Metadata.Intent intent, @NonNull Map<String, String> map) {
        Map<String, Slot> parsed = parseImplicitSlots(intent);
        for (Metadata.Slot declared : intent.getSlots()) {
            String rawValue = map.get(declared.getName());
            if (rawValue == null) {
                String captureName = declared.getCaptureName();
                if (!parsed.containsKey(captureName)) {
                    parsed.put(captureName, new Slot(captureName, declared.getType(), null, null));
                }
            } else {
                Slot slot = parseSlotValue(declared, rawValue);
                parsed.put(slot.getName(), slot);
            }
        }
        return parsed;
    }

    /**
     * Creates slots for the intent's implicit slot declarations, using each
     * declaration's own value and the string form of that value.
     *
     * @param intent the intent providing implicit slots
     * @return a new mutable map of implicit slots keyed by capture name
     */
    private Map<String, Slot> parseImplicitSlots(Metadata.Intent intent) {
        Map<String, Slot> implicit = new HashMap<>();
        for (Metadata.Slot declared : intent.getImplicitSlots()) {
            String captureName = declared.getCaptureName();
            implicit.put(captureName,
                    new Slot(captureName, declared.getType(), String.valueOf(declared.getValue()), declared.getValue()));
        }
        return implicit;
    }

    /**
     * Parses a raw slot string with the parser registered for the slot's
     * type.
     *
     * @param slot the slot declaration
     * @param str  the raw slot text
     * @return the parsed slot, named by the declaration's capture name
     * @throws IllegalArgumentException if no parser is registered for the
     *                                  slot's type or the parser fails
     */
    private Slot parseSlotValue(Metadata.Slot slot, String str) {
        String type = slot.getType();
        SlotParser parser = this.slotParsers.get(type);
        if (parser == null) {
            throw new IllegalArgumentException("No parser found for \"" + type + "\" slot");
        }
        String captureName = slot.getCaptureName();
        try {
            return new Slot(captureName, slot.getType(), str, parser.parse(slot.getFacets(), str));
        } catch (Exception e) {
            // Any parser failure is reported as an invalid argument, with
            // the original exception kept as the cause.
            throw new IllegalArgumentException("Error parsing slot " + captureName, e);
        }
    }
}
