package ai.djl.examples.training.util;

import ai.djl.Device;
import ai.djl.engine.Engine;
import ai.djl.examples.util.MemoryUtils;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.training.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.util.ProgressBar;
import java.io.IOException;
import java.util.List;
import java.util.Properties;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/examples/training/util/AbstractTraining.class */
/**
 * Base class for the example training programs.
 *
 * <p>It parses the command line into {@link Arguments}, delegates the actual work to
 * {@link #train(Arguments)}, reports latency percentiles collected in {@link #metrics}, and keeps
 * the progress bars and the most recent accuracy/loss figures up to date through the
 * {@link TrainingListener} callbacks.
 */
public abstract class AbstractTraining implements TrainingListener {
    private static final Logger logger = LoggerFactory.getLogger(AbstractTraining.class);

    /** Divisor that converts a nanosecond duration into milliseconds. */
    private static final float NANOS_PER_MILLI = 1000000.0f;

    /** Divisor that converts a nanosecond duration into seconds. */
    private static final float NANOS_PER_SECOND = 1.0E9f;

    /** Most recent value read from the {@code train_Accuracy} metric. */
    protected float trainingAccuracy;
    /** Most recent value read from the {@code train_<loss name>} metric. */
    protected float trainingLoss;
    /** Most recent value read from the {@code validate_Accuracy} metric. */
    protected float validationAccuracy;
    /** Most recent value read from the {@code validate_<loss name>} metric. */
    protected float validationLoss;
    /** Batch size taken from the command line arguments; used for the speed estimate. */
    protected int batchSize;
    /** Maximum of the training progress bar; expected to be set by the subclass. */
    protected int trainDataSize;
    /** Maximum of the validation progress bar; expected to be set by the subclass. */
    protected int validateDataSize;
    /** Number of training batches reported since the last epoch boundary. */
    protected int trainingProgress;
    /** Number of validation batches reported since the last epoch boundary. */
    protected int validateProgress;
    /** {@link System#nanoTime()} reading taken when the current epoch started. */
    private long epochTime;
    /** Number of epochs finished so far. */
    private int numEpochs;
    /** Created lazily on the first training batch. */
    private ProgressBar trainingProgressBar;
    /** Created lazily on the first validation batch. */
    private ProgressBar validateProgressBar;
    /** Metrics that the timing and accuracy reports are read from. */
    protected Metrics metrics = new Metrics();
    /** Loss whose name selects the loss metrics; must be set before status is reported. */
    protected Loss loss;

    /**
     * Parses the arguments, runs the training and logs a timing summary.
     *
     * @param strArr the raw command line arguments
     * @return {@code true} if training completed, {@code false} if the arguments could not be
     *     parsed (usage help is printed) or any error was thrown (the error is logged)
     */
    public boolean runExample(String[] strArr) {
        Options options = Arguments.getOptions();
        try {
            Arguments arguments =
                    new Arguments(new DefaultParser().parse(options, strArr, (Properties) null, false));
            int maxGpus = arguments.getMaxGpus();
            batchSize = arguments.getBatchSize();

            String devices = maxGpus > 0 ? maxGpus + " GPUs" : Device.cpu().toString();
            logger.info(
                    "Running {} on: {}, epoch: {}.",
                    getClass().getSimpleName(),
                    devices,
                    arguments.getEpoch());

            // Take the timestamps around the engine call so the reported load time is the
            // time actually spent obtaining the engine, not the gap between two clock reads.
            long loadStart = System.nanoTime();
            String version = Engine.getInstance().getVersion();
            long loadEnd = System.nanoTime();
            logger.info(
                    String.format(
                            "Load library %s in %.3f ms.",
                            version, ((float) (loadEnd - loadStart)) / NANOS_PER_MILLI));

            epochTime = System.nanoTime();
            train(arguments);

            logger.info("Training: {} batches", trainDataSize);
            logger.info("Validation: {} batches", validateDataSize);
            logPercentiles("train", NANOS_PER_MILLI, "ms");
            logPercentiles("forward", NANOS_PER_MILLI, "ms");
            logPercentiles("training-metrics", NANOS_PER_MILLI, "ms");
            logPercentiles("backward", NANOS_PER_MILLI, "ms");
            logPercentiles("step", NANOS_PER_MILLI, "ms");
            logPercentiles("epoch", NANOS_PER_SECOND, "s");

            // Dumping the collected metrics is optional and only done when a directory is given.
            if (arguments.getOutputDir() != null) {
                MemoryUtils.dumpMemoryInfo(metrics, arguments.getOutputDir());
                TrainingUtils.dumpTrainingTimeInfo(metrics, arguments.getOutputDir());
            }
            return true;
        } catch (ParseException e) {
            HelpFormatter helpFormatter = new HelpFormatter();
            helpFormatter.setLeftPadding(1);
            helpFormatter.setWidth(120);
            helpFormatter.printHelp(e.getMessage(), options);
            return false;
        } catch (Throwable th) {
            // Top-level boundary of the example: report any failure through the return value.
            logger.error("Unexpected error", th);
            return false;
        }
    }

    /**
     * Runs the actual training of the concrete example.
     *
     * @param arguments the parsed command line arguments
     * @throws IOException declared for implementations that perform I/O
     * @throws ModelNotFoundException declared for implementations that load a model
     */
    protected abstract void train(Arguments arguments) throws IOException, ModelNotFoundException;

