/*
 * Decompiled with CFR 0.152.
 */
package org.neo4j.genai.vector.providers;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.text.StringSubstitutor;
import org.eclipse.collections.api.factory.Maps;
import org.eclipse.collections.api.map.MutableMap;
import org.neo4j.genai.util.HttpService;
import org.neo4j.genai.util.JsonUtils;
import org.neo4j.genai.util.MalformedGenAIResponseException;
import org.neo4j.genai.vector.VectorEncoding;

/**
 * {@link VectorEncoding.Provider} that turns text into embedding vectors by calling the
 * {@code :predict} endpoint of a Google Vertex AI publisher model.
 *
 * <p>The endpoint URL is derived from the configured region, project id and model, see
 * {@link #configure(Parameters)}.
 */
public final class VertexAI
implements VectorEncoding.Provider<Parameters> {
    /** Name under which this provider is exposed; also passed to the response parser. */
    public static final String NAME = "VertexAI";
    private static final String ENDPOINT_TEMPLATE = "https://${region}-aiplatform.googleapis.com/v1/projects/${projectId}/locations/${region}/publishers/google/models/${model}:predict";
    static final String DEFAULT_REGION = "us-central1";
    static final Set<String> SUPPORTED_REGIONS = Set.of("us-west1", "us-west2", "us-west3", "us-west4", "us-central1", "us-east1", "us-east4", "us-south1", "northamerica-northeast1", "northamerica-northeast2", "southamerica-east1", "southamerica-west1", "europe-west2", "europe-west1", "europe-west4", "europe-west6", "europe-west3", "europe-north1", "europe-central2", "europe-west8", "europe-west9", "europe-southwest1", "asia-south1", "asia-southeast1", "asia-southeast2", "asia-east2", "asia-east1", "asia-northeast1", "asia-northeast2", "australia-southeast1", "australia-southeast2", "asia-northeast3", "me-west1");
    private static final String STRINGIFIED_SUPPORTED_REGIONS = stringify(SUPPORTED_REGIONS);
    static final String DEFAULT_MODEL = "textembedding-gecko@001";
    static final Set<String> SUPPORTED_MODELS = Set.of("textembedding-gecko@001", "textembedding-gecko@002", "textembedding-gecko@003", "textembedding-gecko-multilingual@001");
    private static final String STRINGIFIED_SUPPORTED_MODELS = stringify(SUPPORTED_MODELS);

    /**
     * Renders the given values as {@code ['a', 'b', ...]} for use in error messages.
     * The element order follows the iteration order of {@code values}.
     */
    private static String stringify(Set<String> values) {
        return values.stream().map(s -> "'" + s + "'").collect(Collectors.joining(", ", "[", "]"));
    }

    /** @return the class whose public fields declare this provider's configuration. */
    @Override
    public Class<Parameters> parameterDeclarations() {
        return Parameters.class;
    }

    /** @return {@link #NAME}. */
    @Override
    public String name() {
        return NAME;
    }

    /**
     * Validates the configuration and builds an encoder bound to the resulting endpoint.
     *
     * @param configuration provider parameters; {@code model} must be one of
     *                      {@link #SUPPORTED_MODELS} and {@code region} one of
     *                      {@link #SUPPORTED_REGIONS}
     * @return an encoder targeting the endpoint derived from region, project id and model
     * @throws IllegalArgumentException if the model or the region is not supported
     * @throws NullPointerException     if {@code projectId} is {@code null}
     *                                  ({@link Map#of} rejects null values)
     */
    @Override
    public VectorEncoding.Provider.Encoder configure(Parameters configuration) {
        if (!SUPPORTED_MODELS.contains(configuration.model)) {
            throw new IllegalArgumentException("Provided model '%s' is not supported. Supported models: %s".formatted(configuration.model, STRINGIFIED_SUPPORTED_MODELS));
        }
        if (!SUPPORTED_REGIONS.contains(configuration.region)) {
            throw new IllegalArgumentException("Provided region '%s' is not supported. Supported regions: %s".formatted(configuration.region, STRINGIFIED_SUPPORTED_REGIONS));
        }
        Map<String, String> substitutions = Map.of(
                "region", configuration.region,
                "projectId", configuration.projectId,
                "model", configuration.model);
        URI endpoint = URI.create(StringSubstitutor.replace(ENDPOINT_TEMPLATE, substitutions));
        return new Encoder(endpoint, configuration);
    }

    /**
     * Configuration for the provider.
     *
     * <p>NOTE(review): the fields are public and mutable and the {@code Optional} fields have
     * no initializer; presumably an external parameter-binding mechanism populates them before
     * {@link VertexAI#configure(Parameters)} is called - confirm against the framework.
     */
    public static class Parameters {
        /** Sent as {@code Authorization: Bearer <token>} on every request. */
        public String token;
        /** Substituted into the endpoint URL as the project id. */
        public String projectId;
        /** Model name; must be one of {@link VertexAI#SUPPORTED_MODELS}. */
        public String model = DEFAULT_MODEL;
        /** Region; must be one of {@link VertexAI#SUPPORTED_REGIONS}. */
        public String region = DEFAULT_REGION;
        /** When present, sent as {@code task_type} on every instance of the request. */
        public Optional<String> taskType;
        /** When present, sent as {@code title} on every instance of the request. */
        public Optional<String> title;
    }

    /**
     * Encoder that posts resources to {@code endpoint} and parses the embeddings from the reply.
     *
     * @param endpoint      URL the requests are posted to
     * @param configuration parameters supplying the token and the optional per-instance fields
     */
    record Encoder(URI endpoint, Parameters configuration) implements VectorEncoding.Provider.Encoder
    {
        /**
         * Encodes a single resource by issuing a one-element batch request.
         *
         * @return the vector of the first row produced by the batch call
         * @throws java.util.NoSuchElementException if the batch call yields no rows
         */
        @Override
        public float[] encode(HttpService httpService, String data) {
            return this.encode(httpService, List.of(data), ArrayUtils.EMPTY_INT_ARRAY).findFirst().orElseThrow().vector();
        }

        /**
         * Encodes a batch of resources with one HTTP POST.
         *
         * @param httpService service used to perform the request
         * @param resources   texts to encode, written as the request's {@code instances}
         * @param nullIndexes forwarded unchanged to the response parser
         * @return the rows produced by the response parser
         */
        @Override
        public Stream<VectorEncoding.BatchRow> encode(HttpService httpService, List<String> resources, int[] nullIndexes) {
            return httpService.request(
                    this.endpoint,
                    builder -> builder
                            .headers(
                                    "Authorization", "Bearer " + this.configuration.token,
                                    "Content-Type", "application/json; charset=" + StandardCharsets.UTF_8,
                                    "Accept", "application/json")
                            .POST(HttpService.pipe(outputStream -> this.writeRequestPayload((OutputStream)outputStream, resources)))
                            .build(),
                    // Delegate to the shared helper so the response layout is defined in one place.
                    inputStream -> parseResponse(resources, inputStream, nullIndexes));
        }

        /**
         * Parses a response body, reading the {@code embeddings} / {@code values} properties
         * under the top-level {@code predictions} key via {@link JsonUtils#parseResponse}.
         *
         * @throws MalformedGenAIResponseException propagated from the JSON parser
         */
        static Stream<VectorEncoding.BatchRow> parseResponse(List<String> resources, InputStream inputStream, int[] nullIndexes) throws MalformedGenAIResponseException {
            String[] properties = new String[]{"embeddings", "values"};
            return JsonUtils.parseResponse(VertexAI.NAME, "predictions", properties, resources, inputStream, nullIndexes);
        }

        /**
         * Serializes {@code {"instances": [...]}} to {@code out}, one instance per resource.
         *
         * @throws UncheckedIOException wrapping any {@link IOException} raised while writing
         */
        private void writeRequestPayload(OutputStream out, List<String> resources) {
            try {
                List<MutableMap<String, String>> instances = resources.stream().map(this::toInstance).toList();
                JsonUtils.getObjectMapper().writeValue(out, Map.of("instances", instances));
            }
            catch (IOException e) {
                throw new UncheckedIOException(e);
            }
        }

        /** Builds one request instance: the content plus the optional task type and title. */
        private MutableMap<String, String> toInstance(String resource) {
            MutableMap<String, String> instance = Maps.mutable.of("content", resource);
            this.configuration.taskType.ifPresent(taskType -> instance.put("task_type", taskType));
            this.configuration.title.ifPresent(title -> instance.put("title", title));
            return instance;
        }
    }
}

