package com.clearnlp.nlp.train;

import com.clearnlp.classification.feature.JointFtrXml;
import com.clearnlp.classification.model.AbstractModel;
import com.clearnlp.classification.model.StringModel;
import com.clearnlp.classification.train.AbstractTrainSpace;
import com.clearnlp.classification.train.StringTrainSpace;
import com.clearnlp.component.AbstractStatisticalComponent;
import com.clearnlp.constant.universal.UNConstant;
import com.clearnlp.constant.universal.UNPunct;
import com.clearnlp.dependency.DEPTree;
import com.clearnlp.nlp.AbstractNLP;
import com.clearnlp.propbank.frameset.PBFLib;
import com.clearnlp.reader.JointReader;
import com.clearnlp.run.AdaGradTrain;
import com.clearnlp.run.LiblinearTrain;
import com.clearnlp.util.UTInput;
import com.clearnlp.util.UTOutput;
import com.clearnlp.util.UTXml;
import com.clearnlp.util.pair.ObjectDoublePair;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.ObjectOutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.zip.GZIPOutputStream;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/**
 * Base class for trainers of statistical NLP components.
 *
 * <p>The training pipeline works in four steps:
 * <ol>
 *   <li>Optionally collect lexica from the training files.</li>
 *   <li>Collect training instances, running one {@link TrainTask} per training file on a thread pool.</li>
 *   <li>Merge the per-file instance spaces.</li>
 *   <li>Train one model per feature template. The resulting component is then saved
 *       ({@link #train}) or evaluated on development files ({@link #decode}).</li>
 * </ol>
 *
 * <p>Throughout this class, an {@code int} "dev" index names a training file to leave out of
 * both lexicon and instance collection. Any index that matches no file (e.g. {@code -1})
 * leaves every file in.
 */
public abstract class AbstractNLPTrainer extends AbstractNLP {

    /**
     * Reads every tree from one training file and passes it to a component, which is expected
     * to record training instances.
     *
     * <p>Each task obtains its reader through {@code getJointReader}. Running tasks concurrently
     * assumes that call returns a fresh reader per invocation. NOTE(review): confirm.
     */
    public class TrainTask implements Runnable {
        final AbstractStatisticalComponent<?> j_component;
        final JointReader j_reader;

        /**
         * Creates the task. The training file is opened immediately, in the calling thread.
         *
         * @param element   configuration element containing the {@code <reader>} settings
         * @param trainFile path of the training file to read
         * @param component component that processes each tree read from {@code trainFile}
         */
        public TrainTask(Element element, String trainFile, AbstractStatisticalComponent<?> component) {
            this.j_reader = AbstractNLPTrainer.this.getJointReader(UTXml.getFirstElementByTagName(element, "reader"));
            this.j_reader.open(UTInput.createBufferedFileReader(trainFile));
            this.j_component = component;
        }

        @Override
        public void run() {
            try {
                DEPTree tree;
                while ((tree = j_reader.next()) != null) {
                    j_component.process(tree);
                }
            } finally {
                // Close even when processing throws, so the file handle is not leaked.
                j_reader.close();
            }
            AbstractNLPTrainer.this.LOG.debug(".");
        }
    }

