package com.clearnlp.component.online;

import com.clearnlp.classification.feature.FtrToken;
import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.instance.StringInstance;
import com.clearnlp.classification.model.StringModelAD;
import com.clearnlp.classification.prediction.StringPrediction;
import com.clearnlp.classification.vector.StringFeatureVector;
import com.clearnlp.component.evaluation.POSEval;
import com.clearnlp.component.state.TagState;
import com.clearnlp.dependency.DEPNode;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.nlp.NLPProcess;
import com.clearnlp.pattern.PTPunct;
import com.clearnlp.util.UTArray;
import com.clearnlp.util.UTString;
import com.clearnlp.util.map.Prob2DMap;
import com.clearnlp.util.pair.Pair;
import com.clearnlp.util.pair.StringDoublePair;
import com.google.common.collect.Sets;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Matcher;

/* loaded from: input_file:com/clearnlp/component/online/OnlinePOSTagger.class */
public class OnlinePOSTagger extends AbstractOnlineStatisticalComponent<TagState> {
    // Indices into the lexica array exchanged via getLexica()/setLexia().
    // These are constant variables (final + constant initializer), so references by
    // simple name are inlined by the compiler and are valid even if used before this
    // object's field initializers have run (e.g. if the superclass constructor calls setLexia).
    protected final int LEXICA_LOWER_SIMPLIFIED_FORMS = 0;
    protected final int LEXICA_AMBIGUITY_CLASSE_PROB = 1;
    protected final int LEXICA_AMBIGUITY_CLASSE_MAP = 2;

    // NOTE: the fields below deliberately have no initializers. The Object[] constructor
    // passes the lexica to the superclass, which presumably stores them through setLexia();
    // an initializer here would run afterwards and clobber those values.

    /** Lower simplified forms whose form features / ambiguity statistics are used; null means "all". */
    protected Set<String> s_lsfs;
    /** (simplified form, POS tag) observations accumulated during the collection pass. */
    protected Prob2DMap p_ambi;
    /** Simplified form -> "_"-joined ambiguity class, derived from {@link #p_ambi}. */
    protected Map<String, String> m_ambi;
    private StringModelAD s_model;
    private JointFtrXml f_xml;

    /**
     * Creates a tagger without a statistical model, set up to accumulate ambiguity-class
     * statistics (presumably the collection pass, flag 0).
     *
     * @param jointFtrXmlArr feature templates; only the first one is used.
     * @param set lower simplified forms for which statistics are collected.
     */
    public OnlinePOSTagger(JointFtrXml[] jointFtrXmlArr, Set<String> set) {
        super(jointFtrXmlArr);
        this.f_xml = this.f_xmls[0];
        this.s_lsfs = set;
        this.p_ambi = new Prob2DMap();
    }

    /**
     * Creates a tagger from feature templates and previously collected lexica
     * (see {@link #getLexica()}).
     *
     * @param jointFtrXmlArr feature templates; only the first one is used.
     * @param objArr lexica array laid out by the LEXICA_* indices.
     */
    public OnlinePOSTagger(JointFtrXml[] jointFtrXmlArr, Object[] objArr) {
        super(jointFtrXmlArr, objArr, 1);
        init();
    }

    /**
     * Creates a tagger by deserializing a previously saved model.
     *
     * @param objectInputStream stream to read the model from.
     */
    public OnlinePOSTagger(ObjectInputStream objectInputStream) {
        super(objectInputStream);
        init();
    }

    /** Caches the first statistical model and feature template provided by the superclass. */
    private void init() {
        this.s_model = this.s_models[0];
        this.f_xml = this.f_xmls[0];
    }

    /**
     * Returns the lexica as an array indexed by the LEXICA_* constants. If the ambiguity
     * class map has not been set, it is derived from {@link #p_ambi} (the result is not cached).
     */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    public Object[] getLexica() {
        Object[] lexica = new Object[3];
        lexica[LEXICA_LOWER_SIMPLIFIED_FORMS] = s_lsfs;
        lexica[LEXICA_AMBIGUITY_CLASSE_PROB] = p_ambi;
        lexica[LEXICA_AMBIGUITY_CLASSE_MAP] = (m_ambi != null) ? m_ambi : getAmbiguityClasses();
        return lexica;
    }

