/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.rag.advisor;

import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.stream.Collectors;
import org.jspecify.annotations.Nullable;
import org.springframework.ai.chat.client.ChatClientRequest;
import org.springframework.ai.chat.client.ChatClientResponse;
import org.springframework.ai.chat.client.advisor.api.AdvisorChain;
import org.springframework.ai.chat.client.advisor.api.BaseAdvisor;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.document.Document;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.generation.augmentation.ContextualQueryAugmenter;
import org.springframework.ai.rag.generation.augmentation.QueryAugmenter;
import org.springframework.ai.rag.postretrieval.document.DocumentPostProcessor;
import org.springframework.ai.rag.preretrieval.query.expansion.QueryExpander;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.retrieval.join.ConcatenationDocumentJoiner;
import org.springframework.ai.rag.retrieval.join.DocumentJoiner;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.core.task.TaskDecorator;
import org.springframework.core.task.TaskExecutor;
import org.springframework.core.task.support.ContextPropagatingTaskDecorator;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.util.Assert;
import reactor.core.scheduler.Scheduler;

/**
 * Advisor that augments the user message with documents retrieved for it.
 *
 * <p>In {@link #before}, the user message is turned into a {@link Query}. That query is
 * passed through the configured {@link QueryTransformer}s in order. It is then optionally
 * expanded into several queries by a {@link QueryExpander}, and every distinct resulting
 * query is sent to the {@link DocumentRetriever} on the configured {@link TaskExecutor}.
 * The per-query results are merged by the {@link DocumentJoiner} and refined by the
 * {@link DocumentPostProcessor}s. They are stored in the request context under
 * {@link #DOCUMENT_CONTEXT} and finally handed to the {@link QueryAugmenter}, whose output
 * text replaces the user message.
 *
 * <p>In {@link #after}, the documents found in the response context under
 * {@link #DOCUMENT_CONTEXT} are copied into the {@link ChatResponse} metadata under the same
 * key.
 *
 * <p>The lists given to the {@link Builder} are copied when the advisor is built. Later
 * changes to those lists therefore do not affect an existing advisor.
 */
public final class RetrievalAugmentationAdvisor implements BaseAdvisor {

    /** Key used for the retrieved documents in the request context and the response metadata. */
    public static final String DOCUMENT_CONTEXT = "rag_document_context";

    private final List<QueryTransformer> queryTransformers;

    private final @Nullable QueryExpander queryExpander;

    private final DocumentRetriever documentRetriever;

    private final DocumentJoiner documentJoiner;

    private final List<DocumentPostProcessor> documentPostProcessors;

    private final QueryAugmenter queryAugmenter;

    private final TaskExecutor taskExecutor;

    private final Scheduler scheduler;

    private final int order;

    private RetrievalAugmentationAdvisor(@Nullable List<QueryTransformer> queryTransformers,
            @Nullable QueryExpander queryExpander, DocumentRetriever documentRetriever,
            @Nullable DocumentJoiner documentJoiner, @Nullable List<DocumentPostProcessor> documentPostProcessors,
            @Nullable QueryAugmenter queryAugmenter, @Nullable TaskExecutor taskExecutor,
            @Nullable Scheduler scheduler, @Nullable Integer order) {
        Assert.notNull(documentRetriever, "documentRetriever cannot be null");
        Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
        Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
        // Snapshot the lists so that later mutation by the caller cannot change this advisor.
        this.queryTransformers = queryTransformers != null ? List.copyOf(queryTransformers) : List.of();
        this.queryExpander = queryExpander;
        this.documentRetriever = documentRetriever;
        this.documentJoiner = documentJoiner != null ? documentJoiner : new ConcatenationDocumentJoiner();
        this.documentPostProcessors = documentPostProcessors != null ? List.copyOf(documentPostProcessors)
                : List.of();
        this.queryAugmenter = queryAugmenter != null ? queryAugmenter : ContextualQueryAugmenter.builder().build();
        this.taskExecutor = taskExecutor != null ? taskExecutor : buildDefaultTaskExecutor();
        this.scheduler = scheduler != null ? scheduler : BaseAdvisor.DEFAULT_SCHEDULER;
        this.order = order != null ? order : 0;
    }

