/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.loss;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.training.loss.Loss;
import java.util.Arrays;
import java.util.Objects;

/**
 * A {@link Loss} that combines several component losses by summing them.
 *
 * <p>Every operation is forwarded to each component, in the order the components were supplied:
 * the loss array is the sum of the components' loss arrays, {@link #update} and {@link #reset} are
 * applied to every component, and {@link #getValue()} is the sum of the components' values.
 */
public class CompositeLoss
extends Loss {
    /** The combined losses; a private copy, so it cannot be changed through the caller's array. */
    private final Loss[] components;

    /**
     * Creates a composite loss named {@code "CompositeLoss"} from the given components.
     *
     * <p>The array is copied defensively, so later writes to the caller's array (including an
     * explicitly passed varargs array) do not affect this instance.
     *
     * @param components the losses to combine; must not be {@code null}
     * @throws NullPointerException if {@code components} is {@code null}
     */
    public CompositeLoss(Loss ... components) {
        super("CompositeLoss");
        this.components = Objects.requireNonNull(components, "components").clone();
    }

    /**
     * Computes each component's loss on the same labels and predictions and returns their sum.
     *
     * @param label the labels passed unchanged to every component
     * @param prediction the predictions passed unchanged to every component
     * @return the single component's loss when there is exactly one component, otherwise the
     *     result of {@link NDArrays#add} over all component losses
     */
    @Override
    public NDArray getLoss(NDList label, NDList prediction) {
        NDArray[] losses = new NDArray[this.components.length];
        for (int i = 0; i < losses.length; i++) {
            losses[i] = this.components[i].getLoss(label, prediction);
        }
        // A lone component needs no summation; returning it directly also avoids depending on
        // NDArrays.add accepting a one-element input.
        if (losses.length == 1) {
            return losses[0];
        }
        return NDArrays.add(losses);
    }

    /**
     * Returns a new {@code CompositeLoss} built from {@code duplicate()} of every component.
     *
     * @return a composite of the duplicated components, in the same order
     */
    @Override
    public Loss duplicate() {
        Loss[] copies = new Loss[this.components.length];
        for (int i = 0; i < copies.length; i++) {
            copies[i] = this.components[i].duplicate();
        }
        return new CompositeLoss(copies);
    }

    /**
     * Forwards the labels and predictions to {@code update} on every component, in order.
     *
     * @param labels the labels passed to each component
     * @param predictions the predictions passed to each component
     */
    @Override
    public void update(NDList labels, NDList predictions) {
        for (Loss component : this.components) {
            component.update(labels, predictions);
        }
    }

    /** Calls {@code reset} on every component, in order. */
    @Override
    public void reset() {
        for (Loss component : this.components) {
            component.reset();
        }
    }

    /**
     * Returns the sum of every component's {@code getValue()}.
     *
     * @return the components' values summed in double precision and narrowed to {@code float}
     *     ({@code 0} when there are no components)
     */
    @Override
    public float getValue() {
        // DoubleStream.sum is kept (rather than a plain loop) to preserve its summation behaviour.
        return (float) Arrays.stream(this.components).mapToDouble(Loss::getValue).sum();
    }
}