    /**
     * Installs lexica laid out as returned by {@link #getLexica()}.
     *
     * @param objArr lexica array indexed by the LEXICA_* constants.
     */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    @SuppressWarnings("unchecked") // element types are fixed by getLexica()
    public void setLexia(Object[] objArr) {
        this.s_lsfs = (Set<String>) objArr[LEXICA_LOWER_SIMPLIFIED_FORMS];
        this.p_ambi = (Prob2DMap) objArr[LEXICA_AMBIGUITY_CLASSE_PROB];
        this.m_ambi = (Map<String, String>) objArr[LEXICA_AMBIGUITY_CLASSE_MAP];
    }

    /**
     * Builds the ambiguity class of every simplified form in {@link #p_ambi}: the tags whose
     * probability exceeds the template's threshold, in descending probability order, joined by "_".
     * Forms with no tag above the threshold are omitted.
     */
    private Map<String, String> getAmbiguityClasses() {
        Map<String, String> classes = new HashMap<>();
        double threshold = this.f_xml.getAmbiguityClassThreshold();

        for (String form : this.p_ambi.keySet()) {
            StringDoublePair[] tagProbs = this.p_ambi.getProb1D(form);
            UTArray.sortReverseOrder(tagProbs);

            StringBuilder sb = new StringBuilder();
            for (StringDoublePair tagProb : tagProbs) {
                // Sorted in reverse order, so everything after the first miss is also below.
                if (tagProb.d <= threshold) {
                    break;
                }
                sb.append('_').append(tagProb.s);
            }
            if (sb.length() > 0) {
                classes.put(form, sb.substring(1)); // drop the leading "_"
            }
        }
        return classes;
    }

    /** Reads the default model state and then the lexica; closes the stream afterwards. */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    public void load(ObjectInputStream objectInputStream) throws Exception {
        loadDefault(objectInputStream);
        loadLexica(objectInputStream);
        objectInputStream.close();
    }

    /** Writes the default model state and then the lexica; closes the stream afterwards. */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    public void save(ObjectOutputStream objectOutputStream) throws Exception {
        saveDefault(objectOutputStream);
        saveLexica(objectOutputStream);
        objectOutputStream.close();
    }

    /** Reads the ambiguity class map; the other lexica are not part of the saved model. */
    @SuppressWarnings("unchecked") // written by saveLexica as Map<String, String>
    protected void loadLexica(ObjectInputStream objectInputStream) throws Exception {
        this.m_ambi = (Map<String, String>) objectInputStream.readObject();
    }

    /** Writes the ambiguity class map (only). */
    protected void saveLexica(ObjectOutputStream objectOutputStream) throws Exception {
        objectOutputStream.writeObject(this.m_ambi);
    }

    /** Returns a fresh mutable set of the labels known to the model. */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    public Set<String> getLabels() {
        return Sets.newHashSet(this.s_model.getLabels());
    }

    /**
     * Tags (or collects from / trains on / evaluates) one tree, depending on the flag.
     *
     * @param dEPTree tree to process.
     * @param b processing flag; see {@link #processAux}.
     */
    @Override // com.clearnlp.component.online.AbstractOnlineComponent
    public void process(DEPTree dEPTree, byte b) {
        TagState state = initialize(dEPTree, b);
        finalizeState(state, processAux(state, b), b);
    }

    /**
     * Walks the tree left to right, handling each node according to the flag:
     * 0 collects statistics, 1 trains, 3 bootstraps, anything else decodes.
     *
     * @return training instances gathered during the walk (the list from getEmptyInstanceList).
     */
    private List<StringInstance> processAux(TagState tagState, byte b) {
        List<StringInstance> instances = getEmptyInstanceList(b);

        while (!tagState.isTerminate()) {
            if (b == 0) {
                // Collection produces no label; leave the node's gold POS tag untouched
                // (initialize() deliberately does not clear tags for this flag).
                processCollect(tagState);
            } else {
                String label;
                switch (b) {
                    case 1:
                        label = processTrain(tagState, instances);
                        break;
                    case 3:
                        label = processBootstrap(tagState, instances);
                        break;
                    default: // 2 and any other flag decode
                        label = processDecode(tagState);
                        break;
                }
                setLabel(tagState.getInput(), label);
            }
            tagState.moveForward();
        }
        return instances;
    }