    /**
     * Creates a new {@link Builder}.
     * @return a fresh builder; {@link Builder#documentRetriever} must be set before
     * {@link Builder#build()}
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Retrieves documents for the user message and replaces the message with the augmented
     * query text.
     *
     * <p>Transformed and expanded queries are used only for retrieval. Post-processing and
     * augmentation receive the original query built from the user message.
     * @param chatClientRequest the incoming request
     * @param advisorChain the advisor chain (not used by this method)
     * @return a copy of the request with the augmented user message, and a context that also
     * holds the retrieved documents under {@link #DOCUMENT_CONTEXT}
     */
    @Override
    public ChatClientRequest before(ChatClientRequest chatClientRequest, @Nullable AdvisorChain advisorChain) {
        // Work on a copy so the incoming request's context is not mutated.
        Map<String, Object> context = new HashMap<>(chatClientRequest.context());

        // 0. Build the original query from the user message and the prompt's instructions.
        String userText = chatClientRequest.prompt().getUserMessage().getText();
        Query originalQuery = Query.builder()
            .text(Objects.requireNonNullElse(userText, ""))
            .history(chatClientRequest.prompt().getInstructions())
            .context(context)
            .build();

        // 1. Apply each transformer to the output of the previous one.
        Query transformedQuery = originalQuery;
        for (QueryTransformer queryTransformer : this.queryTransformers) {
            transformedQuery = queryTransformer.apply(transformedQuery);
        }

        // 2. Expand into several queries, or keep the single transformed query.
        List<Query> expandedQueries = this.queryExpander != null ? this.queryExpander.expand(transformedQuery)
                : List.of(transformedQuery);

        // 3. Retrieve for every query, then 4. join the per-query results.
        Map<Query, List<List<Document>>> documentsForQuery = retrieveDocuments(expandedQueries);
        List<Document> documents = this.documentJoiner.join(documentsForQuery);

        // 5. Post-process against the original (not the transformed) query.
        for (DocumentPostProcessor documentPostProcessor : this.documentPostProcessors) {
            documents = documentPostProcessor.process(originalQuery, documents);
        }
        context.put(DOCUMENT_CONTEXT, documents);

        // 6. Augment the original query with the final documents.
        Query augmentedQuery = this.queryAugmenter.augment(originalQuery, documents);
        return chatClientRequest.mutate()
            .prompt(chatClientRequest.prompt().augmentUserMessage(augmentedQuery.text()))
            .context(context)
            .build();
    }

    /**
     * Runs the retriever for every distinct query concurrently on {@link #taskExecutor} and
     * waits for all results.
     *
     * <p>Queries are deduplicated with {@link java.util.stream.Stream#distinct()}. That step
     * uses the same {@code equals} as the map keys, so duplicate queries are retrieved only
     * once instead of making {@code Collectors.toMap} fail with a duplicate-key
     * {@link IllegalStateException}. The result is a {@link LinkedHashMap}, so the joiner
     * sees the queries in expansion order.
     * @param queries the (possibly expanded) queries to retrieve for
     * @return for each distinct query, a single-element list holding its retrieved documents
     */
    private Map<Query, List<List<Document>>> retrieveDocuments(List<Query> queries) {
        // Submit all retrievals before joining any of them so they run concurrently.
        List<CompletableFuture<Map.Entry<Query, List<Document>>>> futures = queries.stream()
            .distinct()
            .map(query -> CompletableFuture.supplyAsync(() -> getDocumentsForQuery(query), this.taskExecutor))
            .toList();
        return futures.stream()
            .map(CompletableFuture::join)
            .collect(Collectors.toMap(Map.Entry::getKey, entry -> List.of(entry.getValue()),
                    (first, second) -> first, LinkedHashMap::new));
    }

