/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.TrainingDivergedException;
import ai.djl.metric.Metrics;
import ai.djl.mxnet.engine.MxGradientCollector;
import ai.djl.mxnet.engine.MxModel;
import ai.djl.mxnet.engine.MxNDManager;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.GradientCollector;
import ai.djl.training.LocalParameterServer;
import ai.djl.training.ParameterServer;
import ai.djl.training.ParameterStore;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.TrainingListener;
import ai.djl.training.dataset.Batch;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.TrainingMetric;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link Trainer} implementation backed by the MXNet engine.
 *
 * <p>The trainer owns a sub-manager created from the model's manager and a {@link ParameterStore}
 * wired to a {@link LocalParameterServer} that is built from the optimizer of the supplied
 * {@link TrainingConfig}. It keeps two parallel metric lists (training and validation); the
 * configured loss is appended to the training list and a duplicate of it to the validation list.
 *
 * <p>Call {@link #close()} when done; {@link #finalize()} only acts as a last-resort safety net.
 */
public class MxTrainer
implements Trainer {
    private static final Logger logger = LoggerFactory.getLogger(MxTrainer.class);
    private final MxModel model;
    private final MxNDManager manager;
    private Metrics metrics;
    private TrainingListener listener;
    private final Device[] devices;
    private final ParameterStore parameterStore;
    private final List<TrainingMetric> trainingMetrics;
    private final List<TrainingMetric> validateMetrics;
    private final Loss trainingLoss;
    private final Loss validationLoss;
    /**
     * {@code System.nanoTime()} taken at the end of the previous {@link #trainBatch(Batch)} call.
     * Stays {@code 0} until the first batch completes, which suppresses the first "train" timing.
     */
    long batchBeginTime;
    private boolean gradientsChecked;

    /**
     * Creates a trainer for the given model.
     *
     * @param model the model whose block is trained
     * @param trainingConfig supplies the devices, loss function, training metrics and optimizer
     * @throws IllegalArgumentException if the configuration has no loss function
     */
    MxTrainer(MxModel model, TrainingConfig trainingConfig) {
        // Validate before acquiring any native resource, so a bad config cannot leak a manager.
        Loss loss = trainingConfig.getLossFunction();
        if (loss == null) {
            throw new IllegalArgumentException("You must specify a loss for the trainer");
        }
        this.model = model;
        this.devices = trainingConfig.getDevices();
        this.trainingLoss = loss;
        this.validationLoss = loss.duplicate();

        this.trainingMetrics = new ArrayList<>(trainingConfig.getTrainingMetrics());
        this.validateMetrics = new ArrayList<>();
        // Validation gets its own copies so the two phases accumulate independently.
        for (TrainingMetric metric : this.trainingMetrics) {
            this.validateMetrics.add(metric.duplicate());
        }
        // The losses are tracked as regular metrics too; they are appended last.
        this.trainingMetrics.add((TrainingMetric)this.trainingLoss);
        this.validateMetrics.add((TrainingMetric)this.validationLoss);

        this.manager = (MxNDManager)model.getNDManager().newSubManager();
        LocalParameterServer parameterServer = new LocalParameterServer(trainingConfig.getOptimizer());
        this.parameterStore = new ParameterStore((NDManager)this.manager, false);
        this.parameterStore.setParameterServer((ParameterServer)parameterServer, this.devices);
    }

    /**
     * Initializes the model's block for the given input shapes, then requests every parameter
     * from the parameter store once per configured device.
     *
     * @param shapes the input shapes forwarded to the block's {@code initialize}
     */
    public void initialize(Shape ... shapes) {
        this.model.getBlock().initialize(this.model.getNDManager(), this.model.getDataType(), shapes);
        this.model.getBlock().getParameters().forEach(pair -> {
            for (Device device : this.devices) {
                // Return value intentionally ignored; presumably this makes the store
                // materialize the per-device copy up front -- TODO confirm.
                this.parameterStore.getValue((Parameter)pair.getValue(), device);
            }
        });
    }

    /**
     * Creates a new gradient collector.
     *
     * @return a fresh {@link MxGradientCollector}
     */
    public GradientCollector newGradientCollector() {
        return new MxGradientCollector();
    }

    /**
     * Runs forward, loss, backward and metric updates for one batch.
     *
     * <p>The batch is split across the configured devices and each split is processed inside a
     * single gradient collector. Timings are recorded under "backward", "training-metrics" and
     * "train"; the listener, if any, is notified afterwards. Parameters are not updated here;
     * see {@link #step()}.
     *
     * @param batch the batch to train on
     * @throws TrainingDivergedException if the training loss becomes NaN
     */
    public void trainBatch(Batch batch) {
        Batch[] splits = batch.split(this.devices, false);
        try (MxGradientCollector collector = new MxGradientCollector()) {
            for (Batch split : splits) {
                NDList data = split.getData();
                NDList labels = split.getLabels();
                NDList preds = this.forward(data);

                long time = System.nanoTime();
                NDArray loss = this.trainingLoss.getLoss(labels, preds);
                collector.backward(loss);
                this.addMetric("backward", time);

                time = System.nanoTime();
                this.updateTrainingMetrics(labels, preds);
                this.addMetric("training-metrics", time);
            }
        }
        // Measured from the end of the previous batch, so it includes time spent between calls.
        this.addMetric("train", this.batchBeginTime);
        this.batchBeginTime = System.nanoTime();
        if (this.listener != null) {
            this.listener.onTrainingBatch();
        }
    }

    /**
     * Runs the model's block on the input using this trainer's parameter store.
     *
     * <p>The elapsed time is recorded under "forward" even when the block throws.
     *
     * @param input the input list
     * @return the block's output
     */
    public NDList forward(NDList input) {
        long begin = System.nanoTime();
        try {
            return this.model.getBlock().forward(this.parameterStore, input);
        } finally {
            this.addMetric("forward", begin);
        }
    }

    /**
     * Runs a forward pass on one batch and updates the validation metrics.
     *
     * <p>The total time is recorded under "validate" and the listener, if any, is notified.
     *
     * @param batch the batch to validate on
     */
    public void validateBatch(Batch batch) {
        long begin = System.nanoTime();
        Batch[] splits = batch.split(this.devices, false);
        for (Batch split : splits) {
            NDList data = split.getData();
            NDList labels = split.getLabels();
            NDList preds = this.forward(data);
            this.updateValidationMetrics(labels, preds);
        }
        this.addMetric("validate", begin);
        if (this.listener != null) {
            this.listener.onValidationBatch();
        }
    }

    /**
     * Asks the parameter store to update all parameters and records the time under "step".
     *
     * <p>Until one check has succeeded, the gradients are verified to be non-zero first.
     *
     * @throws IllegalStateException if the gradient check finds only zeros
     */
    public void step() {
        if (!this.gradientsChecked) {
            this.checkGradients();
        }
        long begin = System.nanoTime();
        this.parameterStore.updateAllParameters();
        this.addMetric("step", begin);
    }

    /**
     * Sets the sink for timing and metric values.
     *
     * @param metrics the metrics collector; {@code null} disables recording
     */
    public void setMetrics(Metrics metrics) {
        this.metrics = metrics;
    }

    /**
     * Sets the listener notified after training batches, validation batches and epochs.
     *
     * @param listener the listener; {@code null} disables notifications
     */
    public void setTrainingListener(TrainingListener listener) {
        this.listener = listener;
    }

    /**
     * Updates every training metric with the batch result and records their values under the
     * "train" stage.
     *
     * @throws TrainingDivergedException if the training loss value is NaN after the update
     */
    private void updateTrainingMetrics(NDList labels, NDList preds) {
        // Recording/training flags are switched off while the metrics are computed.
        MxGradientCollector.setRecording(false);
        MxGradientCollector.setTraining(false);
        this.trainingMetrics.forEach(metric -> metric.update(labels, preds));
        // The loss is recorded before the NaN check so the diverged value still reaches the sink.
        this.addMetric("train", (TrainingMetric)this.trainingLoss);
        if (Float.isNaN(this.trainingLoss.getValue())) {
            throw new TrainingDivergedException("The Loss became NaN, try to reduce the learning rate, add the clipGradient option to your optimizer, check input data and loss calculation.");
        }
        // NOTE(review): trainingLoss is also a member of trainingMetrics, so its value is
        // recorded a second time by this loop -- confirm whether that is intended.
        this.trainingMetrics.forEach(metric -> this.addMetric("train", (TrainingMetric)metric));
        MxGradientCollector.setRecording(true);
        MxGradientCollector.setTraining(true);
    }

    /**
     * Updates every validation metric with the batch result and records their values under the
     * "validate" stage.
     */
    private void updateValidationMetrics(NDList labels, NDList preds) {
        this.validateMetrics.forEach(metric -> metric.update(labels, preds));
        this.validateMetrics.forEach(metric -> this.addMetric("validate", (TrainingMetric)metric));
    }

    /**
     * Resets all training and validation metrics, then signals the end of an epoch to the
     * listener, if any.
     */
    public void resetTrainingMetrics() {
        this.trainingMetrics.forEach(TrainingMetric::reset);
        this.validateMetrics.forEach(TrainingMetric::reset);
        if (this.listener != null) {
            this.listener.onEpoch();
        }
    }

    /**
     * Returns the loss used for training.
     *
     * @return the training loss
     */
    public Loss getLoss() {
        return this.trainingLoss;
    }

    /**
     * Returns the duplicate of the training loss that is used for validation.
     *
     * @return the validation loss
     */
    public Loss getValidationLoss() {
        return this.validationLoss;
    }

    /**
     * Returns the model being trained.
     *
     * @return the model
     */
    public Model getModel() {
        return this.model;
    }

    /**
     * Returns the metrics sink.
     *
     * @return the metrics collector, or {@code null} if none was set
     */
    public Metrics getMetrics() {
        return this.metrics;
    }

    /**
     * Finds the first training metric that is an instance of the given class.
     *
     * @param clazz the metric type to look for
     * @param <T> the metric type
     * @return the first matching metric, or {@code null} if there is none
     */
    public final <T extends TrainingMetric> T getTrainingMetric(Class<T> clazz) {
        for (TrainingMetric metric : this.trainingMetrics) {
            if (clazz.isInstance(metric)) {
                return clazz.cast(metric);
            }
        }
        return null;
    }

    /**
     * Finds the first validation metric that is an instance of the given class.
     *
     * @param clazz the metric type to look for
     * @param <T> the metric type
     * @return the first matching metric, or {@code null} if there is none
     */
    public <T extends TrainingMetric> T getValidationMetric(Class<T> clazz) {
        for (TrainingMetric metric : this.validateMetrics) {
            if (clazz.isInstance(metric)) {
                return clazz.cast(metric);
            }
        }
        return null;
    }

    /**
     * Returns the sub-manager owned by this trainer.
     *
     * @return the trainer's manager
     */
    public NDManager getManager() {
        return this.manager;
    }

    /**
     * Verifies that the gradients on the first device do not sum to zero.
     *
     * <p>Only parameters that require a gradient are considered. Each gradient is reduced with
     * {@code sum()}, the results are stacked and summed again, and the total is compared with
     * zero. On success the check is marked as done so that {@link #step()} skips it afterwards.
     *
     * @throws IllegalStateException if the total is exactly zero
     */
    private void checkGradients() {
        List<NDArray> grads = new ArrayList<>();
        this.model.getBlock().getParameters().values().stream()
                .filter(Parameter::requireGradient)
                .forEach(param -> grads.add(this.parameterStore.getValue(param, this.devices[0]).getGradient()));

        float[] sums;
        NDList perParameterSums = new NDList(grads.stream().map(NDArray::sum).toArray(NDArray[]::new));
        // Nested finally blocks release every intermediate array even if a later step throws.
        try {
            NDArray stacked = NDArrays.stack(perParameterSums);
            try {
                NDArray total = stacked.sum();
                try {
                    sums = total.toFloatArray();
                } finally {
                    total.close();
                }
            } finally {
                stacked.close();
            }
        } finally {
            perParameterSums.close();
        }

        float sum = 0.0f;
        for (float num : sums) {
            sum += num;
        }
        // NOTE(review): this is a signed sum, so gradients that cancel out exactly would also
        // trip the check -- confirm whether summing absolute values was intended.
        if (sum == 0.0f) {
            throw new IllegalStateException("Gradient values are all zeros, please call gradientCollector.backward() on your target NDArray (usually loss), before calling step()");
        }
        this.gradientsChecked = true;
    }

    /**
     * Safety net that closes the trainer if the caller forgot to.
     *
     * <p>The warning is only emitted when debug logging is enabled.
     */
    @Override
    protected void finalize() throws Throwable {
        try {
            // manager is null if the constructor threw before creating it.
            if (this.manager != null && this.manager.isOpen()) {
                if (logger.isDebugEnabled()) {
                    logger.warn("Trainer was not closed explicitly: {}", this.getClass().getSimpleName());
                }
                this.close();
            }
        } finally {
            super.finalize();
        }
    }

    /**
     * Synchronizes the parameter store and closes the trainer's manager.
     *
     * <p>The manager is closed even if the synchronization fails.
     */
    public void close() {
        try {
            this.parameterStore.sync();
        } finally {
            this.manager.close();
        }
    }

    /**
     * Records the nanoseconds elapsed since {@code begin} under the given name.
     *
     * <p>Nothing is recorded when no metrics sink is set or when {@code begin} is not positive
     * (i.e. the timer was never started).
     */
    private void addMetric(String metricName, long begin) {
        if (this.metrics != null && begin > 0L) {
            this.metrics.addMetric(metricName, (Number)(System.nanoTime() - begin));
        }
    }

    /**
     * Records the current value of a training metric under the name {@code stage_metricName}.
     *
     * <p>Nothing is recorded when no metrics sink is set.
     */
    private void addMetric(String stage, TrainingMetric metric) {
        if (this.metrics != null) {
            this.metrics.addMetric(stage + '_' + metric.getName(), (Number)Float.valueOf(metric.getValue()));
        }
    }
}

