package com.clearnlp.classification.algorithm;

import com.clearnlp.classification.instance.IntInstance;
import com.clearnlp.classification.model.StringModelAD;
import com.clearnlp.classification.prediction.IntPrediction;
import com.clearnlp.classification.vector.SparseFeatureVector;
import com.clearnlp.util.UTMath;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:com/clearnlp/classification/algorithm/AdaGradOnlineHingeLoss.class */
/**
 * Online trainer built on {@link AbstractAdaGrad} that performs a margin-based
 * (hinge-loss style, per the class name) update.
 *
 * <p>For each instance, the score of the instance's own label is lowered by 1 before the
 * best-scoring label is chosen. If that best label differs from the instance's label,
 * the per-weight accumulators in {@code d_gradients} and the model weights are updated
 * for both labels; otherwise the instance is left alone.
 */
public class AdaGradOnlineHingeLoss extends AbstractAdaGrad {
    /**
     * Creates the trainer; all three arguments are forwarded unchanged to the
     * {@link AbstractAdaGrad} constructor.
     *
     * @param d  first numeric setting expected by the superclass
     * @param d2 second numeric setting expected by the superclass
     * @param z  boolean setting expected by the superclass
     */
    public AdaGradOnlineHingeLoss(double d, double d2, boolean z) {
        super(d, d2, z);
    }

    /**
     * Processes one training instance.
     *
     * @param stringModelAD model whose weights may be adjusted
     * @param intInstance   training instance (label plus sparse feature vector)
     * @param i             value passed through untouched to {@code updateWeight}
     * @return {@code true} if the margin-adjusted best label differed from the instance's
     *         label and an update was applied, {@code false} otherwise
     */
    @Override // com.clearnlp.classification.algorithm.AbstractAdaGrad
    protected boolean update(StringModelAD stringModelAD, IntInstance intInstance, int i) {
        IntPrediction best = getPrediction(stringModelAD, intInstance);
        int goldLabel = intInstance.getLabel();

        if (best.label != goldLabel) {
            // Accumulators are refreshed first so the weight update sees the new totals.
            updateCounts(stringModelAD, intInstance, goldLabel, best.label);
            updateWeights(stringModelAD, intInstance, goldLabel, best.label, i);
            return true;
        }

        return false;
    }

    /**
     * Scores every label for the instance, subtracts 1 from the score of the instance's
     * own label, and returns the maximum prediction according to the predictions'
     * natural ordering.
     *
     * <p>The subtraction is applied in place to the prediction object found at the list
     * position equal to the instance's label.
     *
     * @param stringModelAD model used to score the instance
     * @param intInstance   instance providing the feature vector and the label
     * @return the highest-ranked prediction after the margin adjustment
     */
    protected IntPrediction getPrediction(StringModelAD stringModelAD, IntInstance intInstance) {
        List<IntPrediction> scored = stringModelAD.getIntPredictions(intInstance.getFeatureVector());
        IntPrediction gold = scored.get(intInstance.getLabel());
        gold.score -= 1.0;
        return Collections.max(scored);
    }

    /**
     * Adds the square of every feature value in the instance to the {@code d_gradients}
     * entries of both the gold label and the predicted label.
     */
    private void updateCounts(StringModelAD model, IntInstance instance, int goldLabel, int predictedLabel) {
        SparseFeatureVector features = instance.getFeatureVector();
        int featureCount = features.size();

        for (int k = 0; k < featureCount; k++) {
            int featureIndex = features.getIndex(k);
            double squared = UTMath.sq(features.getWeight(k));

            this.d_gradients[model.getWeightIndex(goldLabel, featureIndex)] += squared;
            this.d_gradients[model.getWeightIndex(predictedLabel, featureIndex)] += squared;
        }
    }

    /**
     * For every feature in the instance, calls {@code updateWeight} with the feature value
     * for the gold label and with the negated feature value for the predicted label.
     */
    private void updateWeights(StringModelAD model, IntInstance instance, int goldLabel, int predictedLabel, int i) {
        SparseFeatureVector features = instance.getFeatureVector();
        int featureCount = features.size();

        for (int k = 0; k < featureCount; k++) {
            int featureIndex = features.getIndex(k);
            double value = features.getWeight(k);

            updateWeight(model, goldLabel, featureIndex, value, i);
            updateWeight(model, predictedLabel, featureIndex, -value, i);
        }
    }
}
