package ai.djl.examples.training;

import ai.djl.Device;
import ai.djl.Model;
import ai.djl.basicdataset.Mnist;
import ai.djl.examples.training.util.AbstractTraining;
import ai.djl.examples.training.util.Arguments;
import ai.djl.examples.training.util.TrainingUtils;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.initializer.XavierInitializer;
import ai.djl.training.loss.Loss;
import ai.djl.training.metrics.Accuracy;
import ai.djl.training.optimizer.Optimizer;
import ai.djl.training.optimizer.Sgd;
import ai.djl.training.optimizer.learningrate.LearningRateTracker;
import ai.djl.training.util.ProgressBar;
import ai.djl.zoo.cv.classification.Mlp;
import java.io.IOException;
import java.nio.file.Paths;

/**
 * Example that trains a multilayer perceptron ({@link Mlp}) on the MNIST dataset and saves the
 * resulting model to the configured output directory.
 */
public final class TrainMnist extends AbstractTraining {

    /** Image height passed to the {@link Mlp} constructor. */
    private static final int IMAGE_HEIGHT = 28;

    /** Image width passed to the {@link Mlp} constructor. */
    private static final int IMAGE_WIDTH = 28;

    /**
     * Number of samples used to derive the learning-rate tracker's step length in iterations.
     * NOTE(review): 60000 matches the size of the standard MNIST training split; confirm this still
     * holds if the dataset or {@code maxIterations} changes the effective epoch length.
     */
    private static final int TRAIN_SAMPLES_PER_EPOCH = 60000;

    /** Model name passed to {@link TrainingUtils#fit} and {@link Model#save}. */
    private static final String MODEL_NAME = "mlp";

    /**
     * Command-line entry point.
     *
     * @param args command-line arguments, forwarded to {@code runExample}
     */
    public static void main(String[] args) {
        new TrainMnist().runExample(args);
    }

    /**
     * Builds an {@link Mlp}, trains it on the MNIST train split while validating on the test split,
     * then records the epoch count and validation accuracy as model properties and saves the model.
     *
     * <p>The model and trainer are managed with try-with-resources so that each is closed exactly
     * once, whether training and saving succeed or throw.
     *
     * @param arguments parsed command-line arguments (epochs, batch size, output directory, ...)
     * @throws IOException if preparing a dataset or saving the model fails
     */
    @Override
    protected void train(Arguments arguments) throws IOException {
        Mlp mlp = new Mlp(IMAGE_HEIGHT, IMAGE_WIDTH);

        try (Model model = Model.newInstance()) {
            model.setBlock(mlp);

            Dataset trainingSet = getDataset(model.getNDManager(), Dataset.Usage.TRAIN, arguments);
            Dataset validationSet = getDataset(model.getNDManager(), Dataset.Usage.TEST, arguments);

            // Also assigns this.loss as a side effect.
            TrainingConfig config = setupTrainingConfig(arguments);

            try (Trainer trainer = model.newTrainer(config)) {
                trainer.setMetrics(this.metrics);
                trainer.setTrainingListener(this);
                // Input shape (1, IMAGE_HEIGHT * IMAGE_WIDTH): presumably one flattened image.
                trainer.initialize(
                        new Shape[] {new Shape(new long[] {1, IMAGE_HEIGHT * IMAGE_WIDTH})});
                TrainingUtils.fit(
                        trainer,
                        arguments.getEpoch(),
                        trainingSet,
                        validationSet,
                        arguments.getOutputDir(),
                        MODEL_NAME);
            }

            // The trainer is already closed here; only the model is still open.
            model.setProperty("Epoch", String.valueOf(arguments.getEpoch()));
            model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));
            model.save(Paths.get(arguments.getOutputDir()), MODEL_NAME);
        }
    }

    /**
     * Creates the training configuration, which has the following parts:
     * <ul>
     *   <li>SGD with momentum, weight decay and gradient clipping</li>
     *   <li>a factor learning-rate tracker with warm-up</li>
     *   <li>softmax cross-entropy loss</li>
     *   <li>Xavier initialization</li>
     *   <li>an accuracy training metric</li>
     * </ul>
     *
     * <p>Side effect: stores the loss in {@code this.loss}.
     *
     * @param arguments supplies the batch size and the maximum number of GPUs to use
     * @return the configuration to pass to {@link Model#newTrainer}
     */
    private TrainingConfig setupTrainingConfig(Arguments arguments) {
        int batchSize = arguments.getBatchSize();

        Sgd optimizer =
                Optimizer.sgd()
                        .setRescaleGrad(1.0f / batchSize)
                        .setLearningRateTracker(
                                LearningRateTracker.factorTracker()
                                        .optBaseLearningRate(0.1f)
                                        // Step length: one assumed epoch's worth of iterations.
                                        .setStep(TRAIN_SAMPLES_PER_EPOCH / batchSize)
                                        .optFactor(0.1f)
                                        .optWarmUpBeginLearningRate(0.01f)
                                        .optWarmUpSteps(500)
                                        .optStopFactorLearningRate(0.001f)
                                        .build())
                        .optWeightDecays(0.001f)
                        .optMomentum(0.9f)
                        .optClipGrad(1.0f)
                        .build();

        this.loss = Loss.softmaxCrossEntropyLoss();
        return new DefaultTrainingConfig(new XavierInitializer(), this.loss)
                .setOptimizer(optimizer)
                .addTrainingMetric(new Accuracy())
                .setBatchSize(batchSize)
                .setDevices(Device.getDevices(arguments.getMaxGpus()));
    }

    /**
     * Builds and prepares the MNIST dataset for the given split.
     *
     * <p>It also records {@code min(size / batchSize, maxIterations)} as a side effect:
     * <ul>
     *   <li>in {@code trainDataSize} for the training split</li>
     *   <li>in {@code validateDataSize} for any other split</li>
     * </ul>
     *
     * @param manager manager passed to the {@link Mnist} builder
     * @param usage which split to load
     * @param arguments supplies the batch size and the maximum number of iterations
     * @return the prepared dataset
     * @throws IOException if preparing the dataset fails
     */
    private Dataset getDataset(NDManager manager, Dataset.Usage usage, Arguments arguments)
            throws IOException {
        int batchSize = arguments.getBatchSize();
        long maxIterations = arguments.getMaxIterations();

        Mnist mnist =
                Mnist.builder(manager)
                        .optUsage(usage)
                        .setSampling(batchSize, true)
                        .optMaxIteration(maxIterations)
                        .build();
        mnist.prepare(new ProgressBar());

        int iterations = (int) Math.min(mnist.size() / batchSize, maxIterations);
        if (usage == Dataset.Usage.TRAIN) {
            this.trainDataSize = iterations;
        } else {
            this.validateDataSize = iterations;
        }
        return mnist;
    }
}
