/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.Constant;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.FunctionReference;
import ai.vespa.models.evaluation.LazyArrayContext;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * A named collection of ranking-expression functions which can be evaluated.
 *
 * <p>An evaluation context prototype is prepared for every function at construction time.
 * {@link #evaluatorOf(String...)} hands out an evaluator backed by a fresh copy of the relevant prototype.
 * Functions whose name starts with {@code imported_ml_function_} are treated as internal: they can be looked
 * up and evaluated, but are not returned by {@link #functions()}.
 */
@Beta
public class Model {

    /** Name prefix of internal functions generated for intermediate operations; these are not public. */
    private static final String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";

    /** Value handed to every evaluation context; NaN, frozen. */
    private static final Value MISSING_VALUE = DoubleValue.frozen(Double.NaN);

    private final String name;

    /** All functions of this model, including internal intermediate-operation functions. */
    private final ImmutableList<ExpressionFunction> functions;

    /** The functions of this model excluding internal intermediate-operation functions. */
    private final ImmutableList<ExpressionFunction> publicFunctions;

    /** Referenced functions, optimized at construction time. */
    private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;

    /** Evaluation context prototypes keyed by function name; copied for each evaluator. */
    private final ImmutableMap<String, LazyArrayContext> contextPrototypes;

    private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();

    /**
     * Creates a model from a collection of functions, each keyed by a reference created from its name.
     *
     * @param name the name of this model
     * @param functions the functions of this model
     * @throws IllegalStateException if two functions map to the same {@link FunctionReference}
     *         (raised by {@link Collectors#toMap})
     * @throws IllegalArgumentException if an evaluation context cannot be prepared for a function
     */
    public Model(String name, Collection<ExpressionFunction> functions) {
        this(name,
             functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
             Collections.emptyMap(),
             Collections.emptyList());
    }

    /**
     * Creates a model.
     *
     * <p>Note: {@code functions} is updated in place. A function lacking a return type gets
     * {@code TensorType.empty} as its return type. Arguments used by a function's evaluation context
     * but not declared by the function are added to it. The map must therefore be mutable whenever
     * such updates are needed; it is written to only for functions which actually change.
     *
     * @param name the name of this model
     * @param functions the functions of this model, keyed by reference; may be modified as described above
     * @param referencedFunctions referenced functions; each is passed to the evaluation contexts and
     *        optimized using the context prototype named by its reference's function name
     * @param constants constants passed to each evaluation context
     * @throws IllegalArgumentException if an evaluation context cannot be prepared for a function
     */
    Model(String name,
          Map<FunctionReference, ExpressionFunction> functions,
          Map<FunctionReference, ExpressionFunction> referencedFunctions,
          List<Constant> constants) {
        this.name = name;

        ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
        for (Map.Entry<FunctionReference, ExpressionFunction> entry : functions.entrySet()) {
            try {
                ExpressionFunction function = entry.getValue();
                // NOTE(review): 'this' escapes to the context before this model's fields are assigned
                LazyArrayContext context = new LazyArrayContext(function, referencedFunctions, constants, this, MISSING_VALUE);
                contextBuilder.put(function.getName(), context);

                // Accumulate all updates locally and write back once. This way successive updates compose
                // even for maps whose entries do not reflect later puts. Replacing the value of an existing
                // key is not a structural modification for e.g. HashMap/TreeMap, so it is safe while iterating.
                ExpressionFunction completed = withMissingTypesAdded(function, context);
                if (completed != function)
                    functions.put(entry.getKey(), completed);
            }
            catch (RuntimeException e) {
                throw new IllegalArgumentException("Could not prepare an evaluation context for " + entry, e);
            }
        }
        this.contextPrototypes = contextBuilder.build();
        this.functions = ImmutableList.copyOf(functions.values());
        this.publicFunctions = ImmutableList.copyOf(
                functions.values().stream()
                         .filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX))
                         .collect(Collectors.toList()));

        ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
        for (Map.Entry<FunctionReference, ExpressionFunction> entry : referencedFunctions.entrySet()) {
            // NOTE(review): if no context prototype has this name, null is passed to the optimizer
            ContextIndex context = (ContextIndex) this.contextPrototypes.get(entry.getKey().functionName());
            functionsBuilder.put(entry.getKey(), optimize(entry.getValue(), context));
        }
        this.referencedFunctions = functionsBuilder.build();
    }

    /**
     * Returns the given function with missing type information filled in. If the function has no return
     * type, it gets {@code TensorType.empty}. Each argument used by {@code context} that the function does
     * not already know is added. For intermediate-operation functions it is added without a type (when
     * absent from {@code arguments()}). For other functions it is added with type {@code TensorType.empty}
     * (when absent from {@code argumentTypes()}).
     *
     * @return the same instance if nothing was added, otherwise an updated function
     */
    private static ExpressionFunction withMissingTypesAdded(ExpressionFunction function, LazyArrayContext context) {
        ExpressionFunction result = function;
        if (!result.returnType().isPresent())
            result = result.withReturnType(TensorType.empty);

        boolean intermediate = function.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX);
        for (String argument : context.arguments()) {
            if (intermediate) {
                if (!result.arguments().contains(argument))
                    result = result.withArgument(argument);
            }
            else if (result.argumentTypes().get(argument) == null) {
                result = result.withArgument(argument, TensorType.empty);
            }
        }
        return result;
    }

    /**
     * Runs the expression optimizer on the body of the given function using the given context.
     * The optimizer's return value is discarded, and the same function instance is returned.
     */
    private ExpressionFunction optimize(ExpressionFunction function, ContextIndex context) {
        expressionOptimizer.optimize(function.getBody(), context);
        return function;
    }

    /** Returns the name of this model. */
    public String name() {
        return name;
    }

    /** Returns an immutable list of the functions of this model, excluding internal intermediate-operation functions. */
    public List<ExpressionFunction> functions() {
        return publicFunctions;
    }

    /**
     * Returns the function with the given name, searching all functions including internal ones.
     *
     * @throws IllegalArgumentException if there is no function with this name
     */
    ExpressionFunction requireFunction(String name) {
        ExpressionFunction function = function(name);
        if (function == null)
            throw undeterminedFunction("No function named '" + name + "' in " + this);
        return function;
    }

    /**
     * Returns the evaluation context prototype of the function with the given name.
     *
     * @throws IllegalArgumentException if there is no context prototype with this name
     */
    private LazyArrayContext requireContextPrototype(String name) {
        LazyArrayContext context = contextPrototypes.get(name);
        if (context == null)
            throw undeterminedFunction("No function named '" + name + "' in " + this);
        return context;
    }

    /** Returns the first function with the given name, including internal ones, or null if there is none. */
    ExpressionFunction function(String name) {
        for (ExpressionFunction function : functions) {
            if (function.getName().equals(name))
                return function;
        }
        return null;
    }

    /** Returns the immutable map of optimized referenced functions. */
    Map<FunctionReference, ExpressionFunction> referencedFunctions() {
        return referencedFunctions;
    }

    /**
     * Returns the referenced function with the given reference.
     *
     * @throws IllegalArgumentException if there is no such referenced function
     */
    ExpressionFunction requireReferencedFunction(FunctionReference reference) {
        ExpressionFunction function = referencedFunctions.get(reference);
        if (function == null)
            throw new IllegalArgumentException("No " + reference + " in " + this + ". References: " +
                                               referencedFunctions.keySet().stream()
                                                                  .map(FunctionReference::serialForm)
                                                                  .collect(Collectors.joining(", ")));
        return function;
    }

    /**
     * Returns an evaluator for a function in this model. Each evaluator gets its own copy of the
     * function's evaluation context prototype.
     *
     * @param names zero to two names identifying the function:
     *        <ul>
     *          <li>none: the single public function of this model</li>
     *          <li>one: the function with exactly this name or, if none, the single function whose name
     *              starts with this name followed by a dot</li>
     *          <li>two: the function named by joining both names with a dot</li>
     *        </ul>
     * @throws IllegalArgumentException if the names match no function or more than one, or if more than
     *         two names are given
     */
    public FunctionEvaluator evaluatorOf(String ... names) {
        if (names.length == 0) {
            // Internal intermediate functions are not user-selectable, so only public functions count here
            if (publicFunctions.isEmpty())
                throw undeterminedFunction("No functions are available in " + this);
            if (publicFunctions.size() > 1)
                throw undeterminedFunction("More than one function is available in " + this + ", but no name is given");
            return evaluatorOf(publicFunctions.get(0));
        }
        if (names.length == 1) {
            String name = names[0];
            ExpressionFunction function = function(name);
            if (function != null)
                return evaluatorOf(function);

            List<ExpressionFunction> functionsStartingByName =
                    functions.stream()
                             .filter(f -> f.getName().startsWith(name + "."))
                             .collect(Collectors.toList());
            if (functionsStartingByName.isEmpty())
                throw undeterminedFunction("No function '" + name + "' in " + this);
            if (functionsStartingByName.size() > 1)
                throw undeterminedFunction("Multiple functions start by '" + name + "' in " + this);
            return evaluatorOf(functionsStartingByName.get(0));
        }
        if (names.length == 2) {
            String name = names[0] + "." + names[1];
            ExpressionFunction function = function(name);
            if (function == null)
                throw undeterminedFunction("No function '" + name + "' in " + this);
            return evaluatorOf(function);
        }
        throw new IllegalArgumentException("No more than 2 names can be given when choosing a function, got " + Arrays.toString(names));
    }

    /** Returns an evaluator of the given function, using a copy of its context prototype. */
    private FunctionEvaluator evaluatorOf(ExpressionFunction function) {
        return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
    }

    /**
     * Returns (does not throw) an exception whose message is the given message followed by the names
     * of all functions in this model. Callers {@code throw} the result, so control flow is explicit.
     */
    private IllegalArgumentException undeterminedFunction(String message) {
        return new IllegalArgumentException(message + ". Available functions: " +
                                            functions.stream()
                                                     .map(ExpressionFunction::getName)
                                                     .collect(Collectors.joining(", ")));
    }

    @Override
    public String toString() {
        return "model '" + name + "'";
    }

}

