/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.chroma.vectorstore;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.lang.Nullable;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.client.HttpClientErrorException;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestClient;

public class ChromaApi {
    // Captures the single-quoted text inside a "ValueError('...')" fragment of an error message.
    private static final Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)");
    // Captures the value of a JSON "message" field (non-greedy, up to the next double quote).
    private static final Pattern MESSAGE_ERROR_PATTERN = Pattern.compile("\"message\":\"(.*?)\"");
    // Used by where(String) to parse JSON text into a map.
    private final ObjectMapper objectMapper;
    // Not final: withBasicAuthCredentials replaces it with a mutated copy that carries an extra interceptor.
    private RestClient restClient;
    // Optional token; when it contains text, httpHeaders sends it as a Bearer Authorization header.
    @Nullable
    private String keyToken;

    /**
     * Creates a client for the given base URL using a {@code RestClient} builder backed by
     * a {@code SimpleClientHttpRequestFactory} and a freshly created {@code ObjectMapper}.
     *
     * @param baseUrl base URL that request paths are resolved against
     */
    public ChromaApi(String baseUrl) {
        this(baseUrl, RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()),
                new ObjectMapper());
    }

    /**
     * Creates a client for the given base URL from a caller-supplied {@code RestClient}
     * builder, using a freshly created {@code ObjectMapper}.
     *
     * @param baseUrl base URL that request paths are resolved against
     * @param restClientBuilder builder used to create the underlying {@code RestClient}
     */
    public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder) {
        this(baseUrl, restClientBuilder, new ObjectMapper());
    }

    /**
     * Creates a client for the given base URL.
     *
     * <p>The supplied builder is configured with the base URL and a default
     * {@code Content-Type: application/json} header before the client is built.
     *
     * @param baseUrl base URL that request paths are resolved against
     * @param restClientBuilder builder used to create the underlying {@code RestClient}
     * @param objectMapper mapper used by {@link #where(String)} to parse JSON text
     */
    public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder, ObjectMapper objectMapper) {
        this.objectMapper = objectMapper;
        this.restClient = restClientBuilder
            .baseUrl(baseUrl)
            .defaultHeaders(headers -> headers.setContentType(MediaType.APPLICATION_JSON))
            .build();
    }

    /**
     * Stores the token that subsequent requests send as a Bearer {@code Authorization}
     * header (only when the token contains text).
     *
     * @param keyToken the token to send
     * @return this instance, for call chaining
     */
    public ChromaApi withKeyToken(String keyToken) {
        this.keyToken = keyToken;
        return this;
    }

    /**
     * Replaces the underlying {@code RestClient} with a copy that additionally applies a
     * {@code BasicAuthenticationInterceptor} built from the given credentials.
     *
     * @param username user name passed to the interceptor
     * @param password password passed to the interceptor
     * @return this instance, for call chaining
     */
    public ChromaApi withBasicAuthCredentials(String username, String password) {
        ClientHttpRequestInterceptor basicAuth = new BasicAuthenticationInterceptor(username, password);
        this.restClient = this.restClient.mutate().requestInterceptor(basicAuth).build();
        return this;
    }

    /**
     * Flattens the first result row of a query response into a list of {@code Embedding}
     * values, one per returned id.
     *
     * <p>Only the first inner list of each response field is read. Any of the
     * embeddings, documents, metadata or distances sections may be absent (for example
     * when the request's {@code include} list did not ask for them); the corresponding
     * {@code Embedding} component is then {@code null} instead of the conversion failing.
     *
     * @param queryResponse the response to convert; may be {@code null}
     * @return a mutable list of embeddings, empty when the response is {@code null} or has no ids
     */
    public List<Embedding> toEmbeddingResponseList(@Nullable QueryResponse queryResponse) {
        List<Embedding> result = new ArrayList<>();
        if (queryResponse == null || CollectionUtils.isEmpty(queryResponse.ids())) {
            return result;
        }
        List<String> ids = queryResponse.ids().get(0);
        if (ids == null) {
            return result;
        }
        for (int i = 0; i < ids.size(); i++) {
            result.add(new Embedding(ids.get(i),
                    firstRowValue(queryResponse.embeddings(), i),
                    firstRowValue(queryResponse.documents(), i),
                    firstRowValue(queryResponse.metadata(), i),
                    firstRowValue(queryResponse.distances(), i)));
        }
        return result;
    }

    /**
     * Returns element {@code index} of the first inner list of {@code rows}, or
     * {@code null} when {@code rows} is null/empty, its first inner list is null, or the
     * index is past the end of that list.
     */
    @Nullable
    private static <T> T firstRowValue(@Nullable List<List<T>> rows, int index) {
        if (CollectionUtils.isEmpty(rows)) {
            return null;
        }
        List<T> firstRow = rows.get(0);
        if (firstRow == null || index >= firstRow.size()) {
            return null;
        }
        return firstRow.get(index);
    }

    /**
     * Creates a collection by POSTing the request to {@code /api/v1/collections}.
     *
     * @param createCollectionRequest request body describing the collection
     * @return the response body mapped to a {@code Collection}; may be {@code null}
     */
    @Nullable
    public Collection createCollection(CreateCollectionRequest createCollectionRequest) {
        return this.restClient.post()
            .uri("/api/v1/collections")
            .headers(this::httpHeaders)
            .body(createCollectionRequest)
            .retrieve()
            .toEntity(Collection.class)
            .getBody();
    }

    /**
     * Deletes a collection by issuing a DELETE to
     * {@code /api/v1/collections/{collection_name}}. The response body is discarded.
     *
     * @param collectionName name substituted into the request path
     */
    public void deleteCollection(String collectionName) {
        this.restClient.delete()
            .uri("/api/v1/collections/{collection_name}", collectionName)
            .headers(this::httpHeaders)
            .retrieve()
            .toBodilessEntity();
    }

    /**
     * Fetches a collection by name via GET {@code /api/v1/collections/{collection_name}}.
     *
     * <p>An HTTP client or server error whose extracted message is exactly
     * {@code "Collection <name> does not exist."} is treated as "not found" and yields
     * {@code null}. Any other such error is rethrown as a {@code RuntimeException} with
     * the original exception as its cause.
     *
     * @param collectionName name substituted into the request path
     * @return the collection, or {@code null} when the server reports it does not exist
     * @throws RuntimeException for any other HTTP client/server error
     */
    @Nullable
    public Collection getCollection(String collectionName) {
        try {
            return this.restClient.get()
                .uri("/api/v1/collections/{collection_name}", collectionName)
                .headers(this::httpHeaders)
                .retrieve()
                .toEntity(Collection.class)
                .getBody();
        }
        catch (HttpClientErrorException | HttpServerErrorException e) {
            String msg = this.getErrorMessage(e);
            if (String.format("Collection %s does not exist.", collectionName).equals(msg)) {
                return null;
            }
            // getErrorMessage returns "" when it cannot parse the error; fall back to the
            // original exception message so the rethrown exception is never message-less.
            throw new RuntimeException(StringUtils.hasText(msg) ? msg : e.getMessage(), e);
        }
    }

    /**
     * Lists collections via GET {@code /api/v1/collections}.
     *
     * @return the response body as a list of collections; may be {@code null}
     */
    @Nullable
    public List<Collection> listCollections() {
        return this.restClient.get()
            .uri("/api/v1/collections")
            .headers(this::httpHeaders)
            .retrieve()
            .toEntity(CollectionList.class)
            .getBody();
    }

    /**
     * Upserts embeddings by POSTing the request to
     * {@code /api/v1/collections/{collection_id}/upsert}. The response body is discarded.
     *
     * @param collectionId id substituted into the request path
     * @param embedding request body carrying the embeddings to upsert
     */
    public void upsertEmbeddings(@Nullable String collectionId, AddEmbeddingsRequest embedding) {
        this.restClient.post()
            .uri("/api/v1/collections/{collection_id}/upsert", collectionId)
            .headers(this::httpHeaders)
            .body(embedding)
            .retrieve()
            .toBodilessEntity();
    }

    /**
     * Deletes embeddings by POSTing the request to
     * {@code /api/v1/collections/{collection_id}/delete}.
     *
     * @param collectionId id substituted into the request path
     * @param deleteRequest request body selecting the embeddings to delete
     * @return the numeric HTTP status code of the response
     */
    public int deleteEmbeddings(@Nullable String collectionId, DeleteEmbeddingsRequest deleteRequest) {
        return this.restClient.post()
            .uri("/api/v1/collections/{collection_id}/delete", collectionId)
            .headers(this::httpHeaders)
            .body(deleteRequest)
            .retrieve()
            .toEntity(String.class)
            .getStatusCode()
            .value();
    }

    /**
     * Retrieves a count via GET {@code /api/v1/collections/{collection_id}/count}.
     *
     * @param collectionId id substituted into the request path
     * @return the response body mapped to a {@code Long}; may be {@code null}
     */
    @Nullable
    public Long countEmbeddings(String collectionId) {
        return this.restClient.get()
            .uri("/api/v1/collections/{collection_id}/count", collectionId)
            .headers(this::httpHeaders)
            .retrieve()
            .toEntity(Long.class)
            .getBody();
    }

    /**
     * Runs a query by POSTing the request to
     * {@code /api/v1/collections/{collection_id}/query}.
     *
     * @param collectionId id substituted into the request path
     * @param queryRequest request body describing the query
     * @return the response body mapped to a {@code QueryResponse}; may be {@code null}
     */
    @Nullable
    public QueryResponse queryCollection(@Nullable String collectionId, QueryRequest queryRequest) {
        return this.restClient.post()
            .uri("/api/v1/collections/{collection_id}/query", collectionId)
            .headers(this::httpHeaders)
            .body(queryRequest)
            .retrieve()
            .toEntity(QueryResponse.class)
            .getBody();
    }

    /**
     * Fetches embeddings by POSTing the request to
     * {@code /api/v1/collections/{collection_id}/get}.
     *
     * @param collectionId id substituted into the request path
     * @param getEmbeddingsRequest request body selecting the embeddings to return
     * @return the response body mapped to a {@code GetEmbeddingResponse}; may be {@code null}
     */
    @Nullable
    public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) {
        return this.restClient.post()
            .uri("/api/v1/collections/{collection_id}/get", collectionId)
            .headers(this::httpHeaders)
            .body(getEmbeddingsRequest)
            .retrieve()
            .toEntity(GetEmbeddingResponse.class)
            .getBody();
    }

    /**
     * Parses JSON text into a map using this client's {@code ObjectMapper}.
     *
     * @param text JSON text to parse
     * @return the parsed map
     * @throws RuntimeException wrapping the {@code JsonProcessingException} when parsing fails
     */
    @SuppressWarnings("unchecked")
    public Map<String, Object> where(String text) {
        Map<String, Object> parsed;
        try {
            parsed = this.objectMapper.readValue(text, Map.class);
        }
        catch (JsonProcessingException ex) {
            throw new RuntimeException(ex);
        }
        return parsed;
    }

    /**
     * Per-request header customizer: adds a Bearer {@code Authorization} header when a
     * key token containing text has been configured; otherwise leaves the headers untouched.
     *
     * @param headers the outgoing request headers to modify
     */
    private void httpHeaders(HttpHeaders headers) {
        if (!StringUtils.hasText(this.keyToken)) {
            return;
        }
        headers.setBearerAuth(this.keyToken);
    }

    /**
     * Extracts a short error description from the exception's message.
     *
     * <p>For server errors, the text inside a {@code ValueError('...')} fragment takes
     * precedence. Otherwise (and for client errors) the value of a JSON
     * {@code "message"} field is used.
     *
     * @param e the HTTP status exception to inspect
     * @return the extracted text, or an empty string when the message is blank or
     * neither pattern matches
     */
    private String getErrorMessage(HttpStatusCodeException e) {
        String raw = e.getMessage();
        if (!StringUtils.hasText(raw)) {
            return "";
        }
        if (e instanceof HttpServerErrorException) {
            Matcher valueError = VALUE_ERROR_PATTERN.matcher(raw);
            if (valueError.find()) {
                return valueError.group(1);
            }
        }
        Matcher jsonMessage = MESSAGE_ERROR_PATTERN.matcher(raw);
        return jsonMessage.find() ? jsonMessage.group(1) : "";
    }

    /**
     * JSON shape of a query response. Every field is a list of lists.
     *
     * @param ids ids, JSON property {@code ids}
     * @param embeddings embedding vectors, JSON property {@code embeddings}
     * @param documents document texts, JSON property {@code documents}
     * @param metadata metadata maps, JSON property {@code metadatas}
     * @param distances distances, JSON property {@code distances}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record QueryResponse(
            @JsonProperty("ids") List<List<String>> ids,
            @JsonProperty("embeddings") List<List<float[]>> embeddings,
            @JsonProperty("documents") List<List<String>> documents,
            @JsonProperty("metadatas") List<List<Map<String, Object>>> metadata,
            @JsonProperty("distances") List<List<Double>> distances) {
    }

    /**
     * A single embedding entry.
     *
     * @param id entry id, JSON property {@code id}
     * @param embedding embedding vector, JSON property {@code embedding}
     * @param document document text, JSON property {@code document}
     * @param metadata metadata map, JSON property {@code metadata}; may be {@code null}
     * @param distances a single distance value, JSON property {@code distances}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record Embedding(
            @JsonProperty("id") String id,
            @JsonProperty("embedding") float[] embedding,
            @JsonProperty("document") String document,
            @Nullable @JsonProperty("metadata") Map<String, Object> metadata,
            @JsonProperty("distances") Double distances) {
    }

    /**
     * JSON shape of a collection.
     *
     * @param id collection id, JSON property {@code id}
     * @param name collection name, JSON property {@code name}
     * @param metadata collection metadata, JSON property {@code metadata}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record Collection(
            @JsonProperty("id") String id,
            @JsonProperty("name") String name,
            @JsonProperty("metadata") Map<String, Object> metadata) {
    }

    /**
     * Concrete {@code ArrayList<Collection>} subtype, used by {@code listCollections} as
     * the target class when mapping the response body to a list of collections.
     */
    private static class CollectionList
    extends ArrayList<Collection> {
        private CollectionList() {
        }
    }

    /**
     * JSON shape of a "get embeddings" response. Unlike {@code QueryResponse}, each field
     * is a flat list.
     *
     * @param ids ids, JSON property {@code ids}
     * @param embeddings embedding vectors, JSON property {@code embeddings}
     * @param documents document texts, JSON property {@code documents}
     * @param metadata metadata maps with string values, JSON property {@code metadatas}
     */
    // NOTE(review): metadata values are typed String here but Object in the other
    // records of this class - confirm whether that difference is intentional.
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record GetEmbeddingResponse(
            @JsonProperty("ids") List<String> ids,
            @JsonProperty("embeddings") List<float[]> embeddings,
            @JsonProperty("documents") List<String> documents,
            @JsonProperty("metadatas") List<Map<String, String>> metadata) {
    }

    /**
     * JSON shape of a query request. Null components are omitted from the serialized JSON.
     *
     * @param queryEmbeddings query vectors, JSON property {@code query_embeddings}
     * @param nResults requested number of results, JSON property {@code n_results}
     * @param where optional filter, JSON property {@code where}; may be {@code null}
     * @param include response sections to request, JSON property {@code include}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record QueryRequest(
            @JsonProperty("query_embeddings") List<float[]> queryEmbeddings,
            @JsonProperty("n_results") Integer nResults,
            @Nullable @JsonProperty("where") Map<String, Object> where,
            @JsonProperty("include") List<Include> include) {

        /**
         * Creates a request for a single query vector with no filter, requesting every
         * section listed in {@code Include.all}.
         */
        public QueryRequest(float[] queryEmbedding, Integer nResults) {
            this(queryEmbedding, nResults, null);
        }

        /**
         * Creates a request for a single query vector, requesting every section listed in
         * {@code Include.all}. A {@code null} or empty filter is normalized to {@code null}.
         */
        public QueryRequest(float[] queryEmbedding, Integer nResults, @Nullable Map<String, Object> where) {
            this(List.of(queryEmbedding), nResults, CollectionUtils.isEmpty(where) ? null : where, Include.all);
        }

        /**
         * Response sections that a request can ask for.
         */
        // NOTE(review): no explicit JSON names are declared on these constants in this
        // source - confirm the serialized form matches what the server expects.
        public enum Include {
            METADATAS,
            DOCUMENTS,
            DISTANCES,
            EMBEDDINGS;

            /** Immutable list of all four sections, in declaration order. */
            public static final List<Include> all = List.of(METADATAS, DOCUMENTS, DISTANCES, EMBEDDINGS);
        }
    }

    /**
     * JSON shape of a "get embeddings" request. Null components are omitted from the
     * serialized JSON.
     *
     * @param ids ids to fetch, JSON property {@code ids}
     * @param where optional filter, JSON property {@code where}; may be {@code null}
     * @param limit JSON property {@code limit}
     * @param offset JSON property {@code offset}
     * @param include response sections to request, JSON property {@code include}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record GetEmbeddingsRequest(
            @JsonProperty("ids") List<String> ids,
            @Nullable @JsonProperty("where") Map<String, Object> where,
            @JsonProperty("limit") Integer limit,
            @JsonProperty("offset") Integer offset,
            @JsonProperty("include") List<QueryRequest.Include> include) {

        /** Creates a request for the given ids with no filter, limit 10 and offset 0. */
        public GetEmbeddingsRequest(List<String> ids) {
            this(ids, null);
        }

        /**
         * Creates a request for the given ids and filter with limit 10 and offset 0.
         * A {@code null} or empty filter is normalized to {@code null}.
         */
        public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where) {
            this(ids, where, 10, 0);
        }

        /**
         * Creates a request that asks for every section in {@code QueryRequest.Include.all}.
         * A {@code null} or empty filter is normalized to {@code null}.
         */
        public GetEmbeddingsRequest(List<String> ids, Map<String, Object> where, Integer limit, Integer offset) {
            this(ids, CollectionUtils.isEmpty(where) ? null : where, limit, offset, QueryRequest.Include.all);
        }
    }

    /**
     * JSON shape of a "delete embeddings" request. Null components are omitted from the
     * serialized JSON.
     *
     * @param ids ids to delete, JSON property {@code ids}; may be {@code null}
     * @param where optional filter, JSON property {@code where}; may be {@code null}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record DeleteEmbeddingsRequest(
            @Nullable @JsonProperty("ids") List<String> ids,
            @Nullable @JsonProperty("where") Map<String, Object> where) {

        /** Creates a request that selects by ids only, with no filter. */
        public DeleteEmbeddingsRequest(List<String> ids) {
            this(ids, null);
        }
    }

    /**
     * JSON shape of an add/upsert embeddings request, holding four lists.
     *
     * @param ids entry ids, JSON property {@code ids}
     * @param embeddings embedding vectors, JSON property {@code embeddings}
     * @param metadata metadata maps, JSON property {@code metadatas}
     * @param documents document texts, JSON property {@code documents}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record AddEmbeddingsRequest(
            @JsonProperty("ids") List<String> ids,
            @JsonProperty("embeddings") List<float[]> embeddings,
            @JsonProperty("metadatas") List<Map<String, Object>> metadata,
            @JsonProperty("documents") List<String> documents) {

        /**
         * Creates a single-entry request. Each argument is wrapped with {@code List.of},
         * which rejects {@code null}, so all four arguments must be non-null.
         */
        public AddEmbeddingsRequest(String id, float[] embedding, Map<String, Object> metadata, String document) {
            this(List.of(id), List.of(embedding), List.of(metadata), List.of(document));
        }
    }

    /**
     * JSON shape of a "create collection" request.
     *
     * @param name collection name, JSON property {@code name}
     * @param metadata collection metadata, JSON property {@code metadata}
     */
    @JsonInclude(JsonInclude.Include.NON_NULL)
    public record CreateCollectionRequest(
            @JsonProperty("name") String name,
            @JsonProperty("metadata") Map<String, Object> metadata) {

        /**
         * Creates a request whose metadata is a mutable map containing the single entry
         * {@code "hnsw:space" -> "cosine"}.
         */
        public CreateCollectionRequest(String name) {
            this(name, defaultMetadata());
        }

        /** Builds the mutable default metadata map used by the one-argument constructor. */
        private static Map<String, Object> defaultMetadata() {
            Map<String, Object> defaults = new HashMap<>();
            defaults.put("hnsw:space", "cosine");
            return defaults;
        }
    }
}

