package com.clearnlp.classification.algorithm.old;

import com.clearnlp.classification.prediction.IntPrediction;
import com.clearnlp.util.UTMath;

/* loaded from: input_file:com/clearnlp/classification/algorithm/old/AdaGradHinge.class */
/**
 * AdaGrad trainer whose per-instance update is triggered by a margin-adjusted
 * arg-max: one is subtracted from the score of the expected label before the
 * highest-scoring label is selected, and parameters change only when that
 * selected label differs from the expected one.
 *
 * <p>Scoring, weight indexing, cost computation and averaged updates are
 * delegated to methods inherited from {@link AbstractAdaGrad}; their exact
 * semantics are defined there, not in this class.
 */
public class AdaGradHinge extends AbstractAdaGrad {
    /**
     * Passes the three configuration values straight to the superclass
     * constructor; their meaning is defined by {@link AbstractAdaGrad}.
     */
    public AdaGradHinge(double d, double d2, double d3) {
        super(d, d2, d3);
    }

    /**
     * Performs one training step without averaging.
     *
     * @param i     bound of the label loop, also forwarded to the inherited
     *              index/cost helpers (presumably the number of labels)
     * @param i2    label treated as correct for this instance
     * @param iArr  indices of the active entries (presumably feature indices)
     * @param dArr  values parallel to {@code iArr}, or {@code null} to treat
     *              every entry as having unit value
     * @param dArr2 accumulator that receives the squared values of the active
     *              entries for both labels involved in the update
     * @param dArr3 weight array that is scored and then modified in place
     * @return {@code true} if the weights were changed, {@code false} if the
     *         margin-adjusted prediction already equals {@code i2}
     */
    @Override
    protected boolean update(int i, int i2, int[] iArr, double[] dArr, double[] dArr2, double[] dArr3) {
        final int predicted = getPrediction(i, i2, iArr, dArr, dArr3).label;
        final boolean mistaken = predicted != i2;
        if (mistaken) {
            updateCounts(i, dArr2, i2, predicted, iArr, dArr);
            updateWeights(i, dArr2, i2, predicted, iArr, dArr, dArr3);
        }
        return mistaken;
    }

    /**
     * Performs one training step, routing every weight change through the
     * inherited {@code updateWeightForAveraging}.
     *
     * <p>The first six parameters have the same roles as in the non-averaging
     * overload. {@code dArr4} and {@code i3} are forwarded untouched to
     * {@code updateWeightForAveraging}; presumably they are the averaged-weight
     * buffer and a step counter, to be confirmed against the superclass.
     *
     * @return {@code true} if the weights were changed, {@code false} if the
     *         margin-adjusted prediction already equals {@code i2}
     */
    @Override
    protected boolean update(int i, int i2, int[] iArr, double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4, int i3) {
        final int predicted = getPrediction(i, i2, iArr, dArr, dArr3).label;
        final boolean mistaken = predicted != i2;
        if (mistaken) {
            updateCounts(i, dArr2, i2, predicted, iArr, dArr);
            updateWeights(i, dArr2, i2, predicted, iArr, dArr, dArr3, dArr4, i3);
        }
        return mistaken;
    }

    /**
     * Scores the instance via the inherited {@code getScores}, lowers the
     * expected label's score by one, and returns the first label holding the
     * maximum adjusted score among labels {@code 0 .. labelCount-1}.
     */
    private IntPrediction getPrediction(int labelCount, int expected, int[] indices, double[] values, double[] weights) {
        final double[] scores = getScores(labelCount, indices, values, weights);
        // Margin: the expected label must beat every other label by at least 1 to be selected.
        scores[expected] -= 1.0d;

        final IntPrediction best = new IntPrediction(0, scores[0]);
        int label = 1;
        while (label < labelCount) {
            final double candidate = scores[label];
            // Strict comparison keeps the lowest-numbered label on ties.
            if (candidate > best.score) {
                best.set(label, candidate);
            }
            label++;
        }
        return best;
    }

    /**
     * Adds the squared value of every active entry to the accumulator slots of
     * both the expected and the predicted label. When {@code values} is
     * {@code null} each entry contributes exactly {@code 1.0}.
     */
    private void updateCounts(int labelCount, double[] accumulator, int expected, int predicted, int[] indices, double[] values) {
        final boolean unitValues = (values == null);
        for (int k = 0; k < indices.length; k++) {
            final double increment = unitValues ? 1.0d : UTMath.sq(values[k]);
            final int index = indices[k];
            accumulator[getWeightIndex(labelCount, expected, index)] += increment;
            accumulator[getWeightIndex(labelCount, predicted, index)] += increment;
        }
    }

    /**
     * Moves the weights of the expected label up and those of the predicted
     * label down, each by the inherited {@code getCost} for that slot times the
     * entry's value (unit value when {@code values} is {@code null}).
     */
    private void updateWeights(int labelCount, double[] accumulator, int expected, int predicted, int[] indices, double[] values, double[] weights) {
        final boolean unitValues = (values == null);
        for (int k = 0; k < indices.length; k++) {
            final int index = indices[k];
            // Multiplying by 1.0 is exact, so the unit-value case matches an unscaled update bit for bit.
            final double scale = unitValues ? 1.0d : values[k];
            weights[getWeightIndex(labelCount, expected, index)] += getCost(labelCount, accumulator, expected, index) * scale;
            weights[getWeightIndex(labelCount, predicted, index)] -= getCost(labelCount, accumulator, predicted, index) * scale;
        }
    }

    /**
     * Averaging counterpart of the seven-argument {@code updateWeights}: the
     * same signed, value-scaled costs are computed, but each one is handed to
     * the inherited {@code updateWeightForAveraging} together with
     * {@code averaged} and {@code step} instead of being applied directly here.
     */
    private void updateWeights(int labelCount, double[] accumulator, int expected, int predicted, int[] indices, double[] values, double[] weights, double[] averaged, int step) {
        final boolean unitValues = (values == null);
        for (int k = 0; k < indices.length; k++) {
            final int index = indices[k];
            // Multiplying by 1.0 is exact, so the unit-value case matches an unscaled update bit for bit.
            final double scale = unitValues ? 1.0d : values[k];
            updateWeightForAveraging(getWeightIndex(labelCount, expected, index), getCost(labelCount, accumulator, expected, index) * scale, weights, averaged, step);
            updateWeightForAveraging(getWeightIndex(labelCount, predicted, index), (-getCost(labelCount, accumulator, predicted, index)) * scale, weights, averaged, step);
        }
    }
}
