/*
 * Decompiled with CFR 0.152.
 */
package io.trino.plugin.functions.python;

import com.dylibso.chicory.log.Logger;
import com.dylibso.chicory.runtime.ExportFunction;
import com.dylibso.chicory.runtime.HostFunction;
import com.dylibso.chicory.runtime.ImportFunction;
import com.dylibso.chicory.runtime.ImportValues;
import com.dylibso.chicory.runtime.Instance;
import com.dylibso.chicory.runtime.Memory;
import com.dylibso.chicory.wasi.WasiOptions;
import com.dylibso.chicory.wasi.WasiPreview1;
import com.dylibso.chicory.wasm.ChicoryException;
import com.dylibso.chicory.wasm.WasmModule;
import com.dylibso.chicory.wasm.types.ValueType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.io.Closer;
import com.google.common.jimfs.Configuration;
import com.google.common.jimfs.Jimfs;
import io.airlift.slice.BasicSliceInput;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceInput;
import io.airlift.slice.Slices;
import io.airlift.units.DataSize;
import io.trino.plugin.functions.python.JdkLogger;
import io.trino.plugin.functions.python.LoggingOutputStream;
import io.trino.plugin.functions.python.TrinoTypes;
import io.trino.spi.ErrorCodeSupplier;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.type.Type;
import io.trino.wasm.python.PythonModule;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileSystem;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Stream;

