/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.data.mongodb.repository.aot;

import java.util.List;
import org.bson.Document;
import org.jspecify.annotations.NullUnmarked;
import org.springframework.data.core.TypeInformation;
import org.springframework.data.domain.Limit;
import org.springframework.data.domain.ScoringFunction;
import org.springframework.data.domain.Sort;
import org.springframework.data.mongodb.core.MongoOperations;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation;
import org.springframework.data.mongodb.repository.VectorSearch;
import org.springframework.data.mongodb.repository.aot.AotStringQuery;
import org.springframework.data.mongodb.repository.aot.ExpressionSnippet;
import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks;
import org.springframework.data.mongodb.repository.aot.QueryInteraction;
import org.springframework.data.mongodb.repository.aot.Snippet;
import org.springframework.data.mongodb.repository.aot.VariableSnippet;
import org.springframework.data.mongodb.repository.query.MongoQueryExecution;
import org.springframework.data.mongodb.repository.query.MongoQueryMethod;
import org.springframework.data.repository.aot.generate.AotQueryMethodGenerationContext;
import org.springframework.javapoet.CodeBlock;
import org.springframework.util.StringUtils;

class VectorSearchBlocks {
    /**
     * Package-private no-argument constructor with an empty body.
     */
    VectorSearchBlocks() {
    }

    @NullUnmarked
    static class VectorSearchQueryCodeBlockBuilder {
        // Generation context; queried for parameter names, local variable names,
        // the MongoOperations field name and repository information.
        private final AotQueryMethodGenerationContext context;
        // Query method being rendered; source of the VectorSearch annotation and the return type.
        private final MongoQueryMethod queryMethod;
        // Annotation obtained from queryMethod in the constructor; read for
        // indexName, searchType, numCandidates and limit.
        private final VectorSearch vectorSearchAnnotation;
        // Name given to the generated AggregationPipeline variable; set via usingVariableName(...).
        private String searchQueryVariableName;
        // Filter query set via withFilter(...); consulted for its query string, sort and limit.
        private AotStringQuery filter;
        // Value rendered into the generated path(...) call of the vector search operation.
        private final String searchPath;

        /**
         * Creates a builder bound to the given generation context and query method.
         * The {@link VectorSearch} annotation is looked up immediately through
         * {@link MongoQueryMethod#getRequiredVectorSearchAnnotation()}.
         *
         * @param context generation context used while rendering code
         * @param queryMethod query method the vector search is rendered for
         * @param searchPath value rendered into the generated {@code path(...)} call
         */
        VectorSearchQueryCodeBlockBuilder(AotQueryMethodGenerationContext context, MongoQueryMethod queryMethod, String searchPath) {
            this.searchPath = searchPath;
            this.vectorSearchAnnotation = queryMethod.getRequiredVectorSearchAnnotation();
            this.queryMethod = queryMethod;
            this.context = context;
        }

        /**
         * Sets the filter query whose query string, sort and limit are consulted
         * while rendering.
         * <p>
         * NOTE(review): the rendering helpers dereference the filter without a null
         * check, so this presumably has to be called before {@link #build()} —
         * confirm against callers.
         *
         * @param filter the filter query to use
         * @return this builder
         */
        VectorSearchQueryCodeBlockBuilder withFilter(AotStringQuery filter) {
            this.filter = filter;
            return this;
        }

        /**
         * Sets the name used for the generated {@code AggregationPipeline} variable.
         *
         * @param searchQueryVariableName variable name to render
         * @return this builder
         */
        VectorSearchQueryCodeBlockBuilder usingVariableName(String searchQueryVariableName) {
            this.searchQueryVariableName = searchQueryVariableName;
            return this;
        }