    /**
     * Retrieves documents for a single query.
     * @param query the query to pass to the {@link DocumentRetriever}
     * @return the query paired with the retriever's result
     */
    private Map.Entry<Query, List<Document>> getDocumentsForQuery(Query query) {
        List<Document> documents = this.documentRetriever.retrieve(query);
        return Map.entry(query, documents);
    }

    /**
     * Copies the documents stored under {@link #DOCUMENT_CONTEXT} in the response context into
     * the chat response metadata.
     * @param chatClientResponse the response produced further down the chain
     * @param advisorChain the advisor chain (not used by this method)
     * @return a new response built from the original chat response (or an empty one if none
     * was present) with the same context
     */
    @Override
    public ChatClientResponse after(ChatClientResponse chatClientResponse, @Nullable AdvisorChain advisorChain) {
        ChatResponse originalResponse = chatClientResponse.chatResponse();
        ChatResponse.Builder chatResponseBuilder = originalResponse == null ? ChatResponse.builder()
                : ChatResponse.builder().from(originalResponse);
        Object documents = chatClientResponse.context().get(DOCUMENT_CONTEXT);
        if (documents != null) {
            chatResponseBuilder.metadata(DOCUMENT_CONTEXT, documents);
        }
        return ChatClientResponse.builder()
            .chatResponse(chatResponseBuilder.build())
            .context(chatClientResponse.context())
            .build();
    }

    /**
     * Returns the scheduler configured for this advisor.
     * @return the configured scheduler, or {@link BaseAdvisor#DEFAULT_SCHEDULER} if none was
     * set
     */
    @Override
    public Scheduler getScheduler() {
        return this.scheduler;
    }

    /**
     * Returns the order configured for this advisor.
     * @return the configured order, or {@code 0} if none was set
     */
    @Override
    public int getOrder() {
        return this.order;
    }

    /**
     * Builds the executor used when no {@link TaskExecutor} is supplied.
     *
     * <p>The executor has 4 core threads, a maximum of 16, and thread-name prefix
     * {@code "ai-advisor-"}, and it is decorated with a {@link ContextPropagatingTaskDecorator}.
     * A new executor is created for every advisor built without an explicit executor, and this
     * class never shuts it down. Supply a managed {@link TaskExecutor} through the builder when
     * that matters.
     * @return an initialized executor
     */
    private static TaskExecutor buildDefaultTaskExecutor() {
        ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
        taskExecutor.setThreadNamePrefix("ai-advisor-");
        taskExecutor.setCorePoolSize(4);
        // NOTE(review): ThreadPoolTaskExecutor's default queue capacity is unbounded, so threads
        // beyond the core size are likely never created — confirm whether max 16 is effective.
        taskExecutor.setMaxPoolSize(16);
        taskExecutor.setTaskDecorator(new ContextPropagatingTaskDecorator());
        taskExecutor.initialize();
        return taskExecutor;
    }

    /**
     * Builder for {@link RetrievalAugmentationAdvisor}.
     *
     * <p>Only {@link #documentRetriever} is required. Every other setting falls back to a
     * default when unset or set to {@code null}.
     */
    public static final class Builder {

        private @Nullable List<QueryTransformer> queryTransformers;

        private @Nullable QueryExpander queryExpander;

        private @Nullable DocumentRetriever documentRetriever;

        private @Nullable DocumentJoiner documentJoiner;

        private @Nullable List<DocumentPostProcessor> documentPostProcessors;

        private @Nullable QueryAugmenter queryAugmenter;

        private @Nullable TaskExecutor taskExecutor;

        private @Nullable Scheduler scheduler;

        private @Nullable Integer order;

        private Builder() {
        }

        /**
         * Sets the transformers applied, in list order, to the query before expansion.
         * @param queryTransformers the transformers; may be {@code null} (meaning none) but
         * must not contain {@code null} elements
         * @return this builder
         */
        public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
            Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
            this.queryTransformers = queryTransformers;
            return this;
        }

