package ai.djl.examples.training.util;

import ai.djl.Model;
import ai.djl.metric.Metric;
import ai.djl.metric.Metrics;
import ai.djl.training.Trainer;
import ai.djl.training.dataset.Batch;
import ai.djl.training.dataset.Dataset;
import java.io.BufferedWriter;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardOpenOption;
import java.nio.file.attribute.FileAttribute;
import java.util.Iterator;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Static helpers for driving a {@link Trainer} through a simple epoch loop and for dumping
 * collected training metrics to a log file.
 *
 * <p>This class is not instantiable.
 */
public final class TrainingUtils {
    private static final Logger logger = LoggerFactory.getLogger(TrainingUtils.class);

    private TrainingUtils() {
    }

    /**
     * Runs the training loop for the given number of epochs.
     *
     * <p>For every epoch, each batch of {@code dataset} is passed to
     * {@link Trainer#trainBatch} followed by {@link Trainer#step}. If {@code dataset2} is not
     * {@code null}, each of its batches is then passed to {@link Trainer#validateBatch}. After
     * that the trainer's training metrics are reset and, if {@code str} is not {@code null}, the
     * trainer's model is saved with its {@code "Epoch"} property set to the zero-based epoch
     * index.
     *
     * <p>Every batch is closed once it has been processed, including when processing it throws.
     *
     * @param trainer the trainer that iterates the datasets and processes the batches
     * @param i the number of epochs to run; no work is done when it is zero or negative
     * @param dataset the dataset whose batches are used for training
     * @param dataset2 the dataset whose batches are used for validation, or {@code null} to skip
     *     validation
     * @param str the directory the model is saved to after each epoch, or {@code null} to skip
     *     saving
     * @param str2 the name passed to {@link Model#save} when saving the model
     * @throws IOException declared for the iteration and save calls made on the trainer and model
     */
    public static void fit(Trainer trainer, int i, Dataset dataset, Dataset dataset2, String str, String str2) throws IOException {
        for (int epoch = 0; epoch < i; epoch++) {
            for (Batch batch : trainer.iterateDataset(dataset)) {
                try {
                    trainer.trainBatch(batch);
                    trainer.step();
                } finally {
                    // Release the batch even if the training step fails.
                    batch.close();
                }
            }
            if (dataset2 != null) {
                for (Batch batch : trainer.iterateDataset(dataset2)) {
                    try {
                        trainer.validateBatch(batch);
                    } finally {
                        // Release the batch even if validation fails.
                        batch.close();
                    }
                }
            }
            trainer.resetTrainingMetrics();
            if (str != null) {
                Model model = trainer.getModel();
                model.setProperty("Epoch", String.valueOf(epoch));
                model.save(Paths.get(str), str2);
            }
        }
    }

    /**
     * Appends every metric recorded under the name {@code "train"} to the file
     * {@code training.log} inside the given directory, one metric per line.
     *
     * <p>The directory is created if it does not exist, and the file is created if missing or
     * appended to otherwise. An {@link IOException} raised while creating the directory or
     * writing the file is logged and not rethrown.
     *
     * @param metrics the metrics container the {@code "train"} entries are read from
     * @param str the directory that receives {@code training.log}; when {@code null} the method
     *     returns without doing anything
     */
    public static void dumpTrainingTimeInfo(Metrics metrics, String str) {
        if (str == null) {
            return;
        }
        try {
            Path dir = Paths.get(str);
            Files.createDirectories(dir);
            // try-with-resources closes the writer on every path and attaches a close failure
            // as a suppressed exception instead of masking the primary one.
            try (BufferedWriter writer = Files.newBufferedWriter(dir.resolve("training.log"), StandardOpenOption.CREATE, StandardOpenOption.APPEND)) {
                for (Metric metric : metrics.getMetric("train")) {
                    writer.append(metric.toString());
                    writer.newLine();
                }
            }
        } catch (IOException e) {
            logger.error("Failed dump training log", e);
        }
    }
}
