/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training;

import ai.djl.Device;
import ai.djl.training.TrainingConfig;
import ai.djl.training.initializer.Initializer;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.TrainingMetric;
import ai.djl.training.optimizer.Optimizer;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * A {@link TrainingConfig} assembled through chained setter calls.
 *
 * <p>The initializer and loss function are fixed at construction time. Devices, optimizer,
 * batch size and training metrics can be configured afterwards; each setter returns
 * {@code this} so calls can be chained. No synchronization is performed, so instances should
 * not be mutated concurrently.
 */
public class DefaultTrainingConfig
implements TrainingConfig {
    private final Initializer initializer;
    private Optimizer optimizer;
    /** Explicitly configured devices, or {@code null} to fall back to the default lookup. */
    private Device[] devices;
    private final Loss loss;
    private final List<TrainingMetric> trainingMetrics;
    /** Batch size; stays {@code 0} until {@link #setBatchSize(int)} is called. */
    private int batchSize;

    /**
     * Creates a config with the given initializer and loss function and no training metrics.
     *
     * @param initializer the parameter initializer returned by {@link #getInitializer()}
     * @param loss the loss function returned by {@link #getLossFunction()}
     */
    public DefaultTrainingConfig(Initializer initializer, Loss loss) {
        this.initializer = initializer;
        this.trainingMetrics = new ArrayList<TrainingMetric>();
        this.loss = loss;
    }

    /**
     * Sets the devices to train on.
     *
     * <p>The array is copied, so later changes to the caller's array do not affect this config.
     *
     * @param devices the devices to use, or {@code null} to restore the default lookup
     * @return this config, for chaining
     */
    public DefaultTrainingConfig setDevices(Device[] devices) {
        this.devices = devices == null ? null : devices.clone();
        return this;
    }

    /**
     * Sets the optimizer returned by {@link #getOptimizer()}.
     *
     * @param optimizer the optimizer to use
     * @return this config, for chaining
     */
    public DefaultTrainingConfig setOptimizer(Optimizer optimizer) {
        this.optimizer = optimizer;
        return this;
    }

    /**
     * Appends a metric to the list returned by {@link #getTrainingMetrics()}.
     *
     * @param trainingMetric the metric to add; must not be {@code null}
     * @return this config, for chaining
     * @throws NullPointerException if {@code trainingMetric} is {@code null}
     */
    public DefaultTrainingConfig addTrainingMetric(TrainingMetric trainingMetric) {
        // Fail fast here rather than when the metrics list is later consumed.
        this.trainingMetrics.add(Objects.requireNonNull(trainingMetric, "trainingMetric"));
        return this;
    }

    /**
     * Sets the batch size returned by {@link #getBatchSize()}.
     *
     * @param batchSize the batch size
     * @return this config, for chaining
     */
    public DefaultTrainingConfig setBatchSize(int batchSize) {
        this.batchSize = batchSize;
        return this;
    }

    /**
     * Returns the devices to train on.
     *
     * <p>If no devices were set explicitly, the result of
     * {@code Device.getDevices(Integer.MAX_VALUE)} is returned. Otherwise a copy of the
     * configured array is returned, so callers cannot mutate this config through it.
     *
     * @return the devices to use
     */
    @Override
    public Device[] getDevices() {
        if (this.devices == null) {
            // NOTE(review): presumably "all available devices"; confirm against Device.getDevices.
            return Device.getDevices(Integer.MAX_VALUE);
        }
        return this.devices.clone();
    }

    /**
     * Returns the initializer supplied to the constructor.
     *
     * @return the parameter initializer
     */
    @Override
    public Initializer getInitializer() {
        return this.initializer;
    }

    /**
     * Returns the optimizer set via {@link #setOptimizer(Optimizer)}.
     *
     * @return the optimizer, or {@code null} if none was set
     */
    @Override
    public Optimizer getOptimizer() {
        return this.optimizer;
    }

    /**
     * Returns the loss function supplied to the constructor.
     *
     * @return the loss function
     */
    @Override
    public Loss getLossFunction() {
        return this.loss;
    }

    /**
     * Returns the training metrics added so far, in insertion order.
     *
     * <p>This is the live backing list, not a copy. Modifications through it are visible to
     * this config.
     *
     * @return the training metrics
     */
    @Override
    public List<TrainingMetric> getTrainingMetrics() {
        return this.trainingMetrics;
    }

    /**
     * Returns the configured batch size.
     *
     * @return the batch size, or {@code 0} if {@link #setBatchSize(int)} was never called
     */
    @Override
    public int getBatchSize() {
        return this.batchSize;
    }
}