final class PythonEngine
implements Closeable {
    private static final io.airlift.log.Logger log = io.airlift.log.Logger.get(PythonEngine.class);
    private static final Logger logger = JdkLogger.get(PythonEngine.class);
    private static final Configuration FS_CONFIG = Configuration.unix().toBuilder().setAttributeViews("unix", new String[0]).setMaxSize(DataSize.of((long)8L, (DataSize.Unit)DataSize.Unit.MEGABYTE).toBytes()).build();
    private static final Map<Integer, ErrorCodeSupplier> ERROR_CODES = (Map)Stream.of(StandardErrorCode.values()).collect(ImmutableMap.toImmutableMap(error -> error.toErrorCode().getCode(), Function.identity()));
    private static final WasmModule PYTHON_MODULE = PythonModule.load();
    private final Closer closer = Closer.create();
    private final LimitedOutputStream stderr = new LimitedOutputStream();
    private final ExportFunction allocate;
    private final ExportFunction deallocate;
    private final ExportFunction setup;
    private final ExportFunction execute;
    private final Memory memory;
    private Type returnType;
    private List<Type> argumentTypes;
    private TrinoException error;

    public PythonEngine(String guestCode) {
        FileSystem fileSystem = (FileSystem)this.closer.register((Closeable)Jimfs.newFileSystem((Configuration)FS_CONFIG));
        Path guestRoot = fileSystem.getPath("/guest", new String[0]);
        try {
            Files.createDirectories(guestRoot, new FileAttribute[0]);
            Files.writeString(guestRoot.resolve("guest.py"), (CharSequence)guestCode, new OpenOption[0]);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        OutputStream stdout = (OutputStream)this.closer.register((Closeable)new LoggingOutputStream(log));
        WasiOptions wasiOptions = WasiOptions.builder().withStdout(stdout).withStderr((OutputStream)this.stderr).withDirectory(guestRoot.toString(), guestRoot).build();
        WasiPreview1 wasi = (WasiPreview1)this.closer.register((Closeable)new WasiPreview1(logger, wasiOptions));
        ImportValues importValues = ImportValues.builder().addFunction((ImportFunction[])wasi.toHostFunctions()).addFunction(new ImportFunction[]{this.returnErrorHostFunction()}).build();
        Instance instance = Instance.builder((WasmModule)PYTHON_MODULE).withMachineFactory(PythonModule::create).withImportValues(importValues).build();
        this.allocate = instance.export("allocate");
        this.deallocate = instance.export("deallocate");
        this.setup = instance.export("setup");
        this.execute = instance.export("execute");
        this.memory = instance.memory();
    }

    public void setup(Type returnType, List<Type> argumentTypes, String handlerName) {
        try {
            this.doSetup(returnType, argumentTypes, handlerName);
        }
        catch (ChicoryException e) {
            throw this.fatalError("Python error", e);
        }
    }

    private void doSetup(Type returnType, List<Type> argumentTypes, String handlerName) {
        byte[] nameBytes = handlerName.getBytes(StandardCharsets.UTF_8);
        int nameAddress = this.allocate(nameBytes.length + 1);
        this.memory.write(nameAddress, nameBytes);
        this.memory.writeByte(nameAddress + nameBytes.length, (byte)0);
        Slice argumentTypeSlice = TrinoTypes.toRowTypeDescriptor(argumentTypes);
        int argTypeAddress = this.allocate(argumentTypeSlice.length());
        this.writeSliceTo(argumentTypeSlice, argTypeAddress);
        Slice returnTypeSlice = TrinoTypes.toTypeDescriptor(returnType);
        int returnTypeAddress = this.allocate(returnTypeSlice.length());
        this.writeSliceTo(returnTypeSlice, returnTypeAddress);
        this.setup.apply(new long[]{nameAddress, argTypeAddress, returnTypeAddress});
        this.deallocate(nameAddress);
        this.returnType = Objects.requireNonNull(returnType, "returnType is null");
        this.argumentTypes = ImmutableList.copyOf((Collection)Objects.requireNonNull(argumentTypes, "argumentTypes is null"));
    }

    private void writeSliceTo(Slice slice, int address) {
        this.memory.write(address, slice.byteArray(), slice.byteArrayOffset(), slice.length());
    }

    private int allocate(int size) {
        return Math.toIntExact(this.allocate.apply(new long[]{size})[0]);
    }

    private void deallocate(int address) {
        this.deallocate.apply(new long[]{address});
    }

    private int execute(int address) {
        return Math.toIntExact(this.execute.apply(new long[]{address})[0]);
    }

    public Object execute(Object[] arguments) {
        int resultAddress;
        Slice slice = TrinoTypes.javaToBinary(this.argumentTypes, arguments);
        int argAddress = this.allocate(slice.length());
        this.writeSliceTo(slice, argAddress);
        this.error = null;
        try {
            resultAddress = this.execute(argAddress);
        }
        catch (ChicoryException e) {
            throw this.fatalError("Failed to invoke Python function", e);
        }
        this.deallocate(argAddress);
        if (this.error != null) {
            throw new TrinoException(() -> ((TrinoException)this.error).getErrorCode(), this.error.getMessage(), this.error.getCause());
        }
        if (resultAddress == 0) {
            throw new TrinoException((ErrorCodeSupplier)StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR, "Python function did not return a result");
        }
        int resultSize = this.memory.readInt(resultAddress);
        byte[] bytes = this.memory.readBytes(resultAddress + 4, resultSize);
        this.deallocate(resultAddress);
        BasicSliceInput input = new BasicSliceInput(Slices.wrappedBuffer((byte[])bytes));
        return TrinoTypes.binaryToJava(this.returnType, (SliceInput)input);
    }

    public TrinoException fatalError(String message, ChicoryException e) {
        String error = this.stderr.toString(StandardCharsets.UTF_8).strip();
        if (!error.isEmpty()) {
            message = (String)message + ":";
            message = (String)message + (error.contains("\n") ? "\n" : " ");
            message = (String)message + error;
        }
        return new TrinoException((ErrorCodeSupplier)StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR, (String)message, (Throwable)e);
    }

    @Override
    public void close() {
        try {
            this.closer.close();
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private long[] returnError(Instance instance, long ... args) {
        ErrorCodeSupplier errorCode;
        int code = Math.toIntExact(args[0]);
        int messageAddress = Math.toIntExact(args[1]);
        int messageSize = Math.toIntExact(args[2]);
        int tracebackAddress = Math.toIntExact(args[3]);
        int tracebackSize = Math.toIntExact(args[4]);
        Memory memory = instance.memory();
        String message = memory.readString(messageAddress, messageSize);
        RuntimeException traceback = null;
        if (tracebackAddress != 0) {
            String value = memory.readString(tracebackAddress, tracebackSize);
            traceback = new RuntimeException("Python traceback:\n" + value.stripTrailing());
        }
        if ((errorCode = ERROR_CODES.get(code)) == null) {
            errorCode = StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR;
            message = "Unknown error code (%s): %s".formatted(code, message);
        }
        this.error = new TrinoException(errorCode, message, (Throwable)traceback);
        return null;
    }

    private HostFunction returnErrorHostFunction() {
        return new HostFunction("trino", "return_error", List.of(ValueType.I32, ValueType.I32, ValueType.I32, ValueType.I32, ValueType.I32), List.of(), this::returnError);
    }

    private static class LimitedOutputStream
    extends ByteArrayOutputStream {
        private static final int LIMIT = 4096;

        private LimitedOutputStream() {
        }

        @Override
        public void write(byte[] b, int off, int len) {
            if (this.count < 4096) {
                super.write(b, off, Math.min(len, 4096 - this.count));
            }
        }
    }
}