        /**
         * Renders the code for the vector search query method. In order, the
         * resulting block contains:
         * <ol>
         * <li>an optional {@code limitToUse} variable when the limit needs evaluation
         * and the number of candidates is going to be derived from it,</li>
         * <li>the {@code VectorSearchOperation} declaration (search type, number of
         * candidates, search score and score filter are added when applicable),</li>
         * <li>the optional filter assignment,</li>
         * <li>the sort stage and the {@code AggregationPipeline} declaration,</li>
         * <li>the {@code ScoringFunction} variable,</li>
         * <li>a return statement executing a
         * {@code MongoQueryExecution.VectorSearchExecution}.</li>
         * </ol>
         *
         * @return the rendered code block
         */
        CodeBlock build() {
            CodeBlock.Builder code = CodeBlock.builder();
            String vectorParameter = this.context.getVectorParameterName();
            String index = this.vectorSearchAnnotation.indexName();
            VectorSearchOperation.SearchType searchType = this.vectorSearchAnnotation.searchType();

            ExpressionSnippet limit = this.getLimitExpression();
            if (limit.requiresEvaluation()) {
                boolean explicitNumCandidates = StringUtils.hasText(this.vectorSearchAnnotation.numCandidates());
                boolean approximate = searchType == VectorSearchOperation.SearchType.ANN || searchType == VectorSearchOperation.SearchType.DEFAULT;
                if (!explicitNumCandidates && approximate) {
                    // The limit is rendered twice (limit and numCandidates), so it is declared as a variable first.
                    VariableSnippet limitVariable = limit.as(VariableSnippet::create).variableName(this.context.localVariable("limitToUse"));
                    limitVariable.renderDeclaration(code);
                    limit = limitVariable;
                }
            }

            Snippet.BuilderStyleBuilder searchStage = Snippet.declare(code)
                    .variableBuilder(VectorSearchOperation.class, this.context.localVariable("$vectorSearch"))
                    .as("$T.vectorSearch($S).path($S).vector($L).limit($L)", Aggregation.class, index, this.searchPath, vectorParameter, limit.code());
            if (searchType != VectorSearchOperation.SearchType.DEFAULT) {
                searchStage.call("searchType").with("$T.$L", VectorSearchOperation.SearchType.class, searchType.name());
            }
            ExpressionSnippet numCandidates = this.getNumCandidatesExpression(searchType, limit);
            if (!numCandidates.isEmpty()) {
                searchStage.call("numCandidates").with(numCandidates);
            }
            searchStage.call("withSearchScore").with("\"__score__\"", new Object[0]);

            String scoreParameter = this.context.getScoreParameterName();
            if (StringUtils.hasText(scoreParameter)) {
                searchStage.call("withFilterBySore").with("$1L -> { $1L.gt($2L.getValue()); }", this.context.localVariable("criteria"), scoreParameter);
            } else {
                String scoreRangeParameter = this.context.getScoreRangeParameterName();
                if (StringUtils.hasText(scoreRangeParameter)) {
                    searchStage.call("withFilterBySore").with("scoreBetween($1L.getLowerBound(), $1L.getUpperBound())", scoreRangeParameter);
                }
            }

            VariableSnippet searchOperation = searchStage.variable();
            this.getFilter(searchOperation.getVariableName()).appendTo(code);

            VariableSnippet sort = this.getSort().as(VariableSnippet::create).variableName(this.context.localVariable("$sort"));
            sort.renderDeclaration(code);
            code.add("\n");

            VariableSnippet pipeline = Snippet.declare(code)
                    .variable(AggregationPipeline.class, this.searchQueryVariableName)
                    .as("new $T($T.of($L, $L))", AggregationPipeline.class, List.class, searchOperation.getVariableName(), sort.code());

            String scoringFunctionVariable = this.context.localVariable("scoringFunction");
            code.add("$1T $2L = ", ScoringFunction.class, scoringFunctionVariable);
            code.add(this.scoringFunctionInitializer());

            code.addStatement("return ($5T) new $1T($2L, $3T.class, $2L.getCollectionName($3T.class), $4T.of($5T.class), $6L, $7L).execute(null)",
                    MongoQueryExecution.VectorSearchExecution.class,
                    this.context.fieldNameOf(MongoOperations.class),
                    this.context.getRepositoryInformation().getDomainType(),
                    TypeInformation.class,
                    this.queryMethod.getReturnType().getType(),
                    pipeline.getVariableName(),
                    scoringFunctionVariable);
            return code.build();
        }

        /**
         * Renders the right-hand side (including the terminating {@code ;} and line
         * break) of the scoring function declaration: taken from the score parameter,
         * derived from the score range parameter, or {@code ScoringFunction.unspecified()}
         * when neither parameter name is present.
         */
        private CodeBlock scoringFunctionInitializer() {
            String scoreParameter = this.context.getScoreParameterName();
            if (StringUtils.hasText(scoreParameter)) {
                return CodeBlock.of("$L.getFunction();\n", scoreParameter);
            }
            String scoreRangeParameter = this.context.getScoreRangeParameterName();
            if (StringUtils.hasText(scoreRangeParameter)) {
                return CodeBlock.of("scoringFunction($L);\n", scoreRangeParameter);
            }
            return CodeBlock.of("$1T.unspecified();\n", ScoringFunction.class);
        }

        /**
         * Renders the sort stage expression.
         * <p>
         * When the filter carries a sort, a lambda cast to {@link AggregationOperation}
         * is rendered that maps the parsed sort string and appends
         * {@code "__score__": -1} to it inside a {@code $sort} document. Otherwise a
         * plain descending sort on {@code __score__} is rendered.
         *
         * @return the sort expression snippet
         */
        private ExpressionSnippet getSort() {
            if (this.filter.isSorted()) {
                String contextVariable = this.context.localVariable("ctx");
                String mappedSortVariable = this.context.localVariable("mappedSort");
                CodeBlock lambda = CodeBlock.builder()
                        .add("($T) ($L) -> {\n", AggregationOperation.class, contextVariable)
                        .indent()
                        .add("$1T $4L = $5L.getMappedObject(parse($2S), $3T.class);\n", Document.class, this.filter.getSortString(), this.context.getMethodReturn().getActualClassName(), mappedSortVariable, contextVariable)
                        .add("return new $1T($2S, $3L.append(\"__score__\", -1));\n", Document.class, "$sort", mappedSortVariable)
                        .unindent()
                        .add("}")
                        .build();
                return new ExpressionSnippet(lambda);
            }
            return new ExpressionSnippet(CodeBlock.of("$T.sort($T.Direction.DESC, $S)", Aggregation.class, Sort.class, "__score__"));
        }

