package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.io.ByteArrayUtils;
import edu.stanford.nlp.util.ArrayUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.Serializable;
import java.util.Arrays;

/* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/Weight.class */
/**
 * A sparse vector of float scores keyed by integer indices, packed into a single
 * {@code short[]}.
 *
 * <p>Entry {@code k} occupies three consecutive shorts: {@code packed[3k]} holds the index (so an
 * index must fit in a {@code short}), and {@code packed[3k + 1]} / {@code packed[3k + 2]} hold the
 * high and low 16 bits of {@link Float#floatToIntBits(float)} applied to the score. Lookups are a
 * linear scan over the entries.
 *
 * <p>No synchronization is performed, so concurrent mutation of one instance is unsafe.
 *
 * <p>Decompiler metadata marks {@code condense}, {@code maxAbs}, {@code l1Reg}, {@code l2Reg},
 * {@code writeBytes} and {@code readBytes} as package-private in the original source; they are
 * kept {@code public} here so existing callers still compile.
 */
public class Weight implements Serializable {
    /** Shared zero-length backing array used by empty weights. */
    static final short[] EMPTY = new short[0];

    /** Entries whose absolute score is not above this value are dropped by {@link #condense()}. */
    private static final float THRESHOLD = 1.0E-5f;

    /** Shorts per entry: one for the index, two for the score's float bits. */
    private static final int STRIDE = 3;

    /** Packed (index, scoreHigh, scoreLow) triples; see the class comment for the layout. */
    private short[] packed;

    private static final long serialVersionUID = 3;

    /** Creates a weight with no entries. */
    public Weight() {
        this.packed = EMPTY;
    }

    /**
     * Creates a copy of {@code other} and then condenses it, so entries whose absolute score is
     * not above the threshold are not carried over.
     *
     * @param other the weight to copy
     */
    public Weight(Weight other) {
        if (other.size() == 0) {
            this.packed = EMPTY;
        } else {
            // ArrayUtils.copy is assumed to return an independent array (condense/l1Reg/l2Reg
            // mutate it in place) -- TODO confirm.
            this.packed = ArrayUtils.copy(other.packed);
            condense();
        }
    }

    /** Returns the number of (index, score) entries stored. */
    public int size() {
        return this.packed.length / STRIDE;
    }

    /** Returns the index stored in entry {@code slot}. */
    private short unpackIndex(int slot) {
        return this.packed[slot * STRIDE];
    }

    /** Reassembles the float score stored in entry {@code slot} from its two 16-bit halves. */
    private float unpackScore(int slot) {
        int base = slot * STRIDE + 1;
        // The high half is sign-extended on promotion, but the << 16 discards those extra bits;
        // the low half must be masked so its sign extension does not clobber the high half.
        return Float.intBitsToFloat((this.packed[base] << 16) | (this.packed[base + 1] & 0xFFFF));
    }

    /**
     * Writes entry {@code slot} of {@code target} as (index, score).
     *
     * @throws ArithmeticException if {@code index} does not fit in a {@code short}
     */
    private static void pack(short[] target, int slot, int index, float score) {
        // The index is narrowed to a short below; anything outside the short range would be
        // silently wrapped (e.g. to a negative value) and corrupt later lookups and scoring.
        if (index > Short.MAX_VALUE || index < Short.MIN_VALUE) {
            throw new ArithmeticException("How did you make an index with 30,000 weights??");
        }
        int base = slot * STRIDE;
        int bits = Float.floatToIntBits(score);
        target[base] = (short) index;
        target[base + 1] = (short) (bits >> 16);
        target[base + 2] = (short) bits;
    }

    /** Writes entry {@code slot} of this weight's own backing array. */
    private void pack(int slot, int index, float score) {
        pack(this.packed, slot, index, score);
    }

    /**
     * Adds each stored score into {@code scores} at its index.
     *
     * @param scores accumulator indexed by entry index; modified in place
     * @throws AssertionError if there are more entries than {@code scores} has slots
     */
    public void score(float[] scores) {
        // Coarse sanity check: updateWeight keeps indices distinct, so more entries than scores
        // means at least one index must fall outside the array.
        if (this.packed.length > scores.length * STRIDE) {
            throw new AssertionError("Called with an array of scores too small to fit");
        }
        for (int base = 0; base < this.packed.length; base += STRIDE) {
            short index = this.packed[base];
            int bits = (this.packed[base + 1] << 16) | (this.packed[base + 2] & 0xFFFF);
            scores[index] += Float.intBitsToFloat(bits);
        }
    }

    /**
     * Adds {@code scale} times every entry of {@code other} into this weight.
     *
     * @param other the weight whose entries are added
     * @param scale multiplier applied to each of {@code other}'s scores
     */
    public void addScaled(Weight other, float scale) {
        int count = other.size();
        for (int k = 0; k < count; k++) {
            updateWeight(other.unpackIndex(k), other.unpackScore(k) * scale);
        }
    }

