package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import java.util.Objects;

/* loaded from: input_file:ai/djl/training/loss/IndexLoss.class */
/**
 * A {@link Loss} that applies a wrapped loss to one entry selected by index from the labels and
 * predictions lists.
 *
 * <p>When an index is {@code null}, the corresponding list is passed to the wrapped loss
 * unchanged. This instance reports the same name as the wrapped loss.
 */
public class IndexLoss extends Loss {
    /** The wrapped loss that performs the actual computation. */
    private final Loss loss;
    /** Index selected from the predictions list, or {@code null} to pass the whole list through. */
    private final Integer predictionsIndex;
    /** Index selected from the labels list, or {@code null} to pass the whole list through. */
    private final Integer labelsIndex;

    /**
     * Creates an {@code IndexLoss} that selects the same index from both labels and predictions.
     *
     * @param loss the loss to apply to the selected entries
     * @param index the index used for both the predictions and the labels
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public IndexLoss(Loss loss, int index) {
        this(loss, Integer.valueOf(index), Integer.valueOf(index));
    }

    /**
     * Creates an {@code IndexLoss} with separate indices for predictions and labels.
     *
     * @param loss the loss to apply to the selected entries
     * @param predictionsIndex the index to select from the predictions, or {@code null} to use
     *     the whole predictions list
     * @param labelsIndex the index to select from the labels, or {@code null} to use the whole
     *     labels list
     * @throws NullPointerException if {@code loss} is {@code null}
     */
    public IndexLoss(Loss loss, Integer predictionsIndex, Integer labelsIndex) {
        // Fail with a message that names the argument instead of an anonymous NPE from getName().
        super(Objects.requireNonNull(loss, "loss").getName());
        this.loss = loss;
        this.predictionsIndex = predictionsIndex;
        this.labelsIndex = labelsIndex;
    }

    /**
     * Evaluates the wrapped loss on the selected label and prediction entries.
     *
     * @param labels the labels list; narrowed to one entry when a labels index is set
     * @param predictions the predictions list; narrowed to one entry when a predictions index is
     *     set
     * @return the result of the wrapped loss's {@code evaluate}
     */
    @Override // ai.djl.training.evaluator.Evaluator
    public NDArray evaluate(NDList labels, NDList predictions) {
        return loss.evaluate(select(labels, labelsIndex), select(predictions, predictionsIndex));
    }

    /**
     * Returns {@code list} unchanged when {@code index} is {@code null}. Otherwise returns a new
     * single-element list holding {@code list.get(index)}.
     */
    private static NDList select(NDList list, Integer index) {
        return index == null ? list : new NDList(list.get(index.intValue()));
    }
}
