/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.mongodb.atlas;

import com.mongodb.MongoCommandException;
import com.mongodb.client.result.DeleteResult;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.bson.Document;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.mongodb.atlas.MongoDBAtlasFilterExpressionConverter;
import org.springframework.ai.vectorstore.mongodb.atlas.VectorSearchAggregation;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.data.mongodb.UncategorizedMongoDbException;
import org.springframework.data.mongodb.core.MongoTemplate;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
import org.springframework.data.mongodb.core.query.BasicQuery;
import org.springframework.data.mongodb.core.query.Criteria;
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
import org.springframework.data.mongodb.core.query.Query;
import org.springframework.util.Assert;

public class MongoDBAtlasVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(MongoDBAtlasVectorStore.class);
    public static final String ID_FIELD_NAME = "_id";
    public static final String METADATA_FIELD_NAME = "metadata";
    public static final String CONTENT_FIELD_NAME = "content";
    public static final String SCORE_FIELD_NAME = "score";
    public static final String DEFAULT_VECTOR_COLLECTION_NAME = "vector_store";
    private static final String DEFAULT_VECTOR_INDEX_NAME = "vector_index";
    private static final String DEFAULT_PATH_NAME = "embedding";
    private static final int DEFAULT_NUM_CANDIDATES = 200;
    private static final int INDEX_ALREADY_EXISTS_ERROR_CODE = 68;
    private static final String INDEX_ALREADY_EXISTS_ERROR_CODE_NAME = "IndexAlreadyExists";
    private final MongoTemplate mongoTemplate;
    private final String collectionName;
    private final String vectorIndexName;
    private final String pathName;
    private final List<String> metadataFieldsToFilter;
    private final int numCandidates;
    private final MongoDBAtlasFilterExpressionConverter filterExpressionConverter;
    private final boolean initializeSchema;

    protected MongoDBAtlasVectorStore(Builder builder) {
        super((AbstractVectorStoreBuilder)builder);
        Assert.notNull((Object)builder.mongoTemplate, (String)"MongoTemplate must not be null");
        this.mongoTemplate = builder.mongoTemplate;
        this.collectionName = builder.collectionName;
        this.vectorIndexName = builder.vectorIndexName;
        this.pathName = builder.pathName;
        this.numCandidates = builder.numCandidates;
        this.metadataFieldsToFilter = builder.metadataFieldsToFilter;
        this.filterExpressionConverter = builder.filterExpressionConverter;
        this.initializeSchema = builder.initializeSchema;
    }

    public void afterPropertiesSet() throws Exception {
        if (!this.initializeSchema) {
            return;
        }
        if (!this.mongoTemplate.collectionExists(this.collectionName)) {
            this.mongoTemplate.createCollection(this.collectionName);
        }
        this.createSearchIndex();
    }

    private void createSearchIndex() {
        try {
            this.mongoTemplate.executeCommand(this.createSearchIndexDefinition());
        }
        catch (UncategorizedMongoDbException e) {
            MongoCommandException commandException;
            Throwable cause = e.getCause();
            if (cause instanceof MongoCommandException && (68 == (commandException = (MongoCommandException)cause).getCode() || INDEX_ALREADY_EXISTS_ERROR_CODE_NAME.equals(commandException.getErrorCodeName()))) {
                return;
            }
            throw e;
        }
    }

    private Document createSearchIndexDefinition() {
        ArrayList<Document> vectorFields = new ArrayList<Document>();
        vectorFields.add(new Document().append("type", (Object)"vector").append("path", (Object)this.pathName).append("numDimensions", (Object)this.embeddingModel.dimensions()).append("similarity", (Object)"cosine"));
        vectorFields.addAll(this.metadataFieldsToFilter.stream().map(fieldName -> new Document().append("type", (Object)"filter").append("path", (Object)("metadata." + fieldName))).toList());
        return new Document().append("createSearchIndexes", (Object)this.collectionName).append("indexes", List.of(new Document().append("name", (Object)this.vectorIndexName).append("type", (Object)"vectorSearch").append("definition", (Object)new Document("fields", vectorFields))));
    }

    private org.springframework.ai.document.Document mapMongoDocument(Document mongoDocument, float[] queryEmbedding) {
        String id = mongoDocument.getString((Object)ID_FIELD_NAME);
        String content = mongoDocument.getString((Object)CONTENT_FIELD_NAME);
        double score = mongoDocument.getDouble((Object)SCORE_FIELD_NAME);
        Map metadata = (Map)mongoDocument.get((Object)METADATA_FIELD_NAME, Document.class);
        metadata.put(DocumentMetadata.DISTANCE.value(), 1.0 - score);
        return org.springframework.ai.document.Document.builder().id(id).text(content).metadata(metadata).score(Double.valueOf(score)).build();
    }

    public void doAdd(List<org.springframework.ai.document.Document> documents) {
        List embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        for (org.springframework.ai.document.Document document : documents) {
            MongoDBDocument mdbDocument = new MongoDBDocument(document.getId(), document.getText(), document.getMetadata(), (float[])embeddings.get(documents.indexOf(document)));
            this.mongoTemplate.save((Object)mdbDocument, this.collectionName);
        }
    }

    public void doDelete(List<String> idList) {
        Query query = new Query((CriteriaDefinition)Criteria.where((String)ID_FIELD_NAME).in(idList));
        this.mongoTemplate.remove(query, this.collectionName);
    }

    protected void doDelete(Filter.Expression filterExpression) {
        Assert.notNull((Object)filterExpression, (String)"Filter expression must not be null");
        try {
            String nativeFilterExpression = this.filterExpressionConverter.convertExpression(filterExpression);
            BasicQuery query = new BasicQuery(nativeFilterExpression);
            DeleteResult deleteResult = this.mongoTemplate.remove((Query)query, this.collectionName);
            logger.debug("Deleted {} documents matching filter expression", (Object)deleteResult.getDeletedCount());
        }
        catch (Exception e) {
            throw new IllegalStateException("Failed to delete documents by filter", e);
        }
    }

    public List<org.springframework.ai.document.Document> similaritySearch(String query) {
        return this.similaritySearch(SearchRequest.builder().query(query).build());
    }

    public List<org.springframework.ai.document.Document> doSimilaritySearch(SearchRequest request) {
        String nativeFilterExpressions = request.getFilterExpression() != null ? this.filterExpressionConverter.convertExpression(request.getFilterExpression()) : "";
        float[] queryEmbedding = this.embeddingModel.embed(request.getQuery());
        VectorSearchAggregation vectorSearch = new VectorSearchAggregation(EmbeddingUtils.toList((float[])queryEmbedding), this.pathName, this.numCandidates, this.vectorIndexName, request.getTopK(), nativeFilterExpressions);
        Aggregation aggregation = Aggregation.newAggregation((AggregationOperation[])new AggregationOperation[]{vectorSearch, Aggregation.addFields().addField(SCORE_FIELD_NAME).withValueOfExpression("{\"$meta\":\"vectorSearchScore\"}", new Object[0]).build(), Aggregation.match((Criteria)new Criteria(SCORE_FIELD_NAME).gte((Object)request.getSimilarityThreshold()))});
        return this.mongoTemplate.aggregate(aggregation, this.collectionName, Document.class).getMappedResults().stream().map(d -> this.mapMongoDocument((Document)d, queryEmbedding)).toList();
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder((String)VectorStoreProvider.MONGODB.value(), (String)operationName).collectionName(this.collectionName).dimensions(Integer.valueOf(this.embeddingModel.dimensions())).fieldName(this.pathName);
    }

    public <T> Optional<T> getNativeClient() {
        MongoTemplate client = this.mongoTemplate;
        return Optional.of(client);
    }

    public static Builder builder(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel) {
        return new Builder(mongoTemplate, embeddingModel);
    }

    public static class Builder
    extends AbstractVectorStoreBuilder<Builder> {
        private final MongoTemplate mongoTemplate;
        private String collectionName = "vector_store";
        private String vectorIndexName = "vector_index";
        private String pathName = "embedding";
        private int numCandidates = 200;
        private List<String> metadataFieldsToFilter = Collections.emptyList();
        private boolean initializeSchema = false;
        private MongoDBAtlasFilterExpressionConverter filterExpressionConverter = new MongoDBAtlasFilterExpressionConverter();

        private Builder(MongoTemplate mongoTemplate, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull((Object)mongoTemplate, (String)"MongoTemplate must not be null");
            this.mongoTemplate = mongoTemplate;
        }

        public Builder collectionName(String collectionName) {
            Assert.hasText((String)collectionName, (String)"Collection Name must not be null or empty");
            this.collectionName = collectionName;
            return this;
        }

        public Builder vectorIndexName(String vectorIndexName) {
            Assert.hasText((String)vectorIndexName, (String)"Vector Index Name must not be null or empty");
            this.vectorIndexName = vectorIndexName;
            return this;
        }

        public Builder pathName(String pathName) {
            Assert.hasText((String)pathName, (String)"Path Name must not be null or empty");
            this.pathName = pathName;
            return this;
        }

        public Builder numCandidates(int numCandidates) {
            this.numCandidates = numCandidates;
            return this;
        }

        public Builder metadataFieldsToFilter(List<String> metadataFieldsToFilter) {
            Assert.notEmpty(metadataFieldsToFilter, (String)"Fields list must not be empty");
            this.metadataFieldsToFilter = metadataFieldsToFilter;
            return this;
        }

        public Builder initializeSchema(boolean initializeSchema) {
            this.initializeSchema = initializeSchema;
            return this;
        }

        public Builder filterExpressionConverter(MongoDBAtlasFilterExpressionConverter converter) {
            Assert.notNull((Object)((Object)converter), (String)"filterExpressionConverter must not be null");
            this.filterExpressionConverter = converter;
            return this;
        }

        public MongoDBAtlasVectorStore build() {
            return new MongoDBAtlasVectorStore(this);
        }
    }

    public record MongoDBDocument(String id, String content, Map<String, Object> metadata, float[] embedding) {
    }
}

