/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.training.listener;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
import ai.djl.training.Trainer;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.EvaluatorTrainingListener;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link TrainingListener} that reports training progress through progress bars and SLF4J log
 * messages: per-batch evaluator status and throughput, per-epoch train/validate evaluator values,
 * and, at the end of training, P50/P90 timings for the recorded timing metrics.
 */
public class LoggingTrainingListener implements TrainingListener {
    private static final Logger logger = LoggerFactory.getLogger(LoggingTrainingListener.class);

    /** Divisor converting a raw timing metric value to milliseconds (values assumed in ns). */
    private static final float NANOS_PER_MILLI = 1_000_000f;

    /** Divisor converting a raw timing metric value to seconds (values assumed in ns). */
    private static final float NANOS_PER_SECOND = 1_000_000_000f;

    private int numEpochs;
    private ProgressBar trainingProgressBar;
    private ProgressBar validateProgressBar;

    /**
     * Logs the evaluator values collected for the epoch that just finished, for both the training
     * and (when it was run) the validation stage.
     *
     * @param trainer the trainer whose metrics and evaluators are reported
     */
    @Override
    public void onEpoch(Trainer trainer) {
        logger.info("Epoch {} finished.", numEpochs + 1);
        Metrics metrics = trainer.getMetrics();
        if (metrics != null) {
            Loss loss = trainer.getLoss();
            String status =
                    getEvaluatorsStatus(
                            metrics, trainer.getEvaluators(), "train/epoch", Short.MAX_VALUE);
            logger.info("Train: {}", status);

            // The validation loss metric only exists if a validation pass was executed.
            String metricName = EvaluatorTrainingListener.metricName(loss, "validate/epoch");
            if (metrics.hasMetric(metricName)) {
                status =
                        getEvaluatorsStatus(
                                metrics,
                                trainer.getEvaluators(),
                                "validate/epoch",
                                Short.MAX_VALUE);
                logger.info("Validate: {}", status);
            } else {
                logger.info("validation has not been run.");
            }
        }
        numEpochs++;
    }

