/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore.coherence;

import com.oracle.coherence.ai.DistanceAlgorithm;
import com.oracle.coherence.ai.DocumentChunk;
import com.oracle.coherence.ai.Float32Vector;
import com.oracle.coherence.ai.QueryResult;
import com.oracle.coherence.ai.Vector;
import com.oracle.coherence.ai.distance.CosineDistance;
import com.oracle.coherence.ai.distance.InnerProductDistance;
import com.oracle.coherence.ai.distance.L2SquaredDistance;
import com.oracle.coherence.ai.hnsw.HnswIndex;
import com.oracle.coherence.ai.index.BinaryQuantIndex;
import com.oracle.coherence.ai.search.SimilaritySearch;
import com.oracle.coherence.ai.util.Vectors;
import com.tangosol.net.NamedMap;
import com.tangosol.net.Session;
import com.tangosol.util.Filter;
import com.tangosol.util.InvocableMap;
import com.tangosol.util.ValueExtractor;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.vectorstore.AbstractVectorStoreBuilder;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.coherence.CoherenceFilterExpressionConverter;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.observation.AbstractObservationVectorStore;
import org.springframework.ai.vectorstore.observation.VectorStoreObservationContext;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

public class CoherenceVectorStore
extends AbstractObservationVectorStore
implements InitializingBean {
    public static final String DEFAULT_MAP_NAME = "spring-ai-documents";
    public static final DistanceType DEFAULT_DISTANCE_TYPE = DistanceType.COSINE;
    public static final CoherenceFilterExpressionConverter FILTER_EXPRESSION_CONVERTER = new CoherenceFilterExpressionConverter();
    private final int dimensions;
    private final Session session;
    private NamedMap<DocumentChunk.Id, DocumentChunk> documentChunks;
    private String mapName;
    private DistanceType distanceType;
    private boolean forcedNormalization;
    private IndexType indexType;

    protected CoherenceVectorStore(Builder builder) {
        super((AbstractVectorStoreBuilder)builder);
        Assert.notNull((Object)builder.session, (String)"Session must not be null");
        this.session = builder.session;
        this.dimensions = builder.getEmbeddingModel().dimensions();
        this.mapName = builder.mapName;
        this.distanceType = builder.distanceType;
        this.forcedNormalization = builder.forcedNormalization;
        this.indexType = builder.indexType;
    }

    public static Builder builder(Session session, EmbeddingModel embeddingModel) {
        return new Builder(session, embeddingModel);
    }

    public void doAdd(List<Document> documents) {
        HashMap<DocumentChunk.Id, DocumentChunk> chunks = new HashMap<DocumentChunk.Id, DocumentChunk>((int)Math.ceil((float)documents.size() / 0.75f));
        for (Document doc : documents) {
            DocumentChunk.Id id = this.toChunkId(doc.getId());
            DocumentChunk chunk = new DocumentChunk(doc.getText(), doc.getMetadata(), (Vector)this.toFloat32Vector(this.embeddingModel.embed(doc)));
            chunks.put(id, chunk);
        }
        this.documentChunks.putAll(chunks);
    }

    public void doDelete(List<String> idList) {
        List<DocumentChunk.Id> chunkIds = idList.stream().map(this::toChunkId).toList();
        this.documentChunks.invokeAll(chunkIds, (InvocableMap.EntryProcessor & Serializable)entry -> {
            if (entry.isPresent()) {
                entry.remove(false);
                return true;
            }
            return false;
        });
    }

    public List<Document> doSimilaritySearch(SearchRequest request) {
        Float32Vector vector = this.toFloat32Vector(this.embeddingModel.embed(request.getQuery()));
        Filter.Expression expression = request.getFilterExpression();
        Filter<?> filter = expression == null ? null : FILTER_EXPRESSION_CONVERTER.convert((Filter.Operand)expression);
        SimilaritySearch search = new SimilaritySearch(DocumentChunk::vector, (Vector)vector, request.getTopK()).algorithm(this.getDistanceAlgorithm()).filter(filter);
        List results = (List)this.documentChunks.aggregate((InvocableMap.EntryAggregator)search);
        ArrayList<Document> documents = new ArrayList<Document>(results.size());
        for (QueryResult r : results) {
            if (this.distanceType == DistanceType.COSINE && !(1.0 - r.getDistance() >= request.getSimilarityThreshold())) continue;
            DocumentChunk.Id id = (DocumentChunk.Id)r.getKey();
            DocumentChunk chunk = (DocumentChunk)r.getValue();
            HashMap<String, Double> mergedMetadata = new HashMap<String, Double>(chunk.metadata());
            mergedMetadata.put(DocumentMetadata.DISTANCE.value(), r.getDistance());
            documents.add(Document.builder().id(id.docId()).text(chunk.text()).metadata(mergedMetadata).score(Double.valueOf(1.0 - r.getDistance())).build());
        }
        return documents;
    }

    private DistanceAlgorithm<float[]> getDistanceAlgorithm() {
        return switch (this.distanceType.ordinal()) {
            default -> throw new IncompatibleClassChangeError();
            case 0 -> new CosineDistance();
            case 1 -> new InnerProductDistance();
            case 2 -> new L2SquaredDistance();
        };
    }

    public void afterPropertiesSet() throws Exception {
        this.documentChunks = this.session.getMap(this.mapName, new NamedMap.Option[0]);
        switch (this.indexType.ordinal()) {
            case 2: {
                this.documentChunks.addIndex((ValueExtractor)new HnswIndex(DocumentChunk::vector, this.distanceType.name(), this.dimensions));
                break;
            }
            case 1: {
                this.documentChunks.addIndex((ValueExtractor)new BinaryQuantIndex(DocumentChunk::vector));
            }
        }
    }

    private DocumentChunk.Id toChunkId(String id) {
        return new DocumentChunk.Id(id, 0);
    }

    private Float32Vector toFloat32Vector(float[] floats) {
        return new Float32Vector(this.forcedNormalization ? Vectors.normalize((float[])floats) : floats);
    }

    String getMapName() {
        return this.mapName;
    }

    public VectorStoreObservationContext.Builder createObservationContextBuilder(String operationName) {
        return VectorStoreObservationContext.builder((String)VectorStoreProvider.NEO4J.value(), (String)operationName).collectionName(this.mapName).dimensions(Integer.valueOf(this.embeddingModel.dimensions()));
    }

    public <T> Optional<T> getNativeClient() {
        Session client = this.session;
        return Optional.of(client);
    }

    public static class Builder
    extends AbstractVectorStoreBuilder<Builder> {
        private final Session session;
        private String mapName = "spring-ai-documents";
        private DistanceType distanceType = DEFAULT_DISTANCE_TYPE;
        private boolean forcedNormalization = false;
        private IndexType indexType = IndexType.NONE;

        private Builder(Session session, EmbeddingModel embeddingModel) {
            super(embeddingModel);
            Assert.notNull((Object)session, (String)"Session must not be null");
            this.session = session;
        }

        public Builder mapName(String mapName) {
            if (StringUtils.hasText((String)mapName)) {
                this.mapName = mapName;
            }
            return this;
        }

        public Builder distanceType(DistanceType distanceType) {
            Assert.notNull((Object)((Object)distanceType), (String)"DistanceType must not be null");
            this.distanceType = distanceType;
            return this;
        }

        public Builder forcedNormalization(boolean forcedNormalization) {
            this.forcedNormalization = forcedNormalization;
            return this;
        }

        public Builder indexType(IndexType indexType) {
            Assert.notNull((Object)((Object)indexType), (String)"IndexType must not be null");
            this.indexType = indexType;
            return this;
        }

        public CoherenceVectorStore build() {
            return new CoherenceVectorStore(this);
        }
    }

    public static enum DistanceType {
        COSINE,
        IP,
        L2;

    }

    public static enum IndexType {
        NONE,
        BINARY,
        HNSW;

    }
}

