package com.clearnlp.classification.algorithm.old;

import com.carrotsearch.hppc.IntArrayList;
import com.clearnlp.classification.model.AbstractModel;
import com.clearnlp.classification.train.AbstractTrainSpace;
import com.clearnlp.util.UTArray;
import com.clearnlp.util.UTMath;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:com/clearnlp/classification/algorithm/old/AbstractAdaGrad.class */
/**
 * Base class for the AdaGrad-style multiclass trainers in this package.
 *
 * <p>This class owns the epoch loop ({@link #updateWeights}) and a few numeric helpers;
 * the per-instance update rule itself is supplied by subclasses through the two
 * abstract {@code update} overloads.
 */
public abstract class AbstractAdaGrad extends AbstractMulticlass {
    /** Upper bound on the number of training epochs run by {@link #updateWeights}. */
    protected final int MAX_ITER = 1000;
    /** Numerator of the per-weight step size returned by {@link #getCost}. */
    protected double d_alpha;
    /** Constant added to the square root in the denominator of {@link #getCost}. */
    protected double d_rho;
    /**
     * Stopping tolerance: training ends once {@code UTMath.stdev} of the previous and the
     * current epoch accuracy falls below this value.
     */
    protected double d_eps;

    /**
     * Stores the three hyper-parameters used by this trainer.
     *
     * @param alpha value assigned to {@link #d_alpha}
     * @param rho   value assigned to {@link #d_rho}
     * @param eps   value assigned to {@link #d_eps}
     */
    public AbstractAdaGrad(double alpha, double rho, double eps) {
        this.d_alpha = alpha;
        this.d_rho = rho;
        this.d_eps = eps;
    }

    /**
     * Processes one training instance without weight averaging.
     *
     * <p>{@link #updateWeights} counts every {@code false} result towards the accuracy it
     * logs for the epoch. NOTE(review): presumably {@code true} means the instance was
     * misclassified and the weights were adjusted; confirm against the implementations.
     *
     * @param labelSize number of labels reported by the training space
     * @param y         label of the instance
     * @param x         feature array of the instance
     * @param v         per-feature values of the instance, or {@code null} when the
     *                  training space never reported {@code hasWeight()}
     * @param gs        per-weight accumulator that {@link #updateWeights} zeroes at the
     *                  start of every epoch
     * @param cWeights  current weight vector, of length featureSize * labelSize
     * @return see the note above on how the result is interpreted
     */
    protected abstract boolean update(int labelSize, int y, int[] x, double[] v, double[] gs, double[] cWeights);

    /**
     * Processes one training instance with weight averaging.
     *
     * <p>Same contract as the six-argument overload, plus the averaging state.
     *
     * @param labelSize number of labels reported by the training space
     * @param y         label of the instance
     * @param x         feature array of the instance
     * @param v         per-feature values of the instance, or {@code null} when the
     *                  training space never reported {@code hasWeight()}
     * @param gs        per-weight accumulator that {@link #updateWeights} zeroes at the
     *                  start of every epoch
     * @param cWeights  current weight vector
     * @param aWeights  averaging accumulator, same length as {@code cWeights}
     * @param count     1-based counter of the instances processed so far across all epochs
     * @return see the six-argument overload
     */
    protected abstract boolean update(int labelSize, int y, int[] x, double[] v, double[] gs, double[] cWeights, double[] aWeights, int count);