    /**
     * Advances the training progress bar, lazily creating it on the first batch.
     *
     * @param trainer the trainer being run
     * @param batchData the data of the batch that was just processed
     */
    @Override
    public void onTrainingBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        if (trainingProgressBar == null) {
            trainingProgressBar =
                    new ProgressBar("Training", batchData.getBatch().getProgressTotal());
        }
        trainingProgressBar.update(
                batchData.getBatch().getProgress(),
                getTrainingStatus(trainer, batchData.getBatch().getSize()));
    }

    /**
     * Builds the status suffix shown next to the training progress bar.
     *
     * @param trainer the trainer being run
     * @param batchSize number of items in the current batch, used for the throughput estimate
     * @return up to two evaluator values plus the throughput, or {@code ""} if there are no metrics
     */
    private String getTrainingStatus(Trainer trainer, int batchSize) {
        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return "";
        }

        StringBuilder sb = new StringBuilder();
        sb.append(getEvaluatorsStatus(metrics, trainer.getEvaluators(), "train/progress", 2));

        if (metrics.hasMetric("train")) {
            float batchTime =
                    metrics.latestMetric("train").getValue().longValue() / NANOS_PER_SECOND;
            // Skip a non-positive duration instead of printing "Infinity"/"NaN" images/sec.
            if (batchTime > 0) {
                sb.append(String.format(", speed: %.2f images/sec", batchSize / batchTime));
            }
        }
        return sb.toString();
    }

    /**
     * Advances the validation progress bar, lazily creating it on the first batch.
     *
     * @param trainer the trainer being run
     * @param batchData the data of the batch that was just validated
     */
    @Override
    public void onValidationBatch(Trainer trainer, TrainingListener.BatchData batchData) {
        if (validateProgressBar == null) {
            validateProgressBar =
                    new ProgressBar("Validating", batchData.getBatch().getProgressTotal());
        }
        validateProgressBar.update(batchData.getBatch().getProgress());
    }

    /**
     * Logs the training devices and how long it took to obtain the engine name and version.
     *
     * @param trainer the trainer about to start
     */
    @Override
    public void onTrainingBegin(Trainer trainer) {
        List<Device> devices = trainer.getDevices();
        String devicesMsg;
        if (devices.size() == 1 && "cpu".equals(devices.get(0).getDeviceType())) {
            devicesMsg = Device.cpu().toString();
        } else {
            devicesMsg = devices.size() + " GPUs";
        }
        logger.info("Training on: {}.", devicesMsg);

        // Time the first engine access, which is where the engine gets loaded.
        long init = System.nanoTime();
        Engine engine = Engine.getInstance();
        String engineName = engine.getEngineName();
        String version = engine.getVersion();
        long loaded = System.nanoTime();
        logger.info(
                String.format(
                        "Load %s Engine Version %s in %.3f ms.",
                        engineName, version, (loaded - init) / NANOS_PER_MILLI));
    }

    /**
     * Logs P50/P90 timing statistics for every timing metric that was actually recorded.
     *
     * @param trainer the trainer that just finished
     */
    @Override
    public void onTrainingEnd(Trainer trainer) {
        Metrics metrics = trainer.getMetrics();
        if (metrics == null) {
            return;
        }

        // Any of these may be absent (e.g. "train" when only one iteration ran), so each is
        // guarded inside logPercentiles rather than assumed to exist.
        logPercentiles(metrics, "train", NANOS_PER_MILLI, "ms");
        logPercentiles(metrics, "forward", NANOS_PER_MILLI, "ms");
        logPercentiles(metrics, "training-metrics", NANOS_PER_MILLI, "ms");
        logPercentiles(metrics, "backward", NANOS_PER_MILLI, "ms");
        logPercentiles(metrics, "step", NANOS_PER_MILLI, "ms");
        logPercentiles(metrics, "epoch", NANOS_PER_SECOND, "s");
    }

    /**
     * Logs the 50th and 90th percentile of a timing metric, if it has been recorded.
     *
     * @param metrics the metrics to read from
     * @param metricName the name of the timing metric
     * @param divisor factor converting the raw metric value into {@code unit}
     * @param unit the unit label printed after each value
     */
    private static void logPercentiles(
            Metrics metrics, String metricName, float divisor, String unit) {
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p50 = metrics.percentile(metricName, 50).getValue().longValue() / divisor;
        float p90 = metrics.percentile(metricName, 90).getValue().longValue() / divisor;
        logger.info(
                String.format(
                        "%s P50: %.3f %s, P90: %.3f %s", metricName, p50, unit, p90, unit));
    }

    /**
     * Formats the latest value of each evaluator for the given stage.
     *
     * @param metrics the metrics holding evaluator values
     * @param toOutput the evaluators to report, in order
     * @param stage the stage suffix used to build each evaluator's metric name
     * @param limit maximum number of evaluators to print; if more exist, {@code "..."} is appended
     * @return a comma-separated {@code name: value} list; evaluators without a recorded value are
     *     shown as {@code name: _}
     */
    private String getEvaluatorsStatus(
            Metrics metrics, List<Evaluator> toOutput, String stage, int limit) {
        // Size by what will actually be emitted; limit can be Short.MAX_VALUE.
        List<String> metricOutputs = new ArrayList<>(Math.min(limit, toOutput.size()) + 1);
        int count = 0;
        for (Evaluator evaluator : toOutput) {
            if (++count > limit) {
                metricOutputs.add("...");
                break;
            }
            String metricName = EvaluatorTrainingListener.metricName(evaluator, stage);
            if (metrics.hasMetric(metricName)) {
                float value = metrics.latestMetric(metricName).getValue().floatValue();
                metricOutputs.add(String.format("%s: %.2f", evaluator.getName(), value));
            } else {
                metricOutputs.add(String.format("%s: _", evaluator.getName()));
            }
        }
        return String.join(", ", metricOutputs);
    }
}

