package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.metrics.EvaluateTreebank;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceTrainOptions;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.stats.TwoDimensionalIntCounter;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/PerceptronModel.class */
public class PerceptronModel extends BaseModel {
    /** Step size handed to every {@code TrainingUpdate}; set to 1.0 by the constructors and optionally decayed while training. */
    private float learningRate;
    /** Maps each feature string to the {@code Weight} that scores transitions for that feature. */
    WeightMap featureWeights;
    /** Produces the list of feature strings for a parser {@code State}. */
    final FeatureFactory featureFactory;
    private static final long serialVersionUID = 1;
    /** Logging channel used by this class. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(PerceptronModel.class);
    /** Two-decimal formatter used for scores in log lines and in intermediate model file names. */
    private static final NumberFormat NF = new DecimalFormat("0.00");
    /** Zero-padded four-digit formatter used for the iteration number in intermediate model file names. */
    private static final NumberFormat FILENAME = new DecimalFormat("0000");

    /**
     * Adapter that lets a {@link MulticoreWrapper} train individual examples by
     * delegating to {@code trainTree} on the enclosing model.
     */
    public class TrainTreeProcessor implements ThreadsafeProcessor<TrainingExample, TrainingResult> {
        public TrainTreeProcessor() {
        }

        /**
         * Trains the enclosing model's view of one example.
         *
         * @param trainingExample the example to train on
         * @return the updates and statistics produced for that example
         */
        @Override
        public TrainingResult process(TrainingExample trainingExample) {
            return PerceptronModel.this.trainTree(trainingExample);
        }

        /**
         * Returns this same processor instead of a copy; the processor keeps no
         * fields of its own, only the reference to the enclosing model.
         *
         * <p>The decompiler had renamed this method to {@code newInstance2}, which left the
         * {@code @Override} pointing at nothing; the original name is restored here.
         */
        @Override
        public ThreadsafeProcessor<TrainingExample, TrainingResult> newInstance() {
            return this;
        }
    }

    /**
     * Creates an untrained model with an empty weight map.
     *
     * <p>{@code shiftReduceOptions.featureFactoryClass} is a {@code ;}-separated list of feature
     * factory specs. Each spec is either a class name, or {@code ClassName(argument)} in which case
     * the text between the parentheses is passed as a single String constructor argument. One spec
     * yields that factory directly; any other number of specs is wrapped in a
     * {@link CombinationFeatureFactory}.
     *
     * <p>Unlike before, the {@code ClassName(argument)} form is also honoured when only one spec is
     * given (previously the whole text, parentheses included, was used as the class name).
     *
     * @param shiftReduceOptions parser options, forwarded to the superclass
     * @param index transition index, forwarded to the superclass
     * @param set forwarded unchanged to the superclass constructor
     * @param set2 forwarded unchanged to the superclass constructor
     * @param set3 forwarded unchanged to the superclass constructor
     */
    public PerceptronModel(ShiftReduceOptions shiftReduceOptions, Index<Transition> index, Set<String> set, Set<String> set2, Set<String> set3) {
        super(shiftReduceOptions, index, set, set2, set3);
        this.learningRate = 1.0f;
        this.featureWeights = new WeightMap();
        String[] specs = shiftReduceOptions.featureFactoryClass.split(";");
        FeatureFactory[] factories = new FeatureFactory[specs.length];
        for (int i = 0; i < specs.length; i++) {
            String spec = specs[i];
            int paren = spec.indexOf('(');
            if (paren >= 0) {
                // "ClassName(argument)": strip the opening parenthesis and the final character.
                String className = spec.substring(0, paren);
                String argument = spec.substring(paren + 1, spec.length() - 1);
                factories[i] = (FeatureFactory) ReflectionLoading.loadByReflection(className, argument);
            } else {
                factories[i] = (FeatureFactory) ReflectionLoading.loadByReflection(spec, new Object[0]);
            }
        }
        this.featureFactory = factories.length == 1 ? factories[0] : new CombinationFeatureFactory(factories);
    }

    public PerceptronModel(PerceptronModel perceptronModel) {
        super(perceptronModel);
        this.learningRate = 1.0f;
        this.featureFactory = perceptronModel.featureFactory;
        this.featureWeights = new WeightMap();
        for (String str : perceptronModel.featureWeights.keySet()) {
            this.featureWeights.put(str, new Weight(perceptronModel.featureWeights.get(str)));
        }
    }