    /**
     * Trains on every instance of {@code space} for up to {@link #MAX_ITER} epochs and
     * stores the resulting weights in the model of the training space.
     *
     * <p>Each epoch shuffles the instance order, zeroes the {@code gs} accumulator, runs
     * the subclass update on every instance and logs the epoch accuracy. Training stops
     * early once {@code UTMath.stdev(previousAccuracy, currentAccuracy)} drops below
     * {@link #d_eps}. An empty training space runs no epoch at all and yields all-zero
     * weights.
     *
     * @param space   source of labels, features, optional feature values and the model
     * @param average {@code true} to use the averaging overload of {@code update} and to
     *                derive the final weights through {@link #getWeights}
     */
    @Override // com.clearnlp.classification.algorithm.old.AbstractMulticlass
    public void updateWeights(AbstractTrainSpace space, boolean average) {
        final int featureSize = space.getFeatureSize();
        final int labelSize = space.getLabelSize();
        final int instanceSize = space.getInstanceSize();
        final int weightSize = featureSize * labelSize;

        IntArrayList ys = space.getYs();
        ArrayList<int[]> xs = space.getXs();
        ArrayList<double[]> vs = space.getVs();
        AbstractModel model = space.getModel();

        double[] cWeights = new double[weightSize];
        double[] aWeights = average ? new double[weightSize] : null;
        double[] gs = new double[weightSize];

        int[] indices = UTArray.range(instanceSize);
        double currScore = 0.0d;
        int count = 1;
        double[] vi = null;

        // With no instances the accuracy below would be 0.0 / 0 = NaN, which can never be
        // smaller than d_eps, so the loop would spin through every epoch doing nothing
        // but logging. Skip training entirely; the weights stay at zero either way.
        final int maxEpochs = (instanceSize > 0) ? MAX_ITER : 0;

        for (int epoch = 0; epoch < maxEpochs; epoch++) {
            // A generator with the same fixed seed is created for every epoch, so each
            // shuffle call receives an identical random stream.
            // NOTE(review): hoisting the generator out of the loop would vary the stream
            // per epoch but also change the trained weights; left as is on purpose.
            UTArray.shuffle(new Random(5L), indices, instanceSize);

            double prevScore = currScore;
            Arrays.fill(gs, 0.0d);
            int correct = 0;

            for (int j = 0; j < instanceSize; j++) {
                int instance = indices[j];
                int yi = ys.get(instance);
                int[] xi = xs.get(instance);

                if (space.hasWeight()) {
                    vi = vs.get(instance);
                }

                boolean updated;
                if (average) {
                    updated = update(labelSize, yi, xi, vi, gs, cWeights, aWeights, count);
                    count++;
                } else {
                    updated = update(labelSize, yi, xi, vi, gs, cWeights);
                }

                if (!updated) {
                    correct++;
                }
            }

            currScore = (100.0d * correct) / instanceSize;
            double stdev = UTMath.stdev(prevScore, currScore);
            this.LOG.info(String.format("%4d: acc = %5.2f, stdev = %7.4f\n", epoch + 1, currScore, stdev));

            if (stdev < this.d_eps) {
                break;
            }
        }

        if (average) {
            model.setWeights(getWeights(cWeights, aWeights, count));
        } else {
            model.setWeights(UTArray.toFloatArray(cWeights));
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the step size for a single weight:
     * {@code d_alpha / (d_rho + sqrt(gs[getWeightIndex(labelSize, y, x)]))}.
     *
     * @param labelSize first argument forwarded to {@code getWeightIndex}
     * @param gs        array indexed by the weight index; NOTE(review): presumably the
     *                  accumulated squared gradients, confirm with the subclasses
     * @param y         second argument forwarded to {@code getWeightIndex}
     * @param x         third argument forwarded to {@code getWeightIndex}
     * @return the step size for the addressed weight
     */
    public double getCost(int labelSize, double[] gs, int y, int x) {
        int weightIndex = getWeightIndex(labelSize, y, x);
        return this.d_alpha / (this.d_rho + Math.sqrt(gs[weightIndex]));
    }

    /**
     * Builds the final float weights of an averaged run:
     * {@code result[k] = cWeights[k] - (1.0 / count) * aWeights[k]}.
     *
     * @param cWeights current weight vector
     * @param aWeights averaging accumulator, at least as long as {@code cWeights}
     * @param count    divisor applied to the accumulator
     * @return a new array with the same length as {@code cWeights}
     */
    protected float[] getWeights(double[] cWeights, double[] aWeights, int count) {
        final int size = cWeights.length;
        final double inverse = 1.0d / count;
        float[] weights = new float[size];

        for (int k = 0; k < size; k++) {
            weights[k] = (float) (cWeights[k] - (inverse * aWeights[k]));
        }

        return weights;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Applies one weight change to both the current weights and the averaging accumulator:
     * adds {@code cost} to {@code cWeights[index]} and {@code cost * count} to
     * {@code aWeights[index]}.
     *
     * @param index    position of the weight to change
     * @param cost     amount added to the current weight
     * @param cWeights current weight vector, modified in place
     * @param aWeights averaging accumulator, modified in place
     * @param count    factor applied to {@code cost} before it is added to the accumulator
     */
    public void updateWeightForAveraging(int index, double cost, double[] cWeights, double[] aWeights, int count) {
        cWeights[index] += cost;
        aWeights[index] += cost * count;
    }
}
