/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;

/**
 * L2 (squared-error) loss.
 *
 * <p>For every element this computes {@code weight * (label - prediction)^2}. It then takes the
 * mean over the axes returned by {@code excludeBatchAxis(loss, batchAxis)}. NOTE(review): that
 * helper lives in {@link Loss}. It presumably selects every axis except {@code batchAxis}, giving
 * one loss value per batch entry; confirm against the superclass.
 */
public class L2Loss
extends Loss {
    /** Multiplier applied to each squared difference (0.5 when using the no-arg constructor). */
    private final float weight;
    /** Axis handed to {@code excludeBatchAxis} when reducing the element-wise loss. */
    private final int batchAxis;

    /**
     * Creates an L2 loss named {@code "L2Loss"} with an explicit weight and batch axis.
     *
     * @param weight multiplier applied to each squared difference
     * @param batchAxis axis passed to {@code excludeBatchAxis} when averaging the loss
     */
    public L2Loss(float weight, int batchAxis) {
        super("L2Loss");
        this.weight = weight;
        this.batchAxis = batchAxis;
    }

    /** Creates an L2 loss with weight {@code 0.5f} and batch axis {@code 0}. */
    public L2Loss() {
        this(0.5f, 0);
    }

    /**
     * Computes the weighted squared error between the label and the prediction.
     *
     * <p>Each {@link NDList} is expected to hold a single array (obtained via
     * {@code singletonOrThrow()}, which presumably throws otherwise). The label is reshaped to
     * the prediction's shape before subtraction.
     *
     * @param label list containing the single ground-truth array
     * @param prediction list containing the single predicted array
     * @return {@code weight * (label - prediction)^2}, averaged over the axes chosen by
     *     {@code excludeBatchAxis(loss, batchAxis)}
     */
    @Override
    public NDArray getLoss(NDList label, NDList prediction) {
        NDArray pred = prediction.singletonOrThrow();
        // Align the label with the prediction's shape so the element-wise subtraction matches up.
        NDArray labelReshaped = label.singletonOrThrow().reshape(pred.getShape());
        // Autoboxing selects the same mul(Number) overload the explicit Float.valueOf did.
        NDArray loss = labelReshaped.sub(pred).square().mul(weight);
        return loss.mean(excludeBatchAxis(loss, batchAxis));
    }
}

