/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.models.evaluation;

import ai.vespa.models.evaluation.Constant;
import ai.vespa.models.evaluation.FunctionReference;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.OnnxModel;
import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Builds one {@link Model} per rank profile from the rank profiles config, together with the
 * large ranking constants and ONNX models described by their respective configs.
 */
public class RankProfilesConfigImporter {

    private final FileAcquirer fileAcquirer;

    /**
     * @param fileAcquirer used to obtain the local files referenced by large constants and ONNX models
     */
    public RankProfilesConfigImporter(FileAcquirer fileAcquirer) {
        this.fileAcquirer = fileAcquirer;
    }

    /**
     * Imports every rank profile in the given config as a model.
     *
     * @param config the rank profiles to import
     * @param constantsConfig large constants added to every model; may be null (no large constants)
     * @param onnxModelsConfig ONNX models added to every model; may be null (no ONNX models)
     * @return the imported models keyed by their name (the rank profile name)
     * @throws IllegalArgumentException if a ranking expression in the config cannot be parsed,
     *                                  or a model cannot be constructed
     */
    public Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig) {
        try {
            Map<String, Model> models = new HashMap<>();
            for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
                Model model = importProfile(profile, constantsConfig, onnxModelsConfig);
                models.put(model.name(), model);
            }
            return models;
        }
        catch (ParseException e) {
            throw new IllegalArgumentException("Could not read rank profiles config - version mismatch?", e);
        }
    }

    /**
     * Builds the model for a single rank profile from its fef properties.
     *
     * @throws ParseException if a ranking expression property cannot be parsed
     * @throws IllegalArgumentException if a type property refers to a function not yet defined,
     *                                  or the model cannot be constructed
     */
    private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig, OnnxModelsConfig onnxModelsConfig) throws ParseException {
        // NOTE(review): these are re-read (and their files re-acquired) for every rank profile;
        // consider reading once in importFrom if Constant/OnnxModel instances can be shared between models.
        List<OnnxModel> onnxModels = readOnnxModelsConfig(onnxModelsConfig);
        List<Constant> constants = readLargeConstants(constantsConfig);

        // Insertion-ordered so functions keep the order in which they appear in the config.
        // 'functions' holds only those whose reference isFree(); 'referencedFunctions' holds all of them.
        LinkedHashMap<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
        LinkedHashMap<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>();
        SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
        ExpressionFunction firstPhase = null;
        ExpressionFunction secondPhase = null;

        for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
            String propertyName = property.name();
            Optional<FunctionReference> reference = FunctionReference.fromSerial(propertyName);
            Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(propertyName);
            Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(propertyName);

            if (reference.isPresent()) {
                FunctionReference functionReference = reference.get();
                RankingExpression expression = new RankingExpression(functionReference.functionName(), property.value());
                ExpressionFunction function = new ExpressionFunction(functionReference.functionName(), Collections.emptyList(), expression);
                putFunction(functionReference, function, functions, referencedFunctions);
            }
            else if (argumentType.isPresent()) {
                // Type properties refine a function that must already have been defined above
                FunctionReference functionReference = argumentType.get().getFirst();
                ExpressionFunction function = requireFunction(functionReference, referencedFunctions, propertyName)
                        .withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value()));
                putFunction(functionReference, function, functions, referencedFunctions);
            }
            else if (returnType.isPresent()) {
                FunctionReference functionReference = returnType.get();
                ExpressionFunction function = requireFunction(functionReference, referencedFunctions, propertyName)
                        .withReturnType(TensorType.fromSpec(property.value()));
                putFunction(functionReference, function, functions, referencedFunctions);
            }
            else if (propertyName.equals("vespa.rank.firstphase")) {
                firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), new RankingExpression("first-phase", property.value()));
            }
            else if (propertyName.equals("vespa.rank.secondphase")) {
                secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(), new RankingExpression("second-phase", property.value()));
            }
            else {
                smallConstantsInfo.addIfSmallConstantInfo(propertyName, property.value());
            }
        }

        // A function explicitly named firstphase/secondphase takes precedence over the phase expression
        if (firstPhase != null && functionByName("firstphase", functions.values()) == null) {
            functions.put(FunctionReference.fromName("firstphase"), firstPhase);
        }
        if (secondPhase != null && functionByName("secondphase", functions.values()) == null) {
            functions.put(FunctionReference.fromName("secondphase"), secondPhase);
        }

        constants.addAll(smallConstantsInfo.asConstants());
        try {
            return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
        }
        catch (RuntimeException e) {
            throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
        }
    }

    /** Stores a function under its reference, also adding it to {@code functions} when the reference is free. */
    private static void putFunction(FunctionReference reference,
                                    ExpressionFunction function,
                                    Map<FunctionReference, ExpressionFunction> functions,
                                    Map<FunctionReference, ExpressionFunction> referencedFunctions) {
        if (reference.isFree()) {
            functions.put(reference, function);
        }
        referencedFunctions.put(reference, function);
    }

    /**
     * Returns the previously defined function a type property refers to.
     *
     * @throws IllegalArgumentException if no such function has been defined yet
     */
    private static ExpressionFunction requireFunction(FunctionReference reference,
                                                      Map<FunctionReference, ExpressionFunction> referencedFunctions,
                                                      String propertyName) {
        ExpressionFunction function = referencedFunctions.get(reference);
        if (function == null) {
            throw new IllegalArgumentException("Property '" + propertyName + "' refers to function '" +
                                               reference.functionName() + "', which is not defined before it");
        }
        return function;
    }

    /** Returns the first function with the given name, or null if there is none. */
    private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
        for (ExpressionFunction function : functions) {
            if (function.getName().equals(name)) {
                return function;
            }
        }
        return null;
    }

    /** Reads all ONNX models in the given config; a null config yields an empty (mutable) list. */
    private List<OnnxModel> readOnnxModelsConfig(OnnxModelsConfig onnxModelsConfig) {
        List<OnnxModel> onnxModels = new ArrayList<>();
        if (onnxModelsConfig != null) {
            for (OnnxModelsConfig.Model onnxModelConfig : onnxModelsConfig.model()) {
                onnxModels.add(readOnnxModelConfig(onnxModelConfig));
            }
        }
        return onnxModels;
    }

    /**
     * Waits (up to 7 days) for the model file to become available and wraps it as an {@link OnnxModel}.
     *
     * @throws IllegalStateException if interrupted while waiting; the interrupt status is preserved
     */
    private OnnxModel readOnnxModelConfig(OnnxModelsConfig.Model onnxModelConfig) {
        String name = onnxModelConfig.name();
        try {
            File file = fileAcquirer.waitFor(onnxModelConfig.fileref(), 7L, TimeUnit.DAYS);
            return new OnnxModel(name, file);
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // let callers up the stack observe the interruption
            throw new IllegalStateException("Gave up waiting for ONNX model " + name, e);
        }
    }

    /**
     * Reads all large constants in the given config. The returned list is mutable since small
     * constants are appended to it later. A null config yields an empty list, consistent with
     * {@link #readOnnxModelsConfig}.
     */
    private List<Constant> readLargeConstants(RankingConstantsConfig constantsConfig) {
        List<Constant> constants = new ArrayList<>();
        if (constantsConfig == null) {
            return constants;
        }
        for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) {
            TensorType type = TensorType.fromSpec(constantConfig.type());
            Tensor value = readTensorFromFile(constantConfig.name(), type, constantConfig.fileref());
            constants.add(new Constant(constantConfig.name(), value));
        }
        return constants;
    }

    /**
     * Waits (up to 7 days) for the referenced file and decodes it as a tensor of the given type.
     * Only files whose name ends in {@code .tbf} (typed binary format) are supported.
     *
     * @param name the constant name, used in error messages
     * @param type the expected tensor type
     * @param fileReference reference to the file holding the tensor
     * @return the decoded tensor
     * @throws IllegalArgumentException if the file is not a .tbf file
     * @throws IllegalStateException if interrupted while waiting; the interrupt status is preserved
     * @throws UncheckedIOException if the file cannot be read
     */
    protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) {
        try {
            File file = fileAcquirer.waitFor(fileReference, 7L, TimeUnit.DAYS);
            if ( ! file.getName().endsWith(".tbf")) {
                throw new IllegalArgumentException("Constant files on other formats than .tbf are not supported, got " + file + " for constant " + name);
            }
            return TypedBinaryFormat.decode(Optional.of(type), GrowableByteBuffer.wrap(IOUtils.readFileBytes(file)));
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // let callers up the stack observe the interruption
            throw new IllegalStateException("Gave up waiting for constant " + name, e);
        }
        catch (IOException e) {
            throw new UncheckedIOException("Could not read constant " + name, e);
        }
    }

    /**
     * Collects small constants given inline as rank properties, keyed
     * {@code constant(NAME).value} and {@code constant(NAME).type}.
     */
    private static class SmallConstantsInfo {

        private static final Pattern valuePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.value");
        private static final Pattern typePattern = Pattern.compile("constant\\(([a-zA-Z0-9_.]+)\\)\\.type");

        private final Map<String, TensorType> types = new HashMap<>();
        // Insertion-ordered so the resulting constants follow config order deterministically
        private final Map<String, String> values = new LinkedHashMap<>();

        /** Records the key/value pair if the key is a small-constant value or type property; ignores it otherwise. */
        void addIfSmallConstantInfo(String key, String value) {
            tryValue(key, value);
            tryType(key, value);
        }

        private void tryValue(String key, String value) {
            Matcher matcher = valuePattern.matcher(key);
            if (matcher.matches()) {
                values.put(matcher.group(1), value);
            }
        }

        private void tryType(String key, String value) {
            Matcher matcher = typePattern.matcher(key);
            if (matcher.matches()) {
                types.put(matcher.group(1), TensorType.fromSpec(value));
            }
        }

        /**
         * Returns a constant for every recorded value. A type without a value is ignored.
         *
         * @throws IllegalStateException if a value has no accompanying type
         */
        List<Constant> asConstants() {
            List<Constant> constants = new ArrayList<>();
            for (Map.Entry<String, String> entry : values.entrySet()) {
                TensorType type = types.get(entry.getKey());
                if (type == null) {
                    throw new IllegalStateException("Missing type of '" + entry.getKey() + "'");
                }
                constants.add(new Constant(entry.getKey(), Tensor.from(type, entry.getValue())));
            }
            return constants;
        }
    }
}

