package ai.djl.examples.inference;

import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.DetectedObjects;
import ai.djl.modality.cv.ImageVisualization;
import ai.djl.modality.cv.Joints;
import ai.djl.modality.cv.Rectangle;
import ai.djl.modality.cv.util.BufferedImageUtils;
import ai.djl.mxnet.zoo.MxModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import javax.imageio.ImageIO;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Example that estimates the pose (body joints) of a person in an image.
 *
 * <p>Two MXNet model-zoo models are chained:
 *
 * <ol>
 *   <li>An SSD detector finds the first object labelled {@code person}.
 *   <li>A Simple Pose model predicts joints on the cropped region.
 * </ol>
 *
 * <p>The crop, annotated with the joints, is written to {@code build/output/joints.png}. All paths
 * are relative to the current working directory.
 */
public class PoseEstimation {
    private static final Logger logger = LoggerFactory.getLogger(PoseEstimation.class);

    /**
     * Runs the example and logs the predicted joints.
     *
     * @param args unused
     * @throws IOException if the image cannot be read/written or a model cannot be loaded
     * @throws ModelException if a model cannot be found or is malformed
     * @throws TranslateException if pre/post-processing around inference fails
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        logger.info("{}", new PoseEstimation().predict());
    }

    /**
     * Detects a person in {@code src/test/resources/pose_soccer.png} and predicts that person's
     * joints. As a side effect, the annotated crop is saved (see {@link #drawJoints}).
     *
     * @return the predicted joints, or {@code null} if no usable person box was detected
     * @throws IOException if the image cannot be read/written or a model cannot be loaded
     * @throws ModelException if a model cannot be found or is malformed
     * @throws TranslateException if pre/post-processing around inference fails
     */
    public Joints predict() throws IOException, ModelException, TranslateException {
        BufferedImage img =
                BufferedImageUtils.fromFile(Paths.get("src/test/resources/pose_soccer.png"));

        BufferedImage person = detectPerson(img);
        if (person == null) {
            logger.warn("No person found in image.");
            return null;
        }
        return estimateJoints(person);
    }

    /** Runs the SSD detector and returns the crop of the first usable person box, or null. */
    private static BufferedImage detectPerson(BufferedImage img)
            throws IOException, ModelException, TranslateException {
        Map<String, String> criteria = new ConcurrentHashMap<>();
        criteria.put("size", "512");
        criteria.put("backbone", "resnet50");
        criteria.put("flavor", "v1");
        criteria.put("dataset", "voc");

        DetectedObjects detectedObjects;
        // try-with-resources closes predictor then model on every path, including exceptions;
        // the previous code leaked the model if prediction threw.
        try (ZooModel<BufferedImage, DetectedObjects> ssd =
                        MxModelZoo.SSD.loadModel(criteria, new ProgressBar());
                Predictor<BufferedImage, DetectedObjects> predictor = ssd.newPredictor()) {
            detectedObjects = predictor.predict(img);
        }

        List<DetectedObjects.DetectedObject> items = detectedObjects.items();
        for (DetectedObjects.DetectedObject item : items) {
            if (!"person".equals(item.getClassName())) {
                continue;
            }
            BufferedImage crop = cropToBox(img, item.getBoundingBox().getBounds());
            if (crop != null) {
                return crop;
            }
            // Degenerate box (empty after clamping): keep looking for another person.
        }
        return null;
    }

    /**
     * Crops {@code img} to {@code box}, where the box coordinates are treated as fractions of
     * the image size (they are scaled by the image width/height).
     *
     * <p>The box is clamped to the image bounds, because {@link BufferedImage#getSubimage} throws
     * if the region is not fully inside the image. For a box already inside the image, the result
     * is identical to the unclamped crop.
     *
     * @return the cropped view (sharing pixels with {@code img}), or {@code null} if the clamped
     *     region is empty
     */
    private static BufferedImage cropToBox(BufferedImage img, Rectangle box) {
        int imageWidth = img.getWidth();
        int imageHeight = img.getHeight();
        int x = (int) (box.getX() * imageWidth);
        int y = (int) (box.getY() * imageHeight);
        int w = (int) (box.getWidth() * imageWidth);
        int h = (int) (box.getHeight() * imageHeight);

        int left = Math.max(0, x);
        int top = Math.max(0, y);
        // long arithmetic so x + w cannot overflow for pathological box values
        int right = (int) Math.min(imageWidth, (long) x + w);
        int bottom = (int) Math.min(imageHeight, (long) y + h);
        if (right <= left || bottom <= top) {
            return null;
        }
        return img.getSubimage(left, top, right - left, bottom - top);
    }

    /**
     * Runs the Simple Pose model on a person crop and saves the annotated crop to disk.
     *
     * <p>Both the predictor and the model are closed on every path.
     */
    private static Joints estimateJoints(BufferedImage person)
            throws IOException, ModelException, TranslateException {
        Map<String, String> criteria = new ConcurrentHashMap<>();
        criteria.put("flavor", "v1b");
        criteria.put("backbone", "resnet18");
        criteria.put("dataset", "imagenet");

        try (ZooModel<BufferedImage, Joints> pose = MxModelZoo.SIMPLE_POSE.loadModel(criteria);
                Predictor<BufferedImage, Joints> predictor = pose.newPredictor()) {
            Joints joints = predictor.predict(person);
            logger.info("Pose image has been saved in: {}", drawJoints(person, joints));
            return joints;
        }
    }

    /**
     * Draws {@code joints} onto {@code bufferedImage} in place and writes it as
     * {@code build/output/joints.png}, creating the directory if needed.
     *
     * @return the path of the written PNG file
     * @throws IOException if the directory cannot be created or the image cannot be written
     */
    private static Path drawJoints(BufferedImage bufferedImage, Joints joints) throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);
        ImageVisualization.drawJoints(bufferedImage, joints);
        Path imagePath = outputDir.resolve("joints.png");
        ImageIO.write(bufferedImage, "png", imagePath.toFile());
        return imagePath;
    }
}
