/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.onnxruntime.zoo.tabular.softmax_regression;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.onnxruntime.zoo.OrtModelZoo;
import ai.djl.onnxruntime.zoo.tabular.softmax_regression.IrisFlower;
import ai.djl.repository.MRL;
import ai.djl.repository.Repository;
import ai.djl.repository.zoo.BaseModelLoader;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;

public class IrisClassificationModelLoader
extends BaseModelLoader {
    private static final Application APPLICATION = Application.Tabular.SOFTMAX_REGRESSION;
    private static final String GROUP_ID = "ai.djl.onnxruntime";
    private static final String ARTIFACT_ID = "iris_flowers";
    private static final String VERSION = "0.0.1";

    public IrisClassificationModelLoader(Repository repository) {
        super(repository, MRL.model((Application)APPLICATION, (String)GROUP_ID, (String)ARTIFACT_ID), VERSION, (ModelZoo)new OrtModelZoo());
        this.factories.put(new Pair(IrisFlower.class, Classifications.class), new FactoryImpl());
    }

    public ZooModel<String, Classifications> loadModel() throws IOException, ModelNotFoundException, MalformedModelException {
        Criteria criteria = Criteria.builder().setTypes(String.class, Classifications.class).build();
        return this.loadModel(criteria);
    }

    private static final class IrisTranslator
    implements Translator<IrisFlower, Classifications> {
        private List<String> synset = Arrays.asList("setosa", "versicolor", "virginica");

        public NDList processInput(TranslatorContext ctx, IrisFlower input) {
            float[] data = new float[]{input.getSepalLength(), input.getSepalWidth(), input.getPetalLength(), input.getPetalWidth()};
            NDArray array = ctx.getNDManager().create(data, new Shape(new long[]{1L, 4L}));
            return new NDList(new NDArray[]{array});
        }

        public Classifications processOutput(TranslatorContext ctx, NDList list) {
            float[] data = ((NDArray)list.get(1)).toFloatArray();
            ArrayList<Double> probabilities = new ArrayList<Double>(data.length);
            for (float f : data) {
                probabilities.add(Double.valueOf(f));
            }
            return new Classifications(this.synset, probabilities);
        }

        public Batchifier getBatchifier() {
            return null;
        }
    }

    private static final class FactoryImpl
    implements TranslatorFactory<IrisFlower, Classifications> {
        private FactoryImpl() {
        }

        public Translator<IrisFlower, Classifications> newInstance(Model model, Map<String, ?> arguments) {
            return new IrisTranslator();
        }
    }
}

