package ai.djl.examples.inference.util;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.examples.util.MemoryUtils;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.zoo.ModelZoo;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.time.Duration;
import java.util.Properties;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.DefaultParser;
import org.apache.commons.cli.HelpFormatter;
import org.apache.commons.cli.Options;
import org.apache.commons.cli.ParseException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for command-line inference benchmarks.
 *
 * <p>Subclasses implement {@link #predict(Arguments, Metrics, int)}. {@link #runBenchmark(String[])}
 * parses the command line, invokes {@code predict} repeatedly until the requested duration is used
 * up (at least once), and logs the timing and memory statistics recorded in {@link Metrics}.
 *
 * @param <T> the type of the prediction result
 */
public abstract class AbstractBenchmark<T> {

    private static final Logger logger = LoggerFactory.getLogger(AbstractBenchmark.class);

    /** Divisor converting nanosecond timings into milliseconds for logging. */
    private static final float NANOS_PER_MILLI = 1_000_000f;

    /** Result of the most recent {@link #predict} call; {@code null} until one completes. */
    private T lastResult;

    /** Progress bar for the current pass; recreated at the start of every pass. */
    protected ProgressBar progressBar;

    /**
     * Runs the benchmark workload once.
     *
     * <p>Values recorded in {@code metrics} under the names {@code "thread"}, {@code "LoadModel"},
     * {@code "Inference"}, {@code "Preprocess"}, {@code "Postprocess"} (and, when the {@code
     * collect-memory} system property is set, {@code "Heap"}, {@code "NonHeap"}, {@code "cpu"},
     * {@code "rss"}) are reported by {@link #runBenchmark(String[])} after the pass.
     *
     * @param arguments the parsed command-line arguments
     * @param metrics the collector for this pass
     * @param iteration the iteration count requested on the command line
     * @return the prediction result, which is logged and exposed via {@link #getPredictResult()}
     * @throws IOException if an I/O error occurs
     * @throws ModelException if the model cannot be loaded
     * @throws TranslateException if pre/post-processing fails
     */
    protected abstract T predict(Arguments arguments, Metrics metrics, int iteration)
            throws IOException, ModelException, TranslateException;

    /**
     * Returns the command-line options understood by this benchmark.
     *
     * @return the options; by default {@link Arguments#getOptions()}
     */
    protected Options getOptions() {
        return Arguments.getOptions();
    }

    /**
     * Converts the parsed command line into benchmark arguments.
     *
     * @param commandLine the parsed command line
     * @return the benchmark arguments
     */
    protected Arguments parseArguments(CommandLine commandLine) {
        return new Arguments(commandLine);
    }

    /**
     * Parses {@code args} and runs the benchmark, logging results after each pass.
     *
     * @param args the command-line arguments
     * @return {@code true} on success; {@code false} if the arguments could not be parsed (help is
     *     printed) or any error occurred (it is logged)
     */
    public final boolean runBenchmark(String[] args) {
        Options options = getOptions();
        try {
            CommandLine cmd = new DefaultParser().parse(options, args, (Properties) null, false);
            Arguments arguments = parseArguments(cmd);

            // Time the Engine lookup itself. Subtracting two back-to-back nanoTime() calls (as before)
            // always produced ~0 or a negative value.
            long loadStart = System.nanoTime();
            String version = Engine.getInstance().getVersion();
            float loadMillis = (System.nanoTime() - loadStart) / NANOS_PER_MILLI;
            logger.info(String.format("Load library %s in %.3f ms.", version, loadMillis));

            Duration remaining = Duration.ofMinutes(arguments.getDuration());
            if (arguments.getDuration() != 0) {
                logger.info(
                        "Running {} on: {}, duration: {} minutes.",
                        getClass().getSimpleName(),
                        Device.defaultDevice(),
                        remaining.toMinutes());
            }

            int iteration = arguments.getIteration();
            // A zero duration still runs exactly one pass: the loop only stops once `remaining`
            // has gone negative.
            while (!remaining.isNegative()) {
                Metrics metrics = new Metrics();
                logger.info(
                        "Running {} on: {}, iteration: {}.",
                        getClass().getSimpleName(),
                        Device.defaultDevice(),
                        iteration);
                progressBar = new ProgressBar("Iteration", iteration);

                // Use the monotonic clock: currentTimeMillis() follows wall-clock adjustments and
                // truncates very fast passes to 0 ms, which would stall the countdown below.
                long passStart = System.nanoTime();
                lastResult = predict(arguments, metrics, iteration);
                long totalMillis = (System.nanoTime() - passStart) / 1_000_000L;
                logger.info("Inference result: {}", lastResult);

                logSummary(metrics, iteration, totalMillis);
                MemoryUtils.dumpMemoryInfo(metrics, arguments.getOutputDir());

                // Charge the whole pass, including reporting, against the remaining duration.
                remaining = remaining.minus(Duration.ofNanos(System.nanoTime() - passStart));
                if (!remaining.isNegative()) {
                    logger.info("{} minutes left", remaining.toMinutes());
                }
            }
            return true;
        } catch (ParseException e) {
            HelpFormatter formatter = new HelpFormatter();
            formatter.setLeftPadding(1);
            formatter.setWidth(120);
            formatter.printHelp(e.getMessage(), options);
            return false;
        } catch (Throwable t) {
            logger.error("Unexpected error", t);
            return false;
        }
    }

    /**
     * Returns the result of the most recent {@link #predict} call.
     *
     * @return the last prediction result, or {@code null} if no pass has completed
     */
    public T getPredictResult() {
        return lastResult;
    }

    /**
     * Loads the image-classification model named in {@code arguments} and records its load time
     * under the {@code "LoadModel"} metric (in nanoseconds).
     *
     * <p>Uses {@code "RESNET"} when no model name is given. Loads through {@link ModelZoo} when the
     * imperative flag is set, and through {@link MxModelZoo} otherwise.
     *
     * @param arguments the benchmark arguments supplying model name, criteria and mode
     * @param metrics the collector that receives the {@code "LoadModel"} timing
     * @return the loaded model
     * @throws ModelException if the model cannot be found or loaded
     * @throws IOException if model artifacts cannot be read
     */
    public ZooModel<BufferedImage, Classifications> loadModel(Arguments arguments, Metrics metrics)
            throws ModelException, IOException {
        long begin = System.nanoTime();
        String requested = arguments.getModelName();
        String modelName = requested == null ? "RESNET" : requested;
        ZooModel<BufferedImage, Classifications> model =
                (arguments.isImperative()
                                ? ModelZoo.getModelLoader(modelName)
                                : MxModelZoo.getModelLoader(modelName))
                        .loadModel(arguments.getCriteria(), new ProgressBar());
        long elapsed = System.nanoTime() - begin;
        logger.info(
                "Model {} loaded in: {} ms.",
                model.getName(),
                String.format("%.3f", elapsed / NANOS_PER_MILLI));
        metrics.addMetric("LoadModel", Long.valueOf(elapsed));
        return model;
    }

    /** Logs total runs/time, model load time and per-stage latency statistics for one pass. */
    private static void logSummary(Metrics metrics, int iteration, long totalMillis) {
        int totalRuns = iteration;
        if (metrics.hasMetric("thread")) {
            // Scale by the recorded thread count (presumably each thread ran `iteration` times).
            totalRuns *= ((Metric) metrics.getMetric("thread").get(0)).getValue().intValue();
        }
        logger.info(
                String.format(
                        "total time: %d ms, total runs: %d iterations", totalMillis, totalRuns));

        if (metrics.hasMetric("LoadModel")) {
            Metric loadModel = (Metric) metrics.getMetric("LoadModel").get(0);
            logger.info("Model loading time: {} ms.", String.format("%.3f", toMillis(loadModel)));
        }

        // Percentiles are only reported when more than one iteration ran.
        if (metrics.hasMetric("Inference") && iteration > 1) {
            logLatencyPercentiles(metrics, "Inference", "inference");
            logLatencyPercentiles(metrics, "Preprocess", "preprocess");
            logLatencyPercentiles(metrics, "Postprocess", "postprocess");
            if (Boolean.getBoolean("collect-memory")) {
                logP90(metrics, "Heap", "heap");
                logP90(metrics, "NonHeap", "nonHeap");
                logP90(metrics, "cpu", "cpu");
                logP90(metrics, "rss", "rss");
            }
        }
    }

    /** Logs P50/P90/P99 of a nanosecond-valued metric in milliseconds, if it was recorded. */
    private static void logLatencyPercentiles(Metrics metrics, String metricName, String label) {
        // Skip stages this benchmark did not record rather than querying an absent metric.
        if (!metrics.hasMetric(metricName)) {
            return;
        }
        float p50 = toMillis(metrics.percentile(metricName, 50));
        float p90 = toMillis(metrics.percentile(metricName, 90));
        float p99 = toMillis(metrics.percentile(metricName, 99));
        logger.info(
                String.format(
                        "%s P50: %.3f ms, P90: %.3f ms, P99: %.3f ms", label, p50, p90, p99));
    }

    /** Logs the raw P90 value of a metric, if it was recorded. */
    private static void logP90(Metrics metrics, String metricName, String label) {
        if (metrics.hasMetric(metricName)) {
            float value = metrics.percentile(metricName, 90).getValue().longValue();
            logger.info(String.format("%s P90: %.3f", label, value));
        }
    }

    /** Converts a nanosecond-valued metric to milliseconds. */
    private static float toMillis(Metric metric) {
        return metric.getValue().longValue() / NANOS_PER_MILLI;
    }
}
