/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.zero.cv;

import ai.djl.Model;
import ai.djl.basicdataset.cv.ObjectDetectionDataset;
import ai.djl.basicmodelzoo.cv.object_detection.ssd.SingleShotDetection;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.DefaultTrainingConfig;
import ai.djl.training.EasyTrain;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.training.dataset.Dataset;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.training.evaluator.BoundingBoxError;
import ai.djl.training.evaluator.Evaluator;
import ai.djl.training.evaluator.SingleShotDetectionAccuracy;
import ai.djl.training.listener.TrainingListener;
import ai.djl.training.loss.Loss;
import ai.djl.training.loss.SingleShotDetectionLoss;
import ai.djl.translate.Transform;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.zero.Performance;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/** Zero-configuration training of a small Single Shot Detection (SSD) object detector. */
public final class ObjectDetection {

    /** Name given to the model created by {@link #train}. */
    private static final String MODEL_NAME = "ObjectDetection";

    /** Number of epochs passed to {@link EasyTrain#fit}. */
    private static final int EPOCHS = 50;

    /** Relative weight of the training portion in the random dataset split. */
    private static final int TRAIN_SPLIT_WEIGHT = 8;

    /** Relative weight of the validation portion in the random dataset split. */
    private static final int VALIDATION_SPLIT_WEIGHT = 2;

    /** Threshold handed to the SSD translator via {@code optThreshold}. */
    private static final float DETECTION_THRESHOLD = 0.6f;

    /** Filter counts of the down-sampling blocks stacked to form the SSD base network. */
    private static final int[] BASE_NETWORK_FILTERS = {16, 32, 64};

    private ObjectDetection() {
    }

    /**
     * Trains an SSD object-detection model on {@code dataset}.
     *
     * <p>The dataset is randomly split 8:2 into training and validation portions. A fresh SSD
     * model is then trained for a fixed 50 epochs. The result is wrapped with a
     * {@link SingleShotDetectionTranslator} that uses the dataset's classes as its synset and a
     * threshold of 0.6.
     *
     * <p>NOTE: {@code performance} is currently not consulted; the same fixed configuration is
     * used regardless of its value.
     *
     * <p>If anything fails after the model is created, the model is closed before the exception
     * propagates. On success the caller owns the returned model and is responsible for closing it.
     *
     * @param dataset the dataset to train on; it must report a fixed image width and height
     * @param performance the requested performance level (currently unused)
     * @return the trained model, ready for inference on {@link Image}s
     * @throws IllegalArgumentException if the dataset does not have a fixed image width or height
     * @throws IOException if reading the dataset or training fails
     * @throws TranslateException if data translation fails during training
     */
    public static ZooModel<Image, DetectedObjects> train(ObjectDetectionDataset dataset, Performance performance) throws IOException, TranslateException {
        List<String> classes = dataset.getClasses();
        int channels = dataset.getImageChannels();
        int width = dataset.getImageWidth()
                .orElseThrow(() -> new IllegalArgumentException("The dataset must have a fixed image width"));
        int height = dataset.getImageHeight()
                .orElseThrow(() -> new IllegalArgumentException("The dataset must have a fixed image height"));
        Shape imageShape = new Shape(channels, height, width);

        RandomAccessDataset[] splitDataset = dataset.randomSplit(TRAIN_SPLIT_WEIGHT, VALIDATION_SPLIT_WEIGHT);
        RandomAccessDataset trainDataset = splitDataset[0];
        RandomAccessDataset validateDataset = splitDataset[1];

        Block block = getSsdTrainBlock(classes.size());
        Model model = Model.newInstance(MODEL_NAME);
        try {
            model.setBlock(block);
            fitModel(model, imageShape, trainDataset, validateDataset);

            Translator<Image, DetectedObjects> translator = SingleShotDetectionTranslator.builder()
                    .addTransform(new ToTensor())
                    .optSynset(classes)
                    .optThreshold(DETECTION_THRESHOLD)
                    .build();
            return new ZooModel<>(model, translator);
        } catch (IOException | TranslateException | RuntimeException e) {
            // The caller only receives the model on success, so release it here instead of
            // leaking its resources; keep the original failure as the primary exception.
            try {
                model.close();
            } catch (RuntimeException closeFailure) {
                e.addSuppressed(closeFailure);
            }
            throw e;
        }
    }

    /**
     * Creates a trainer for {@code model} and fits it for {@link #EPOCHS} epochs.
     *
     * @param model the model whose block has already been set
     * @param imageShape the per-image shape (channels, height, width)
     * @param trainDataset the training portion of the data
     * @param validateDataset the validation portion of the data
     * @throws IOException if reading the data fails
     * @throws TranslateException if data translation fails
     */
    private static void fitModel(Model model, Shape imageShape, Dataset trainDataset, Dataset validateDataset) throws IOException, TranslateException {
        TrainingConfig trainingConfig = new DefaultTrainingConfig(new SingleShotDetectionLoss())
                .addEvaluator(new SingleShotDetectionAccuracy("classAccuracy"))
                .addEvaluator(new BoundingBoxError("boundingBoxError"))
                .addTrainingListeners(TrainingListener.Defaults.basic());
        try (Trainer trainer = model.newTrainer(trainingConfig)) {
            // Prepend a leading dimension of 1 to the per-image shape for initialization.
            trainer.initialize(new Shape(1L).addAll(imageShape));
            EasyTrain.fit(trainer, EPOCHS, trainDataset, validateDataset);
        }
    }

    /**
     * Builds the SSD block used for training.
     *
     * <p>The base network is a stack of down-sampling blocks with 16, 32 and 64 filters. Five
     * anchor-size pairs are configured, and every entry uses the same aspect ratios
     * {1, 2, 0.5}.
     *
     * @param numClasses the number of object classes to detect
     * @return the SSD block
     */
    private static Block getSsdTrainBlock(int numClasses) {
        SequentialBlock baseBlock = new SequentialBlock();
        for (int numFilter : BASE_NETWORK_FILTERS) {
            baseBlock.add(SingleShotDetection.getDownSamplingBlock(numFilter));
        }

        List<List<Float>> sizes = new ArrayList<>();
        sizes.add(Arrays.asList(0.2f, 0.272f));
        sizes.add(Arrays.asList(0.37f, 0.447f));
        sizes.add(Arrays.asList(0.54f, 0.619f));
        sizes.add(Arrays.asList(0.71f, 0.79f));
        sizes.add(Arrays.asList(0.88f, 0.961f));

        // One ratio entry per size entry, so the two lists always have the same length.
        List<List<Float>> ratios = new ArrayList<>(sizes.size());
        for (int i = 0; i < sizes.size(); ++i) {
            ratios.add(Arrays.asList(1.0f, 2.0f, 0.5f));
        }

        return SingleShotDetection.builder()
                .setNumClasses(numClasses)
                .setNumFeatures(3)
                .optGlobalPool(true)
                .setRatios(ratios)
                .setSizes(sizes)
                .setBaseNetwork(baseBlock)
                .build();
    }
}

