/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * L1 loss: the absolute difference between label and prediction, optionally scaled by a
 * constant weight and averaged over every axis except the batch axis.
 *
 * <p>For a label {@code y} and prediction {@code p}, the per-element loss is
 * {@code weight * |y - p|}.
 */
public class L1Loss
extends Loss {
    /** Constant multiplier applied to the element-wise absolute error. */
    private final float weight;
    /** Axis treated as the batch axis; it is excluded from the final mean reduction. */
    private final int batchAxis;

    /**
     * Creates an L1 loss with the given weight and batch axis.
     *
     * @param weight multiplier applied to the absolute error
     * @param batchAxis index of the batch axis, which is not reduced by the mean
     */
    public L1Loss(float weight, int batchAxis) {
        super("L1Loss");
        this.weight = weight;
        this.batchAxis = batchAxis;
    }

    /** Creates an L1 loss with weight {@code 1.0f} and batch axis {@code 0}. */
    public L1Loss() {
        this(1.0f, 0);
    }

    /**
     * Computes the weighted L1 loss between {@code label} and {@code prediction}.
     *
     * <p>The label is reshaped to the prediction's shape before subtraction. The result is
     * averaged over the axes returned by the inherited {@code excludeBatchAxis} helper.
     * Presumably these are all axes except {@code batchAxis}, which would give one value per
     * batch entry. TODO confirm against {@link Loss}.
     *
     * @param label list that must hold exactly one array, the ground truth
     * @param prediction list that must hold exactly one array, the model output
     * @return the reduced, weighted absolute error
     */
    @Override
    public NDArray getLoss(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        NDArray labelReshaped = label.singletonOrThrow().reshape(pred.getShape());
        NDArray loss = labelReshaped.sub(pred).abs();
        if (this.weight != 1.0f) {
            // Scale the absolute error itself. The previous code scaled the label instead,
            // which threw away the |label - pred| term whenever weight != 1.
            loss = loss.mul(Float.valueOf(this.weight));
        }
        return loss.mean(this.excludeBatchAxis(loss, this.batchAxis));
    }
}