    /**
     * Creates the tagging state, simplifies word forms, and, unless decoding (flag 2),
     * records the gold tags; for flags other than 0 the tree's tags are then cleared.
     */
    private TagState initialize(DEPTree dEPTree, byte b) {
        TagState state = new TagState(dEPTree);
        NLPProcess.simplifyForms(dEPTree);

        if (b != 2) {
            state.setGoldLabels(dEPTree.getPOSTags());
            if (b != 0) {
                dEPTree.clearPOSTags();
            }
        }
        return state;
    }

    /**
     * Post-processing after a tree has been walked: hands instances to the model when
     * training/bootstrapping, or scores predictions and restores the gold tags when evaluating.
     */
    private void finalizeState(TagState tagState, List<StringInstance> list, byte b) {
        if (isTrainOrBootstrap(b)) {
            this.s_model.addInstances(list);
            return;
        }
        if (isEvaluate(b)) {
            if (this.e_eval == null) {
                this.e_eval = new POSEval();
            }
            Object[] goldLabels = tagState.getGoldLabels();
            DEPTree tree = tagState.getTree();
            this.e_eval.countAccuracy(tree, goldLabels);
            tree.setPOSTags((String[]) goldLabels);
        }
    }

    /** Records (simplified form, gold tag) for the current node if its lower form is tracked. */
    private void processCollect(TagState tagState) {
        DEPNode input = tagState.getInput();
        if (this.s_lsfs.contains(input.lowerSimplifiedForm)) {
            this.p_ambi.add(input.simplifiedForm, input.pos);
        }
    }

    /** Adds a gold-labelled instance and returns the gold label. */
    private String processTrain(TagState tagState, List<StringInstance> list) {
        StringFeatureVector vector = getFeatureVector(this.f_xml, tagState);
        String gold = getGoldLabel(tagState);
        addInstance(list, gold, vector);
        return gold;
    }

    /** Adds a gold-labelled instance but returns the model's prediction (used as context). */
    private String processBootstrap(TagState tagState, List<StringInstance> list) {
        StringFeatureVector vector = getFeatureVector(this.f_xml, tagState);
        String predicted = getAutoLabel(tagState, vector);
        addInstance(list, getGoldLabel(tagState), vector);
        return predicted;
    }

    /** Returns the model's prediction for the current node. */
    private String processDecode(TagState tagState) {
        return getAutoLabel(tagState, getFeatureVector(this.f_xml, tagState));
    }

    private String getGoldLabel(TagState tagState) {
        return tagState.getGoldLabel();
    }

    /**
     * Returns the top-scoring label. When the runner-up scores within 1.0 of it, the
     * runner-up is stored on the current node as feature "p2".
     */
    private String getAutoLabel(TagState tagState, StringFeatureVector stringFeatureVector) {
        Pair<StringPrediction, StringPrediction> top2 = this.s_model.predictTop2(stringFeatureVector);
        StringPrediction best = top2.o1;
        StringPrediction second = top2.o2;

        if (best.score - second.score < 1.0d) {
            tagState.getInput().addFeat("p2", second.label);
        }
        return best.label;
    }

    /** Appends an instance unless its feature vector is empty. */
    private void addInstance(List<StringInstance> list, String label, StringFeatureVector vector) {
        if (!vector.isEmpty()) {
            list.add(new StringInstance(label, vector));
        }
    }

    private void setLabel(DEPNode dEPNode, String str) {
        dEPNode.setPOSTag(str);
    }

