/*
 * Decompiled with CFR 0.152.
 */
package org.springframework.ai.embedding;

import com.knuddels.jtokkit.api.EncodingType;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.springframework.ai.document.ContentFormatter;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.tokenizer.JTokkitTokenCountEstimator;
import org.springframework.ai.tokenizer.TokenCountEstimator;
import org.springframework.util.Assert;

/**
 * {@link BatchingStrategy} that groups documents into consecutive batches so that the
 * summed estimated token count of each batch stays within a configured limit.
 *
 * <p>The effective limit is the configured maximum input token count reduced by a reserve
 * percentage: {@code round(maxInputTokenCount * (1 - reservePercentage))}. Token counts are
 * estimated on the content produced by
 * {@link Document#getFormattedContent(ContentFormatter, MetadataMode)}.
 */
public class TokenCountBatchingStrategy implements BatchingStrategy {

    /** Maximum input token count used by the no-argument constructor. */
    private static final int MAX_INPUT_TOKEN_COUNT = 8191;

    /** Reserve percentage used by the no-argument constructor. */
    private static final double DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE = 0.1;

    private final TokenCountEstimator tokenCountEstimator;

    /** Effective per-batch token limit, i.e. the configured maximum with the reserve already subtracted. */
    private final int maxInputTokenCount;

    private final ContentFormatter contentFormatter;

    private final MetadataMode metadataMode;

    /**
     * Creates a strategy using {@link EncodingType#CL100K_BASE}, the default maximum input
     * token count and the default reserve percentage.
     */
    public TokenCountBatchingStrategy() {
        this(EncodingType.CL100K_BASE, MAX_INPUT_TOKEN_COUNT, DEFAULT_TOKEN_COUNT_RESERVE_PERCENTAGE);
    }

    /**
     * Creates a strategy using {@link Document#DEFAULT_CONTENT_FORMATTER} and
     * {@link MetadataMode#NONE}.
     *
     * @param encodingType the encoding used to estimate token counts; must not be null
     * @param maxInputTokenCount the maximum input token count; must be greater than 0
     * @param reservePercentage the fraction of {@code maxInputTokenCount} to hold back; must be in [0, 1)
     * @throws IllegalArgumentException if any argument violates its constraint
     */
    public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage) {
        this(encodingType, maxInputTokenCount, reservePercentage, Document.DEFAULT_CONTENT_FORMATTER, MetadataMode.NONE);
    }

    /**
     * Creates a strategy that estimates tokens with a {@link JTokkitTokenCountEstimator}
     * built from the given encoding type.
     *
     * @param encodingType the encoding used to estimate token counts; must not be null
     * @param maxInputTokenCount the maximum input token count; must be greater than 0
     * @param reservePercentage the fraction of {@code maxInputTokenCount} to hold back; must be in [0, 1)
     * @param contentFormatter the formatter applied to each document before estimation; must not be null
     * @param metadataMode the metadata mode passed to the formatter; must not be null
     * @throws IllegalArgumentException if any argument violates its constraint
     */
    public TokenCountBatchingStrategy(EncodingType encodingType, int maxInputTokenCount, double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) {
        Assert.notNull(encodingType, "EncodingType must not be null");
        int effectiveLimit = effectiveTokenLimit(maxInputTokenCount, reservePercentage);
        Assert.notNull(contentFormatter, "ContentFormatter must not be null");
        Assert.notNull(metadataMode, "MetadataMode must not be null");
        this.tokenCountEstimator = new JTokkitTokenCountEstimator(encodingType);
        this.maxInputTokenCount = effectiveLimit;
        this.contentFormatter = contentFormatter;
        this.metadataMode = metadataMode;
    }

    /**
     * Creates a strategy that estimates tokens with the supplied estimator.
     *
     * @param tokenCountEstimator the estimator used to count tokens; must not be null
     * @param maxInputTokenCount the maximum input token count; must be greater than 0
     * @param reservePercentage the fraction of {@code maxInputTokenCount} to hold back; must be in [0, 1)
     * @param contentFormatter the formatter applied to each document before estimation; must not be null
     * @param metadataMode the metadata mode passed to the formatter; must not be null
     * @throws IllegalArgumentException if any argument violates its constraint
     */
    public TokenCountBatchingStrategy(TokenCountEstimator tokenCountEstimator, int maxInputTokenCount, double reservePercentage, ContentFormatter contentFormatter, MetadataMode metadataMode) {
        Assert.notNull(tokenCountEstimator, "TokenCountEstimator must not be null");
        int effectiveLimit = effectiveTokenLimit(maxInputTokenCount, reservePercentage);
        Assert.notNull(contentFormatter, "ContentFormatter must not be null");
        Assert.notNull(metadataMode, "MetadataMode must not be null");
        this.tokenCountEstimator = tokenCountEstimator;
        this.maxInputTokenCount = effectiveLimit;
        this.contentFormatter = contentFormatter;
        this.metadataMode = metadataMode;
    }

    /**
     * Splits the documents, in iteration order, into batches whose summed estimated token
     * count does not exceed the effective limit.
     *
     * <p>Every input element ends up in exactly one batch, including elements that are equal
     * to one another. No empty batch is produced.
     *
     * @param documents the documents to batch; must not be null
     * @return the batches in input order; empty if {@code documents} is empty
     * @throws IllegalArgumentException if a single document's estimated token count exceeds
     * the effective limit
     */
    @Override
    public List<List<Document>> batch(List<Document> documents) {
        List<List<Document>> batches = new ArrayList<>();
        List<Document> currentBatch = new ArrayList<>();
        int currentSize = 0;
        for (Document document : documents) {
            int tokenCount = this.tokenCountEstimator.estimate(document.getFormattedContent(this.contentFormatter, this.metadataMode));
            if (tokenCount > this.maxInputTokenCount) {
                throw new IllegalArgumentException("Tokens in a single document exceeds the maximum number of allowed input tokens");
            }
            // Compare against the remaining capacity rather than summing first, so the check
            // cannot overflow int when the limit is close to Integer.MAX_VALUE.
            // NOTE(review): assumes the estimator never returns a negative count - confirm.
            if (tokenCount > this.maxInputTokenCount - currentSize) {
                batches.add(currentBatch);
                currentBatch = new ArrayList<>();
                currentSize = 0;
            }
            currentBatch.add(document);
            // Count the document in the batch it was actually added to; this also covers the
            // document that opens a fresh batch right after a flush.
            currentSize += tokenCount;
        }
        if (!currentBatch.isEmpty()) {
            batches.add(currentBatch);
        }
        return batches;
    }

    /**
     * Validates the limit arguments and returns the effective per-batch token limit.
     */
    private static int effectiveTokenLimit(int maxInputTokenCount, double reservePercentage) {
        Assert.isTrue(maxInputTokenCount > 0, "MaxInputTokenCount must be greater than 0");
        Assert.isTrue(reservePercentage >= 0.0 && reservePercentage < 1.0, "ReservePercentage must be in range [0, 1)");
        return (int) Math.round((double) maxInputTokenCount * (1.0 - reservePercentage));
    }
}

