package ai.djl.examples.inference.benchmark.util;

import ai.djl.modality.Classifications;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import com.google.gson.Gson;
import com.google.gson.reflect.TypeToken;
import java.awt.image.BufferedImage;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.commons.cli.OptionGroup;
import org.apache.commons.cli.Options;

/* loaded from: input_file:ai/djl/examples/inference/benchmark/util/Arguments.class */
/**
 * Command-line arguments for the inference benchmark.
 *
 * <p>An instance is built from a {@link CommandLine} parsed against {@link #getOptions()}.
 * Scalar option values are parsed eagerly in the constructor; file-system checks for the model
 * directory and image file are deferred to {@link #getModelDir()} and {@link #getImageFile()}.
 */
public class Arguments {

    private final String modelDir;
    private final String artifactId;
    private final String imageFile;
    private final String outputDir;
    private final Map<String, String> criteria;
    private final int duration;
    private final int iteration;
    private final int threads;
    private final String inputClass;
    private final String outputClass;
    private final Shape inputShape;

    /**
     * Reads the benchmark settings from a parsed command line.
     *
     * <p>Defaults when an option is absent: duration {@code 0}, iteration {@code 1}, threads
     * {@code 2 * availableProcessors() - 1}; criteria and input shape stay {@code null}.
     *
     * @param commandLine the parsed command line
     * @throws NumberFormatException if {@code --duration}, {@code --iteration}, {@code --threads}
     *     or any {@code --input-shape} dimension is not a valid integer
     */
    public Arguments(CommandLine commandLine) {
        modelDir = commandLine.getOptionValue("model-dir");
        artifactId = commandLine.getOptionValue("artifact-id");
        outputDir = commandLine.getOptionValue("output-dir");
        imageFile = commandLine.getOptionValue("image");
        inputClass = commandLine.getOptionValue("input-class");
        outputClass = commandLine.getOptionValue("output-class");
        duration = parseIntOption(commandLine, "duration", 0);
        iteration = parseIntOption(commandLine, "iteration", 1);
        threads =
                parseIntOption(
                        commandLine,
                        "threads",
                        Runtime.getRuntime().availableProcessors() * 2 - 1);
        criteria =
                commandLine.hasOption("criteria")
                        ? parseCriteria(commandLine.getOptionValue("criteria"))
                        : null;
        inputShape =
                commandLine.hasOption("input-shape")
                        ? parseShape(commandLine.getOptionValue("input-shape"))
                        : null;
    }

    /** Returns the integer value of option {@code name}, or {@code defaultValue} if absent. */
    private static int parseIntOption(CommandLine commandLine, String name, int defaultValue) {
        if (!commandLine.hasOption(name)) {
            return defaultValue;
        }
        return Integer.parseInt(commandLine.getOptionValue(name).trim());
    }

    /** Decodes a JSON object of string keys and string values into a map. */
    private static Map<String, String> parseCriteria(String json) {
        // Declared in a static method so the anonymous TypeToken does not capture `this`.
        return new Gson().fromJson(json, new TypeToken<Map<String, String>>() {}.getType());
    }

    /**
     * Parses a comma-separated list of dimensions such as {@code "1,3,224,224"}.
     * Whitespace around each dimension is ignored.
     */
    private static Shape parseShape(String value) {
        long[] dims =
                Arrays.stream(value.split(","))
                        .map(String::trim)
                        .mapToLong(Long::parseLong)
                        .toArray();
        return new Shape(dims);
    }

    /**
     * Builds the set of command-line options understood by this class.
     *
     * <p>{@code --duration} and {@code --iteration} are registered as an {@link OptionGroup}, so
     * at most one of them may be given.
     *
     * @return a new {@link Options} instance
     */
    public static Options getOptions() {
        Options options = new Options();
        options.addOption(
                Option.builder("p")
                        .longOpt("model-dir")
                        .hasArg()
                        .argName("MODEL-DIR")
                        .desc("Path to the model directory.")
                        .build());
        options.addOption(
                Option.builder("n")
                        .longOpt("artifact-id")
                        .hasArg()
                        .argName("ARTIFACT-ID")
                        .desc("Model artifact id.")
                        .build());
        options.addOption(
                Option.builder("ic")
                        .longOpt("input-class")
                        .hasArg()
                        .argName("INPUT-CLASS")
                        .desc("Input class type.")
                        .build());
        options.addOption(
                Option.builder("is")
                        .longOpt("input-shape")
                        .hasArg()
                        .argName("INPUT-SHAPE")
                        .desc("Input data shape.")
                        .build());
        options.addOption(
                Option.builder("oc")
                        .longOpt("output-class")
                        .hasArg()
                        .argName("OUTPUT-CLASS")
                        .desc("Output class type.")
                        .build());
        options.addOption(
                Option.builder("i")
                        .longOpt("image")
                        .hasArg()
                        .argName("IMAGE")
                        .desc("Image file.")
                        .build());

        Option durationOption =
                Option.builder("d")
                        .longOpt("duration")
                        .hasArg()
                        .argName("DURATION")
                        .desc("Duration of the test in minutes.")
                        .build();
        Option iterationOption =
                Option.builder("c")
                        .longOpt("iteration")
                        .hasArg()
                        .argName("ITERATION")
                        .desc("Number of total iterations.")
                        .build();
        options.addOptionGroup(
                new OptionGroup().addOption(durationOption).addOption(iterationOption));

        options.addOption(
                Option.builder("t")
                        .longOpt("threads")
                        .hasArg()
                        .argName("NUMBER_THREADS")
                        .desc("Number of inference threads.")
                        .build());
        options.addOption(
                Option.builder("o")
                        .longOpt("output-dir")
                        .hasArg()
                        .argName("OUTPUT-DIR")
                        .desc("Directory for output logs.")
                        .build());
        options.addOption(
                Option.builder("r")
                        .longOpt("criteria")
                        .hasArg()
                        .argName("CRITERIA")
                        .desc("The criteria used for the model.")
                        .build());
        return options;
    }