    /**
     * Records memory usage and advances the training progress bar by one batch, showing the
     * current training status next to it.
     */
    public void onTrainingBatch() {
        MemoryUtils.collectMemoryInfo(metrics);
        if (trainingProgressBar == null) {
            trainingProgressBar = new ProgressBar("Training", trainDataSize);
        }
        // The bar is updated with the pre-increment counter value.
        trainingProgressBar.update(trainingProgress++, getTrainingStatus(metrics));
    }

    /** Records memory usage and advances the validation progress bar by one batch. */
    public void onValidationBatch() {
        MemoryUtils.collectMemoryInfo(metrics);
        if (validateProgressBar == null) {
            validateProgressBar = new ProgressBar("Validating", validateDataSize);
        }
        // The bar is updated with the pre-increment counter value.
        validateProgressBar.update(validateProgress++);
    }

    /**
     * Closes the current epoch: records its duration as the {@code epoch} metric, logs the
     * training status and resets the per-epoch progress counters.
     */
    public void onEpoch() {
        // epochTime is 0 until runExample (or a previous epoch) has started the clock.
        if (epochTime > 0) {
            metrics.addMetric("epoch", Long.valueOf(System.nanoTime() - epochTime));
        }
        logger.info("Epoch {} finished.", numEpochs);
        printTrainingStatus(metrics);
        epochTime = System.nanoTime();
        numEpochs++;
        trainingProgress = 0;
        validateProgress = 0;
    }

    /**
     * Returns the training accuracy captured by the last status update.
     *
     * @return the training accuracy
     */
    public float getTrainingAccuracy() {
        return trainingAccuracy;
    }

    /**
     * Returns the training loss captured by the last status update.
     *
     * @return the training loss
     */
    public float getTrainingLoss() {
        return trainingLoss;
    }

    /**
     * Returns the validation accuracy captured by the last status update.
     *
     * @return the validation accuracy
     */
    public float getValidationAccuracy() {
        return validationAccuracy;
    }

    /**
     * Returns the validation loss captured by the last status update.
     *
     * @return the validation loss
     */
    public float getValidationLoss() {
        return validationLoss;
    }

    /**
     * Refreshes {@link #trainingLoss} and {@link #trainingAccuracy} from the given metrics and
     * formats them, followed by a throughput estimate when a {@code train} timing is available.
     *
     * <p>NOTE(review): the loss and accuracy lists are read without an emptiness check, so this
     * is expected to fail with an {@link IndexOutOfBoundsException} if either metric has no
     * entries yet, and with a {@link NullPointerException} if {@link #loss} is unset.
     *
     * @param metrics the metrics to read the latest training values from
     * @return the status text, e.g. {@code "accuracy: 0.91 loss: 0.23 speed: 512.00 images/sec"}
     */
    public String getTrainingStatus(Metrics metrics) {
        trainingLoss = latestValue(metrics, "train_" + loss.getName());
        trainingAccuracy = latestValue(metrics, "train_Accuracy");

        StringBuilder status = new StringBuilder();
        status.append(String.format("accuracy: %.2f loss: %.2f", trainingAccuracy, trainingLoss));

        List<Metric> batchTimes = metrics.getMetric("train");
        if (!batchTimes.isEmpty()) {
            long lastBatchNanos = batchTimes.get(batchTimes.size() - 1).getValue().longValue();
            float lastBatchSeconds = ((float) lastBatchNanos) / NANOS_PER_SECOND;
            status.append(String.format(" speed: %.2f images/sec", batchSize / lastBatchSeconds));
        }
        return status.toString();
    }

    /**
     * Refreshes the training figures, and the validation figures when present, from the given
     * metrics and logs them.
     *
     * <p>Validation is considered "not run" when the {@code validate_<loss name>} metric has no
     * entries; in that case the validation fields are left untouched.
     *
     * @param metrics the metrics to read the latest values from
     */
    public void printTrainingStatus(Metrics metrics) {
        trainingLoss = latestValue(metrics, "train_" + loss.getName());
        trainingAccuracy = latestValue(metrics, "train_Accuracy");
        logger.info("train accuracy: {}, train loss: {}", trainingAccuracy, trainingLoss);

        List<Metric> validationLosses = metrics.getMetric("validate_" + loss.getName());
        if (validationLosses.isEmpty()) {
            logger.info("validation has not been run.");
            return;
        }
        validationLoss =
                validationLosses.get(validationLosses.size() - 1).getValue().floatValue();
        validationAccuracy = latestValue(metrics, "validate_Accuracy");
        logger.info(
                "validate accuracy: {}, validate loss: {}", validationAccuracy, validationLoss);
    }

    /**
     * Logs the 50th and 90th percentile of a nanosecond timing metric from {@link #metrics}.
     *
     * @param metricName the metric to summarise; also used as the label of the log line
     * @param divisor the divisor converting nanoseconds into the reported unit
     * @param unit the unit suffix printed after each value
     */
    private void logPercentiles(String metricName, float divisor, String unit) {
        float p50 = ((float) metrics.percentile(metricName, 50).getValue().longValue()) / divisor;
        float p90 = ((float) metrics.percentile(metricName, 90).getValue().longValue()) / divisor;
        logger.info(
                String.format("%s P50: %.3f %s, P90: %.3f %s", metricName, p50, unit, p90, unit));
    }

    /**
     * Returns the value of the last entry recorded under the given metric name.
     *
     * @param metrics the metrics to read from
     * @param name the metric name
     * @return the last recorded value as a float
     */
    private static float latestValue(Metrics metrics, String name) {
        List<Metric> values = metrics.getMetric(name);
        return values.get(values.size() - 1).getValue().floatValue();
    }
}