        /**
         * Renders the filter part of the vector search: the query declaration produced
         * by {@code MongoCodeBlocks.queryBlockBuilder(...)} followed by a statement that
         * reassigns the vector search variable to its {@code filter(...)} result.
         *
         * @param vectorSearchVar name of the generated vector search operation variable
         * @return the rendered snippet, or an empty snippet when the filter has no query text
         */
        private Snippet getFilter(String vectorSearchVar) {
            boolean hasQuery = StringUtils.hasText(this.filter.getQueryString());
            if (!hasQuery) {
                return ExpressionSnippet.empty();
            }
            String filterVariable = this.context.localVariable("filter");
            CodeBlock queryDeclaration = MongoCodeBlocks.queryBlockBuilder(this.context, this.queryMethod)
                    .usingQueryVariableName("filter")
                    .filter(new QueryInteraction(this.filter, false, false, false))
                    .buildJustTheQuery();
            CodeBlock rendered = CodeBlock.builder()
                    .add(queryDeclaration)
                    .addStatement("$1L = $1L.filter($2L.getQueryObject())", vectorSearchVar, filterVariable)
                    .add("\n")
                    .build();
            return new ExpressionSnippet(rendered);
        }

        /**
         * Renders the expression passed to the generated {@code numCandidates(...)} call.
         * <p>
         * An annotation value wins: it is rendered through
         * {@code MongoCodeBlocks.evaluateNumberPotentially} when it contains a
         * placeholder or expression, and as a literal otherwise. Without an annotation
         * value, and only for the {@code ANN} and {@code DEFAULT} search types, the
         * limit multiplied by 20 is rendered (from the limit parameter, the filter
         * limit, or the given limit snippet, in that order). For any other search type
         * the result is empty.
         *
         * @param searchType search type taken from the annotation
         * @param limit limit snippet used as the last fallback
         * @return the expression snippet, possibly empty
         */
        private ExpressionSnippet getNumCandidatesExpression(VectorSearchOperation.SearchType searchType, ExpressionSnippet limit) {
            String configured = this.vectorSearchAnnotation.numCandidates();
            if (StringUtils.hasText(configured)) {
                boolean needsEvaluation = MongoCodeBlocks.containsPlaceholder(configured) || MongoCodeBlocks.containsExpression(configured);
                return needsEvaluation
                        ? new ExpressionSnippet(MongoCodeBlocks.evaluateNumberPotentially(configured, Integer.class, this.context), true)
                        : new ExpressionSnippet(CodeBlock.of("$L", configured));
            }
            if (searchType != VectorSearchOperation.SearchType.ANN && searchType != VectorSearchOperation.SearchType.DEFAULT) {
                return ExpressionSnippet.empty();
            }
            String limitParameter = this.context.getLimitParameterName();
            if (StringUtils.hasText(limitParameter)) {
                return new ExpressionSnippet(CodeBlock.of("$L.max() * 20", limitParameter));
            }
            if (this.filter.isLimited()) {
                return new ExpressionSnippet(CodeBlock.of("$L", this.filter.getLimit() * 20));
            }
            return new ExpressionSnippet(CodeBlock.of("$L * 20", limit.code()));
        }

        /**
         * Renders the expression passed to the generated {@code limit(...)} call.
         * Sources are tried in this order: the limit parameter name, the filter limit,
         * the annotation's {@code limit} value (rendered through
         * {@code MongoCodeBlocks.evaluateNumberPotentially} when it contains a
         * placeholder or expression, as a literal otherwise), and finally
         * {@code Limit.unlimited()}.
         *
         * @return the limit expression snippet
         */
        private ExpressionSnippet getLimitExpression() {
            String limitParameter = this.context.getLimitParameterName();
            if (StringUtils.hasText(limitParameter)) {
                return new ExpressionSnippet(CodeBlock.of("$L", limitParameter));
            }
            if (this.filter.isLimited()) {
                return new ExpressionSnippet(CodeBlock.of("$L", this.filter.getLimit()));
            }
            String configured = this.vectorSearchAnnotation.limit();
            if (!StringUtils.hasText(configured)) {
                return new ExpressionSnippet(CodeBlock.of("$T.unlimited()", Limit.class));
            }
            boolean needsEvaluation = MongoCodeBlocks.containsPlaceholder(configured) || MongoCodeBlocks.containsExpression(configured);
            if (needsEvaluation) {
                return new ExpressionSnippet(MongoCodeBlocks.evaluateNumberPotentially(configured, Integer.class, this.context), true);
            }
            return new ExpressionSnippet(CodeBlock.of("$L", configured));
        }
    }
}

