/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore;

import io.micrometer.observation.ObservationRegistry;
import java.text.MessageFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
import org.springframework.ai.vectorstore.RedisFilterExpressionConverter;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.FilterExpressionConverter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationConvention;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import redis.clients.jedis.JedisPooled;
import redis.clients.jedis.Pipeline;
import redis.clients.jedis.json.Path2;
import redis.clients.jedis.search.Document;
import redis.clients.jedis.search.FTCreateParams;
import redis.clients.jedis.search.IndexDataType;
import redis.clients.jedis.search.Query;
import redis.clients.jedis.search.RediSearchUtil;
import redis.clients.jedis.search.Schema;
import redis.clients.jedis.search.SearchResult;
import redis.clients.jedis.search.schemafields.NumericField;
import redis.clients.jedis.search.schemafields.SchemaField;
import redis.clients.jedis.search.schemafields.TagField;
import redis.clients.jedis.search.schemafields.TextField;
import redis.clients.jedis.search.schemafields.VectorField;

public class RedisVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    /** Index name used by the config builder when none is supplied. */
    public static final String DEFAULT_INDEX_NAME = "spring-ai-index";
    /** Default name of the JSON field that stores the document text. */
    public static final String DEFAULT_CONTENT_FIELD_NAME = "content";
    /** Default name of the JSON field that stores the embedding vector. */
    public static final String DEFAULT_EMBEDDING_FIELD_NAME = "embedding";
    /** Default key prefix; it is prepended to document ids and registered as the index prefix. */
    public static final String DEFAULT_PREFIX = "embedding:";
    /** Default vector index algorithm ({@code HSNW} is mapped to the client's HNSW algorithm). */
    public static final Algorithm DEFAULT_VECTOR_ALGORITHM = Algorithm.HSNW;
    /** Alias of the score returned by the KNN query; also the sort key and a metadata key on results. */
    public static final String DISTANCE_FIELD_NAME = "vector_score";
    // Placeholders, in order: filter, top-K, embedding field, vector parameter name, score alias.
    private static final String QUERY_FORMAT = "%s=>[KNN %s @%s $%s AS %s]";
    // Root JSON path under which each document is written.
    private static final Path2 JSON_SET_PATH = Path2.of((String)"$");
    // Prefix turning a field name into a JSON path for the index schema.
    private static final String JSON_PATH_PREFIX = "$.";
    private static final Logger logger = LoggerFactory.getLogger(RedisVectorStore.class);
    // Matches the "OK" reply expected from JSON.SET pipeline entries and from index creation.
    private static final Predicate<Object> RESPONSE_OK = Predicate.isEqual("OK");
    // Matches the reply 1L expected from each pipelined delete.
    private static final Predicate<Object> RESPONSE_DEL_OK = Predicate.isEqual(1L);
    private static final String VECTOR_TYPE_FLOAT32 = "FLOAT32";
    // Name of the query parameter that carries the query vector bytes.
    private static final String EMBEDDING_PARAM_NAME = "BLOB";
    private static final String DEFAULT_DISTANCE_METRIC = "COSINE";
    // When false, afterPropertiesSet() does not touch the index.
    private final boolean initializeSchema;
    private final JedisPooled jedis;
    private final EmbeddingModel embeddingModel;
    private final RedisVectorStoreConfig config;
    private final BatchingStrategy batchingStrategy;
    // Assigned once in the constructor from the configured metadata fields.
    private FilterExpressionConverter filterExpressionConverter;

    /**
     * Creates a store without observation support: delegates to the full constructor with
     * {@link ObservationRegistry#NOOP}, no custom observation convention and a
     * {@link TokenCountBatchingStrategy}.
     *
     * @param config store configuration
     * @param embeddingModel model used to embed documents and queries
     * @param jedis Redis client
     * @param initializeSchema whether {@link #afterPropertiesSet()} may create the index
     */
    public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema) {
        this(
                config,
                embeddingModel,
                jedis,
                initializeSchema,
                ObservationRegistry.NOOP,
                null,
                new TokenCountBatchingStrategy());
    }

    /**
     * Creates a fully configured store.
     *
     * <p>All collaborators that the store dereferences later (config, embedding model, Redis
     * client, batching strategy) are validated here so that a misconfiguration fails at
     * construction time instead of on the first add, delete or search.
     *
     * @param config store configuration; must not be null
     * @param embeddingModel model used to embed documents and queries; must not be null
     * @param jedis Redis client; must not be null
     * @param initializeSchema whether {@link #afterPropertiesSet()} may create the index
     * @param observationRegistry registry passed to the superclass
     * @param customObservationConvention optional convention passed to the superclass
     * @param batchingStrategy strategy handed to the embedding model when adding documents; must not be null
     * @throws IllegalArgumentException if a required argument is null
     */
    public RedisVectorStore(RedisVectorStoreConfig config, EmbeddingModel embeddingModel, JedisPooled jedis, boolean initializeSchema, ObservationRegistry observationRegistry, VectorStoreObservationConvention customObservationConvention, BatchingStrategy batchingStrategy) {
        super(observationRegistry, customObservationConvention);
        Assert.notNull(config, "Config must not be null");
        Assert.notNull(embeddingModel, "Embedding model must not be null");
        Assert.notNull(jedis, "JedisPooled must not be null");
        Assert.notNull(batchingStrategy, "BatchingStrategy must not be null");
        this.initializeSchema = initializeSchema;
        this.jedis = jedis;
        this.embeddingModel = embeddingModel;
        this.config = config;
        this.filterExpressionConverter = new RedisFilterExpressionConverter(config.metadataFields);
        this.batchingStrategy = batchingStrategy;
    }

    /**
     * Returns the Redis client this store was constructed with.
     *
     * @return the underlying {@link JedisPooled} instance
     */
    public JedisPooled getJedis() {
        return jedis;
    }

    /**
     * Embeds the given documents and writes each of them to Redis as a JSON value, using one
     * pipeline for the whole batch.
     *
     * <p>Each JSON value holds the embedding, the content and all metadata entries; the key is
     * the configured prefix followed by the document id.
     *
     * @param documents documents to store
     * @throws RuntimeException if any pipelined reply differs from {@code "OK"}
     */
    public void doAdd(List<org.springframework.ai.document.Document> documents) {
        try (Pipeline batch = this.jedis.pipelined()) {
            // The return value is ignored; the documents are read back below via getEmbedding().
            // NOTE(review): presumably embed(...) stores the vectors on the documents - confirm.
            this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);

            for (org.springframework.ai.document.Document doc : documents) {
                // NOTE(review): self-assignment; looks redundant, kept to preserve behavior.
                doc.setEmbedding(doc.getEmbedding());

                Map<String, Object> payload = new HashMap<>();
                payload.put(this.config.embeddingFieldName, doc.getEmbedding());
                payload.put(this.config.contentFieldName, doc.getContent());
                // Metadata goes last, so an entry with the same key as the fields above wins.
                payload.putAll(doc.getMetadata());

                batch.jsonSetWithEscape(key(doc.getId()), JSON_SET_PATH, payload);
            }

            Optional<Object> failure = batch.syncAndReturnAll()
                .stream()
                .filter(RESPONSE_OK.negate())
                .findAny();
            if (failure.isEmpty()) {
                return;
            }

            String message = MessageFormat.format("Could not add document: {0}", failure.get());
            if (logger.isErrorEnabled()) {
                logger.error(message);
            }
            throw new RuntimeException(message);
        }
    }

    /**
     * Builds the Redis key for a document id by prepending the configured prefix.
     *
     * @param id document id
     * @return prefix followed by {@code id}
     */
    private String key(String id) {
        String prefix = this.config.prefix;
        return prefix + id;
    }

    /**
     * Deletes the JSON values for the given document ids in one pipeline.
     *
     * @param idList ids of the documents to delete (without the key prefix)
     * @return an Optional holding {@code true} when every reply equals {@code 1},
     * {@code false} (after logging one offending reply) otherwise
     */
    public Optional<Boolean> doDelete(List<String> idList) {
        try (Pipeline batch = this.jedis.pipelined()) {
            idList.forEach(id -> batch.jsonDel(key(id)));

            Optional<Object> failure = batch.syncAndReturnAll()
                .stream()
                .filter(RESPONSE_DEL_OK.negate())
                .findAny();

            if (failure.isPresent() && logger.isErrorEnabled()) {
                logger.error("Could not delete document: {}", failure.get());
            }
            return Optional.of(failure.isEmpty());
        }
    }

    /**
     * Runs a KNN vector query for the request and converts the hits to Spring AI documents.
     *
     * <p>The query text is embedded, sent as the {@code BLOB} parameter, sorted ascending by the
     * score alias and limited to the requested top-K. Hits whose similarity score is below the
     * request's threshold are dropped.
     *
     * @param request search request (query, top-K, similarity threshold, optional filter)
     * @return matching documents as an unmodifiable list
     * @throws IllegalArgumentException if top-K is not positive or the threshold is outside [0, 1]
     */
    public List<org.springframework.ai.document.Document> doSimilaritySearch(SearchRequest request) {
        Assert.isTrue(request.getTopK() > 0, "The number of documents to be returned must be greater than zero");
        Assert.isTrue(request.getSimilarityThreshold() >= 0.0 && request.getSimilarityThreshold() <= 1.0,
                "The similarity score is bounded between 0 and 1; least to most similar respectively.");

        String knnQuery = String.format(QUERY_FORMAT, nativeExpressionFilter(request), request.getTopK(),
                this.config.embeddingFieldName, EMBEDDING_PARAM_NAME, DISTANCE_FIELD_NAME);

        // Fields to return: every configured metadata field, then embedding, content and score.
        List<String> projection = new ArrayList<>();
        for (MetadataField field : this.config.metadataFields) {
            projection.add(field.name());
        }
        projection.add(this.config.embeddingFieldName);
        projection.add(this.config.contentFieldName);
        projection.add(DISTANCE_FIELD_NAME);

        float[] queryVector = this.embeddingModel.embed(request.getQuery());
        Query query = new Query(knnQuery)
            .addParam(EMBEDDING_PARAM_NAME, RediSearchUtil.toByteArray(queryVector))
            .returnFields(projection.toArray(new String[0]))
            .setSortBy(DISTANCE_FIELD_NAME, true)
            .limit(0, request.getTopK())
            .dialect(2);

        SearchResult hits = this.jedis.ftSearch(this.config.indexName, query);
        return hits.getDocuments()
            .stream()
            .filter(hit -> similarityScore(hit) >= request.getSimilarityThreshold())
            .map(this::toDocument)
            .toList();
    }

    /**
     * Converts a search hit into a Spring AI document.
     *
     * <p>The document id is the hit's key with the first {@code prefix.length()} characters
     * removed. The content is read from the content field when present, otherwise it is
     * {@code null}. Metadata holds every configured metadata field present on the hit plus the
     * {@code vector_score} entry, set to {@code 1 - similarityScore(doc)}.
     *
     * <p>The metadata map is built explicitly as a {@link HashMap}: the map returned by
     * {@code Collectors.toMap} has no documented mutability, and that collector fails on null
     * values and duplicate keys.
     *
     * @param doc search hit returned by Redis
     * @return the corresponding Spring AI document
     */
    private org.springframework.ai.document.Document toDocument(Document doc) {
        String id = doc.getId().substring(this.config.prefix.length());
        String content = doc.hasProperty(this.config.contentFieldName)
                ? doc.getString(this.config.contentFieldName)
                : null;

        Map<String, Object> metadata = new HashMap<>();
        for (MetadataField field : this.config.metadataFields) {
            String name = field.name();
            if (!doc.hasProperty(name)) {
                continue;
            }
            String value = doc.getString(name);
            // Defensive: a property without a value is treated like an absent one.
            if (value != null) {
                metadata.put(name, value);
            }
        }
        metadata.put(DISTANCE_FIELD_NAME, 1.0f - similarityScore(doc));

        return new org.springframework.ai.document.Document(id, content, metadata);
    }

    /**
     * Derives a similarity score from the hit's {@code vector_score} value as
     * {@code (2 - score) / 2}.
     *
     * @param doc search hit carrying the {@code vector_score} field
     * @return the derived score (in [0, 1] when the raw score lies in [0, 2])
     */
    private float similarityScore(Document doc) {
        float rawScore = Float.parseFloat(doc.getString(DISTANCE_FIELD_NAME));
        return (2.0f - rawScore) / 2.0f;
    }

    /**
     * Produces the filter part of the KNN query.
     *
     * @param request search request
     * @return the converted filter expression wrapped in parentheses, or {@code "*"} when the
     * request has no filter expression
     */
    private String nativeExpressionFilter(SearchRequest request) {
        if (request.getFilterExpression() != null) {
            String converted = this.filterExpressionConverter.convertExpression(request.getFilterExpression());
            return "(" + converted + ")";
        }
        return "*";
    }

    /**
     * Creates the search index when schema initialization is enabled and no index with the
     * configured name is listed yet. The index is defined on JSON values under the configured
     * key prefix.
     *
     * @throws RuntimeException if the create command does not answer {@code "OK"}
     */
    public void afterPropertiesSet() {
        // Short-circuit keeps Redis untouched when schema initialization is disabled.
        if (!this.initializeSchema || this.jedis.ftList().contains(this.config.indexName)) {
            return;
        }

        FTCreateParams createParams = FTCreateParams.createParams()
            .on(IndexDataType.JSON)
            .addPrefix(this.config.prefix);
        String reply = this.jedis.ftCreate(this.config.indexName, createParams, schemaFields());

        if (RESPONSE_OK.test(reply)) {
            return;
        }
        throw new RuntimeException(MessageFormat.format("Could not create index: {0}", reply));
    }

    /**
     * Assembles the index schema: a text field for the content, a FLOAT32 / COSINE vector field
     * for the embedding (dimension taken from the embedding model) and one field per configured
     * metadata field. Each field is addressed by its JSON path and aliased to its plain name.
     *
     * @return the schema fields in the order content, embedding, metadata
     */
    private Iterable<SchemaField> schemaFields() {
        Map<String, Object> vectorAttributes = new HashMap<>();
        vectorAttributes.put("DIM", this.embeddingModel.dimensions());
        vectorAttributes.put("DISTANCE_METRIC", DEFAULT_DISTANCE_METRIC);
        vectorAttributes.put("TYPE", VECTOR_TYPE_FLOAT32);

        String contentName = this.config.contentFieldName;
        SchemaField contentField = TextField.of(jsonPath(contentName)).as(contentName).weight(1.0);

        String embeddingName = this.config.embeddingFieldName;
        SchemaField embeddingField = VectorField.builder()
            .fieldName(jsonPath(embeddingName))
            .algorithm(vectorAlgorithm())
            .attributes(vectorAttributes)
            .as(embeddingName)
            .build();

        List<SchemaField> schema = new ArrayList<>();
        schema.add(contentField);
        schema.add(embeddingField);
        if (!CollectionUtils.isEmpty(this.config.metadataFields)) {
            this.config.metadataFields.stream().map(this::schemaField).forEach(schema::add);
        }
        return schema;
    }

    /**
     * Maps a metadata field declaration to the matching index schema field, addressed by its
     * JSON path and aliased to the plain field name.
     *
     * <p>The case labels use unqualified enum constants; qualified constants in case labels are
     * only accepted by Java 21 and later.
     *
     * @param field metadata field declaration
     * @return a numeric, tag or text schema field
     * @throws IllegalArgumentException if the field type is not NUMERIC, TAG or TEXT
     */
    private SchemaField schemaField(MetadataField field) {
        String fieldName = this.jsonPath(field.name());
        return switch (field.fieldType()) {
            case NUMERIC -> NumericField.of(fieldName).as(field.name());
            case TAG -> TagField.of(fieldName).as(field.name());
            case TEXT -> TextField.of(fieldName).as(field.name());
            default -> throw new IllegalArgumentException(
                    MessageFormat.format("Field {0} has unsupported type {1}", field.name(), field.fieldType()));
        };
    }

    /**
     * Translates the configured algorithm to the client's enum: {@code Algorithm.HSNW} becomes
     * {@code HNSW}; anything else becomes {@code FLAT}.
     *
     * @return the vector algorithm for the index definition
     */
    private VectorField.VectorAlgorithm vectorAlgorithm() {
        return this.config.vectorAlgorithm == Algorithm.HSNW
                ? VectorField.VectorAlgorithm.HNSW
                : VectorField.VectorAlgorithm.FLAT;
    }

    /**
     * Turns a field name into the JSON path used in the index schema.
     *
     * @param field plain field name
     * @return {@code "$."} followed by {@code field}
     */
    private String jsonPath(String field) {
        String path = JSON_PATH_PREFIX + field;
        return path;
    }

    /**
     * Creates an observation context builder for the Redis provider, pre-filled with the index
     * name as collection name, the embedding dimensions, the embedding field name and the cosine
     * similarity metric.
     *
     * @param operationName name of the observed operation
     * @return the pre-filled builder
     */
    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        String provider = VectorStoreProvider.REDIS.value();
        return VectorStoreObservationContext.builder(provider, operationName)
            .withCollectionName(this.config.indexName)
            .withDimensions(this.embeddingModel.dimensions())
            .withFieldName(this.config.embeddingFieldName)
            .withSimilarityMetric(VectorStoreSimilarityMetric.COSINE.value());
    }

    /**
     * Immutable configuration of a {@link RedisVectorStore}: index name, key prefix, content and
     * embedding field names, vector algorithm and the metadata fields to index.
     *
     * <p>Instances are created through {@link #builder()} or {@link #defaultConfig()}.
     */
    public static final class RedisVectorStoreConfig {
        private final String indexName;
        private final String prefix;
        private final String contentFieldName;
        private final String embeddingFieldName;
        private final Algorithm vectorAlgorithm;
        private final List<MetadataField> metadataFields;

        private RedisVectorStoreConfig() {
            this(RedisVectorStoreConfig.builder());
        }

        private RedisVectorStoreConfig(Builder builder) {
            this.indexName = builder.indexName;
            this.prefix = builder.prefix;
            this.contentFieldName = builder.contentFieldName;
            this.embeddingFieldName = builder.embeddingFieldName;
            this.vectorAlgorithm = builder.vectorAlgorithm;
            // Snapshot the list so later changes to the caller's list (or array, for the varargs
            // overload) cannot alter this config; a null list is treated as "no metadata fields".
            this.metadataFields = builder.metadataFields == null
                    ? List.of()
                    : List.copyOf(builder.metadataFields);
        }

        /**
         * Starts a new builder pre-populated with the default values.
         *
         * @return a fresh builder
         */
        public static Builder builder() {
            return new Builder();
        }

        /**
         * Builds a configuration consisting only of the default values.
         *
         * @return the default configuration
         */
        public static RedisVectorStoreConfig defaultConfig() {
            return RedisVectorStoreConfig.builder().build();
        }

        /**
         * Fluent builder for {@link RedisVectorStoreConfig}. Every setter returns this builder.
         */
        public static final class Builder {
            private String indexName = DEFAULT_INDEX_NAME;
            private String prefix = DEFAULT_PREFIX;
            private String contentFieldName = DEFAULT_CONTENT_FIELD_NAME;
            private String embeddingFieldName = DEFAULT_EMBEDDING_FIELD_NAME;
            private Algorithm vectorAlgorithm = DEFAULT_VECTOR_ALGORITHM;
            private List<MetadataField> metadataFields = new ArrayList<MetadataField>();

            private Builder() {
            }

            /**
             * Sets the name of the search index.
             *
             * @param name index name
             * @return this builder
             */
            public Builder withIndexName(String name) {
                this.indexName = name;
                return this;
            }

            /**
             * Sets the key prefix prepended to document ids and registered on the index.
             *
             * @param prefix key prefix
             * @return this builder
             */
            public Builder withPrefix(String prefix) {
                this.prefix = prefix;
                return this;
            }

            /**
             * Sets the name of the JSON field holding the document content.
             *
             * @param name content field name
             * @return this builder
             */
            public Builder withContentFieldName(String name) {
                this.contentFieldName = name;
                return this;
            }

            /**
             * Sets the name of the JSON field holding the embedding vector.
             *
             * @param name embedding field name
             * @return this builder
             */
            public Builder withEmbeddingFieldName(String name) {
                this.embeddingFieldName = name;
                return this;
            }

            /**
             * Sets the vector index algorithm.
             *
             * @param algorithm algorithm to use
             * @return this builder
             */
            public Builder withVectorAlgorithm(Algorithm algorithm) {
                this.vectorAlgorithm = algorithm;
                return this;
            }

            /**
             * Sets the metadata fields to index, replacing any previously set fields.
             *
             * @param fields metadata field declarations
             * @return this builder
             */
            public Builder withMetadataFields(MetadataField ... fields) {
                return this.withMetadataFields(Arrays.asList(fields));
            }

            /**
             * Sets the metadata fields to index, replacing any previously set fields. The list is
             * copied when {@link #build()} is called; it must not contain null elements.
             *
             * @param fields metadata field declarations
             * @return this builder
             */
            public Builder withMetadataFields(List<MetadataField> fields) {
                this.metadataFields = fields;
                return this;
            }

            /**
             * Creates the configuration from the current builder state.
             *
             * @return a new configuration
             * @throws NullPointerException if the metadata field list contains a null element
             */
            public RedisVectorStoreConfig build() {
                return new RedisVectorStoreConfig(this);
            }
        }
    }

    /**
     * Declares a metadata entry that is indexed and returned by searches.
     *
     * @param name metadata key, also used as the schema field alias
     * @param fieldType index field type
     */
    public record MetadataField(String name, Schema.FieldType fieldType) {

        /**
         * Declares a TAG metadata field.
         *
         * @param name metadata key
         * @return the field declaration
         */
        public static MetadataField tag(String name) {
            return new MetadataField(name, Schema.FieldType.TAG);
        }

        /**
         * Declares a NUMERIC metadata field.
         *
         * @param name metadata key
         * @return the field declaration
         */
        public static MetadataField numeric(String name) {
            return new MetadataField(name, Schema.FieldType.NUMERIC);
        }

        /**
         * Declares a TEXT metadata field.
         *
         * @param name metadata key
         * @return the field declaration
         */
        public static MetadataField text(String name) {
            return new MetadataField(name, Schema.FieldType.TEXT);
        }
    }

    /**
     * Vector index algorithms selectable in the store configuration.
     *
     * <p>{@code HSNW} is the historical spelling of this public constant; the store maps it to
     * the client's {@code HNSW} algorithm. Every other value selects {@code FLAT}.
     */
    public enum Algorithm {
        FLAT,
        HSNW
    }
}