    /**
     * Builds a component with {@code getComponent(Element, JointReader, ...)} and saves it,
     * gzip-compressed, to {@code modelDir + "/" + getMode()}.
     *
     * @param element    top-level configuration element (must contain a {@code <reader>} element)
     * @param templates  feature templates, one per model
     * @param trainFiles training file paths
     * @param modelDir   directory the model file is written into
     * @throws Exception if building or saving the component fails
     */
    public void train(Element element, JointFtrXml[] templates, String[] trainFiles, String modelDir) throws Exception {
        // -1 matches no file index, so no training file is held out.
        AbstractStatisticalComponent<?> component = getComponent(element, getJointReader(UTXml.getFirstElementByTagName(element, "reader")), templates, trainFiles, -1);

        try (ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(new GZIPOutputStream(new FileOutputStream(modelDir + UNPunct.FORWARD_SLASH + getMode()))))) {
            // Closing the stream writes the GZIP trailer. Without the close (unless save() closes
            // the stream itself) the model file can be truncated. A second close is harmless.
            component.save(out);
        }
    }

    /**
     * Builds the component that {@link #train} saves.
     * The last argument is a held-out file index; {@code train} passes {@code -1}.
     */
    protected abstract AbstractStatisticalComponent<?> getComponent(Element element, JointReader jointReader, JointFtrXml[] jointFtrXmlArr, String[] strArr, int i);

    /**
     * Creates a component from already-trained models and lexica.
     * Used after training to produce the trained component.
     */
    protected abstract AbstractStatisticalComponent<?> getComponent(Element element, String str, JointFtrXml[] jointFtrXmlArr, StringModel[] stringModelArr, Object[] objArr);

    /**
     * Creates a component that records training instances into the given spaces.
     * The models argument is {@code null} on the first bootstrap round.
     */
    protected abstract AbstractStatisticalComponent<?> getComponent(Element element, String str, JointFtrXml[] jointFtrXmlArr, StringTrainSpace[] stringTrainSpaceArr, StringModel[] stringModelArr, Object[] objArr);

    /**
     * Creates one fresh set of empty training spaces for a single training file.
     * The last argument is the bootstrap round.
     */
    protected abstract StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] jointFtrXmlArr, Object[] objArr, int i);

    /** Returns the mode name: the configuration tag of this trainer and the saved model's file name. */
    public abstract String getMode();

    /**
     * Collects lexica, collects training instances, and trains one model per template (no bootstrapping).
     *
     * @param element         top-level configuration element
     * @param reader          reader used for lexicon collection
     * @param lexicaCollector component used to collect lexica; {@code null} skips lexicon collection
     * @param templates       feature templates, one per model
     * @param trainFiles      training file paths
     * @param devId           index of a training file to leave out
     * @return the trained component
     */
    public AbstractStatisticalComponent<?> getTrainedComponent(Element element, JointReader reader, AbstractStatisticalComponent<?> lexicaCollector, JointFtrXml[] templates, String[] trainFiles, int devId) {
        Object[] lexica = getLexica(reader, lexicaCollector, templates, trainFiles, devId);
        StringTrainSpace[] spaces = getStringTrainSpaces(element, templates, trainFiles, null, lexica, 0, devId);
        return trainComponent(element, templates, spaces, lexica);
    }

    /**
     * Trains a component for {@code getNumerOfBootstraps + 1} rounds.
     * Each round after the first is given the models of the previous round.
     *
     * @return the component of the last round, or {@code null} if no round ran
     */
    public AbstractStatisticalComponent<?> getTrainedComponentBoot(Element element, JointReader reader, AbstractStatisticalComponent<?> lexicaCollector, JointFtrXml[] templates, String[] trainFiles, int devId) {
        Object[] lexica = getLexica(reader, lexicaCollector, templates, trainFiles, devId);
        int numBoots = getNumerOfBootstraps(UTXml.getFirstElementByTagName(element, getMode()));
        AbstractStatisticalComponent<?> component = null;
        StringModel[] models = null;

        for (int boot = 0; boot <= numBoots; boot++) {
            LOG.info(String.format("=== Bootstrap: %d ===\n", boot));
            component = getTrainedComponent(element, templates, trainFiles, models, lexica, boot, devId);
            models = component.getModels();
        }

        return component;
    }

    /**
     * Collects training instances for one bootstrap round and trains one model per template.
     *
     * @param models previous round's models ({@code null} on the first round)
     * @param lexica lexica passed through to the component factories
     * @param boot   bootstrap round
     * @param devId  index of a training file to leave out
     */
    protected AbstractStatisticalComponent<?> getTrainedComponent(Element element, JointFtrXml[] templates, String[] trainFiles, StringModel[] models, Object[] lexica, int boot, int devId) {
        StringTrainSpace[] spaces = getStringTrainSpaces(element, templates, trainFiles, models, lexica, boot, devId);
        return trainComponent(element, templates, spaces, lexica);
    }

    /** Trains one model per training space, clears each space once used, and builds the component. */
    private AbstractStatisticalComponent<?> trainComponent(Element element, JointFtrXml[] templates, StringTrainSpace[] spaces, Object[] lexica) {
        Element eMode = UTXml.getFirstElementByTagName(element, getMode());
        StringModel[] models = new StringModel[spaces.length];

        for (int i = 0; i < spaces.length; i++) {
            models[i] = (StringModel) getModel(eMode, spaces[i], i);
            // Presumably frees the collected instances; they are not needed after training.
            spaces[i].clear();
        }

        return getComponent(eMode, getLanguage(element), templates, models, lexica);
    }

    /**
     * Trains a component (no bootstrapping) and evaluates it on the development files.
     * With {@code print} set, each dev file's output is written to a file with the same path.
     */
    public void developComponent(Element element, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent<?> lexicaCollector, boolean print, int devId) throws Exception {
        decode(reader, getTrainedComponent(element, reader, lexicaCollector, templates, trainFiles, devId), devFiles, UNConstant.EMPTY, print);
    }

    /**
     * Runs bootstrap rounds while the development score keeps improving.
     * The loop stops after the first round whose score does not exceed the previous one.
     */
    public void developComponentBoot(Element element, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent<?> lexicaCollector, boolean print, int devId) throws Exception {
        Object[] lexica = getLexica(reader, lexicaCollector, templates, trainFiles, devId);
        StringModel[] models = null;
        double prevScore;
        double currScore = 0.0d;
        int boot = 0;

        do {
            LOG.info(String.format("=== Bootstrap: %d ===\n", boot));
            prevScore = currScore;
            ObjectDoublePair<StringModel[]> result = developComponent(element, reader, templates, trainFiles, devFiles, models, lexica, boot, print, devId);
            models = (StringModel[]) result.o;
            currScore = result.d;
            boot++;
        } while (prevScore < currScore);
    }

    /** Trains one bootstrap round and returns its models together with the dev score. */
    private ObjectDoublePair<StringModel[]> developComponent(Element element, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, StringModel[] models, Object[] lexica, int boot, boolean print, int devId) throws Exception {
        AbstractStatisticalComponent<?> component = getTrainedComponent(element, templates, trainFiles, models, lexica, boot, devId);
        return new ObjectDoublePair<>(component.getModels(), decode(reader, component, devFiles, "." + boot, print));
    }

    /**
     * Runs {@code component} over every tree of every input file, prints its accuracies,
     * and returns the first accuracy value.
     *
     * @param reader     reader reused for every input file
     * @param component  component to apply
     * @param inputFiles files to decode
     * @param outputExt  suffix appended to each input path to form its output path
     * @param print      whether to write each processed tree to the output file
     * @return {@code component.getAccuracies()[0]}
     * @throws Exception if reading, processing, or writing fails
     */
    protected double decode(JointReader reader, AbstractStatisticalComponent<?> component, String[] inputFiles, String outputExt, boolean print) throws Exception {
        for (String inputFile : inputFiles) {
            PrintStream out = print ? UTOutput.createPrintBufferedFileStream(inputFile + outputExt) : null;
            try {
                reader.open(UTInput.createBufferedFileReader(inputFile));
                try {
                    DEPTree tree;
                    while ((tree = reader.next()) != null) {
                        component.process(tree);
                        if (out != null) {
                            out.println(toString(tree, getMode()) + "\n");
                        }
                    }
                } finally {
                    reader.close();
                }
            } finally {
                // Release the output file even if decoding this input fails.
                if (out != null) {
                    out.close();
                }
            }
        }

        component.printAccuracies();
        return component.getAccuracies()[0];
    }

    /**
     * Runs {@code component} over all training files except {@code devId} and returns its lexica.
     *
     * @return the collected lexica, or {@code null} when {@code component} is {@code null}
     */
    protected Object[] getLexica(JointReader reader, AbstractStatisticalComponent<?> component, JointFtrXml[] templates, String[] trainFiles, int devId) {
        if (component == null) {
            return null;
        }

        LOG.info("Collecting lexica:\n");
        for (int i = 0; i < trainFiles.length; i++) {
            if (i == devId) {
                continue;
            }

            reader.open(UTInput.createBufferedFileReader(trainFiles[i]));
            try {
                DEPTree tree;
                while ((tree = reader.next()) != null) {
                    component.process(tree);
                }
            } finally {
                reader.close();
            }
            LOG.debug(".");
        }
        LOG.debug("\n");

        return component.getLexica();
    }

    /**
     * Loads one feature-template definition per path.
     * Each input stream is closed once its {@link JointFtrXml} has been constructed.
     * NOTE(review): this assumes the {@code JointFtrXml(InputStream)} constructor reads the
     * stream eagerly.
     */
    protected JointFtrXml[] getFeatureTemplates(String[] paths) throws Exception {
        JointFtrXml[] templates = new JointFtrXml[paths.length];

        for (int i = 0; i < paths.length; i++) {
            try (FileInputStream in = new FileInputStream(paths[i])) {
                templates[i] = new JointFtrXml(in);
            }
        }

        return templates;
    }

    /**
     * Collects training instances from every training file except {@code devId} and merges them.
     *
     * <p>Each file is handled by its own {@link TrainTask} on a fixed-size pool, with its own
     * set of spaces. The spaces are then merged into those of the first file.
     *
     * @param models previous round's models ({@code null} on the first round)
     * @param boot   bootstrap round
     * @param devId  index of a training file to leave out
     * @return one merged space per model
     * @throws IllegalArgumentException if no training file remains after excluding {@code devId}
     * @throws IllegalStateException    if a task fails or the calling thread is interrupted
     */
    protected StringTrainSpace[] getStringTrainSpaces(Element element, JointFtrXml[] templates, String[] trainFiles, StringModel[] models, Object[] lexica, int boot, int devId) {
        Element eMode = UTXml.getFirstElementByTagName(element, getMode());
        int numThreads = getNumerOfThreads(eMode);
        String language = getLanguage(element);
        List<StringTrainSpace[]> perFileSpaces = new ArrayList<>();
        ExecutorService pool = Executors.newFixedThreadPool(numThreads);
        LOG.info("Collecting training instances:\n");

        try {
            List<Future<?>> futures = new ArrayList<>();
            for (int i = 0; i < trainFiles.length; i++) {
                if (i == devId) {
                    continue;
                }

                StringTrainSpace[] spaces = getStringTrainSpaces(templates, lexica, boot);
                perFileSpaces.add(spaces);
                futures.add(pool.submit(new TrainTask(element, trainFiles[i], getComponent(eMode, language, templates, spaces, models, lexica))));
            }

            // Wait for every task. get() rethrows a task's failure instead of silently training
            // on partial data.
            for (Future<?> future : futures) {
                future.get();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while collecting training instances", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Failed to collect training instances", e.getCause());
        } finally {
            // Always release the worker threads, including when building a task fails.
            pool.shutdownNow();
        }
        LOG.debug("\n");

        if (perFileSpaces.isEmpty()) {
            throw new IllegalArgumentException("No training files left to collect instances from (files=" + trainFiles.length + ", devId=" + devId + ")");
        }

        StringTrainSpace[] merged = perFileSpaces.get(0).clone();
        int numFiles = perFileSpaces.size();

        for (int i = 0; i < merged.length; i++) {
            if (numFiles > 1) {
                LOG.info("Merging training instances:\n");
                for (int j = 1; j < numFiles; j++) {
                    StringTrainSpace other = perFileSpaces.get(j)[i];
                    merged[i].appendSpace(other);
                    // Clear the merged-in space right away to limit peak memory.
                    other.clear();
                    LOG.debug(".");
                }
                LOG.debug("\n");
            }
        }

        return merged;
    }

    /** Equivalent to {@code getStringTrainSpaces(templates, 0)}. */
    public StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] templates) {
        return getStringTrainSpaces(templates, 0);
    }

    /**
     * Creates one empty space per template.
     * Each space uses that template's label and feature cutoffs at {@code cutoffIndex}.
     */
    public StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] templates, int cutoffIndex) {
        StringTrainSpace[] spaces = new StringTrainSpace[templates.length];

        for (int i = 0; i < templates.length; i++) {
            spaces[i] = new StringTrainSpace(false, templates[i].getLabelCutoff(cutoffIndex), templates[i].getFeatureCutoff(cutoffIndex));
        }

        return spaces;
    }

    /** Creates {@code size} empty spaces, all using {@code template}'s cutoffs at index 0. */
    public StringTrainSpace[] getStringTrainSpaces(JointFtrXml template, int size) {
        StringTrainSpace[] spaces = new StringTrainSpace[size];

        for (int i = 0; i < size; i++) {
            spaces[i] = new StringTrainSpace(false, template.getLabelCutoff(0), template.getFeatureCutoff(0));
        }

        return spaces;
    }

    /**
     * Trains a model for {@code space} with the {@code index}-th {@code <algorithm>} element.
     * Falls back to the first element when {@code index} is out of range.
     *
     * @return the trained model, or {@code null} if the algorithm is neither
     *         {@code liblinear} nor {@code adagrad}
     * @throws IllegalArgumentException if {@code element} has no {@code <algorithm>} element
     */
    protected AbstractModel getModel(Element element, AbstractTrainSpace space, int index) {
        int numThreads = getNumerOfThreads(element);
        Element eAlgorithm = getAlgorithmElement(element, index);
        String name = UTXml.getTrimmedAttribute(eAlgorithm, PBFLib.A_NAME);

        if (name.equals("liblinear")) {
            byte solver = Byte.parseByte(UTXml.getTrimmedAttribute(eAlgorithm, "solver"));
            double cost = Double.parseDouble(UTXml.getTrimmedAttribute(eAlgorithm, "cost"));
            double eps = Double.parseDouble(UTXml.getTrimmedAttribute(eAlgorithm, "eps"));
            double bias = Double.parseDouble(UTXml.getTrimmedAttribute(eAlgorithm, "bias"));
            return getLiblinearModel(space, numThreads, solver, cost, eps, bias);
        }

        if (name.equals("adagrad")) {
            return buildAdaGradModel(space, eAlgorithm);
        }

        return null;
    }

    /**
     * Returns the {@code index}-th {@code <algorithm>} child of {@code element}.
     * Falls back to the first one when {@code index} is out of range.
     */
    private Element getAlgorithmElement(Element element, int index) {
        NodeList algorithms = element.getElementsByTagName("algorithm");
        int count = algorithms.getLength();

        if (count == 0) {
            throw new IllegalArgumentException("No <algorithm> element found in <" + element.getTagName() + ">");
        }

        return (Element) algorithms.item(index >= 0 && index < count ? index : 0);
    }

    /** Reads the AdaGrad settings from {@code eAlgorithm} and trains a model via {@link #getAdaGradModel}. */
    private AbstractModel buildAdaGradModel(AbstractTrainSpace space, Element eAlgorithm) {
        // NOTE(review): 3 and 4 are the solver codes passed to AdaGradTrain for "hinge" and for
        // any other type; confirm they match AdaGradTrain's constants.
        byte solver = UTXml.getTrimmedAttribute(eAlgorithm, "type").equals("hinge") ? (byte) 3 : (byte) 4;
        double alpha = Double.parseDouble(UTXml.getTrimmedAttribute(eAlgorithm, "alpha"));
        double rho = Double.parseDouble(UTXml.getTrimmedAttribute(eAlgorithm, "rho"));
        double eps = Double.parseDouble(UTXml.getTrimmedAttribute(eAlgorithm, "eps"));
        boolean average = UTXml.getTrimmedAttribute(eAlgorithm, "average").equalsIgnoreCase("true");
        return getAdaGradModel(space, solver, alpha, rho, eps, average);
    }

    /** Builds {@code space}, logs the settings, and trains a Liblinear model on it. */
    protected AbstractModel getLiblinearModel(AbstractTrainSpace space, int numThreads, byte solver, double cost, double eps, double bias) {
        space.build();
        LOG.info(String.format("Liblinear: solver=%d, cost=%5.3f, eps=%5.3f, bias=%5.3f\n", solver, cost, eps, bias));
        return LiblinearTrain.getModel(space, numThreads, solver, cost, eps, bias);
    }

    /** Builds {@code space}, logs the settings, and trains an AdaGrad model on it. */
    protected AbstractModel getAdaGradModel(AbstractTrainSpace space, byte solver, double alpha, double rho, double eps, boolean average) {
        space.build();
        LOG.info(String.format("AdaGrad: solver=%d, alpha=%5.3f, rho=%5.3f, eps=%5.3f, average=%b\n", solver, alpha, rho, eps, average));
        return AdaGradTrain.getModel(space, solver, alpha, rho, eps, average);
    }

    /**
     * Trains an AdaGrad model with the {@code index}-th {@code <algorithm>} element.
     * Falls back to the first element when {@code index} is out of range.
     *
     * <p>NOTE(review): this delegates to {@link #getAdaGradModel}, not
     * {@link #updateAdaGradModel}, and ignores {@code boot}. Confirm that this is intended.
     *
     * @return the model, or {@code null} if the selected algorithm is not {@code adagrad}
     * @throws IllegalArgumentException if {@code element} has no {@code <algorithm>} element
     */
    protected AbstractModel updateModel(Element element, AbstractTrainSpace space, int index, int boot) {
        Element eAlgorithm = getAlgorithmElement(element, index);

        if (!UTXml.getTrimmedAttribute(eAlgorithm, PBFLib.A_NAME).equals("adagrad")) {
            return null;
        }

        return buildAdaGradModel(space, eAlgorithm);
    }

    /**
     * Builds {@code space} and logs the settings.
     * It then runs {@code updateWeights} on a fresh AdaGrad algorithm before returning
     * {@code AdaGradTrain.getModel} for the same space.
     * NOTE(review): whether {@code getModel} reuses the updated weights cannot be told from here.
     */
    protected AbstractModel updateAdaGradModel(AbstractTrainSpace space, byte solver, double alpha, double rho, double eps, boolean average) {
        space.build();
        LOG.info(String.format("AdaGrad: solver=%d, alpha=%5.3f, rho=%5.3f, eps=%5.3f, average=%b\n", solver, alpha, rho, eps, average));
        AdaGradTrain.getAlgorithm(solver, alpha, rho, eps).updateWeights(space, average);
        return AdaGradTrain.getModel(space, solver, alpha, rho, eps, average);
    }
}
