/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.pytorch.zoo.cv.objectdetection;

import ai.djl.Application;
import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.transform.Normalize;
import ai.djl.modality.cv.transform.Resize;
import ai.djl.modality.cv.transform.ToTensor;
import ai.djl.modality.cv.translator.SingleShotDetectionTranslator;
import ai.djl.modality.cv.translator.wrapper.FileTranslatorFactory;
import ai.djl.modality.cv.translator.wrapper.InputStreamTranslatorFactory;
import ai.djl.modality.cv.translator.wrapper.UrlTranslatorFactory;
import ai.djl.pytorch.zoo.PtModelZoo;
import ai.djl.pytorch.zoo.cv.objectdetection.PtSSDTranslator;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Transform;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import ai.djl.util.Progress;
import java.io.IOException;
import java.io.InputStream;
import java.net.URL;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;

public class SingleShotDetectionModelLoader
extends BaseModelLoader<Image, DetectedObjects> {
    private static final Application APPLICATION = Application.CV.OBJECT_DETECTION;
    private static final String GROUP_ID = "ai.djl.pytorch";
    private static final String ARTIFACT_ID = "ssd";
    private static final String VERSION = "0.0.1";
    private static final float[] MEAN = new float[]{0.485f, 0.456f, 0.406f};
    private static final float[] STD = new float[]{0.229f, 0.224f, 0.225f};

    public SingleShotDetectionModelLoader(Repository repository) {
        super(repository, MRL.model((Application)APPLICATION, (String)GROUP_ID, (String)ARTIFACT_ID), VERSION, (ModelZoo)new PtModelZoo());
        FactoryImpl factory = new FactoryImpl();
        this.factories.put(new Pair(Image.class, DetectedObjects.class), factory);
        this.factories.put(new Pair(Path.class, DetectedObjects.class), new FileTranslatorFactory((TranslatorFactory)factory));
        this.factories.put(new Pair(URL.class, DetectedObjects.class), new UrlTranslatorFactory((TranslatorFactory)factory));
        this.factories.put(new Pair(InputStream.class, DetectedObjects.class), new InputStreamTranslatorFactory((TranslatorFactory)factory));
    }

    public Application getApplication() {
        return APPLICATION;
    }

    public ZooModel<Image, DetectedObjects> loadModel(Map<String, String> filters, Device device, Progress progress) throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria criteria = Criteria.builder().setTypes(Image.class, DetectedObjects.class).optFilters(filters).optDevice(device).optProgress(progress).build();
        return this.loadModel(criteria);
    }

    private static final class FactoryImpl
    implements TranslatorFactory<Image, DetectedObjects> {
        private FactoryImpl() {
        }

        public Translator<Image, DetectedObjects> newInstance(Model model, Map<String, Object> arguments) {
            int[][] aspectRatio;
            int width = ((Double)arguments.getOrDefault("width", 300)).intValue();
            int height = ((Double)arguments.getOrDefault("height", 300)).intValue();
            double threshold = (Double)arguments.getOrDefault("threshold", 0.4);
            int figSize = ((Double)arguments.getOrDefault("size", 300)).intValue();
            List list = (List)arguments.get("feat_size");
            int[] featSize = list == null ? new int[]{38, 19, 10, 5, 3, 1} : list.stream().mapToInt(Double::intValue).toArray();
            list = (List)arguments.get("steps");
            int[] steps = list == null ? new int[]{8, 16, 32, 64, 100, 300} : list.stream().mapToInt(Double::intValue).toArray();
            list = (List)arguments.get("scale");
            int[] scale = list == null ? new int[]{21, 45, 99, 153, 207, 261, 315} : list.stream().mapToInt(Double::intValue).toArray();
            List ratio = (List)arguments.get("aspect_ratios");
            if (ratio == null) {
                aspectRatio = new int[][]{{2}, {2, 3}, {2, 3}, {2, 3}, {2}, {2}};
            } else {
                aspectRatio = new int[ratio.size()][];
                for (int i = 0; i < aspectRatio.length; ++i) {
                    aspectRatio[i] = ((List)ratio.get(i)).stream().mapToInt(Double::intValue).toArray();
                }
            }
            return ((SingleShotDetectionTranslator.Builder)((SingleShotDetectionTranslator.Builder)((SingleShotDetectionTranslator.Builder)((SingleShotDetectionTranslator.Builder)((SingleShotDetectionTranslator.Builder)PtSSDTranslator.builder().setBoxes(figSize, featSize, steps, scale, aspectRatio).addTransform((Transform)new Resize(width, height))).addTransform((Transform)new ToTensor())).addTransform((Transform)new Normalize(MEAN, STD))).optSynsetArtifactName("classes.txt")).optThreshold((float)threshold)).build();
        }
    }
}

