/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.vectorstore;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.ObjectWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.embedding.EmbeddingClient;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.core.io.Resource;

public class SimpleVectorStore
implements VectorStore {
    private static final Logger logger = LoggerFactory.getLogger(SimpleVectorStore.class);
    protected Map<String, Document> store = new ConcurrentHashMap<String, Document>();
    protected EmbeddingClient embeddingClient;

    public SimpleVectorStore(EmbeddingClient embeddingClient) {
        Objects.requireNonNull(embeddingClient, "EmbeddingClient must not be null");
        this.embeddingClient = embeddingClient;
    }

    @Override
    public void add(List<Document> documents) {
        for (Document document : documents) {
            logger.info("Calling EmbeddingClient for document id = {}", (Object)document.getId());
            List<Double> embedding = this.embeddingClient.embed(document);
            document.setEmbedding(embedding);
            this.store.put(document.getId(), document);
        }
    }

    @Override
    public Optional<Boolean> delete(List<String> idList) {
        for (String id : idList) {
            this.store.remove(id);
        }
        return Optional.of(true);
    }

    @Override
    public List<Document> similaritySearch(SearchRequest request) {
        if (request.getFilterExpression() != null) {
            throw new UnsupportedOperationException("The [" + String.valueOf(this.getClass()) + "] doesn't support metadata filtering!");
        }
        List<Double> userQueryEmbedding = this.getUserQueryEmbedding(request.getQuery());
        return this.store.values().stream().map(entry -> new Similarity(entry.getId(), EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))).filter(s -> s.score >= request.getSimilarityThreshold()).sorted(Comparator.comparingDouble(s -> s.score).reversed()).limit(request.getTopK()).map(s -> this.store.get(s.key)).toList();
    }

    public void save(File file) {
        String json = this.getVectorDbAsJson();
        try {
            if (!file.exists()) {
                logger.info("Creating new vector store file: {}", (Object)file);
                file.createNewFile();
            } else {
                logger.info("Overwriting existing vector store file: {}", (Object)file);
            }
            try (FileOutputStream stream = new FileOutputStream(file);
                 OutputStreamWriter writer = new OutputStreamWriter((OutputStream)stream, StandardCharsets.UTF_8);){
                writer.write(json);
                ((Writer)writer).flush();
            }
        }
        catch (IOException ex) {
            logger.error("IOException occurred while saving vector store file.", (Throwable)ex);
            throw new RuntimeException(ex);
        }
        catch (SecurityException ex) {
            logger.error("SecurityException occurred while saving vector store file.", (Throwable)ex);
            throw new RuntimeException(ex);
        }
        catch (NullPointerException ex) {
            logger.error("NullPointerException occurred while saving vector store file.", (Throwable)ex);
            throw new RuntimeException(ex);
        }
    }

    public void load(File file) {
        TypeReference<HashMap<String, Document>> typeRef = new TypeReference<HashMap<String, Document>>(){};
        ObjectMapper objectMapper = new ObjectMapper();
        try {
            Map deserializedMap;
            this.store = deserializedMap = (Map)objectMapper.readValue(file, (TypeReference)typeRef);
        }
        catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    public void load(Resource resource) {
        TypeReference<HashMap<String, Document>> typeRef = new TypeReference<HashMap<String, Document>>(){};
        ObjectMapper objectMapper = new ObjectMapper();
        try {
            Map deserializedMap;
            this.store = deserializedMap = (Map)objectMapper.readValue(resource.getInputStream(), (TypeReference)typeRef);
        }
        catch (IOException ex) {
            throw new RuntimeException(ex);
        }
    }

    private String getVectorDbAsJson() {
        String json;
        ObjectMapper objectMapper = new ObjectMapper();
        ObjectWriter objectWriter = objectMapper.writerWithDefaultPrettyPrinter();
        try {
            json = objectWriter.writeValueAsString(this.store);
        }
        catch (JsonProcessingException e) {
            throw new RuntimeException("Error serializing documentMap to JSON.", e);
        }
        return json;
    }

    private List<Double> getUserQueryEmbedding(String query) {
        return this.embeddingClient.embed(query);
    }

    public static class Similarity {
        private String key;
        private double score;

        public Similarity(String key, double score) {
            this.key = key;
            this.score = score;
        }
    }

    public class EmbeddingMath {
        private EmbeddingMath() {
            throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
        }

        public static double cosineSimilarity(List<Double> vectorX, List<Double> vectorY) {
            if (vectorX == null || vectorY == null) {
                throw new RuntimeException("Vectors must not be null");
            }
            if (vectorX.size() != vectorY.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }
            double dotProduct = EmbeddingMath.dotProduct(vectorX, vectorY);
            double normX = EmbeddingMath.norm(vectorX);
            double normY = EmbeddingMath.norm(vectorY);
            if (normX == 0.0 || normY == 0.0) {
                throw new IllegalArgumentException("Vectors cannot have zero norm");
            }
            return dotProduct / (Math.sqrt(normX) * Math.sqrt(normY));
        }

        public static double dotProduct(List<Double> vectorX, List<Double> vectorY) {
            if (vectorX.size() != vectorY.size()) {
                throw new IllegalArgumentException("Vectors lengths must be equal");
            }
            double result = 0.0;
            for (int i = 0; i < vectorX.size(); ++i) {
                result += vectorX.get(i) * vectorY.get(i);
            }
            return result;
        }

        public static double norm(List<Double> vector) {
            return EmbeddingMath.dotProduct(vector, vector);
        }
    }
}

