package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.examples.inference.util.AbstractBenchmark;
import ai.djl.examples.inference.util.Arguments;
import ai.djl.examples.util.MemoryUtils;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.TranslateException;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/examples/inference/MultithreadedBenchmark.class */
/**
 * Benchmark that runs image-classification inference on several threads at once.
 *
 * <p>One {@link PredictorCallable} is created per requested thread; all of them share the same
 * loaded model, input image and {@link Metrics} instance, and each one owns a predictor obtained
 * from the model.
 */
public class MultithreadedBenchmark extends AbstractBenchmark<Classifications> {
    private static final Logger logger = LoggerFactory.getLogger(MultithreadedBenchmark.class);

    /**
     * A unit of work that repeatedly runs prediction on a single image.
     *
     * <p>NOTE(review): the predictor created in the constructor is never closed by this class;
     * confirm whether {@code Predictor} holds resources that need an explicit release.
     */
    public static class PredictorCallable implements Callable<Classifications> {
        private final Predictor<BufferedImage, Classifications> predictor;
        private final BufferedImage img;
        private final Metrics metrics;
        private final int iteration;
        private final String workerId;
        private final boolean collectMemory;

        /**
         * Creates a worker with its own predictor obtained from {@code model}.
         *
         * @param model the model used to create this worker's predictor
         * @param img the image passed to every prediction
         * @param metrics the metrics instance attached to the predictor and used for memory info
         * @param iteration how many predictions {@link #call()} performs
         * @param workerIndex numeric index of this worker; rendered as a zero-padded two-digit id
         *     in log messages
         * @param collectMemory whether memory info is collected after each prediction
         */
        public PredictorCallable(
                ZooModel<BufferedImage, Classifications> model,
                BufferedImage img,
                Metrics metrics,
                int iteration,
                int workerIndex,
                boolean collectMemory) {
            this.predictor = model.newPredictor();
            this.img = img;
            this.metrics = metrics;
            this.iteration = iteration;
            this.workerId = String.format("%02d", workerIndex);
            this.collectMemory = collectMemory;
            this.predictor.setMetrics(metrics);
        }

        /**
         * Runs the configured number of predictions on the image.
         *
         * @return the result of the last prediction, or {@code null} if the iteration count is
         *     not positive
         * @throws TranslateException if a prediction fails with this exception
         */
        @Override
        public Classifications call() throws TranslateException {
            Classifications result = null;
            for (int i = 0; i < iteration; i++) {
                result = predictor.predict(img);
                if (collectMemory) {
                    MemoryUtils.collectMemoryInfo(metrics);
                }
                logger.trace("Worker-{}: {} iteration finished.", workerId, i + 1);
            }
            logger.debug("Worker-{}: finished.", workerId);
            return result;
        }
    }

    /**
     * Command-line entry point; exits with status 0 when the benchmark reports success and -1
     * otherwise.
     *
     * @param args command-line arguments forwarded to {@code runBenchmark}
     */
    public static void main(String[] args) {
        boolean success = new MultithreadedBenchmark().runBenchmark(args);
        System.exit(success ? 0 : -1);
    }

    /**
     * Loads the image and model, then runs one {@link PredictorCallable} per configured thread
     * on a fixed-size thread pool and waits for all of them.
     *
     * <p>Only the worker with index 0 is asked to collect memory info. Failures of individual
     * workers are logged rather than rethrown; if fewer workers than requested complete
     * successfully, an error is logged.
     *
     * @param arguments supplies the image file and the number of threads
     * @param metrics metrics instance shared by all workers
     * @param iteration number of predictions each worker performs
     * @return the result of the last worker whose future completed successfully, or {@code null}
     *     if none did
     * @throws IOException declared for the image/model loading calls
     * @throws ModelException declared for the model loading call
     */
    @Override
    public Classifications predict(Arguments arguments, Metrics metrics, int iteration)
            throws IOException, ModelException {
        BufferedImage img = BufferedImageUtils.fromFile(arguments.getImageFile());
        ZooModel<BufferedImage, Classifications> model = loadModel(arguments, metrics);

        int threads = arguments.getThreads();
        logger.info("Multithreaded inference with {} threads.", threads);
        metrics.addMetric("thread", threads);

        List<PredictorCallable> tasks = new ArrayList<>(threads);
        for (int worker = 0; worker < threads; worker++) {
            tasks.add(new PredictorCallable(model, img, metrics, iteration, worker, worker == 0));
        }

        Classifications result = null;
        int finished = 0;
        ExecutorService executor = Executors.newFixedThreadPool(threads);
        try {
            // invokeAll blocks until every task has completed, so get() below does not wait.
            for (Future<Classifications> future : executor.invokeAll(tasks)) {
                try {
                    result = future.get();
                    finished++;
                } catch (ExecutionException e) {
                    // A single failed worker must not prevent collecting the other results.
                    logger.error("", e);
                }
            }
        } catch (InterruptedException e) {
            // Restore the interrupt status so callers further up the stack can observe it.
            Thread.currentThread().interrupt();
            logger.error("", e);
        } finally {
            executor.shutdown();
        }

        if (finished != threads) {
            logger.error("Only {}/{} threads finished.", finished, threads);
        }
        return result;
    }
}
