/*
 * Decompiled with CFR 0.152.
 */
package ai.vespa.llm.generation;

import ai.vespa.llm.InferenceParameters;
import ai.vespa.llm.LanguageModel;
import ai.vespa.llm.completion.Completion;
import ai.vespa.llm.completion.Prompt;
import ai.vespa.llm.completion.StringPrompt;
import ai.vespa.llm.generation.LanguageModelTextGeneratorConfig;
import ai.vespa.llm.generation.LanguageModelUtils;
import com.yahoo.component.AbstractComponent;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
import com.yahoo.language.process.TextGenerator;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.logging.Logger;

/**
 * A {@link TextGenerator} that generates text by sending a templated prompt to a configured
 * {@link LanguageModel} and returning the text of the first completion.
 *
 * <p>The prompt template is resolved once, at construction time, in this order:
 * <ol>
 *   <li>{@code promptTemplate} from the config, if non-null and non-empty;</li>
 *   <li>the contents of {@code promptTemplateFile} (decoded as UTF-8), if present and non-empty;</li>
 *   <li>otherwise the literal {@code "{input}"}, which passes the input prompt through unchanged.</li>
 * </ol>
 * Every occurrence of {@code "{input}"} in the template is replaced by the input prompt text.
 */
public class LanguageModelTextGenerator extends AbstractComponent implements TextGenerator {

    private static final Logger logger = Logger.getLogger(LanguageModelTextGenerator.class.getName());

    /** Token in the prompt template that is replaced by the caller's input text. */
    private static final String INPUT_PLACEHOLDER = "{input}";

    /** Template used when neither an inline template nor a non-empty template file is configured. */
    private static final String DEFAULT_PROMPT_TEMPLATE = INPUT_PLACEHOLDER;

    private final LanguageModel languageModel;
    private final LanguageModelTextGeneratorConfig config;
    private final String promptTemplate;

    /**
     * Creates a generator backed by the language model selected by {@code config.providerId()}.
     *
     * @param config         generator configuration (provider id, prompt template, max length)
     * @param languageModels registry the language model is looked up in
     * @throws IllegalArgumentException if the configured prompt template file cannot be read
     */
    @Inject
    public LanguageModelTextGenerator(LanguageModelTextGeneratorConfig config,
                                      ComponentRegistry<LanguageModel> languageModels) {
        this.languageModel = LanguageModelUtils.findLanguageModel(config.providerId(), languageModels, logger);
        this.config = config;
        this.promptTemplate = loadPromptTemplate(config);
    }

    /**
     * Resolves the prompt template: inline config value first, then the template file, then the default.
     *
     * @throws IllegalArgumentException if the template file is configured but cannot be read
     */
    private static String loadPromptTemplate(LanguageModelTextGeneratorConfig config) {
        String inlineTemplate = config.promptTemplate();
        if (inlineTemplate != null && !inlineTemplate.isEmpty()) {
            return inlineTemplate;
        }
        if (config.promptTemplateFile().isPresent()) {
            Path path = config.promptTemplateFile().get();
            try {
                // Decode explicitly as UTF-8 so the template does not depend on the host's default charset.
                String fileTemplate = new String(Files.readAllBytes(path), StandardCharsets.UTF_8);
                if (!fileTemplate.isEmpty()) {
                    return fileTemplate;
                }
            } catch (IOException e) {
                throw new IllegalArgumentException("Could not read prompt template file: " + path, e);
            }
        }
        return DEFAULT_PROMPT_TEMPLATE;
    }

    /**
     * Generates text for the given prompt (implementation of {@link TextGenerator}).
     *
     * <p>The prompt is inserted into the configured template and sent to the language model. The text
     * of the first completion is returned. If {@code config.maxLength()} is non-negative, the result is
     * truncated to at most that many chars.
     *
     * @param prompt  the input prompt, substituted for {@code "{input}"} in the template
     * @param context generation context; not used by this implementation
     * @return the (possibly truncated) text of the first completion
     * @throws IllegalStateException if the language model returns no completions
     */
    public String generate(Prompt prompt, TextGenerator.Context context) {
        Prompt finalPrompt = buildPrompt(prompt);
        // NOTE(review): assumes InferenceParameters takes a name -> value lookup; answering null for
        // every key means no inference options are passed from here. Confirm against InferenceParameters.
        InferenceParameters options = new InferenceParameters(s -> null);
        List<Completion> completions = languageModel.complete(finalPrompt, options);
        if (completions == null || completions.isEmpty()) {
            throw new IllegalStateException("Language model returned no completions");
        }
        return truncate(completions.get(0).text(), config.maxLength());
    }

    /**
     * Truncates {@code text} to at most {@code maxLength} chars. A negative {@code maxLength} means no limit.
     * If the cut would split a UTF-16 surrogate pair, the cut is moved one char earlier so the result
     * never ends in a lone high surrogate.
     */
    private static String truncate(String text, int maxLength) {
        if (maxLength < 0 || text.length() <= maxLength) {
            return text;
        }
        int end = maxLength;
        if (end > 0
                && Character.isHighSurrogate(text.charAt(end - 1))
                && Character.isLowSurrogate(text.charAt(end))) {
            end--;
        }
        return text.substring(0, end);
    }

    /** Replaces every {@code "{input}"} in the template with the input prompt text. */
    private Prompt buildPrompt(Prompt inputPrompt) {
        // String.replace is a literal, single-pass replacement: no regex escaping is needed, and
        // placeholders appearing inside the input text are not expanded again.
        return StringPrompt.from(promptTemplate.replace(INPUT_PLACEHOLDER, inputPrompt.asString()));
    }
}

