/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.model.openai;

import com.openai.client.OpenAIClientAsync;
import com.openai.core.http.Headers;
import com.openai.errors.OpenAIServiceException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import org.apache.flink.configuration.DescribedEnum;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.configuration.description.InlineElement;
import org.apache.flink.configuration.description.TextElement;
import org.apache.flink.model.openai.ContextOverflowAction;
import org.apache.flink.model.openai.OpenAIOptions;
import org.apache.flink.model.openai.OpenAIUtils;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.catalog.Column;
import org.apache.flink.table.catalog.ResolvedSchema;
import org.apache.flink.table.data.GenericArrayData;
import org.apache.flink.table.data.GenericMapData;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.binary.BinaryStringData;
import org.apache.flink.table.factories.ModelProviderFactory;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.functions.FunctionContext;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.VarCharType;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class AbstractOpenAIModelFunction
extends AsyncPredictFunction {
    private static final Logger LOG = LoggerFactory.getLogger(AbstractOpenAIModelFunction.class);
    protected transient OpenAIClientAsync client;
    private final ErrorHandlingStrategy errorHandlingStrategy;
    private final int numRetry;
    private final RetryFallbackStrategy retryFallbackStrategy;
    private final String baseUrl;
    private final String apiKey;
    private final String model;
    @Nullable
    private final Integer maxContextSize;
    private final ContextOverflowAction contextOverflowAction;
    protected final List<String> outputColumnNames;

    public AbstractOpenAIModelFunction(ModelProviderFactory.Context factoryContext, ReadableConfig config) {
        String endpoint = (String)config.get(OpenAIOptions.ENDPOINT);
        this.baseUrl = endpoint.replaceAll(String.format("/%s/*$", this.getEndpointSuffix()), "");
        this.apiKey = (String)config.get(OpenAIOptions.API_KEY);
        this.errorHandlingStrategy = (ErrorHandlingStrategy)((Object)config.get(OpenAIOptions.ERROR_HANDLING_STRATEGY));
        this.numRetry = this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY ? (Integer)config.get(OpenAIOptions.RETRY_NUM) : 0;
        this.model = (String)config.get(OpenAIOptions.MODEL);
        this.maxContextSize = (Integer)config.get(OpenAIOptions.MAX_CONTEXT_SIZE);
        this.contextOverflowAction = (ContextOverflowAction)((Object)config.get(OpenAIOptions.CONTEXT_OVERFLOW_ACTION));
        this.retryFallbackStrategy = (RetryFallbackStrategy)((Object)config.get(OpenAIOptions.RETRY_FALLBACK_STRATEGY));
        this.validateSingleColumnSchema(factoryContext.getCatalogModel().getResolvedInputSchema(), (LogicalType)new VarCharType(Integer.MAX_VALUE), "input");
        this.outputColumnNames = factoryContext.getCatalogModel().getResolvedOutputSchema().getColumnNames();
    }

    public void open(FunctionContext context) throws Exception {
        super.open(context);
        LOG.debug("Creating an OpenAI client.");
        this.client = OpenAIUtils.createAsyncClient(this.baseUrl, this.apiKey, this.numRetry);
        this.contextOverflowAction.initializeEncodingForContextLimit(this.model, this.maxContextSize);
    }

    public CompletableFuture<Collection<RowData>> asyncPredict(RowData rowData) {
        if (rowData.isNullAt(0)) {
            LOG.warn("Input is null, skipping prediction.");
            return CompletableFuture.completedFuture(Collections.emptyList());
        }
        String input2 = this.contextOverflowAction.processTokensWithLimit(this.model, rowData.getString(0).toString(), this.maxContextSize);
        if (input2 == null) {
            return CompletableFuture.completedFuture(Collections.emptyList());
        }
        return this.asyncPredictInternal(input2);
    }

    public void close() throws Exception {
        super.close();
        if (this.client != null) {
            LOG.debug("Releasing the OpenAI client.");
            OpenAIUtils.releaseAsyncClient(this.baseUrl, this.apiKey);
            this.client = null;
        }
    }

    protected abstract String getEndpointSuffix();

    protected abstract CompletableFuture<Collection<RowData>> asyncPredictInternal(String var1);

    protected void validateSingleColumnSchema(ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
        List columns = schema.getColumns();
        List physicalColumnNames = columns.stream().filter(Column::isPhysical).map(Column::getName).collect(Collectors.toList());
        if (physicalColumnNames.size() != 1) {
            throw new IllegalArgumentException(String.format("Model should have exactly one %s physical column, but actually has %s physical columns: %s", inputOrOutput, physicalColumnNames.size(), physicalColumnNames));
        }
        Column column = (Column)schema.getColumn((String)physicalColumnNames.get(0)).get();
        if (!expectedType.equals((Object)column.getDataType().getLogicalType())) {
            throw new IllegalArgumentException(String.format("%s column %s should be %s, but is a %s.", inputOrOutput, column.getName(), expectedType, column.getDataType().getLogicalType()));
        }
        List metadataColumns = columns.stream().filter(x -> x instanceof Column.MetadataColumn).collect(Collectors.toList());
        if (!metadataColumns.isEmpty()) {
            Preconditions.checkArgument((boolean)"output".equals(inputOrOutput), (Object)"Only output schema supports metadata column");
            for (Column metadataColumn : metadataColumns) {
                ErrorMessageMetadata errorMessageMetadata = ErrorMessageMetadata.get(metadataColumn.getName());
                Preconditions.checkNotNull((Object)((Object)errorMessageMetadata), (String)String.format("Unexpected metadata column %s. Supported metadata columns:\n%s", metadataColumn.getName(), ErrorMessageMetadata.getAllKeysAndDescriptions()));
                Preconditions.checkArgument((boolean)errorMessageMetadata.dataType.equals((Object)metadataColumn.getDataType()), (Object)String.format("Expected metadata column %s to be of type %s, but is of type %s", metadataColumn.getName(), errorMessageMetadata.dataType, metadataColumn.getDataType()));
            }
        }
    }

    protected Collection<RowData> handleErrorsAndRespond(Throwable t) {
        ErrorHandlingStrategy finalErrorHandlingStrategy;
        ErrorHandlingStrategy errorHandlingStrategy = finalErrorHandlingStrategy = this.errorHandlingStrategy == ErrorHandlingStrategy.RETRY ? this.retryFallbackStrategy.strategy : this.errorHandlingStrategy;
        if (finalErrorHandlingStrategy == ErrorHandlingStrategy.FAILOVER) {
            throw new RuntimeException(t);
        }
        if (finalErrorHandlingStrategy == ErrorHandlingStrategy.IGNORE) {
            LOG.warn("The input row data failed to acquire a valid response. Ignoring the input.", t);
            GenericRowData rowData = new GenericRowData(this.outputColumnNames.size());
            boolean isMetadataSet = false;
            for (int i = 0; i < this.outputColumnNames.size(); ++i) {
                String columnName = this.outputColumnNames.get(i);
                ErrorMessageMetadata errorMessageMetadata = ErrorMessageMetadata.get(columnName);
                if (errorMessageMetadata == null) continue;
                rowData.setField(i, errorMessageMetadata.converter.apply(t));
                isMetadataSet = true;
            }
            return isMetadataSet ? Collections.singletonList(rowData) : Collections.emptyList();
        }
        throw new UnsupportedOperationException("Unsupported error handling strategy: " + String.valueOf((Object)finalErrorHandlingStrategy));
    }

    protected static enum ErrorMessageMetadata {
        ERROR_STRING("error-string", DataTypes.STRING(), x -> BinaryStringData.fromString((String)x.getMessage()), "A message associated with the error"),
        HTTP_STATUS_CODE("http-status-code", DataTypes.INT(), e -> ExceptionUtils.findThrowable((Throwable)e, OpenAIServiceException.class).map(OpenAIServiceException::statusCode).orElse(null), "The HTTP status code"),
        HTTP_HEADERS_MAP("http-headers-map", DataTypes.MAP((DataType)DataTypes.STRING(), (DataType)DataTypes.ARRAY((DataType)DataTypes.STRING())), e -> ExceptionUtils.findThrowable((Throwable)e, OpenAIServiceException.class).map(e1 -> {
            HashMap<BinaryStringData, GenericArrayData> map2 = new HashMap<BinaryStringData, GenericArrayData>();
            Headers headers = e1.headers();
            for (String name : headers.names()) {
                map2.put(BinaryStringData.fromString((String)name), new GenericArrayData(headers.values(name).stream().map(BinaryStringData::fromString).toArray()));
            }
            return new GenericMapData(map2);
        }).orElse(null), "The headers returned with the response");

        final String key;
        final DataType dataType;
        final Function<Throwable, Object> converter;
        final String description;

        private ErrorMessageMetadata(String key, DataType dataType, Function<Throwable, Object> converter, String description2) {
            this.key = key;
            this.dataType = dataType;
            this.converter = converter;
            this.description = description2;
        }

        @Nullable
        static ErrorMessageMetadata get(String key) {
            for (ErrorMessageMetadata value : ErrorMessageMetadata.values()) {
                if (!value.key.equals(key)) continue;
                return value;
            }
            return null;
        }

        static String getAllKeysAndDescriptions() {
            return Arrays.stream(ErrorMessageMetadata.values()).map(value -> value.key + ":\t" + value.description).collect(Collectors.joining("\n"));
        }
    }

    public static enum RetryFallbackStrategy implements DescribedEnum
    {
        FAILOVER(ErrorHandlingStrategy.FAILOVER),
        IGNORE(ErrorHandlingStrategy.IGNORE);

        private final ErrorHandlingStrategy strategy;

        private RetryFallbackStrategy(ErrorHandlingStrategy strategy) {
            this.strategy = strategy;
        }

        public InlineElement getDescription() {
            return TextElement.text((String)this.strategy.description);
        }
    }

    public static enum ErrorHandlingStrategy implements DescribedEnum
    {
        RETRY("Retry sending the request."),
        FAILOVER("Throw exceptions and fail the Flink job."),
        IGNORE("Ignore the input that caused the error and continue. The error itself would be recorded in log.");

        private final String description;

        private ErrorHandlingStrategy(String description2) {
            this.description = description2;
        }

        public InlineElement getDescription() {
            return TextElement.text((String)this.description);
        }
    }
}