    /**
     * Extracts a single feature value for the node selected by the token.
     *
     * @return the feature string, or null when the node is missing or the feature does not fire.
     * @throws IllegalArgumentException if the token's field is not a supported feature.
     */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    public String getField(FtrToken ftrToken, TagState tagState) {
        DEPNode node = tagState.getNode(ftrToken);
        if (node == null) {
            return null;
        }
        String field = ftrToken.field;

        // Named fields.
        if (field.equals("sf")) {
            return containsLowerSimplifiedForm(node) ? node.simplifiedForm : null;
        }
        if (field.equals(JointFtrXml.F_LOWER_SIMPLIFIED_FORM)) {
            return containsLowerSimplifiedForm(node) ? node.lowerSimplifiedForm : null;
        }
        if (field.equals("p")) {
            return node.pos;
        }
        if (field.equals("p2")) {
            return node.getFeat("p2");
        }
        if (field.equals(JointFtrXml.F_AMBIGUITY_CLASS)) {
            // m_ambi is not set by the collection-only constructor.
            return (this.m_ambi != null) ? this.m_ambi.get(node.simplifiedForm) : null;
        }

        // Pattern-based fields, checked in order: boolean, extra feat, prefix, suffix.
        Matcher matcher = JointFtrXml.P_BOOLEAN.matcher(field);
        if (matcher.find()) {
            int type = Integer.parseInt(matcher.group(1));
            return getBooleanField(type, field, field + ftrToken.offset, node, tagState);
        }

        matcher = JointFtrXml.P_FEAT.matcher(field);
        if (matcher.find()) {
            return node.getFeat(matcher.group(1));
        }

        String lsf = node.lowerSimplifiedForm;

        matcher = JointFtrXml.P_PREFIX.matcher(field);
        if (matcher.find()) {
            int n = Integer.parseInt(matcher.group(1));
            return (n <= lsf.length()) ? lsf.substring(0, n) : null;
        }

        matcher = JointFtrXml.P_SUFFIX.matcher(field);
        if (matcher.find()) {
            int n = Integer.parseInt(matcher.group(1));
            int length = lsf.length();
            return (n <= length) ? lsf.substring(length - n, length) : null;
        }

        throw new IllegalArgumentException("Unsupported feature: " + field);
    }

    /**
     * Evaluates a boolean orthographic/positional feature.
     *
     * @param type  numeric feature id parsed from the field.
     * @param field original field text (for the error message).
     * @param value string returned when the feature fires.
     * @return {@code value} if the feature holds, otherwise null.
     * @throws IllegalArgumentException for an unknown {@code type}.
     */
    private String getBooleanField(int type, String field, String value, DEPNode node, TagState tagState) {
        String sf = node.simplifiedForm;
        boolean fires;

        switch (type) {
            case 0:  fires = UTString.isAllUpperCase(sf); break;
            case 1:  fires = UTString.isAllLowerCase(sf); break;
            // NOTE(review): the first-node checks below look at the state's current input,
            // not at `node` (which may be at an offset) — confirm this is intended.
            case 2:  fires = UTString.beginsWithUpperCase(sf) && !tagState.isInputFirstNode(); break;
            case 3:  fires = UTString.getNumOfCapitalsNotAtBeginning(sf) == 1; break;
            case 4:  fires = UTString.getNumOfCapitalsNotAtBeginning(sf) > 1; break;
            case 5:  fires = sf.contains("."); break;
            case 6:  fires = UTString.containsDigit(sf); break;
            case 7:  fires = sf.contains("-"); break;
            case 8:  fires = tagState.isInputLastNode(); break;
            case 9:  fires = tagState.isInputFirstNode(); break;
            case 10: fires = PTPunct.containsOnlyPunctuation(node.lowerSimplifiedForm); break;
            default: throw new IllegalArgumentException("Unsupported feature: " + field);
        }
        return fires ? value : null;
    }

    /**
     * Extracts multi-valued prefix or suffix features of the lower simplified form.
     *
     * @return the prefixes/suffixes, or null if the node is missing, the field is neither a
     *         prefix nor a suffix feature, or no values were produced.
     */
    @Override // com.clearnlp.component.online.AbstractOnlineStatisticalComponent
    public String[] getFields(FtrToken ftrToken, TagState tagState) {
        DEPNode node = tagState.getNode(ftrToken);
        if (node == null) {
            return null;
        }

        String[] values = null;
        Matcher prefix = JointFtrXml.P_PREFIX.matcher(ftrToken.field);
        if (prefix.find()) {
            values = UTString.getPrefixes(node.lowerSimplifiedForm, Integer.parseInt(prefix.group(1)));
        } else {
            Matcher suffix = JointFtrXml.P_SUFFIX.matcher(ftrToken.field);
            if (suffix.find()) {
                values = UTString.getSuffixes(node.lowerSimplifiedForm, Integer.parseInt(suffix.group(1)));
            }
        }
        return (values == null || values.length == 0) ? null : values;
    }

    /** True when no form restriction is set, or the node's lower simplified form is in it. */
    private boolean containsLowerSimplifiedForm(DEPNode dEPNode) {
        return this.s_lsfs == null || this.s_lsfs.contains(dEPNode.lowerSimplifiedForm);
    }
}