    /**
     * Removes every entry whose absolute score is not greater than {@link #THRESHOLD}; NaN
     * scores are removed too, since comparisons with NaN are false. Reallocates the backing
     * array only when something is actually removed.
     */
    public void condense() {
        if (this.packed == null || this.packed.length == 0) {
            return;
        }
        int count = size();
        int kept = 0;
        for (int k = 0; k < count; k++) {
            if (Math.abs(unpackScore(k)) > THRESHOLD) {
                kept++;
            }
        }
        if (kept == 0) {
            this.packed = EMPTY;
            return;
        }
        if (kept == count) {
            return;
        }
        short[] condensed = new short[kept * STRIDE];
        int next = 0;
        for (int k = 0; k < count; k++) {
            float score = unpackScore(k);
            if (Math.abs(score) > THRESHOLD) {
                pack(condensed, next, unpackIndex(k), score);
                next++;
            }
        }
        this.packed = condensed;
    }

    /**
     * Returns the score stored for {@code index}, or {@code 0.0f} when there is no such entry.
     *
     * @param index the index to look up
     */
    public float getScore(int index) {
        if (this.packed == null) {
            return 0.0f;
        }
        int count = size();
        for (int k = 0; k < count; k++) {
            if (unpackIndex(k) == index) {
                return unpackScore(k);
            }
        }
        return 0.0f;
    }

    /**
     * Adds {@code increment} to the score for {@code index}, appending a new entry if the index
     * is not present yet. Negative indices are ignored.
     *
     * @param index the index to update
     * @param increment amount added to the current score
     * @throws ArithmeticException if a new entry is needed and {@code index} does not fit in a
     *     {@code short}; this weight is left unchanged in that case
     */
    public void updateWeight(int index, float increment) {
        if (index < 0) {
            return;
        }
        if (this.packed == null || this.packed.length == 0) {
            // Fill a fresh array first so a rejected index cannot leave a bogus entry behind.
            short[] single = new short[STRIDE];
            pack(single, 0, index, increment);
            this.packed = single;
            return;
        }
        int count = size();
        for (int k = 0; k < count; k++) {
            if (unpackIndex(k) == index) {
                pack(k, index, unpackScore(k) + increment);
                return;
            }
        }
        short[] grown = Arrays.copyOf(this.packed, this.packed.length + STRIDE);
        pack(grown, count, index, increment);
        this.packed = grown;
    }

    /** Returns the largest absolute score stored, or {@code 0.0f} if there are no entries. */
    public float maxAbs() {
        if (this.packed == null) {
            return 0.0f;
        }
        float max = 0.0f;
        int count = size();
        for (int k = 0; k < count; k++) {
            max = Math.max(Math.abs(unpackScore(k)), max);
        }
        return max;
    }

    /**
     * Applies an L1 shrink: for a non-negative {@code amount}, each score moves toward zero by
     * {@code amount} and is clamped at zero so it never changes sign. Entries that reach zero
     * stay until {@link #condense()} is called.
     *
     * @param amount how far each score is moved toward zero
     */
    public void l1Reg(float amount) {
        if (this.packed == null) {
            return;
        }
        int count = size();
        for (int k = 0; k < count; k++) {
            float score = unpackScore(k);
            float shrunk = score > 0.0f
                    ? Math.max(0.0f, score - amount)
                    : Math.min(0.0f, score + amount);
            pack(k, unpackIndex(k), shrunk);
        }
    }

    /**
     * Applies an L2 decay: each score {@code s} becomes {@code s - s * factor}.
     *
     * @param factor fraction of each score to subtract
     */
    public void l2Reg(float factor) {
        if (this.packed == null) {
            return;
        }
        int count = size();
        for (int k = 0; k < count; k++) {
            float score = unpackScore(k);
            pack(k, unpackIndex(k), score - (score * factor));
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder("Weight(");
        int count = size();
        for (int k = 0; k < count; k++) {
            if (k > 0) {
                sb.append("  ");
            }
            sb.append((int) unpackIndex(k)).append('=').append(unpackScore(k));
        }
        return sb.append(')').toString();
    }

    /**
     * Serializes this weight as the packed array length (an int) followed by every packed short,
     * using {@code ByteArrayUtils}.
     *
     * @param out destination stream
     */
    public void writeBytes(ByteArrayOutputStream out) {
        ByteArrayUtils.writeInt(out, this.packed.length);
        for (short value : this.packed) {
            ByteArrayUtils.writeShort(out, value);
        }
    }

    /**
     * Reads a weight in the format produced by {@link #writeBytes(ByteArrayOutputStream)}.
     *
     * @param in source stream
     * @return the decoded weight
     * @throws IllegalArgumentException if the stored length is negative or not a multiple of 3,
     *     which no weight written by {@code writeBytes} can produce
     */
    public static Weight readBytes(ByteArrayInputStream in) {
        int length = ByteArrayUtils.readInt(in);
        if (length < 0 || length % STRIDE != 0) {
            throw new IllegalArgumentException(
                    "Corrupt Weight data: packed length " + length
                            + " is not a non-negative multiple of " + STRIDE);
        }
        Weight weight = new Weight();
        if (length == 0) {
            return weight; // already backed by the shared EMPTY array
        }
        short[] data = new short[length];
        for (int i = 0; i < length; i++) {
            data[i] = ByteArrayUtils.readShort(in);
        }
        weight.packed = data;
        return weight;
    }
}
