/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.metrics;

import ai.djl.ndarray.NDArray;
import ai.djl.training.metrics.Accuracy;
import java.util.stream.IntStream;

/**
 * Accuracy metric that counts a sample as correct when its label matches one of the
 * {@code topK} last entries of the arg-sorted prediction row.
 *
 * <p>NOTE(review): taking the last {@code topK} columns of {@code argSort} assumes the sort
 * is ascending, so those columns hold the highest-scoring classes; confirm against the
 * NDArray implementation in use.
 */
public class TopKAccuracy extends Accuracy {
    /** Number of top-ranked classes inspected per sample; always greater than 1. */
    private final int topK;

    /**
     * Creates a top-K accuracy metric with an explicit name.
     *
     * @param name metric name, forwarded to the {@link Accuracy} constructor
     * @param index index forwarded to the {@link Accuracy} constructor
     * @param topK number of top-ranked classes to inspect; must be greater than 1
     * @throws IllegalArgumentException if {@code topK} is less than or equal to 1
     */
    public TopKAccuracy(String name, int index, int topK) {
        super(name, index);
        if (topK <= 1) {
            throw new IllegalArgumentException("Please use TopKAccuracy with topK more than 1");
        }
        this.topK = topK;
    }

    /**
     * Creates a top-K accuracy metric named {@code "Top_<topK>_Accuracy"}.
     *
     * @param index index forwarded to the {@link Accuracy} constructor
     * @param topK number of top-ranked classes to inspect; must be greater than 1
     * @throws IllegalArgumentException if {@code topK} is less than or equal to 1
     */
    public TopKAccuracy(int index, int topK) {
        this("Top_" + topK + "_Accuracy", index, topK);
    }

    /**
     * Creates a top-K accuracy metric named {@code "Top_<topK>_Accuracy"} with index 0.
     *
     * @param topK number of top-ranked classes to inspect; must be greater than 1
     * @throws IllegalArgumentException if {@code topK} is less than or equal to 1
     */
    public TopKAccuracy(int topK) {
        this("Top_" + topK + "_Accuracy", 0, topK);
    }

    /**
     * Accumulates correct and total instance counts for one batch.
     *
     * <p>For 2-D predictions, the last {@code min(topK, numClasses)} columns of the arg-sorted
     * predictions are each compared element-wise with the flattened labels, and every match is
     * added to the correct-instance count. For 1-D predictions, the arg-sorted array itself is
     * compared element-wise with the flattened labels. The size of the labels' first dimension
     * is added to the total-instance count in all cases.
     *
     * @param labels ground-truth labels
     * @param predictions predicted scores, with at most 2 dimensions
     * @throws IllegalStateException if {@code predictions} has more than 2 dimensions
     */
    @Override
    public void update(NDArray labels, NDArray predictions) {
        this.checkLabelShapes(labels, predictions);
        if (predictions.getShape().dimension() > 2) {
            throw new IllegalStateException("Predictions should have at most 2 dimensions");
        }
        NDArray sortedIndices = predictions.argSort(this.axis);
        int numDims = sortedIndices.getShape().dimension();
        if (numDims == 1) {
            this.addCorrectInstances(sortedIndices.flatten().eq(labels.flatten()).countNonzero().getLong(new long[0]));
        } else if (numDims == 2) {
            int numClasses = (int) sortedIndices.getShape().get(1);
            // Clamp K for this batch only; the configured topK must not be altered by the
            // shape of a single batch, otherwise later batches would silently use a smaller K.
            int effectiveK = Math.min(this.topK, numClasses);
            NDArray flatLabels = labels.flatten();
            for (int rank = 0; rank < effectiveK; rank++) {
                // Walk columns from the end of the sorted index matrix towards the front.
                NDArray rankedPrediction = sortedIndices.get(":, " + (numClasses - rank - 1));
                this.addCorrectInstances(rankedPrediction.flatten().eq(flatLabels).countNonzero().getLong(new long[0]));
            }
        }
        this.addTotalInstances((int) labels.getShape().get(0));
    }
}