    /**
     * Logs the score of every given model and then replaces this model's weights
     * with the average of those models (see {@code averageModels}).
     *
     * @param collection scored models to average
     * @throws IllegalArgumentException if {@code collection} is empty
     */
    public void averageScoredModels(Collection<ScoredObject<PerceptronModel>> collection) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Cannot average empty models");
        }
        log.info("Averaging " + collection.size() + " models with scores");
        for (ScoredObject<PerceptronModel> scoredModel : collection) {
            log.info(" " + NF.format(scoredModel.score()));
        }
        log.info(new Object[0]);
        averageModels(CollectionUtils.transformAsList(collection, ScoredObject::object));
    }

    /**
     * Replaces this model's weights with the unweighted average of the given models.
     * A feature missing from a model contributes nothing for that model, but the
     * divisor is always the total number of models.
     *
     * @param collection models to average
     * @throws IllegalArgumentException if {@code collection} is empty
     */
    public void averageModels(Collection<PerceptronModel> collection) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Cannot average empty models");
        }
        // Union of every feature known to any of the models.
        Set<String> allFeatures = Generics.newHashSet();
        for (PerceptronModel model : collection) {
            allFeatures.addAll(model.featureWeights.keySet());
        }
        this.featureWeights = new WeightMap();
        final float scale = 1.0f / collection.size();
        for (String feature : allFeatures) {
            Weight averaged = new Weight();
            this.featureWeights.put(feature, averaged);
            for (PerceptronModel model : collection) {
                if (!model.featureWeights.containsKey(feature)) {
                    continue;
                }
                averaged.addScaled(model.featureWeights.get(feature), scale);
            }
        }
    }

    /** Condenses every weight and drops the features whose weight has size 0 afterwards. */
    private void condenseFeatures() {
        this.featureWeights.keySet().removeIf(feature -> {
            Weight weight = this.featureWeights.get(feature);
            weight.condense();
            return weight.size() == 0;
        });
    }

    /**
     * Removes every feature that is not contained in {@code set}.
     *
     * @param set the features to keep
     */
    private void filterFeatures(Set<String> set) {
        this.featureWeights.keySet().removeIf(feature -> !set.contains(feature));
    }

    /**
     * @return the sum of {@code size()} over all weights in the model
     */
    public int numWeights() {
        int total = 0;
        for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
            total += entry.getValue().size();
        }
        return total;
    }

    /**
     * @return the largest {@code maxAbs()} reported by any weight, or 0 when there are no weights
     */
    public float maxAbs() {
        float largest = 0.0f;
        for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
            largest = Math.max(largest, entry.getValue().maxAbs());
        }
        return largest;
    }

    /**
     * Logs a summary of a training pass: correct/wrong transition counts, model size
     * figures, the most common first errors, reorderer statistics and per-type
     * transition statistics.
     *
     * @param trainingResult the (possibly aggregated) result to report on
     */
    public void outputStats(TrainingResult trainingResult) {
        log.info("While training, got " + trainingResult.numCorrect + " transitions correct and " + trainingResult.numWrong + " transitions wrong");
        log.info("Number of known features: " + this.featureWeights.size());
        log.info("Number of non-zero weights: " + numWeights());
        log.info("Weight values maxAbs: " + maxAbs());

        // Total number of characters across all feature names.
        int totalFeatureChars = 0;
        for (String feature : this.featureWeights.keySet()) {
            totalFeatureChars += feature.length();
        }
        log.info("Total word length: " + totalFeatureChars);
        log.info("Number of transitions: " + this.transitionIndex.size());

        IntCounter<Pair<Integer, Integer>> firstErrorCounts = new IntCounter<>();
        for (Pair<Integer, Integer> firstError : trainingResult.firstErrors) {
            firstErrorCounts.incrementCount(firstError);
        }
        outputFirstErrors(firstErrorCounts);
        outputReordererStats(trainingResult.reorderSuccess, trainingResult.reorderFail);
        outputTransitionStats(trainingResult);
    }

    /**
     * Recovers tags from the names of known features: the first dash-free segment
     * after {@code Q0TQ1T-} and everything after {@code S0T-}. The literal
     * {@code .$$.} is always included.
     *
     * @return the set of tags found, plus {@code .$$.}
     */
    @Override
    public Set<String> tagSet() {
        Pattern[] tagFeaturePatterns = {
            Pattern.compile("Q0TQ1T-([^-]+)-.*"),
            Pattern.compile("S0T-(.*)"),
        };
        Set<String> tags = Generics.newHashSet();
        for (String feature : this.featureWeights.keySet()) {
            for (Pattern pattern : tagFeaturePatterns) {
                Matcher matcher = pattern.matcher(feature);
                if (matcher.matches()) {
                    tags.add(matcher.group(1));
                }
            }
        }
        tags.add(".$$.");
        return tags;
    }

    /**
     * Returns the single best-scoring transition for a state, given its features.
     *
     * @param state the state to score
     * @param list the feature strings of {@code state}
     * @param z whether only legal transitions may be returned
     * @return the best transition index with its score, or {@code null} if none qualified
     */
    private ScoredObject<Integer> findHighestScoringTransition(State state, List<String> list, boolean z) {
        Iterator<ScoredObject<Integer>> best = findHighestScoringTransitions(state, list, z, 1, null).iterator();
        return best.hasNext() ? best.next() : null;
    }

    /**
     * Featurizes {@code state} and returns up to {@code i} of its best-scoring transitions.
     *
     * @param state the state to score
     * @param z whether only legal transitions may be returned
     * @param i maximum number of transitions to return
     * @param list constraints passed to the legality check; may be {@code null}
     * @return the retained transition indices with their scores
     */
    @Override
    public Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, boolean z, int i, List<ParserConstraint> list) {
        List<String> features = this.featureFactory.featurize(state);
        return findHighestScoringTransitions(state, features, z, i, list);
    }

    /**
     * Scores every transition for the given features and keeps the {@code i} highest.
     *
     * @param state the state the features were extracted from (used for the legality check)
     * @param list the feature strings to score
     * @param z when true, transitions that are not legal in {@code state} are skipped
     * @param i maximum number of transitions to keep
     * @param list2 constraints passed to the legality check; may be {@code null}
     * @return a priority queue (lowest score at the head) of at most {@code i} scored transition indices
     */
    private Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> list, boolean z, int i, List<ParserConstraint> list2) {
        // Each known feature adds its contribution into the per-transition score array.
        float[] scores = new float[this.transitionIndex.size()];
        for (String feature : list) {
            Weight weight = this.featureWeights.get(feature);
            if (weight == null) {
                continue;
            }
            weight.score(scores);
        }

        PriorityQueue<ScoredObject<Integer>> best = new PriorityQueue<>(i + 1, ScoredComparator.ASCENDING_COMPARATOR);
        for (int transition = 0; transition < scores.length; transition++) {
            if (z && !this.transitionIndex.get(transition).isLegal(state, list2)) {
                continue;
            }
            best.add(new ScoredObject<>(Integer.valueOf(transition), scores[transition]));
            // Ascending order puts the worst candidate at the head, so polling trims it.
            if (best.size() > i) {
                best.poll();
            }
        }
        return best;
    }

    /**
     * Runs the model over one training example and collects the weight updates it calls for,
     * without modifying any weights itself.
     *
     * <p>Two families of training method are handled. {@code BEAM} and {@code REORDER_BEAM} keep an
     * agenda of at most {@code beamSize} states and compare the best state on the next agenda with
     * the gold state. {@code REORDER_ORACLE}, {@code EARLY_TERMINATION} and {@code GOLD} follow the
     * gold transition list greedily and record an update whenever the top-scoring transition differs
     * from the gold one. Any other training method yields an empty result.
     *
     * @param trainingExample the example to train on
     * @return the updates plus counts of correct/wrong transitions, the first error, per-type
     *         statistics, and two 0-or-1 reorder flags
     * @throws IllegalArgumentException if a beam method is selected with a non-positive beam size
     */
    public TrainingResult trainTree(TrainingExample trainingExample) {
        // i counts transitions the model got right, i2 those it got wrong.
        int i = 0;
        int i2 = 0;
        // Collected TrainingUpdate objects.
        ArrayList newArrayList = Generics.newArrayList();
        // First (predicted index, gold index) pair on which the model disagreed with gold, if any.
        Pair pair = null;
        // Per transition class: how often it was predicted correctly / confused with another class.
        IntCounter intCounter = new IntCounter();
        TwoDimensionalIntCounter twoDimensionalIntCounter = new TwoDimensionalIntCounter();
        // Only the two reordering methods need an oracle.
        ReorderingOracle reorderingOracle = (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) ? new ReorderingOracle(this.op, this.rootOnlyStates) : null;
        // i3 is set to 1 once a reorder has worked for this tree; i4 is set to 1 when a reorder
        // did not help and none had worked before.
        int i3 = 0;
        int i4 = 0;
        if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            if (this.op.trainOptions().beamSize <= 0) {
                throw new IllegalArgumentException("Illegal beam size " + this.op.trainOptions().beamSize);
            }
            // Current agenda; starts with only the initial (gold) state.
            PriorityQueue priorityQueue = new PriorityQueue(this.op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State initialStateFromGoldTagTree = trainingExample.initialStateFromGoldTagTree();
            List<Transition> trainTransitions = trainingExample.trainTransitions();
            priorityQueue.add(initialStateFromGoldTagTree);
            while (true) {
                if (trainTransitions.size() <= 0) {
                    break;
                }
                // Next gold transition, and the best-scoring transition out of the gold state.
                Transition transition = trainTransitions.get(0);
                Transition transition2 = null;
                double d = 0.0d;
                // Agenda for the next step.
                PriorityQueue priorityQueue2 = new PriorityQueue(this.op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
                // state: best-scoring state on the next agenda; state2: the current state it came from.
                State state = null;
                State state2 = null;
                Iterator it = priorityQueue.iterator();
                while (it.hasNext()) {
                    State state3 = (State) it.next();
                    boolean areTransitionsEqual = initialStateFromGoldTagTree.areTransitionsEqual(state3);
                    for (ScoredObject<Integer> scoredObject : findHighestScoringTransitions(state3, this.featureFactory.featurize(state3), true, this.op.trainOptions().beamSize, null)) {
                        State apply = this.transitionIndex.get(scoredObject.object().intValue()).apply(state3, scoredObject.score());
                        priorityQueue2.add(apply);
                        // Keep the new agenda within the beam by dropping its lowest-scoring state.
                        if (priorityQueue2.size() > this.op.trainOptions().beamSize) {
                            priorityQueue2.poll();
                        }
                        if (state == null || state.score() < apply.score()) {
                            state = apply;
                            state2 = state3;
                        }
                        if (areTransitionsEqual && (transition2 == null || scoredObject.score() > d)) {
                            transition2 = this.transitionIndex.get(scoredObject.object().intValue());
                            d = scoredObject.score();
                        }
                    }
                }
                // With REORDER_BEAM, stop when no transition out of the gold state was found.
                if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && transition2 == null) {
                    break;
                }
                if (state == null) {
                    System.err.println("Unable to find a best transition!");
                    System.err.println("Previous agenda:");
                    Iterator it2 = priorityQueue.iterator();
                    while (it2.hasNext()) {
                        System.err.println((State) it2.next());
                    }
                    System.err.println("Gold transitions:");
                    System.err.println(trainingExample.transitions);
                    // Nothing changes between iterations in this situation (same agenda, same
                    // remaining transitions), so without leaving the loop it would never end.
                    break;
                } else {
                    // Gold state after taking the gold transition.
                    State apply2 = transition.apply(initialStateFromGoldTagTree, 0.0d);
                    if (pair == null && !transition2.equals(transition)) {
                        int indexOf = this.transitionIndex.indexOf(transition2);
                        int indexOf2 = this.transitionIndex.indexOf(transition);
                        if (indexOf < 0) {
                            throw new AssertionError("Predicted transition not in the index: " + transition2);
                        }
                        if (indexOf2 < 0) {
                            throw new AssertionError("Gold transition not in the index: " + transition);
                        }
                        pair = new Pair(Integer.valueOf(indexOf), Integer.valueOf(indexOf2));
                    }
                    if (apply2.areTransitionsEqual(state)) {
                        // The best state on the new agenda is the gold one: nothing to update.
                        i++;
                        intCounter.incrementCount(transition.getClass());
                        trainTransitions.remove(0);
                    } else {
                        i2++;
                        twoDimensionalIntCounter.incrementCount(transition.getClass(), transition2.getClass());
                        List<String> featurize = this.featureFactory.featurize(initialStateFromGoldTagTree);
                        // Two one-sided updates: one names only the last transition of the winning
                        // state (as predicted), the other names only the gold transition.
                        newArrayList.add(new TrainingUpdate(this.featureFactory.featurize(state2), -1, this.transitionIndex.indexOf(state.transitions.peek()), this.learningRate));
                        newArrayList.add(new TrainingUpdate(featurize, this.transitionIndex.indexOf(transition), -1, this.learningRate));
                        if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                            // Stop once the gold state is no longer on the agenda.
                            if (!ShiftReduceUtils.findStateOnAgenda(priorityQueue2, apply2)) {
                                break;
                            }
                            trainTransitions.remove(0);
                        } else if (this.op.trainOptions().trainingMethod != ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                            continue;
                        } else if (ShiftReduceUtils.findStateOnAgenda(priorityQueue2, apply2)) {
                            trainTransitions.remove(0);
                        } else if (reorderingOracle.reorder(initialStateFromGoldTagTree, transition2, trainTransitions)) {
                            // The oracle accepted the predicted transition; follow it instead of gold.
                            apply2 = transition2.apply(initialStateFromGoldTagTree);
                            if (ShiftReduceUtils.findStateOnAgenda(priorityQueue2, apply2)) {
                                i3 = 1;
                            } else if (i3 == 0) {
                                i4 = 1;
                            }
                        } else if (i3 == 0) {
                            i4 = 1;
                        }
                    }
                    initialStateFromGoldTagTree = apply2;
                    priorityQueue = priorityQueue2;
                }
            }
        } else if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
            State initialStateFromGoldTagTree2 = trainingExample.initialStateFromGoldTagTree();
            List<Transition> trainTransitions2 = trainingExample.trainTransitions();
            // Cleared to end training on this tree early.
            boolean z = true;
            while (trainTransitions2.size() > 0 && z) {
                Transition transition3 = trainTransitions2.get(0);
                int indexOf3 = this.transitionIndex.indexOf(transition3);
                List<String> featurize2 = this.featureFactory.featurize(initialStateFromGoldTagTree2);
                // Top-scoring transition; legality is not checked here (third argument is false).
                int intValue = findHighestScoringTransition(initialStateFromGoldTagTree2, featurize2, false).object().intValue();
                Transition transition4 = this.transitionIndex.get(intValue);
                if (indexOf3 == intValue) {
                    trainTransitions2.remove(0);
                    initialStateFromGoldTagTree2 = transition3.apply(initialStateFromGoldTagTree2);
                    i++;
                    intCounter.incrementCount(transition3.getClass());
                } else {
                    i2++;
                    twoDimensionalIntCounter.incrementCount(transition3.getClass(), transition4.getClass());
                    if (pair == null) {
                        pair = new Pair(Integer.valueOf(intValue), Integer.valueOf(indexOf3));
                    }
                    newArrayList.add(new TrainingUpdate(featurize2, indexOf3, intValue, this.learningRate));
                    switch (this.op.trainOptions().trainingMethod) {
                        case EARLY_TERMINATION:
                            // Stop at the first mistake.
                            z = false;
                            break;
                        case GOLD:
                            // Carry on along the gold transition sequence.
                            trainTransitions2.remove(0);
                            initialStateFromGoldTagTree2 = transition3.apply(initialStateFromGoldTagTree2);
                            break;
                        case REORDER_ORACLE:
                            // Follow the predicted transition if the oracle can reorder the
                            // remaining gold transitions; otherwise stop.
                            z = reorderingOracle.reorder(initialStateFromGoldTagTree2, transition4, trainTransitions2);
                            if (z) {
                                initialStateFromGoldTagTree2 = transition4.apply(initialStateFromGoldTagTree2);
                                i3 = 1;
                                break;
                            } else if (i3 == 0) {
                                i4 = 1;
                                break;
                            } else {
                                break;
                            }
                        default:
                            throw new IllegalArgumentException("Unexpected method " + this.op.trainOptions().trainingMethod);
                    }
                }
            }
        }
        return new TrainingResult(newArrayList, i, i2, (Pair<Integer, Integer>) pair, (IntCounter<Class<? extends Transition>>) intCounter, (TwoDimensionalIntCounter<Class<? extends Transition>, Class<? extends Transition>>) twoDimensionalIntCounter, i3, i4);
    }

    /**
     * Trains on one batch of examples and merges the per-example results.
     *
     * <p>Examples are processed serially when {@code trainingThreads} is 1 or when no wrapper was
     * supplied; otherwise they are fed through {@code multicoreWrapper}. The null check matters
     * because the caller decides whether to create a wrapper from a different value than
     * {@code trainingThreads}, so the two can disagree.
     *
     * @param list the examples in this batch
     * @param multicoreWrapper wrapper used for parallel training; may be {@code null}
     * @return a result built from the per-example results
     */
    private TrainingResult trainBatch(List<TrainingExample> list, MulticoreWrapper<TrainingExample, TrainingResult> multicoreWrapper) {
        List<TrainingResult> results = new ArrayList<>();
        if (multicoreWrapper == null || this.op.trainOptions.trainingThreads == 1) {
            for (TrainingExample example : list) {
                results.add(trainTree(example));
            }
        } else {
            for (TrainingExample example : list) {
                multicoreWrapper.put(example);
            }
            // join(false) is used here; the caller performs a final join() once training is finished.
            multicoreWrapper.join(false);
            while (multicoreWrapper.peek()) {
                results.add(multicoreWrapper.poll());
            }
        }
        return new TrainingResult(results);
    }

    /**
     * Wraps this model in a {@code ShiftReduceParser}, runs it over the treebank and
     * logs the resulting {@code getLBScore()} value.
     *
     * @param tagger tagger handed to the evaluator
     * @param treebank trees to evaluate on
     * @param str prefix for the log line; the score is appended after {@code ": "}
     * @return the score reported by the evaluator
     */
    private double evaluate(Tagger tagger, Treebank treebank, String str) {
        final ShiftReduceParser parser = new ShiftReduceParser(this.op, this);
        final EvaluateTreebank evaluator = new EvaluateTreebank(parser.getOp(), null, parser, tagger, parser.getExtraEvals(), parser.getParserQueryEvals());
        evaluator.testOnTreebank(treebank);
        final double score = evaluator.getLBScore();
        log.info(str + ": " + score);
        return score;
    }

    /**
     * Appends extra training examples derived from a random selection of the originals.
     * Only examples with more than 10 transitions are eligible, and each eligible one is
     * selected when a fresh random draw is below {@code f}. The new example reuses the
     * original tree and transitions and is given a random integer between 7 and
     * {@code transitions.size() - 4} inclusive as its third constructor argument.
     *
     * @param list list that receives the additional examples
     * @param list2 the original examples to draw from
     * @param random source of randomness
     * @param f selection threshold compared against {@code random.nextDouble()}
     */
    static void augmentSubsentences(List<TrainingExample> list, List<TrainingExample> list2, Random random, float f) {
        for (TrainingExample example : list2) {
            int numTransitions = example.transitions.size();
            if (numTransitions <= 10) {
                continue;
            }
            boolean selected = random.nextDouble() < f;
            if (!selected) {
                continue;
            }
            int drawn = random.nextInt(numTransitions - 10) + 7;
            list.add(new TrainingExample(example.binarizedTree, example.transitions, drawn));
        }
    }

    /**
     * Logs up to nine of the most frequent first errors, most frequent first.
     *
     * <p>Each key pairs the predicted transition index (first) with the gold transition index
     * (second). The count of an entry is read before the entry is zeroed out, so the number
     * logged belongs to the entry being reported rather than to the next most frequent one.
     *
     * @param intCounter how often each (predicted, gold) pair was the first error; may be {@code null}
     */
    private void outputFirstErrors(IntCounter<Pair<Integer, Integer>> intCounter) {
        if (intCounter == null || intCounter.size() == 0) {
            return;
        }
        // Work on a copy so the caller's counts are left untouched.
        IntCounter<Pair<Integer, Integer>> remaining = new IntCounter<>(intCounter);
        log.info("Most common transition errors: gold -> predicted");
        for (int rank = 1; rank <= 9 && remaining.size() > 0; rank++) {
            Pair<Integer, Integer> mostCommon = remaining.argmax();
            int count = remaining.max();
            // Entries already reported are only decremented, so a non-positive maximum
            // means every real error has been listed.
            if (count <= 0) {
                break;
            }
            remaining.decrementCount(mostCommon, count);
            Transition gold = this.transitionIndex.get(mostCommon.second().intValue());
            Transition predicted = this.transitionIndex.get(mostCommon.first().intValue());
            log.info("  # " + rank + ": " + gold + " -> " + predicted + " happened " + count + " times");
        }
    }

    /**
     * Logs how many trees the reorderer helped with and how many it did not,
     * unless both counts are zero.
     *
     * @param i number of trees on which the reorderer succeeded at least once
     * @param i2 number of trees on which it failed to do anything useful
     */
    private void outputReordererStats(int i, int i2) {
        boolean reordererWasUsed = i != 0 || i2 != 0;
        if (reordererWasUsed) {
            log.info("Reorderer successfully operated at least once on " + i + " training trees and failed to do anything useful on " + i2 + " trees");
        }
    }

    /**
     * Logs two multi-line reports: how often each transition type was predicted
     * correctly, and for each gold transition type how often it was mistaken for
     * each other type. Both are listed in the order given by {@code Counters.toSortedList}.
     *
     * @param trainingResult the result whose transition counters are reported
     */
    private void outputTransitionStats(TrainingResult trainingResult) {
        List<String> correctLines = new ArrayList<>();
        correctLines.add("Got the following transition types correct:");
        List<Class<? extends Transition>> correctTypes = Counters.toSortedList(trainingResult.correctTransitions);
        for (Class<? extends Transition> type : correctTypes) {
            correctLines.add(ShiftReduceUtils.transitionShortName(type) + ": " + trainingResult.correctTransitions.getCount(type));
        }
        log.info(StringUtils.join(correctLines, "\n  "));

        List<String> wrongLines = new ArrayList<>();
        wrongLines.add("Got the following transition types incorrect:");
        List<Class<? extends Transition>> goldTypes = Counters.toSortedList(trainingResult.wrongTransitions.totalCounts());
        for (Class<? extends Transition> goldType : goldTypes) {
            IntCounter<Class<? extends Transition>> confusions = trainingResult.wrongTransitions.getCounter(goldType);
            for (Class<? extends Transition> predictedType : Counters.toSortedList(confusions)) {
                wrongLines.add(ShiftReduceUtils.transitionShortName(goldType) + " -> " + ShiftReduceUtils.transitionShortName(predictedType) + ": " + confusions.getCount(predictedType));
            }
        }
        log.info(StringUtils.join(wrongLines, "\n  "));
    }

    /**
     * Trains this model in place for up to {@code trainingIterations} passes over the data.
     *
     * <p>Each pass augments and shuffles the examples, trains batch by batch, applies the updates to
     * the weights, applies L2/L1 regularisation when enabled, logs statistics and, if a dev treebank
     * is given, evaluates on it. Training stops early once the dev score has failed to improve for
     * {@code stalledIterationLimit} passes. Afterwards the best dev-scored snapshots are averaged
     * (when {@code averagedModels > 0}), rare features are filtered (when a frequency counter was
     * kept) and the weights are condensed.
     *
     * <p>Averaging is skipped when no snapshot was ever collected (for example when there is no
     * dev treebank), instead of failing with "Cannot average empty models" after all the training
     * work is done. In the cross-validated averaging path at least one model is always kept, even
     * if no subset scores above zero.
     *
     * @param str path used to name intermediate models; its last 7 characters are stripped,
     *            presumably a {@code .ser.gz} suffix. May be {@code null} to save nothing
     * @param tagger tagger used when evaluating on the dev treebank
     * @param random source of randomness for augmentation and shuffling
     * @param list the training examples
     * @param treebank dev treebank, or {@code null} to skip evaluation
     * @param i when not 1, a {@link MulticoreWrapper} is created for parallel training
     * @param set when non-null, only these features are updated; others are ignored
     */
    private void trainModel(String str, Tagger tagger, Random random, List<TrainingExample> list, Treebank treebank, int i, Set<String> set) {
        double bestScore = 0.0d;
        int bestIteration = 0;

        // Snapshots of the best models so far; lowest score at the head so it can be evicted.
        PriorityQueue<ScoredObject<PerceptronModel>> bestModels = null;
        if (this.op.trainOptions().averagedModels > 0) {
            bestModels = new PriorityQueue<>(this.op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
        }

        MulticoreWrapper<TrainingExample, TrainingResult> wrapper = null;
        if (i != 1) {
            wrapper = new MulticoreWrapper<>(this.op.trainOptions.trainingThreads, new TrainTreeProcessor());
        }

        // Feature frequencies are only tracked when a cutoff applies and no allowed set was given.
        IntCounter<String> featureFrequencies = null;
        if (this.op.trainOptions().featureFrequencyCutoff > 1 && set == null) {
            featureFrequencies = new IntCounter<>();
        }

        for (int iteration = 1; iteration <= this.op.trainOptions.trainingIterations; iteration++) {
            Timing timing = new Timing();
            List<TrainingResult> results = new ArrayList<>();
            List<TrainingExample> augmented = new ArrayList<>(list);
            augmentSubsentences(augmented, list, random, this.op.trainOptions().augmentSubsentences);
            Collections.shuffle(augmented, random);
            log.info("Original list " + list.size() + "; augmented " + augmented.size());

            for (int start = 0; start < augmented.size(); start += this.op.trainOptions.batchSize) {
                int end = Math.min(start + this.op.trainOptions.batchSize, augmented.size());
                TrainingResult batchResult = trainBatch(augmented.subList(start, end), wrapper);
                results.add(batchResult);
                for (TrainingUpdate update : batchResult.updates) {
                    for (String feature : update.features) {
                        if (set != null && !set.contains(feature)) {
                            continue;
                        }
                        Weight weight = this.featureWeights.get(feature);
                        if (weight == null) {
                            weight = new Weight();
                            this.featureWeights.put(feature, weight);
                        }
                        weight.updateWeight(update.goldTransition, update.delta);
                        weight.updateWeight(update.predictedTransition, -update.delta);
                        if (featureFrequencies != null) {
                            // An update naming both a gold and a predicted transition counts twice;
                            // a one-sided update (one index negative) counts once.
                            int occurrences = (update.goldTransition < 0 || update.predictedTransition < 0) ? 1 : 2;
                            featureFrequencies.incrementCount(feature, occurrences);
                        }
                    }
                }
            }

            float l2Reg = this.op.trainOptions().l2Reg;
            if (l2Reg > 0.0f) {
                for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
                    entry.getValue().l2Reg(l2Reg);
                }
            }
            float l1Reg = this.op.trainOptions().l1Reg;
            if (l1Reg > 0.0f) {
                for (Map.Entry<String, Weight> entry : this.featureWeights.entrySet()) {
                    entry.getValue().l1Reg(l1Reg);
                }
            }

            timing.done("Iteration " + iteration);
            outputStats(new TrainingResult(results));

            double devScore = 0.0d;
            if (treebank != null) {
                devScore = evaluate(tagger, treebank, "Label F1 for iteration " + iteration);
                if (devScore <= bestScore) {
                    int stalled = iteration - bestIteration;
                    log.info("Failed to improve for " + stalled + " iteration(s) on previous best score of " + bestScore);
                    if (this.op.trainOptions.stalledIterationLimit > 0 && stalled >= this.op.trainOptions.stalledIterationLimit) {
                        log.info("Failed to improve for too long, stopping training");
                        break;
                    }
                } else {
                    log.info("New best dev score (previous best " + bestScore + ")");
                    bestScore = devScore;
                    bestIteration = iteration;
                }
                log.info("\n\n");
                if (bestModels != null) {
                    PerceptronModel snapshot = new PerceptronModel(this);
                    snapshot.condenseFeatures();
                    bestModels.add(new ScoredObject<>(snapshot, devScore));
                    if (bestModels.size() > this.op.trainOptions().averagedModels) {
                        bestModels.poll();
                    }
                }
            }

            if (this.op.trainOptions().saveIntermediateModels && str != null && this.op.trainOptions.debugOutputFrequency > 0) {
                new ShiftReduceParser(this.op, this).saveModel(str.substring(0, str.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(devScore) + ".ser.gz");
            }
            // Every tenth iteration the learning rate is multiplied by the decay factor.
            if (iteration % 10 == 0 && this.op.trainOptions().decayLearningRate > 0.0d) {
                this.learningRate = (float) (this.learningRate * this.op.trainOptions().decayLearningRate);
            }
        }

        if (wrapper != null) {
            wrapper.join();
        }

        if (bestModels != null && bestModels.isEmpty()) {
            // Snapshots are only taken after a dev evaluation, so there may be none to average.
            log.info("No models were collected for averaging; keeping the current weights");
        } else if (bestModels != null) {
            if (!this.op.trainOptions().cvAveragedModels || treebank == null) {
                averageScoredModels(bestModels);
            } else {
                // Order the snapshots best-first, then find the prefix whose average scores highest.
                List<ScoredObject<PerceptronModel>> ranked = new ArrayList<>();
                while (!bestModels.isEmpty()) {
                    ranked.add(bestModels.poll());
                }
                Collections.reverse(ranked);
                double bestAveragedScore = 0.0d;
                int bestSize = 0;
                for (int size = 1; size <= ranked.size(); size++) {
                    log.info("Testing with " + size + " models averaged together");
                    averageScoredModels(ranked.subList(0, size));
                    double averagedScore = evaluate(tagger, treebank, "Label F1 for " + size + " models");
                    // bestSize == 0 guarantees that some prefix is chosen even if every score is 0.
                    if (bestSize == 0 || averagedScore > bestAveragedScore) {
                        bestAveragedScore = averagedScore;
                        bestSize = size;
                    }
                }
                averageScoredModels(ranked.subList(0, bestSize));
                log.info("Label F1 for " + bestSize + " models: " + bestAveragedScore);
            }
        }

        if (featureFrequencies != null) {
            filterFeatures(featureFrequencies.keysAbove(this.op.trainOptions().featureFrequencyCutoff));
        }
        condenseFeatures();
    }

    /**
     * Returns a random subset of the given features. Each feature is kept when a fresh
     * random draw is greater than {@code d}. If that leaves nothing and the input is not
     * empty, the first feature in iteration order is kept so the result is never empty
     * for a non-empty input.
     *
     * @param set the features to sample from
     * @param random source of randomness; drawn from once per feature, in iteration order
     * @param d drop threshold compared against {@code random.nextDouble()}
     * @return a new set containing the retained features
     */
    static Set<String> pruneFeatures(Set<String> set, Random random, double d) {
        Set<String> kept = new HashSet<>();
        for (String feature : set) {
            if (!(random.nextDouble() > d)) {
                continue;
            }
            kept.add(feature);
        }
        if (kept.isEmpty() && !set.isEmpty()) {
            kept.add(set.iterator().next());
        }
        return kept;
    }

    /**
     * Trains and returns a new model, leaving {@code perceptronModel} (the optional starting point)
     * unmodified.
     *
     * <p>Without retraining (neither {@code retrainAfterCutoff} with a positive frequency cutoff nor
     * more than one retrain shard) a copy of the starting model is trained once. Otherwise a first
     * model is trained to decide which features survive, then a fresh copy restricted to those
     * features is retrained. With more than one shard, additional copies are trained on random
     * subsets of those features and all shards are averaged.
     *
     * <p>A {@code null} model path is tolerated on the retraining path (no temporary model is
     * saved), and the final averaged-shard evaluation is skipped when there is no dev treebank.
     *
     * @param shiftReduceOptions parser and training options
     * @param index transition index for a newly created model
     * @param set forwarded to the model constructor when a new model is created
     * @param set2 forwarded to the model constructor when a new model is created
     * @param set3 forwarded to the model constructor when a new model is created
     * @param perceptronModel model to start from, or {@code null} to start from scratch
     * @param str path used to name saved models; may be {@code null}
     * @param tagger tagger used for dev evaluation
     * @param random source of randomness
     * @param list training examples
     * @param treebank dev treebank, or {@code null}
     * @param i passed through to the per-model training routine
     * @return the trained model
     */
    public static PerceptronModel trainModel(ShiftReduceOptions shiftReduceOptions, Index<Transition> index, Set<String> set, Set<String> set2, Set<String> set3, PerceptronModel perceptronModel, String str, Tagger tagger, Random random, List<TrainingExample> list, Treebank treebank, int i) {
        if (perceptronModel == null) {
            perceptronModel = new PerceptronModel(shiftReduceOptions, index, set, set2, set3);
        }

        boolean retrainAfterCutoff = shiftReduceOptions.trainOptions().retrainAfterCutoff && shiftReduceOptions.trainOptions().featureFrequencyCutoff > 0;
        boolean useShards = shiftReduceOptions.trainOptions().retrainShards > 1;
        if (!retrainAfterCutoff && !useShards) {
            PerceptronModel single = new PerceptronModel(perceptronModel);
            single.trainModel(str, tagger, random, list, treebank, i, null);
            return single;
        }

        // First pass: train once to find out which features are worth keeping.
        String tempPath = (str == null) ? null : str.substring(0, str.length() - 7) + "-temp.ser.gz";
        PerceptronModel firstPass = new PerceptronModel(perceptronModel);
        firstPass.trainModel(tempPath, tagger, random, list, treebank, i, null);
        if (tempPath != null && shiftReduceOptions.trainOptions().saveIntermediateModels) {
            new ShiftReduceParser(shiftReduceOptions, firstPass).saveModel(tempPath);
        }

        log.info("Beginning retraining");
        Set<String> allowedFeatures = firstPass.featureWeights.keySet();
        PerceptronModel result = new PerceptronModel(perceptronModel);
        result.filterFeatures(allowedFeatures);
        result.trainModel(str, tagger, random, list, treebank, i, allowedFeatures);

        if (useShards) {
            List<PerceptronModel> shards = Generics.newArrayList();
            shards.add(result);
            for (int shard = 1; shard < shiftReduceOptions.trainOptions().retrainShards; shard++) {
                log.info("Beginning retraining of shard " + (shard + 1));
                Set<String> shardFeatures = pruneFeatures(allowedFeatures, random, shiftReduceOptions.trainOptions().retrainShardFeatureDrop);
                PerceptronModel shardModel = new PerceptronModel(perceptronModel);
                shardModel.filterFeatures(shardFeatures);
                shardModel.trainModel(str, tagger, random, list, treebank, i, shardFeatures);
                shards.add(shardModel);
            }
            log.info("Averaging " + shiftReduceOptions.trainOptions().retrainShards + " shards");
            result = new PerceptronModel(perceptronModel);
            result.averageModels(shards);
            result.condenseFeatures();
            if (treebank != null) {
                result.evaluate(tagger, treebank, "Label F1 for " + shiftReduceOptions.trainOptions().retrainShards + " averaged shards");
            }
        }
        return result;
    }
}
