/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.zero.cv;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.basicdataset.cv.classification.ImageClassificationDataset;
import ai.djl.basicmodelzoo.cv.classification.MobileNetV2;
import ai.djl.basicmodelzoo.cv.classification.ResNetV1;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.Image;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.Accuracy;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.zero.Performance;
import ai.djl.zero.RequireZoo;
import java.io.IOException;
import java.util.List;

/**
 * Zero-configuration helpers for image classification: either load a pretrained model from a model
 * zoo, or train a new model from an {@link ImageClassificationDataset}.
 */
public final class ImageClassification {

    /** Number of epochs run by {@link #train(ImageClassificationDataset, Performance)}. */
    private static final int TRAINING_EPOCHS = 35;

    private ImageClassification() {
    }

    /**
     * Loads a pretrained image classification model from a model zoo.
     *
     * <p>{@link Classes#IMAGENET} loads the {@code ai.djl.mxnet} "resnet" artifact trained on
     * imagenet; the number of layers (18, 50 or 152) is chosen from {@code performance}.
     * {@link Classes#DIGITS} loads the {@code ai.djl.zoo} "mlp" artifact trained on mnist and does
     * not consult {@code performance}.
     *
     * @param input the input class of the returned model
     * @param classes the label set the model should predict
     * @param performance the speed/accuracy trade-off (only used for {@link Classes#IMAGENET})
     * @param <I> the model input type
     * @return the loaded model; the caller is responsible for closing it
     * @throws IllegalArgumentException if {@code classes} is not a supported value
     * @throws MalformedModelException if the model artifact is malformed
     * @throws ModelNotFoundException if no matching model is found
     * @throws IOException if the model cannot be read
     */
    public static <I> ZooModel<I, Classifications> pretrained(Class<I> input, Classes classes, Performance performance) throws MalformedModelException, ModelNotFoundException, IOException {
        Criteria.Builder<I, Classifications> criteria = Criteria.builder()
                .setTypes(input, Classifications.class)
                .optApplication(Application.CV.IMAGE_CLASSIFICATION);
        switch (classes) {
            case IMAGENET: {
                // The ImageNet ResNet artifacts come from the MXNet zoo.
                RequireZoo.mxnet();
                String layers = performance.switchPerformance("18", "50", "152");
                criteria.optGroupId("ai.djl.mxnet")
                        .optArtifactId("resnet")
                        .optFilter("dataset", "imagenet")
                        .optFilter("layers", layers);
                break;
            }
            case DIGITS: {
                RequireZoo.basic();
                criteria.optGroupId("ai.djl.zoo")
                        .optArtifactId("mlp")
                        .optFilter("dataset", "mnist");
                break;
            }
            default: {
                throw new IllegalArgumentException("Unknown classes");
            }
        }
        return criteria.build().loadModel();
    }

    /**
     * Trains a new image classification model on {@code dataset}.
     *
     * <p>The dataset is randomly split 8:2 into training and validation parts. With
     * {@link Performance#FAST} a MobileNetV2 network is trained; otherwise a ResNetV1 network whose
     * depth (50 or 152 layers) is picked from {@code performance}. Training uses softmax
     * cross-entropy loss, reports accuracy, and runs for {@value #TRAINING_EPOCHS} epochs.
     *
     * @param dataset the dataset to train on; it must have a fixed image width and height
     * @param performance the speed/accuracy trade-off used to choose the network
     * @return the trained model wrapped with the dataset's matching translator; the caller is
     *     responsible for closing it
     * @throws IllegalArgumentException if the dataset has no fixed image width or height
     * @throws IOException if training fails with an I/O error
     * @throws TranslateException if training fails while translating data
     */
    public static ZooModel<Image, Classifications> train(ImageClassificationDataset dataset, Performance performance) throws IOException, TranslateException {
        int channels = dataset.getImageChannels();
        int width = dataset.getImageWidth()
                .orElseThrow(() -> new IllegalArgumentException("The dataset must have a fixed image width"));
        int height = dataset.getImageHeight()
                .orElseThrow(() -> new IllegalArgumentException("The dataset must have a fixed image height"));
        Shape imageShape = new Shape(channels, height, width);
        List<String> classes = dataset.getClasses();

        RandomAccessDataset[] splitDataset = dataset.randomSplit(8, 2);
        RandomAccessDataset trainDataset = splitDataset[0];
        RandomAccessDataset validateDataset = splitDataset[1];

        Block block = buildBlock(imageShape, classes.size(), performance);

        Model model = Model.newInstance("ImageClassification");
        // The model is handed to the returned ZooModel on success; on any failure it must be
        // closed here, otherwise its native resources leak.
        boolean success = false;
        try {
            model.setBlock(block);
            TrainingConfig trainingConfig = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                    .addEvaluator(new Accuracy())
                    .addTrainingListeners(TrainingListener.Defaults.basic());
            try (Trainer trainer = model.newTrainer(trainingConfig)) {
                // Prepend a batch dimension of 1 to the per-image shape for initialization.
                trainer.initialize(new Shape(1).addAll(imageShape));
                EasyTrain.fit(trainer, TRAINING_EPOCHS, trainDataset, validateDataset);
            }
            Translator<Image, Classifications> translator =
                    dataset.matchingTranslatorOptions().option(Image.class, Classifications.class);
            ZooModel<Image, Classifications> zooModel = new ZooModel<>(model, translator);
            success = true;
            return zooModel;
        } finally {
            if (!success) {
                model.close();
            }
        }
    }

    /**
     * Builds the untrained network for {@link #train}: MobileNetV2 for {@link Performance#FAST},
     * otherwise a ResNetV1 with 50 or 152 layers chosen from {@code performance}.
     */
    private static Block buildBlock(Shape imageShape, int numClasses, Performance performance) {
        if (performance.equals(Performance.FAST)) {
            return MobileNetV2.builder().setOutSize(numClasses).build();
        }
        int numLayers = performance.switchPerformance(18, 50, 152);
        return ResNetV1.builder()
                .setImageShape(imageShape)
                .setNumLayers(numLayers)
                .setOutSize(numClasses)
                .build();
    }

    /** The label sets supported by {@link #pretrained(Class, Classes, Performance)}. */
    public static enum Classes {
        /** ImageNet classes, served by an MXNet ResNet model. */
        IMAGENET,
        /** Handwritten digits, served by an MLP model trained on mnist. */
        DIGITS;

    }
}

