/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.model.bedrock;

import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.internal.RetryUtils;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.model.bedrock.BedrockCohereEmbeddingResponse;
import dev.langchain4j.model.bedrock.internal.Json;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.auth.credentials.DefaultCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClientBuilder;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelRequest;
import software.amazon.awssdk.services.bedrockruntime.model.InvokeModelResponse;

public class BedrockCohereEmbeddingModel
implements EmbeddingModel {
    private final BedrockRuntimeClient client;
    private final String model;
    private final String inputType;
    private final String truncate;
    private final int maxRetries;

    public BedrockCohereEmbeddingModel(Builder builder) {
        this.client = (BedrockRuntimeClient)Utils.getOrDefault((Object)builder.client, () -> this.initClient(builder));
        this.model = ValidationUtils.ensureNotBlank((String)builder.model, (String)"model");
        this.inputType = ValidationUtils.ensureNotBlank((String)builder.inputType, (String)"inputType");
        this.truncate = builder.truncate;
        this.maxRetries = (Integer)Utils.getOrDefault((Object)builder.maxRetries, (Object)3);
    }

    private BedrockRuntimeClient initClient(Builder builder) {
        return (BedrockRuntimeClient)((BedrockRuntimeClientBuilder)((BedrockRuntimeClientBuilder)BedrockRuntimeClient.builder().region((Region)Utils.getOrDefault((Object)builder.region, (Object)Region.US_EAST_1))).credentialsProvider((AwsCredentialsProvider)Utils.getOrDefault((Object)builder.credentialsProvider, () -> DefaultCredentialsProvider.builder().build()))).build();
    }

    public Response<List<Embedding>> embedAll(List<TextSegment> textSegments) {
        Map<String, Object> requestParameters = this.toRequestParameters(textSegments);
        String requestJson = Json.toJson(requestParameters);
        InvokeModelResponse invokeModelResponse = (InvokeModelResponse)RetryUtils.withRetryMappingExceptions(() -> this.invoke(requestJson), (int)this.maxRetries);
        String responseJson = invokeModelResponse.body().asUtf8String();
        BedrockCohereEmbeddingResponse embeddingResponse = Json.fromJson(responseJson, BedrockCohereEmbeddingResponse.class);
        List embeddings = Arrays.stream(embeddingResponse.getEmbeddings().getFloatEmbeddings()).map(Embedding::from).collect(Collectors.toList());
        return Response.from(embeddings);
    }

    private Map<String, Object> toRequestParameters(List<TextSegment> textSegments) {
        HashMap<String, Object> parameters = new HashMap<String, Object>();
        parameters.put("texts", textSegments.stream().map(TextSegment::text).collect(Collectors.toList()));
        parameters.put("input_type", this.inputType);
        parameters.put("truncate", this.truncate);
        parameters.put("embedding_types", List.of("float"));
        return parameters;
    }

    private InvokeModelResponse invoke(String body) {
        InvokeModelRequest invokeModelRequest = (InvokeModelRequest)InvokeModelRequest.builder().modelId(this.model).body(SdkBytes.fromString((String)body, (Charset)Charset.defaultCharset())).build();
        return this.client.invokeModel(invokeModelRequest);
    }

    public static Builder builder() {
        return new Builder();
    }

    public static class Builder {
        private String model;
        private String inputType;
        private String truncate;
        private BedrockRuntimeClient client;
        private Region region;
        private AwsCredentialsProvider credentialsProvider;
        private Integer maxRetries;

        public Builder model(Model model) {
            return this.model(model.getValue());
        }

        public Builder model(String model) {
            this.model = model;
            return this;
        }

        public Builder inputType(InputType inputType) {
            return this.inputType(inputType.getValue());
        }

        public Builder inputType(String inputType) {
            this.inputType = inputType;
            return this;
        }

        public Builder truncate(Truncate truncate) {
            return this.truncate(truncate.getValue());
        }

        public Builder truncate(String truncate) {
            this.truncate = truncate;
            return this;
        }

        public Builder client(BedrockRuntimeClient client) {
            this.client = client;
            return this;
        }

        public Builder region(Region region) {
            this.region = region;
            return this;
        }

        public Builder credentialsProvider(AwsCredentialsProvider credentialsProvider) {
            this.credentialsProvider = credentialsProvider;
            return this;
        }

        public Builder maxRetries(Integer maxRetries) {
            this.maxRetries = maxRetries;
            return this;
        }

        public BedrockCohereEmbeddingModel build() {
            return new BedrockCohereEmbeddingModel(this);
        }
    }

    public static enum Truncate {
        NONE("NONE"),
        START("START"),
        END("END");

        private final String value;

        private Truncate(String value) {
            this.value = value;
        }

        public String getValue() {
            return this.value;
        }
    }

    public static enum InputType {
        SEARCH_DOCUMENT("search_document"),
        SEARCH_QUERY("search_query"),
        CLASSIFICATION("classification"),
        CLUSTERING("clustering");

        private final String value;

        private InputType(String value) {
            this.value = value;
        }

        public String getValue() {
            return this.value;
        }
    }

    public static enum Model {
        COHERE_EMBED_ENGLISH_V3("cohere.embed-english-v3"),
        COHERE_EMBED_MULTILINGUAL_V3("cohere.embed-multilingual-v3");

        private final String value;

        private Model(String value) {
            this.value = value;
        }

        public String getValue() {
            return this.value;
        }
    }
}

