/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.rag.content.retriever;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.store.embedding.EmbeddingMatch;
import dev.langchain4j.store.embedding.EmbeddingStore;
import java.util.List;
import java.util.stream.Collectors;

/**
 * A {@link ContentRetriever} that embeds the query text with an {@link EmbeddingModel}
 * and looks up the most relevant {@link TextSegment}s in an {@link EmbeddingStore}.
 *
 * <p>Each retrieval asks the store for at most {@code maxResults} matches whose score is at
 * least {@code minScore}. Each matched segment is wrapped in a {@link Content}, in the order
 * the store returns the matches.
 */
public class EmbeddingStoreContentRetriever
implements ContentRetriever {
    /** Maximum number of results used when none is specified. */
    public static final int DEFAULT_MAX_RESULTS = 3;
    /** Minimum relevance score used when none is specified. */
    public static final double DEFAULT_MIN_SCORE = 0.0;

    private final EmbeddingStore<TextSegment> embeddingStore;
    private final EmbeddingModel embeddingModel;
    private final int maxResults;
    private final double minScore;

    /**
     * Creates a retriever with {@link #DEFAULT_MAX_RESULTS} and {@link #DEFAULT_MIN_SCORE}.
     *
     * @param embeddingStore store to search; must not be null
     * @param embeddingModel model used to embed the query text; must not be null
     */
    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel) {
        this(embeddingStore, embeddingModel, DEFAULT_MAX_RESULTS, DEFAULT_MIN_SCORE);
    }

    /**
     * Creates a retriever with the given result limit and {@link #DEFAULT_MIN_SCORE}.
     *
     * @param embeddingStore store to search; must not be null
     * @param embeddingModel model used to embed the query text; must not be null
     * @param maxResults     maximum number of matches to return; must be greater than zero
     */
    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel, int maxResults) {
        this(embeddingStore, embeddingModel, maxResults, DEFAULT_MIN_SCORE);
    }

    /**
     * Creates a fully configured retriever. This is the constructor the builder uses.
     *
     * <p>Validation is performed via {@code ValidationUtils}: the store and model must be
     * non-null, {@code maxResults} must be greater than zero, and {@code minScore} must lie
     * between 0.0 and 1.0.
     *
     * @param embeddingStore store to search; must not be null
     * @param embeddingModel model used to embed the query text; must not be null
     * @param maxResults     maximum number of matches to return; {@code null} means
     *                       {@link #DEFAULT_MAX_RESULTS}
     * @param minScore       minimum score a match must have; {@code null} means
     *                       {@link #DEFAULT_MIN_SCORE}
     */
    public EmbeddingStoreContentRetriever(EmbeddingStore<TextSegment> embeddingStore, EmbeddingModel embeddingModel, Integer maxResults, Double minScore) {
        this.embeddingStore = ValidationUtils.ensureNotNull(embeddingStore, "embeddingStore");
        this.embeddingModel = ValidationUtils.ensureNotNull(embeddingModel, "embeddingModel");
        // Use the public constants so the documented defaults cannot drift from the applied ones.
        this.maxResults = ValidationUtils.ensureGreaterThanZero(Utils.getOrDefault(maxResults, DEFAULT_MAX_RESULTS), "maxResults");
        this.minScore = ValidationUtils.ensureBetween(Utils.getOrDefault(minScore, DEFAULT_MIN_SCORE), 0.0, 1.0, "minScore");
    }

    /**
     * Embeds {@code query.text()} and returns the matching segments as {@link Content}.
     *
     * @param query the query to retrieve content for; must not be null
     * @return up to {@code maxResults} contents whose match score is at least {@code minScore},
     *         in the order returned by the embedding store
     */
    @Override
    public List<Content> retrieve(Query query) {
        // Fail fast with a named-argument error instead of an NPE deep inside query.text().
        ValidationUtils.ensureNotNull(query, "query");
        Embedding embeddedText = this.embeddingModel.embed(query.text()).content();
        List<EmbeddingMatch<TextSegment>> relevant = this.embeddingStore.findRelevant(embeddedText, this.maxResults, this.minScore);
        return relevant.stream().map(EmbeddingMatch::embedded).map(Content::from).collect(Collectors.toList());
    }

    /**
     * Returns a new builder. Options left unset fall back to the defaults applied by the
     * four-argument constructor.
     *
     * @return a fresh {@link EmbeddingStoreContentRetrieverBuilder}
     */
    public static EmbeddingStoreContentRetrieverBuilder builder() {
        return new EmbeddingStoreContentRetrieverBuilder();
    }

    /**
     * Builder for {@link EmbeddingStoreContentRetriever}. All validation happens in
     * {@link #build()}, not in the individual setters.
     */
    public static class EmbeddingStoreContentRetrieverBuilder {
        private EmbeddingStore<TextSegment> embeddingStore;
        private EmbeddingModel embeddingModel;
        private Integer maxResults;
        private Double minScore;

        EmbeddingStoreContentRetrieverBuilder() {
        }

        /**
         * Sets the store to search (required).
         *
         * @param embeddingStore the embedding store
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder embeddingStore(EmbeddingStore<TextSegment> embeddingStore) {
            this.embeddingStore = embeddingStore;
            return this;
        }

        /**
         * Sets the model used to embed query text (required).
         *
         * @param embeddingModel the embedding model
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder embeddingModel(EmbeddingModel embeddingModel) {
            this.embeddingModel = embeddingModel;
            return this;
        }

        /**
         * Sets the maximum number of results. {@code null} selects {@link #DEFAULT_MAX_RESULTS}.
         *
         * @param maxResults maximum number of matches; must be greater than zero if non-null
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder maxResults(Integer maxResults) {
            this.maxResults = maxResults;
            return this;
        }

        /**
         * Sets the minimum match score. {@code null} selects {@link #DEFAULT_MIN_SCORE}.
         *
         * @param minScore minimum score, between 0.0 and 1.0 if non-null
         * @return this builder
         */
        public EmbeddingStoreContentRetrieverBuilder minScore(Double minScore) {
            this.minScore = minScore;
            return this;
        }

        /**
         * Creates the retriever from the configured values.
         *
         * @return a validated {@link EmbeddingStoreContentRetriever}
         */
        public EmbeddingStoreContentRetriever build() {
            return new EmbeddingStoreContentRetriever(this.embeddingStore, this.embeddingModel, this.maxResults, this.minScore);
        }

        @Override
        public String toString() {
            return "EmbeddingStoreContentRetriever.EmbeddingStoreContentRetrieverBuilder(embeddingStore=" + this.embeddingStore + ", embeddingModel=" + this.embeddingModel + ", maxResults=" + this.maxResults + ", minScore=" + this.minScore + ")";
        }
    }
}

