/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.basicdataset.tabular;

import ai.djl.Model;
import ai.djl.basicdataset.tabular.ListFeatures;
import ai.djl.basicdataset.tabular.MapFeatures;
import ai.djl.basicdataset.tabular.TabularResults;
import ai.djl.basicdataset.tabular.TabularTranslator;
import ai.djl.basicdataset.tabular.utils.Feature;
import ai.djl.modality.Classifications;
import ai.djl.ndarray.NDList;
import ai.djl.translate.ExpansionTranslatorFactory;
import ai.djl.translate.PostProcessor;
import ai.djl.translate.PreProcessor;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import java.lang.reflect.Type;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

/**
 * A translator factory that builds a {@link TabularTranslator} and expands it to additional
 * input and output types.
 *
 * <p>Besides the base {@link ListFeatures} input and {@link TabularResults} output, this factory
 * accepts {@link MapFeatures} as an input type and can produce {@link Classifications} or
 * {@link Float} as an output type. The output expansions require the model to produce exactly one
 * result of the requested type.
 */
public class TabularTranslatorFactory
        extends ExpansionTranslatorFactory<ListFeatures, TabularResults> {

    /** {@inheritDoc} */
    @Override
    protected Translator<ListFeatures, TabularResults> buildBaseTranslator(
            Model model, Map<String, ?> arguments) {
        return new TabularTranslator(model, arguments);
    }

    /** {@inheritDoc} */
    @Override
    public Class<ListFeatures> getBaseInputType() {
        return ListFeatures.class;
    }

    /** {@inheritDoc} */
    @Override
    public Class<TabularResults> getBaseOutputType() {
        return TabularResults.class;
    }

    /**
     * Returns the supported input expansions.
     *
     * @return a new map registering {@link MapFeatures} input, adapted by {@link MapPreProcessor}
     */
    @Override
    protected Map<Type, Function<PreProcessor<ListFeatures>, PreProcessor<?>>>
            getPreprocessorExpansions() {
        Map<Type, Function<PreProcessor<ListFeatures>, PreProcessor<?>>> expansions =
                new ConcurrentHashMap<>();
        expansions.put(MapFeatures.class, MapPreProcessor::new);
        return expansions;
    }

    /**
     * Returns the supported output expansions.
     *
     * @return a new map registering {@link Classifications} and {@link Float} outputs
     */
    @Override
    protected Map<Type, Function<PostProcessor<TabularResults>, PostProcessor<?>>>
            getPostprocessorExpansions() {
        Map<Type, Function<PostProcessor<TabularResults>, PostProcessor<?>>> expansions =
                new ConcurrentHashMap<>();
        expansions.put(Classifications.class, ClassificationsTabularPostProcessor::new);
        expansions.put(Float.class, RegressionTabularPostProcessor::new);
        return expansions;
    }

    /**
     * Validates that {@code results} holds exactly one result of {@code type} and returns it.
     *
     * @param results the base translator output
     * @param type the expected result class
     * @param processorName the processor name used in error messages
     * @param expected a human-readable description of the expected type, used in error messages
     * @param <T> the expected result type
     * @return the single result, cast to {@code type}
     * @throws IllegalStateException if there is not exactly one result, or it is not a {@code type}
     */
    private static <T> T extractSingleResult(
            TabularResults results, Class<T> type, String processorName, String expected) {
        if (results.size() != 1) {
            throw new IllegalStateException(
                    "The "
                            + processorName
                            + " expected the model to produce one output, but instead it produced "
                            + results.size());
        }
        Object result = results.getFeature(0).getResult();
        if (type.isInstance(result)) {
            return type.cast(result);
        }
        throw new IllegalStateException(
                "The "
                        + processorName
                        + " expected the model to produce "
                        + expected
                        + ", but instead it produced "
                        + typeName(result));
    }

    /** Returns the runtime class name of {@code value}, or {@code "null"} (avoids an NPE). */
    private static String typeName(Object value) {
        return value == null ? "null" : value.getClass().getName();
    }

    /** Adapts the base post-processor to return the single {@link Float} (regression) result. */
    static final class RegressionTabularPostProcessor implements PostProcessor<Float> {

        private final PostProcessor<TabularResults> postProcessor;

        RegressionTabularPostProcessor(PostProcessor<TabularResults> postProcessor) {
            this.postProcessor = postProcessor;
        }

        /**
         * {@inheritDoc}
         *
         * @throws IllegalStateException if the model does not produce exactly one {@link Float}
         */
        @Override
        public Float processOutput(TranslatorContext ctx, NDList list) throws Exception {
            TabularResults results = postProcessor.processOutput(ctx, list);
            return extractSingleResult(
                    results, Float.class, "RegressionTabularPostProcessor", "a float");
        }
    }

    /** Adapts the base post-processor to return the single {@link Classifications} result. */
    static final class ClassificationsTabularPostProcessor
            implements PostProcessor<Classifications> {

        private final PostProcessor<TabularResults> postProcessor;

        ClassificationsTabularPostProcessor(PostProcessor<TabularResults> postProcessor) {
            this.postProcessor = postProcessor;
        }

        /**
         * {@inheritDoc}
         *
         * @throws IllegalStateException if the model does not produce exactly one {@link
         *     Classifications}
         */
        @Override
        public Classifications processOutput(TranslatorContext ctx, NDList list)
                throws Exception {
            TabularResults results = postProcessor.processOutput(ctx, list);
            return extractSingleResult(
                    results,
                    Classifications.class,
                    "ClassificationsTabularPostProcessor",
                    "a Classifications");
        }
    }

    /**
     * Converts {@link MapFeatures} input into {@link ListFeatures}, ordered by the translator's
     * feature list, and delegates to the wrapped {@link TabularTranslator}.
     */
    static final class MapPreProcessor implements PreProcessor<MapFeatures> {

        private final TabularTranslator preProcessor;

        /**
         * Creates the adapter.
         *
         * @param preProcessor the base pre-processor; must be a {@link TabularTranslator}
         * @throws IllegalArgumentException if {@code preProcessor} is not a TabularTranslator
         */
        MapPreProcessor(PreProcessor<ListFeatures> preProcessor) {
            if (!(preProcessor instanceof TabularTranslator)) {
                throw new IllegalArgumentException(
                        "The MapPreProcessor for the TabularTranslatorFactory expects a"
                                + " TabularTranslator, but received "
                                + typeName(preProcessor));
            }
            this.preProcessor = (TabularTranslator) preProcessor;
        }

        /**
         * {@inheritDoc}
         *
         * @throws IllegalArgumentException if {@code input} lacks any of the translator's features
         */
        @Override
        public NDList processInput(TranslatorContext ctx, MapFeatures input) throws Exception {
            ListFeatures list = new ListFeatures(preProcessor.getFeatures().size());
            for (Feature feature : preProcessor.getFeatures()) {
                String name = feature.getName();
                if (!input.containsKey(name)) {
                    throw new IllegalArgumentException(
                            "The input to the TabularTranslator is missing the feature: " + name);
                }
                // Values are appended in the translator's feature order, not the map's order.
                list.add((String) input.get(name));
            }
            return preProcessor.processInput(ctx, list);
        }
    }
}