    /** Returns the {@code --duration} value, or {@code 0} if the option was not given. */
    public int getDuration() {
        return duration;
    }

    /**
     * Returns the model directory given by {@code --model-dir}.
     *
     * @return the existing model directory path
     * @throws IOException if {@code --model-dir} was not specified
     * @throws FileNotFoundException if the directory does not exist
     */
    public Path getModelDir() throws IOException {
        if (modelDir == null) {
            throw new IOException("Please specify --model-dir");
        }
        Path path = Paths.get(modelDir);
        if (Files.notExists(path)) {
            throw new FileNotFoundException("model directory not found: " + modelDir);
        }
        return path;
    }

    /** Returns the {@code --artifact-id} value, or {@code null} if not given. */
    public String getArtifactId() {
        return artifactId;
    }

    /**
     * Returns the input image path.
     *
     * <p>When {@code --image} is absent, falls back to {@code src/test/resources/kitten.jpg}
     * relative to the working directory.
     *
     * @return the existing image path
     * @throws FileNotFoundException if the given (or fallback) image file does not exist
     */
    public Path getImageFile() throws FileNotFoundException {
        if (imageFile == null) {
            Path fallback = Paths.get("src/test/resources/kitten.jpg");
            if (Files.notExists(fallback)) {
                throw new FileNotFoundException("Missing --image parameter.");
            }
            return fallback;
        }
        Path path = Paths.get(imageFile);
        if (Files.notExists(path)) {
            throw new FileNotFoundException("image file not found: " + imageFile);
        }
        return path;
    }

    /** Returns the {@code --iteration} value, or {@code 1} if the option was not given. */
    public int getIteration() {
        return iteration;
    }

    /**
     * Returns the {@code --threads} value, or {@code 2 * availableProcessors() - 1} if the option
     * was not given.
     */
    public int getThreads() {
        return threads;
    }

    /** Returns the {@code --output-dir} value, or {@code null} if not given. */
    public String getOutputDir() {
        return outputDir;
    }

    /** Returns the criteria decoded from the {@code --criteria} JSON, or {@code null} if absent. */
    public Map<String, String> getCriteria() {
        return criteria;
    }

    /**
     * Returns the model input type.
     *
     * @return the class named by {@code --input-class}, or {@link BufferedImage} if absent
     * @throws ClassNotFoundException if the named class cannot be loaded
     */
    public Class<?> getInputClass() throws ClassNotFoundException {
        if (inputClass == null) {
            return BufferedImage.class;
        }
        return Class.forName(inputClass);
    }

    /**
     * Returns the model output type.
     *
     * <p>If {@code --output-class} is absent, the type is chosen from the artifact id:
     * {@link DetectedObjects} when it contains {@code "ssd"}, otherwise {@link Classifications}.
     *
     * @return the output class
     * @throws ClassNotFoundException if the class named by {@code --output-class} cannot be loaded
     */
    public Class<?> getOutputClass() throws ClassNotFoundException {
        if (outputClass != null) {
            return Class.forName(outputClass);
        }
        if (artifactId != null && artifactId.contains("ssd")) {
            return DetectedObjects.class;
        }
        return Classifications.class;
    }

    /**
     * Loads the benchmark input for the configured input class.
     *
     * @return the decoded image for {@link BufferedImage} input; {@code null} for {@code float[]}
     *     or {@link NDList} input (presumably the caller synthesizes such data, e.g. from
     *     {@link #getInputShape()} — TODO confirm)
     * @throws IOException if the image cannot be read
     * @throws ClassNotFoundException if the configured input class cannot be loaded
     * @throws IllegalArgumentException if the input class is not supported
     */
    public Object getInputData() throws IOException, ClassNotFoundException {
        Class<?> type = getInputClass();
        if (type == BufferedImage.class) {
            return BufferedImageUtils.fromFile(getImageFile());
        }
        if (type == float[].class || type == NDList.class) {
            return null;
        }
        throw new IllegalArgumentException("Unsupported input class: " + type);
    }

    /** Returns the shape parsed from {@code --input-shape}, or {@code null} if not given. */
    public Shape getInputShape() {
        return inputShape;
    }
}
