/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore;

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.core.BulkRequest;
import co.elastic.clients.elasticsearch.core.BulkResponse;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch.core.bulk.DeleteOperation;
import co.elastic.clients.elasticsearch.core.bulk.IndexOperation;
import co.elastic.clients.elasticsearch.core.search.Hit;
import co.elastic.clients.json.JsonpMapper;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.ElasticsearchTransport;
import co.elastic.clients.transport.Version;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import co.elastic.clients.util.ObjectBuilder;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.micrometer.observation.ObservationRegistry;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import org.elasticsearch.client.RestClient;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.ElasticsearchAiSearchFilterExpressionConverter;
import org.springframework.ai.vectorstore.ElasticsearchVectorStoreOptions;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;

public class ElasticsearchVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class);
    private static Map<SimilarityFunction, VectorStoreSimilarityMetric> SIMILARITY_TYPE_MAPPING = Map.of(SimilarityFunction.cosine, VectorStoreSimilarityMetric.COSINE, SimilarityFunction.l2_norm, VectorStoreSimilarityMetric.EUCLIDEAN, SimilarityFunction.dot_product, VectorStoreSimilarityMetric.DOT);
    private final EmbeddingModel embeddingModel;
    private final ElasticsearchClient elasticsearchClient;
    private final ElasticsearchVectorStoreOptions options;
    private final FilterExpressionConverter filterExpressionConverter;
    private final boolean initializeSchema;
    private final BatchingStrategy batchingStrategy;

    public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
        this(new ElasticsearchVectorStoreOptions(), restClient, embeddingModel, initializeSchema);
    }

    public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
        this(options, restClient, embeddingModel, initializeSchema, ObservationRegistry.NOOP, null, (BatchingStrategy)new TokenCountBatchingStrategy());
    }

    public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {
        super(observationRegistry, customObservationConvention);
        this.initializeSchema = initializeSchema;
        Objects.requireNonNull(embeddingModel, "RestClient must not be null");
        Objects.requireNonNull(embeddingModel, "EmbeddingModel must not be null");
        String version = Version.VERSION == null ? "Unknown" : Version.VERSION.toString();
        this.elasticsearchClient = (ElasticsearchClient)new ElasticsearchClient((ElasticsearchTransport)new RestClientTransport(restClient, (JsonpMapper)new JacksonJsonpMapper(new ObjectMapper().configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)))).withTransportOptions(t -> t.addHeader("user-agent", "spring-ai elastic-java/" + version));
        this.embeddingModel = embeddingModel;
        this.options = options;
        this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
        this.batchingStrategy = batchingStrategy;
    }

    public void doAdd(List<Document> documents) {
        if (!this.indexExists()) {
            throw new IllegalArgumentException("Index not found");
        }
        BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
        this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
        for (Document document : documents) {
            bulkRequestBuilder.operations(op -> op.index(idx -> ((IndexOperation.Builder)((IndexOperation.Builder)idx.index(this.options.getIndexName())).id(document.getId())).document((Object)document)));
        }
        BulkResponse bulkRequest = this.bulkRequest(bulkRequestBuilder.build());
        if (bulkRequest.errors()) {
            List bulkResponseItems = bulkRequest.items();
            for (BulkResponseItem bulkResponseItem : bulkResponseItems) {
                if (bulkResponseItem.error() == null) continue;
                throw new IllegalStateException(bulkResponseItem.error().reason());
            }
        }
    }

    public Optional<Boolean> doDelete(List<String> idList) {
        BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
        if (!this.indexExists()) {
            throw new IllegalArgumentException("Index not found");
        }
        for (String id : idList) {
            bulkRequestBuilder.operations(op -> op.delete(idx -> (ObjectBuilder)((DeleteOperation.Builder)idx.index(this.options.getIndexName())).id(id)));
        }
        return Optional.of(this.bulkRequest(bulkRequestBuilder.build()).errors());
    }

    private BulkResponse bulkRequest(BulkRequest bulkRequest) {
        try {
            return this.elasticsearchClient.bulk(bulkRequest);
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
        Assert.notNull((Object)searchRequest, (String)"The search request must not be null.");
        try {
            float threshold = (float)searchRequest.getSimilarityThreshold();
            if (this.options.getSimilarity().equals((Object)SimilarityFunction.l2_norm)) {
                threshold = 1.0f - threshold;
            }
            float finalThreshold = threshold;
            float[] vectors = this.embeddingModel.embed(searchRequest.getQuery());
            SearchResponse res = this.elasticsearchClient.search(sr -> sr.index(this.options.getIndexName(), new String[0]).knn(knn -> knn.queryVector(EmbeddingUtils.toList((float[])vectors)).similarity(Float.valueOf(finalThreshold)).k(Long.valueOf(searchRequest.getTopK())).field("embedding").numCandidates(Long.valueOf((long)(1.5 * (double)searchRequest.getTopK()))).filter(fl -> fl.queryString(qs -> qs.query(this.getElasticsearchQueryString(searchRequest.getFilterExpression()))))), Document.class);
            return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList());
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private String getElasticsearchQueryString(Filter.Expression filterExpression) {
        return Objects.isNull(filterExpression) ? "*" : this.filterExpressionConverter.convertExpression(filterExpression);
    }

    private Document toDocument(Hit<Document> hit) {
        Document document = (Document)hit.source();
        document.getMetadata().put("distance", Float.valueOf(this.calculateDistance(Float.valueOf(hit.score().floatValue()))));
        return document;
    }

    private float calculateDistance(Float score) {
        switch (this.options.getSimilarity()) {
            case l2_norm: {
                return (float)(1.0 - Math.sqrt(1.0f / score.floatValue() - 1.0f));
            }
        }
        return 2.0f * score.floatValue() - 1.0f;
    }

    public boolean indexExists() {
        try {
            return this.elasticsearchClient.indices().exists(ex -> ex.index(this.options.getIndexName(), new String[0])).value();
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private void createIndexMapping() {
        try {
            this.elasticsearchClient.indices().create(cr -> cr.index(this.options.getIndexName()).mappings(map -> map.properties("embedding", p -> p.denseVector(dv -> dv.similarity(this.options.getSimilarity().toString()).dims(Integer.valueOf(this.options.getDimensions()))))));
        }
        catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public void afterPropertiesSet() {
        if (!this.initializeSchema) {
            return;
        }
        if (!this.indexExists()) {
            this.createIndexMapping();
        }
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder((String)VectorStoreProvider.ELASTICSEARCH.value(), (String)operationName).withCollectionName(this.options.getIndexName()).withDimensions(Integer.valueOf(this.embeddingModel.dimensions())).withSimilarityMetric(this.getSimilarityMetric());
    }

    private String getSimilarityMetric() {
        if (!SIMILARITY_TYPE_MAPPING.containsKey((Object)this.options.getSimilarity())) {
            return this.options.getSimilarity().name();
        }
        return SIMILARITY_TYPE_MAPPING.get((Object)this.options.getSimilarity()).value();
    }
}

