/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.store.embedding.azure.search;

import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.TokenCredential;
import com.azure.core.util.Context;
import com.azure.search.documents.SearchClient;
import com.azure.search.documents.SearchClientBuilder;
import com.azure.search.documents.SearchDocument;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import com.azure.search.documents.indexes.models.HnswAlgorithmConfiguration;
import com.azure.search.documents.indexes.models.HnswParameters;
import com.azure.search.documents.indexes.models.SearchField;
import com.azure.search.documents.indexes.models.SearchFieldDataType;
import com.azure.search.documents.indexes.models.SearchIndex;
import com.azure.search.documents.indexes.models.SemanticConfiguration;
import com.azure.search.documents.indexes.models.SemanticField;
import com.azure.search.documents.indexes.models.SemanticPrioritizedFields;
import com.azure.search.documents.indexes.models.SemanticSearch;
import com.azure.search.documents.indexes.models.VectorSearch;
import com.azure.search.documents.indexes.models.VectorSearchAlgorithmMetric;
import com.azure.search.documents.indexes.models.VectorSearchProfile;
import com.azure.search.documents.models.IndexingResult;
import com.azure.search.documents.models.SearchOptions;
import com.azure.search.documents.models.SearchResult;
import com.azure.search.documents.models.VectorQuery;
import com.azure.search.documents.models.VectorSearchOptions;
import com.azure.search.documents.models.VectorizedQuery;
import com.azure.search.documents.util.SearchPagedIterable;
import dev.langchain4j.data.document.Metadata;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import dev.langchain4j.store.embedding.RelevanceScore;
import dev.langchain4j.store.embedding.azure.search.AzureAiSearchRuntimeException;
import dev.langchain4j.store.embedding.azure.search.Document;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for LangChain4j embedding stores backed by an Azure AI Search index.
 *
 * <p>Subclasses call {@link #initialize} to build the Azure clients and, optionally, to create or
 * update the index. Each stored document has an {@code id}, an optional {@code content} text, a
 * {@code content_vector}, and a {@code metadata} object whose {@code attributes} collection holds
 * the text segment's metadata as {@code key}/{@code value} pairs.
 */
public abstract class AbstractAzureAiSearchEmbeddingStore
implements EmbeddingStore<TextSegment> {
    private static final Logger log = LoggerFactory.getLogger(AbstractAzureAiSearchEmbeddingStore.class);
    /** Index name used when no custom {@link SearchIndex} is passed to {@link #initialize}. */
    public static final String INDEX_NAME = "vectorsearch";
    static final String DEFAULT_FIELD_ID = "id";
    protected static final String DEFAULT_FIELD_CONTENT = "content";
    // Kept as an instance field (not static) to stay binary compatible with existing subclasses.
    protected final String DEFAULT_FIELD_CONTENT_VECTOR = "content_vector";
    protected static final String DEFAULT_FIELD_METADATA = "metadata";
    protected static final String DEFAULT_FIELD_METADATA_SOURCE = "source";
    protected static final String DEFAULT_FIELD_METADATA_ATTRS = "attributes";
    protected static final String SEMANTIC_SEARCH_CONFIG_NAME = "semantic-search-config";
    protected static final String VECTOR_ALGORITHM_NAME = "vector-search-algorithm";
    protected static final String VECTOR_SEARCH_PROFILE_NAME = "vector-search-profile";
    /** Whether this store is allowed to create, update and delete its index. */
    private boolean createOrUpdateIndex;
    /** Index management client; only built when {@link #createOrUpdateIndex} is true. */
    private SearchIndexClient searchIndexClient;
    /** Document client bound to {@link #indexName}. */
    protected SearchClient searchClient;
    /** Name of the index every client of this store operates on. */
    private String indexName = INDEX_NAME;

    /**
     * Builds the Azure clients and, if requested, creates or updates the index.
     *
     * <p>If {@code keyCredential} is non-null it is used; otherwise {@code tokenCredential} is used.
     * When a custom {@code index} is supplied, its name becomes the index this store reads from and
     * writes to; otherwise {@link #INDEX_NAME} is used.
     *
     * @param endpoint            Azure AI Search service endpoint
     * @param keyCredential       API-key credential, or {@code null} to use {@code tokenCredential}
     * @param tokenCredential     token credential used when {@code keyCredential} is {@code null}
     * @param createOrUpdateIndex whether the index may be created/updated/deleted by this store
     * @param dimensions          vector dimensions for the default schema (0 = full-text only)
     * @param index               custom index definition, or {@code null} to use the default schema
     */
    protected void initialize(String endpoint, AzureKeyCredential keyCredential, TokenCredential tokenCredential, boolean createOrUpdateIndex, int dimensions, SearchIndex index) {
        this.createOrUpdateIndex = createOrUpdateIndex;
        // The search client must target the same index that is created below; otherwise a custom
        // index would be created but documents would be uploaded to / searched in INDEX_NAME.
        this.indexName = index != null && index.getName() != null ? index.getName() : INDEX_NAME;
        if (keyCredential != null) {
            if (createOrUpdateIndex) {
                this.searchIndexClient = new SearchIndexClientBuilder().endpoint(endpoint).credential(keyCredential).buildClient();
            }
            this.searchClient = new SearchClientBuilder().endpoint(endpoint).credential(keyCredential).indexName(this.indexName).buildClient();
        } else {
            if (createOrUpdateIndex) {
                this.searchIndexClient = new SearchIndexClientBuilder().endpoint(endpoint).credential(tokenCredential).buildClient();
            }
            this.searchClient = new SearchClientBuilder().endpoint(endpoint).credential(tokenCredential).indexName(this.indexName).buildClient();
        }
        if (createOrUpdateIndex) {
            if (index == null) {
                this.createOrUpdateIndex(dimensions);
            } else {
                this.createOrUpdateIndex(index);
            }
        }
    }

    /**
     * Creates or updates the store's index using the default schema.
     *
     * <p>With {@code dimensions > 0} the schema includes a vector field (HNSW, cosine metric) and a
     * semantic configuration over {@code content}; otherwise only full-text fields are created.
     *
     * @param dimensions number of vector dimensions, or 0 for a full-text-only index
     * @throws IllegalArgumentException if this store was initialized with {@code createOrUpdateIndex = false}
     */
    public void createOrUpdateIndex(int dimensions) {
        ensureIndexManagementEnabled("createOrUpdateIndex is false, so the index cannot be created or updated");
        if (dimensions == 0) {
            log.info("Dimensions is 0, so the index will only be created for full text search");
        }
        List<SearchField> fields = new ArrayList<>();
        fields.add(new SearchField(DEFAULT_FIELD_ID, SearchFieldDataType.STRING)
                .setKey(true)
                .setFilterable(true));
        fields.add(new SearchField(DEFAULT_FIELD_CONTENT, SearchFieldDataType.STRING)
                .setSearchable(true)
                .setFilterable(true));
        if (dimensions > 0) {
            fields.add(new SearchField(DEFAULT_FIELD_CONTENT_VECTOR, SearchFieldDataType.collection(SearchFieldDataType.SINGLE))
                    .setSearchable(true)
                    .setVectorSearchDimensions(dimensions)
                    .setVectorSearchProfileName(VECTOR_SEARCH_PROFILE_NAME));
        }
        fields.add(metadataField());
        SearchIndex index = new SearchIndex(this.indexName).setFields(fields);
        if (dimensions > 0) {
            index.setVectorSearch(vectorSearch()).setSemanticSearch(semanticSearch());
        }
        this.searchIndexClient.createOrUpdateIndex(index);
    }

    /**
     * Creates or updates the given custom index definition as-is.
     *
     * @throws IllegalArgumentException if this store was initialized with {@code createOrUpdateIndex = false}
     */
    void createOrUpdateIndex(SearchIndex index) {
        ensureIndexManagementEnabled("createOrUpdateIndex is false, so the index cannot be created or updated");
        this.searchIndexClient.createOrUpdateIndex(index);
    }

    /**
     * Deletes the index this store operates on.
     *
     * @throws IllegalArgumentException if this store was initialized with {@code createOrUpdateIndex = false}
     */
    public void deleteIndex() {
        ensureIndexManagementEnabled("createOrUpdateIndex is false, so the index cannot be deleted");
        this.searchIndexClient.deleteIndex(this.indexName);
    }

    /** Adds an embedding under a random UUID and returns that id. */
    public String add(Embedding embedding) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, null);
        return id;
    }

    /** Adds an embedding under the given id. */
    public void add(String id, Embedding embedding) {
        this.addInternal(id, embedding, null);
    }

    /** Adds an embedding together with its text segment under a random UUID and returns that id. */
    public String add(Embedding embedding, TextSegment textSegment) {
        String id = Utils.randomUUID();
        this.addInternal(id, embedding, textSegment);
        return id;
    }

    /** Adds the embeddings under freshly generated UUIDs, returned in input order. */
    public List<String> addAll(List<Embedding> embeddings) {
        List<String> ids = randomIds(embeddings);
        this.addAllInternal(ids, embeddings, null);
        return ids;
    }

    /**
     * Adds the embeddings with their text segments under freshly generated UUIDs.
     *
     * @param embeddings embeddings to store
     * @param embedded   segments matching {@code embeddings} one-to-one
     * @return generated ids, in input order
     */
    public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
        List<String> ids = randomIds(embeddings);
        this.addAllInternal(ids, embeddings, embedded);
        return ids;
    }

    /**
     * Runs a k-nearest-neighbour vector query and returns the matches whose relevance score is at
     * least {@code minScore}.
     *
     * <p>The {@code minScore} filter is applied client-side after Azure returns the
     * {@code maxResults} nearest neighbours, so fewer than {@code maxResults} matches may be returned.
     *
     * @param referenceEmbedding query vector
     * @param maxResults         number of nearest neighbours requested from Azure
     * @param minScore           minimum relevance score (as produced by {@link #fromAzureScoreToRelevanceScore})
     * @return matches in the order returned by the service
     */
    public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
        List<Float> vector = referenceEmbedding.vectorAsList();
        VectorizedQuery vectorizedQuery = new VectorizedQuery(vector)
                .setFields(new String[]{DEFAULT_FIELD_CONTENT_VECTOR})
                .setKNearestNeighborsCount(maxResults);
        SearchOptions searchOptions = new SearchOptions()
                .setVectorSearchOptions(new VectorSearchOptions().setQueries(new VectorQuery[]{vectorizedQuery}));
        SearchPagedIterable searchResults = this.searchClient.search(null, searchOptions, Context.NONE);
        List<EmbeddingMatch<TextSegment>> result = new ArrayList<>();
        for (SearchResult searchResult : searchResults) {
            double score = fromAzureScoreToRelevanceScore(searchResult.getScore());
            if (score < minScore) {
                continue;
            }
            SearchDocument searchDocument = searchResult.getDocument(SearchDocument.class);
            String embeddingId = (String) searchDocument.get(DEFAULT_FIELD_ID);
            @SuppressWarnings("unchecked") // the vector field is deserialized as a JSON number array
            List<Double> embeddingList = (List<Double>) searchDocument.get(DEFAULT_FIELD_CONTENT_VECTOR);
            Embedding embedding = Embedding.from(this.doublesListToFloatArray(embeddingList));
            String embeddedContent = (String) searchDocument.get(DEFAULT_FIELD_CONTENT);
            TextSegment embedded = Utils.isNotNullOrBlank(embeddedContent)
                    ? TextSegment.textSegment(embeddedContent, toMetadata(searchDocument.get(DEFAULT_FIELD_METADATA)))
                    : null;
            result.add(new EmbeddingMatch<>(score, embeddingId, embedding, embedded));
        }
        return result;
    }

    /** Throws {@link IllegalArgumentException} with {@code message} unless index management is enabled. */
    private void ensureIndexManagementEnabled(String message) {
        if (!this.createOrUpdateIndex) {
            throw new IllegalArgumentException(message);
        }
    }

    /** Generates one random UUID per embedding. */
    private static List<String> randomIds(List<Embedding> embeddings) {
        return embeddings.stream().map(ignored -> Utils.randomUUID()).collect(Collectors.toList());
    }

    /** Builds the complex {@code metadata} field holding {@code source} and key/value {@code attributes}. */
    private static SearchField metadataField() {
        SearchField attributes = new SearchField(DEFAULT_FIELD_METADATA_ATTRS, SearchFieldDataType.collection(SearchFieldDataType.COMPLEX))
                .setFields(Arrays.asList(
                        new SearchField("key", SearchFieldDataType.STRING).setFilterable(true),
                        new SearchField("value", SearchFieldDataType.STRING).setFilterable(true)));
        return new SearchField(DEFAULT_FIELD_METADATA, SearchFieldDataType.COMPLEX)
                .setFields(Arrays.asList(
                        new SearchField(DEFAULT_FIELD_METADATA_SOURCE, SearchFieldDataType.STRING).setFilterable(true),
                        attributes));
    }

    /** Builds the HNSW (cosine) vector search configuration referenced by the vector field's profile. */
    private static VectorSearch vectorSearch() {
        HnswParameters parameters = new HnswParameters()
                .setMetric(VectorSearchAlgorithmMetric.COSINE)
                .setM(4)
                .setEfSearch(500)
                .setEfConstruction(400);
        return new VectorSearch()
                .setAlgorithms(Collections.singletonList(new HnswAlgorithmConfiguration(VECTOR_ALGORITHM_NAME).setParameters(parameters)))
                .setProfiles(Collections.singletonList(new VectorSearchProfile(VECTOR_SEARCH_PROFILE_NAME, VECTOR_ALGORITHM_NAME)));
    }

    /** Builds the default semantic configuration, prioritizing the {@code content} field. */
    private static SemanticSearch semanticSearch() {
        SemanticPrioritizedFields prioritizedFields = new SemanticPrioritizedFields()
                .setContentFields(new SemanticField[]{new SemanticField(DEFAULT_FIELD_CONTENT)})
                .setKeywordsFields(new SemanticField[]{new SemanticField(DEFAULT_FIELD_CONTENT)});
        return new SemanticSearch()
                .setDefaultConfigurationName(SEMANTIC_SEARCH_CONFIG_NAME)
                .setConfigurations(Collections.singletonList(new SemanticConfiguration(SEMANTIC_SEARCH_CONFIG_NAME, prioritizedFields)));
    }

    /**
     * Rebuilds LangChain4j metadata from the stored {@code metadata} object. A missing object, a
     * missing {@code attributes} list, or malformed attribute entries yield empty/partial metadata
     * instead of a {@link NullPointerException}.
     */
    private static Metadata toMetadata(Object storedMetadata) {
        Map<String, String> attributesMap = new HashMap<>();
        if (storedMetadata instanceof Map) {
            Object attributes = ((Map<?, ?>) storedMetadata).get(DEFAULT_FIELD_METADATA_ATTRS);
            if (attributes instanceof List) {
                for (Object attribute : (List<?>) attributes) {
                    if (!(attribute instanceof Map)) {
                        continue;
                    }
                    Map<?, ?> innerAttribute = (Map<?, ?>) attribute;
                    attributesMap.put((String) innerAttribute.get("key"), (String) innerAttribute.get("value"));
                }
            }
        }
        return Metadata.from(attributesMap);
    }

    private void addInternal(String id, Embedding embedding, TextSegment embedded) {
        this.addAllInternal(Collections.singletonList(id), Collections.singletonList(embedding), embedded == null ? null : Collections.singletonList(embedded));
    }

    /**
     * Uploads one document per id/embedding pair (with the matching segment, if any).
     *
     * @throws IllegalArgumentException      if the list sizes disagree
     * @throws AzureAiSearchRuntimeException if Azure reports a failed indexing result
     */
    private void addAllInternal(List<String> ids, List<Embedding> embeddings, List<TextSegment> embedded) {
        if (Utils.isNullOrEmpty(ids) || Utils.isNullOrEmpty(embeddings)) {
            log.info("Empty embeddings - no ops");
            return;
        }
        ValidationUtils.ensureTrue(ids.size() == embeddings.size(), "ids size is not equal to embeddings size");
        ValidationUtils.ensureTrue(embedded == null || embeddings.size() == embedded.size(), "embeddings size is not equal to embedded size");
        List<Document> documents = new ArrayList<>(ids.size());
        for (int i = 0; i < ids.size(); ++i) {
            documents.add(toDocument(ids.get(i), embeddings.get(i), embedded == null ? null : embedded.get(i)));
        }
        List<IndexingResult> indexingResults = this.searchClient.uploadDocuments(documents).getResults();
        for (IndexingResult indexingResult : indexingResults) {
            if (!indexingResult.isSucceeded()) {
                throw new AzureAiSearchRuntimeException("Failed to add embedding: " + indexingResult.getErrorMessage());
            }
            log.debug("Added embedding: {}", indexingResult.getKey());
        }
    }

    /** Maps one embedding and its optional segment onto the index document shape. */
    private static Document toDocument(String id, Embedding embedding, TextSegment segment) {
        Document document = new Document();
        document.setId(id);
        document.setContentVector(embedding.vectorAsList());
        if (segment != null) {
            document.setContent(segment.text());
            List<Document.Metadata.Attribute> attributes = new ArrayList<>();
            for (Map.Entry<String, ?> entry : segment.metadata().asMap().entrySet()) {
                Document.Metadata.Attribute attribute = new Document.Metadata.Attribute();
                attribute.setKey(entry.getKey());
                attribute.setValue((String) entry.getValue());
                attributes.add(attribute);
            }
            Document.Metadata metadata = new Document.Metadata();
            metadata.setAttributes(attributes);
            document.setMetadata(metadata);
        }
        return document;
    }

    /** Narrows each element of a list of doubles to a {@code float}, preserving order. */
    float[] doublesListToFloatArray(List<Double> doubles) {
        float[] array = new float[doubles.size()];
        for (int i = 0; i < doubles.size(); ++i) {
            array[i] = doubles.get(i).floatValue();
        }
        return array;
    }

    /**
     * Converts an Azure AI Search cosine score to a LangChain4j relevance score.
     *
     * <p>The formula inverts {@code score = 1 / (1 + cosineDistance)} to recover the cosine
     * distance, turns it into a cosine similarity ({@code 1 - distance}), and maps that through
     * {@link RelevanceScore#fromCosineSimilarity(double)}.
     */
    protected static double fromAzureScoreToRelevanceScore(double score) {
        double cosineDistance = (1.0 - score) / score;
        double cosineSimilarity = -cosineDistance + 1.0;
        return RelevanceScore.fromCosineSimilarity(cosineSimilarity);
    }
}