        /**
         * Sets the transformers applied, in argument order, to the query before expansion.
         * @param queryTransformers the transformers; must not be {@code null} nor contain
         * {@code null} elements
         * @return this builder
         */
        public Builder queryTransformers(QueryTransformer... queryTransformers) {
            Assert.notNull(queryTransformers, "queryTransformers cannot be null");
            Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
            this.queryTransformers = Arrays.asList(queryTransformers);
            return this;
        }

        /**
         * Sets the expander that turns the transformed query into one or more queries.
         * @param queryExpander the expander, or {@code null} to retrieve for the transformed
         * query only
         * @return this builder
         */
        public Builder queryExpander(QueryExpander queryExpander) {
            this.queryExpander = queryExpander;
            return this;
        }

        /**
         * Sets the retriever used for every query. Required.
         * @param documentRetriever the retriever
         * @return this builder
         */
        public Builder documentRetriever(DocumentRetriever documentRetriever) {
            this.documentRetriever = documentRetriever;
            return this;
        }

        /**
         * Sets the joiner that merges per-query results.
         * @param documentJoiner the joiner, or {@code null} for a
         * {@link ConcatenationDocumentJoiner}
         * @return this builder
         */
        public Builder documentJoiner(DocumentJoiner documentJoiner) {
            this.documentJoiner = documentJoiner;
            return this;
        }

        /**
         * Sets the post-processors applied, in list order, to the joined documents.
         * @param documentPostProcessors the post-processors; may be {@code null} (meaning none)
         * but must not contain {@code null} elements
         * @return this builder
         */
        public Builder documentPostProcessors(List<DocumentPostProcessor> documentPostProcessors) {
            Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
            this.documentPostProcessors = documentPostProcessors;
            return this;
        }

        /**
         * Sets the post-processors applied, in argument order, to the joined documents.
         * @param documentPostProcessors the post-processors; must not be {@code null} nor
         * contain {@code null} elements
         * @return this builder
         */
        public Builder documentPostProcessors(DocumentPostProcessor... documentPostProcessors) {
            Assert.notNull(documentPostProcessors, "documentPostProcessors cannot be null");
            Assert.noNullElements(documentPostProcessors, "documentPostProcessors cannot contain null elements");
            this.documentPostProcessors = Arrays.asList(documentPostProcessors);
            return this;
        }

        /**
         * Sets the augmenter that combines the original query with the final documents.
         * @param queryAugmenter the augmenter, or {@code null} for a default
         * {@link ContextualQueryAugmenter}
         * @return this builder
         */
        public Builder queryAugmenter(QueryAugmenter queryAugmenter) {
            this.queryAugmenter = queryAugmenter;
            return this;
        }

        /**
         * Sets the executor on which retrievals run.
         * @param taskExecutor the executor, or {@code null} to create a private default pool
         * @return this builder
         */
        public Builder taskExecutor(TaskExecutor taskExecutor) {
            this.taskExecutor = taskExecutor;
            return this;
        }

        /**
         * Sets the scheduler returned by {@link RetrievalAugmentationAdvisor#getScheduler()}.
         * @param scheduler the scheduler, or {@code null} for
         * {@link BaseAdvisor#DEFAULT_SCHEDULER}
         * @return this builder
         */
        public Builder scheduler(Scheduler scheduler) {
            this.scheduler = scheduler;
            return this;
        }

        /**
         * Sets the advisor order.
         * @param order the order, or {@code null} for {@code 0}
         * @return this builder
         */
        public Builder order(Integer order) {
            this.order = order;
            return this;
        }

        /**
         * Builds the advisor.
         * @return a new advisor
         * @throws IllegalStateException if no document retriever was set
         */
        public RetrievalAugmentationAdvisor build() {
            Assert.state(this.documentRetriever != null, "documentRetriever cannot be null");
            return new RetrievalAugmentationAdvisor(this.queryTransformers, this.queryExpander,
                    this.documentRetriever, this.documentJoiner, this.documentPostProcessors, this.queryAugmenter,
                    this.taskExecutor, this.scheduler, this.order);
        }

    }

}

