/*
 * Decompiled with CFR 0.152.
 */
package com.yahoo.searchlib.rankingexpression.integration.ml;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;

/**
 * Base class for importers of machine-learned models.
 *
 * <p>Subclasses implement {@link #importModel(String, String)} for a concrete model format. The static
 * helpers in this class turn the resulting {@link IntermediateGraph} into an {@link ImportedModel}
 * holding signatures, ranking expressions, constants, arguments, macros and import warnings.
 */
public abstract class ModelImporter {

    private static final Logger log = Logger.getLogger(ModelImporter.class.getName());

    /**
     * Imports the model found at the given location.
     *
     * @param modelName the name to give the imported model
     * @param modelPath the location of the model as a path string
     *                  ({@link #importModel(String, File)} passes a directory here)
     * @return the imported model
     */
    public abstract ImportedModel importModel(String modelName, String modelPath);

    /**
     * Imports the model stored in the given directory by delegating to
     * {@link #importModel(String, String)} with the directory's string form.
     */
    public ImportedModel importModel(String modelName, File modelDir) {
        return importModel(modelName, modelDir.toString());
    }

    /**
     * Converts an intermediate graph into an imported model.
     *
     * <p>The steps run in this order:
     * <ol>
     *   <li>optimize the graph;</li>
     *   <li>import its signatures;</li>
     *   <li>import the expressions reachable from the signature outputs;</li>
     *   <li>collect operation warnings;</li>
     *   <li>log the types of constant operations.</li>
     * </ol>
     */
    static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
        ImportedModel model = new ImportedModel(graph.name());
        graph.optimize();
        importSignatures(graph, model);
        importExpressions(graph, model);
        reportWarnings(graph, model);
        logVariableTypes(graph);
        return model;
    }

    /** Copies every signature of the graph, with its named inputs and outputs, into the model. */
    private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
        for (String signatureName : graph.signatures()) {
            ImportedModel.Signature signature = model.signature(signatureName);
            for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
                signature.input(input.getKey(), input.getValue());
            }
            for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
                signature.output(output.getKey(), output.getValue());
            }
        }
    }

    /** Returns whether the operation's name is listed as an output of any signature in the model. */
    private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                if (outputName.equals(operation.name())) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Imports the expressions needed to compute every signature output.
     *
     * <p>An output is recorded as skipped when it yields no function, or when importing it throws an
     * {@link IllegalArgumentException}. Other exceptions propagate.
     */
    static void importExpressions(IntermediateGraph graph, ImportedModel model) {
        // Operations fully imported during this call. Shared subgraphs are processed once
        // instead of once per path that reaches them.
        Set<IntermediateOperation> imported =
                Collections.newSetFromMap(new IdentityHashMap<IntermediateOperation, Boolean>());
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                try {
                    Optional<TensorFunction> function = importExpression(graph.get(outputName), model, imported);
                    if (!function.isPresent()) {
                        signature.skippedOutput(outputName, "No valid output function could be found.");
                    }
                } catch (IllegalArgumentException e) {
                    signature.skippedOutput(outputName, Exceptions.toMessageString(e));
                }
            }
        }
    }

    /**
     * Imports an operation and, recursively, its inputs.
     *
     * @param imported operations already fully imported. An operation is added only after all of its
     *                 import steps complete, so a failed import is retried (and fails again) on
     *                 later visits.
     * @return the operation's function, or empty when the operation has no type
     */
    private static Optional<TensorFunction> importExpression(IntermediateOperation operation,
                                                             ImportedModel model,
                                                             Set<IntermediateOperation> imported) {
        if (!operation.type().isPresent()) {
            return Optional.empty();
        }
        if (operation.isConstant()) {
            return importConstant(operation, model);
        }
        if (imported.contains(operation)) {
            return operation.function();
        }
        importExpressionInputs(operation, model, imported);
        importRankingExpression(operation, model);
        importArgumentExpression(operation, model);
        importMacroExpression(operation, model);
        imported.add(operation);
        return operation.function();
    }

    /** Imports every input of the operation. */
    private static void importExpressionInputs(IntermediateOperation operation,
                                               ImportedModel model,
                                               Set<IntermediateOperation> imported) {
        operation.inputs().forEach(input -> importExpression(input, model, imported));
    }

    /**
     * Registers the constant value of the operation under its Vespa name.
     *
     * <p>Rank-0 tensors are registered as small constants; all other tensors as large constants.
     * Nothing is registered if a constant with that name already exists, or if the value is not
     * a {@link TensorValue}. In the non-tensor case the value is presumably carried by the
     * operation's function itself (TODO confirm).
     *
     * @throws IllegalArgumentException if the operation is constant but has no value
     */
    private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
        String name = operation.vespaName();
        if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
            return operation.function();
        }
        Value value = operation.getConstantValue().orElseThrow(() -> new IllegalArgumentException("Operation '" + operation.vespaName() + "' is constant but does not have a value."));
        if (!(value instanceof TensorValue)) {
            return operation.function();
        }
        Tensor tensor = value.asTensor();
        if (tensor.type().rank() == 0) {
            model.smallConstant(name, tensor);
        } else {
            model.largeConstant(name, tensor);
        }
        return operation.function();
    }

    /**
     * Adds the operation's function to the model as a ranking expression named after the operation,
     * unless it has no function or an expression with that name already exists.
     *
     * <p>Signature outputs whose dimension names differ from the standard naming are wrapped in a
     * {@link Rename} to the standard dimension names.
     *
     * @throws RuntimeException if the generated function text cannot be parsed as a ranking expression
     */
    private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
        if (!operation.function().isPresent()) {
            return;
        }
        String name = operation.name();
        if (model.expressions().containsKey(name)) {
            return;
        }
        TensorFunction function = operation.function().get();
        if (isSignatureOutput(model, operation)) {
            OrderedTensorType operationType = operation.type().get();
            OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
            if (!operationType.equals(standardNamingType)) {
                List<String> renameFrom = operationType.dimensionNames();
                List<String> renameTo = standardNamingType.dimensionNames();
                function = new Rename(function, renameFrom, renameTo);
            }
        }
        try {
            model.expression(name, new RankingExpression(name, function.toString()));
        } catch (ParseException e) {
            throw new RuntimeException("Imported function " + function + " cannot be parsed as a ranking expression", e);
        }
    }

    /**
     * For input operations, registers an argument and a required macro under the operation's Vespa
     * name, both typed with the standard-named version of the operation's type.
     *
     * <p>Callers guarantee the type is present: {@link #importExpression} returns early otherwise.
     */
    private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.isInput()) {
            OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
            model.argument(operation.vespaName(), standardNamingConvention.type());
            model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
        }
    }

    /**
     * Registers the operation's macro, if it has one, as a ranking expression under its macro name.
     *
     * @throws RuntimeException if the macro text cannot be parsed as a ranking expression
     */
    private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) {
        if (operation.macro().isPresent()) {
            TensorFunction function = operation.macro().get();
            try {
                model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
            } catch (ParseException e) {
                // Generic wording: this importer is not specific to any one model format.
                throw new RuntimeException("Imported function " + function + " cannot be parsed as a ranking expression", e);
            }
        }
    }

    /**
     * Adds the warnings of every operation reachable from a signature output to the model's
     * default signature. Each reachable operation reports its warnings exactly once.
     */
    private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
        Set<IntermediateOperation> visited =
                Collections.newSetFromMap(new IdentityHashMap<IntermediateOperation, Boolean>());
        for (ImportedModel.Signature signature : model.signatures().values()) {
            for (String outputName : signature.outputs().values()) {
                reportWarnings(graph.get(outputName), model, visited);
            }
        }
    }

    /**
     * Reports the warnings of this operation and, recursively, of its inputs. Operations already in
     * {@code visited} are skipped, so shared subgraphs neither repeat warnings nor get re-walked.
     */
    private static void reportWarnings(IntermediateOperation operation,
                                       ImportedModel model,
                                       Set<IntermediateOperation> visited) {
        if (!visited.add(operation)) {
            return;
        }
        for (String warning : operation.warnings()) {
            model.defaultSignature().importWarning(warning);
        }
        for (IntermediateOperation input : operation.inputs()) {
            reportWarnings(input, model, visited);
        }
    }

    /** Logs, at INFO level, the name, Vespa name and type of every typed constant operation in the graph. */
    private static void logVariableTypes(IntermediateGraph graph) {
        for (IntermediateOperation operation : graph.operations()) {
            if (operation instanceof Constant && operation.type().isPresent()) {
                log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + " of type " + operation.type().get());
            }
        }
    }
}

