/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.evaluator;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.util.Pair;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Base class for accuracy-style evaluators that keep, per accumulator key, a running count of
 * instances seen and of instances counted as correct.
 *
 * <p>Subclasses implement {@link #accuracyHelper(NDList, NDList)}. Its result is returned directly
 * by {@link #evaluate(NDList, NDList)} and folded into the running counts by {@link
 * #updateAccumulator(String, NDList, NDList)}. {@link #getAccumulator(String)} reports the ratio of
 * the two counts.
 */
public abstract class AbstractAccuracy extends Evaluator {
    /** Running count of correct instances for each accumulator key. */
    protected Map<String, Long> correctInstances = new ConcurrentHashMap<String, Long>();

    /** Axis supplied at construction (defaults to 1); not read in this class, used by subclasses. */
    protected int axis;

    /**
     * Index supplied at construction; not read in this class, used by subclasses (presumably the
     * position of the relevant array inside the label/prediction {@link NDList} — confirm in
     * subclasses).
     */
    protected int index;

    /**
     * Creates an accuracy evaluator with {@code axis} set to 1.
     *
     * @param name the evaluator name, forwarded to {@link Evaluator}
     * @param index the index stored in {@link #index} for use by subclasses
     */
    public AbstractAccuracy(String name, int index) {
        this(name, index, 1);
    }

    /**
     * Creates an accuracy evaluator.
     *
     * @param name the evaluator name, forwarded to {@link Evaluator}
     * @param index the index stored in {@link #index} for use by subclasses
     * @param axis the axis stored in {@link #axis} for use by subclasses
     */
    public AbstractAccuracy(String name, int index, int axis) {
        super(name);
        this.axis = axis;
        this.index = index;
    }

    /**
     * Computes the raw accuracy statistics for one batch.
     *
     * @param labels the ground-truth labels
     * @param predictions the model predictions
     * @return a pair whose key is the number of instances in the batch and whose value is an array
     *     whose element sum is the number of correct instances
     */
    protected abstract Pair<Long, NDArray> accuracyHelper(NDList labels, NDList predictions);

    /**
     * Returns the correctness array produced by {@link #accuracyHelper(NDList, NDList)}.
     *
     * @param labels the ground-truth labels
     * @param predictions the model predictions
     * @return the value part of the helper's result
     */
    @Override
    public NDArray evaluate(NDList labels, NDList predictions) {
        return accuracyHelper(labels, predictions).getValue();
    }

    /**
     * Registers {@code key} with both running counts set to zero. This must be called before
     * {@link #updateAccumulator(String, NDList, NDList)} is used with the same key.
     *
     * @param key the accumulator key
     */
    @Override
    public void addAccumulator(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }

    /**
     * Adds one batch's instance count and correct count to the running totals for {@code key}.
     *
     * @param key an accumulator key previously registered with {@link #addAccumulator(String)}
     * @param labels the ground-truth labels
     * @param predictions the model predictions
     */
    @Override
    public void updateAccumulator(String key, NDList labels, NDList predictions) {
        Pair<Long, NDArray> update = accuracyHelper(labels, predictions);
        long batchTotal = update.getKey();
        long batchCorrect;
        // sum() returns a fresh NDArray that nothing else references: close it as soon as its
        // scalar has been read instead of leaking one array per batch. Doing this before touching
        // either map also keeps the reduction out of the compute() remapping functions and means a
        // failure here leaves both counts unchanged.
        try (NDArray correctSum = update.getValue().sum()) {
            batchCorrect = correctSum.getLong();
        }
        totalInstances.compute(key, (k, v) -> v + batchTotal);
        correctInstances.compute(key, (k, v) -> v + batchCorrect);
    }

    /**
     * Sets both running counts for {@code key} back to zero.
     *
     * @param key the accumulator key
     */
    @Override
    public void resetAccumulator(String key) {
        totalInstances.put(key, 0L);
        correctInstances.put(key, 0L);
    }

    /**
     * Returns the running accuracy for {@code key}, i.e. correct count divided by total count.
     *
     * @param key the accumulator key
     * @return the accuracy, or {@link Float#NaN} if the key is unknown or no instances were counted
     */
    @Override
    public float getAccumulator(String key) {
        Long total = totalInstances.get(key);
        if (total == null || total == 0L) {
            return Float.NaN;
        }
        // Reuse the already-checked total rather than re-reading the map.
        return (float) correctInstances.get(key).longValue() / (float) total.longValue();
    }
}

