/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io.snowflake.services;

import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Locale;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema;
import org.apache.beam.sdk.io.snowflake.enums.CreateDisposition;
import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeBatchServiceConfig;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Batch implementation of {@link SnowflakeService} that moves data between Snowflake and a staging
 * bucket by executing {@code COPY INTO} statements over JDBC.
 *
 * <p>{@link #write} copies staged files into a table, after applying the configured create and
 * write dispositions. {@link #read} unloads a table or query result into the staging bucket as
 * gzip-compressed CSV files.
 */
public class SnowflakeBatchServiceImpl implements SnowflakeService<SnowflakeBatchServiceConfig> {
    private static final Logger LOG = LoggerFactory.getLogger(SnowflakeBatchServiceImpl.class);
    /** GCS scheme substituted into bucket locations passed to Snowflake. */
    private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";
    /** GCS scheme as it may appear in the configured staging bucket directory. */
    private static final String GCS_PREFIX = "gs://";

    /** Loads the staged files listed in {@code config} into the configured Snowflake table. */
    @Override
    public void write(SnowflakeBatchServiceConfig config) throws Exception {
        copyToTable(config);
    }

    /**
     * Unloads the configured table or query into the staging bucket.
     *
     * @return the staging bucket directory followed by {@code "*"}, matching the unloaded files
     */
    @Override
    public String read(SnowflakeBatchServiceConfig config) throws Exception {
        return copyIntoStage(config);
    }

    /**
     * Runs {@code COPY INTO '<bucket>' FROM <source>} to unload data as gzip-compressed CSV.
     *
     * <p>The source is {@code (query)} when a query is configured, otherwise {@code
     * database.schema.table}.
     */
    private String copyIntoStage(SnowflakeBatchServiceConfig config) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn = config.getDataSourceProviderFn();
        String database = config.getDatabase();
        String schema = config.getSchema();
        String table = config.getTable();
        String query = config.getQuery();
        String storageIntegrationName = config.getStorageIntegrationName();
        String stagingBucketDir = config.getStagingBucketDir();

        String source =
            query != null ? String.format("(%s)", query) : getTablePath(database, schema, table);
        String copyQuery =
            String.format(
                "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
                getProperBucketDir(stagingBucketDir),
                source,
                storageIntegrationName,
                getASCIICharRepresentation(config.getQuotationMark()));
        // runStatement closes the connection once the statement has executed.
        runStatement(copyQuery, getConnection(dataSourceProviderFn), null);
        return stagingBucketDir.concat("*");
    }

    /**
     * Encodes {@code input} as a hexadecimal literal of its UTF-8 bytes, e.g. {@code "'"} becomes
     * {@code "0x27"}. This lets a quote character be placed inside a single-quoted SQL option
     * without escaping.
     */
    private String getASCIICharRepresentation(String input) {
        return String.format("0x%x", new BigInteger(1, input.getBytes(StandardCharsets.UTF_8)));
    }

    /**
     * Loads the staged files into the target table with {@code COPY INTO <table> FROM <location>}.
     *
     * <p>The create disposition is applied first, then the write disposition, then the copy.
     *
     * <p>When a non-empty storage integration is configured, the statement uses:
     *
     * <ul>
     *   <li>the fully qualified table path;
     *   <li>the {@code gs://} to {@code gcs://}-rewritten location;
     *   <li>the {@code STORAGE_INTEGRATION} option.
     * </ul>
     *
     * Otherwise it uses the bare table name and the location as configured. A {@code null}
     * integration name is treated like an empty one.
     */
    private void copyToTable(SnowflakeBatchServiceConfig config) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn = config.getDataSourceProviderFn();
        List<String> filesList = config.getFilesList();
        String database = config.getDatabase();
        String schema = config.getSchema();
        String table = config.getTable();
        String query = config.getQuery();
        SnowflakeTableSchema tableSchema = config.getTableSchema();
        CreateDisposition createDisposition = config.getCreateDisposition();
        WriteDisposition writeDisposition = config.getWriteDisposition();
        String storageIntegrationName = config.getStorageIntegrationName();
        String stagingBucketDir = config.getStagingBucketDir();

        String source =
            query != null ? String.format("(%s)", query) : String.format("'%s'", stagingBucketDir);
        String files = toRelativeFilesClause(filesList, stagingBucketDir);

        DataSource dataSource = dataSourceProviderFn.apply(null);
        prepareTableAccordingCreateDisposition(dataSource, table, tableSchema, createDisposition);
        prepareTableAccordingWriteDisposition(dataSource, table, writeDisposition);

        String copyQuery;
        if (storageIntegrationName != null && !storageIntegrationName.isEmpty()) {
            copyQuery =
                String.format(
                    "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
                    getTablePath(database, schema, table),
                    getProperBucketDir(source),
                    files,
                    getASCIICharRepresentation(config.getQuotationMark()),
                    storageIntegrationName);
        } else {
            copyQuery =
                String.format(
                    "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP);",
                    table,
                    source,
                    files,
                    getASCIICharRepresentation(config.getQuotationMark()));
        }
        // runStatement closes the connection once the statement has executed.
        runStatement(copyQuery, dataSource.getConnection(), null);
    }

    /**
     * Builds the body of the {@code FILES=(...)} option from {@code filesList}.
     *
     * <p>Each file is single-quoted and the list is joined with {@code ", "}. Every occurrence of
     * {@code stagingBucketDir} is then removed, so the paths are relative to the {@code FROM}
     * location.
     *
     * <p>The directory is removed as a literal substring. {@code String.replaceAll} would compile it
     * as a regular expression and mishandle bucket paths containing regex metacharacters.
     */
    private static String toRelativeFilesClause(List<String> filesList, String stagingBucketDir) {
        String files =
            filesList.stream()
                .map(file -> String.format("'%s'", file))
                .collect(Collectors.joining(", "));
        return files.replace(stagingBucketDir, "");
    }

    /** Removes all rows from {@code tablePath}; used for the TRUNCATE write disposition. */
    private void truncateTable(DataSource dataSource, String tablePath) throws SQLException {
        String query = String.format("TRUNCATE %s;", tablePath);
        runConnectionWithStatement(dataSource, query, null);
    }

    /**
     * Verifies that {@code tablePath} contains no rows.
     *
     * @throws RuntimeException if the table has rows or the count cannot be read
     */
    private static void checkIfTableIsEmpty(DataSource dataSource, String tablePath)
            throws SQLException {
        String selectQuery = String.format("SELECT count(*) FROM %s LIMIT 1;", tablePath);
        runConnectionWithStatement(
            dataSource,
            selectQuery,
            resultSet -> {
                assert resultSet != null;
                checkIfTableIsEmpty(resultSet);
            });
    }

    /**
     * Reads the row count from the first row of {@code resultSet} and fails unless it is zero.
     *
     * <p>A result set with no rows is also treated as a failure.
     */
    private static void checkIfTableIsEmpty(ResultSet resultSet) {
        int columnId = 1;
        try {
            if (!resultSet.next() || !checkIfTableIsEmpty(resultSet, columnId)) {
                throw new RuntimeException("Table is not empty. Aborting COPY with disposition EMPTY");
            }
        } catch (SQLException e) {
            throw new RuntimeException("Unable run pipeline with EMPTY disposition.", e);
        }
    }

    /** Returns true when the integer in column {@code columnId} of the current row is below 1. */
    private static boolean checkIfTableIsEmpty(ResultSet resultSet, int columnId)
            throws SQLException {
        int rowCount = resultSet.getInt(columnId);
        return rowCount < 1;
    }

    /**
     * Applies the create disposition.
     *
     * <p>For {@code CREATE_IF_NEEDED}, the table is created from {@code tableSchema} when it does
     * not exist. Any other disposition leaves the table untouched.
     */
    private void prepareTableAccordingCreateDisposition(
            DataSource dataSource,
            String table,
            SnowflakeTableSchema tableSchema,
            CreateDisposition createDisposition)
            throws SQLException {
        switch (createDisposition) {
            case CREATE_IF_NEEDED:
                createTableIfNotExists(dataSource, table, tableSchema);
                break;
            case CREATE_NEVER:
            default:
                break;
        }
    }

    /**
     * Applies the write disposition before copying.
     *
     * <p>{@code TRUNCATE} empties the table and {@code EMPTY} fails unless the table is already
     * empty. Any other disposition requires no preparation.
     */
    private void prepareTableAccordingWriteDisposition(
            DataSource dataSource, String table, WriteDisposition writeDisposition)
            throws SQLException {
        switch (writeDisposition) {
            case TRUNCATE:
                truncateTable(dataSource, table);
                break;
            case EMPTY:
                checkIfTableIsEmpty(dataSource, table);
                break;
            default:
                break;
        }
    }

    /**
     * Creates {@code table} from {@code tableSchema} unless {@code information_schema.tables}
     * already lists it.
     */
    private void createTableIfNotExists(
            DataSource dataSource, String table, SnowflakeTableSchema tableSchema)
            throws SQLException {
        // The lookup compares against the upper-cased name, presumably because Snowflake stores
        // unquoted identifiers upper-case. Locale.ROOT keeps the result independent of the JVM's
        // default locale (e.g. the Turkish dotless i).
        String query =
            String.format(
                "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = '%s');",
                table.toUpperCase(Locale.ROOT));
        runConnectionWithStatement(
            dataSource,
            query,
            resultSet -> {
                assert resultSet != null;
                if (!checkResultIfTableExists(resultSet)) {
                    try {
                        createTable(dataSource, table, tableSchema);
                    } catch (SQLException e) {
                        throw new RuntimeException("Unable to create table.", e);
                    }
                }
            });
    }

    /**
     * Reads the boolean {@code EXISTS} result from the first row.
     *
     * @throws RuntimeException if there is no row or the value cannot be read
     */
    private static boolean checkResultIfTableExists(ResultSet resultSet) {
        try {
            if (resultSet.next()) {
                return checkIfResultIsTrue(resultSet);
            }
            throw new RuntimeException("Unable run pipeline with CREATE IF NEEDED - no response.");
        } catch (SQLException e) {
            throw new RuntimeException("Unable run pipeline with CREATE IF NEEDED disposition.", e);
        }
    }

    /**
     * Runs {@code CREATE TABLE table (<schema sql>)}.
     *
     * @throws IllegalArgumentException if {@code tableSchema} is null
     */
    private void createTable(DataSource dataSource, String table, SnowflakeTableSchema tableSchema)
            throws SQLException {
        Preconditions.checkArgument(
            tableSchema != null,
            "The CREATE_IF_NEEDED disposition requires schema if table doesn't exists");
        String query = String.format("CREATE TABLE %s (%s);", table, tableSchema.sql());
        runConnectionWithStatement(dataSource, query, null);
    }

    /** Returns the boolean value in the first column of the current row. */
    private static boolean checkIfResultIsTrue(ResultSet resultSet) throws SQLException {
        int columnId = 1;
        return resultSet.getBoolean(columnId);
    }

    /**
     * Opens a connection from {@code dataSource} and runs {@code query} on it.
     *
     * <p>{@link #runStatement} closes the connection, so it is not closed again here.
     */
    private static void runConnectionWithStatement(
            DataSource dataSource, String query, Consumer<ResultSet> resultSetMethod)
            throws SQLException {
        runStatement(query, dataSource.getConnection(), resultSetMethod);
    }

    /**
     * Executes {@code query} on {@code connection}; the connection is always closed afterwards.
     *
     * @param resultSetMethod if non-null, the query runs via {@code executeQuery()} and the result
     *     set is handed to this callback; otherwise it runs via {@code execute()}
     */
    private static void runStatement(
            String query, Connection connection, Consumer<ResultSet> resultSetMethod)
            throws SQLException {
        // try-with-resources closes the connection even if prepareStatement fails. It also closes
        // each resource even if closing an inner one throws.
        try (Connection conn = connection;
                PreparedStatement statement = conn.prepareStatement(query)) {
            if (resultSetMethod != null) {
                try (ResultSet resultSet = statement.executeQuery()) {
                    resultSetMethod.accept(resultSet);
                }
            } else {
                statement.execute();
            }
        }
    }

    /** Obtains a new connection from the data source produced by {@code dataSourceProviderFn}. */
    private Connection getConnection(SerializableFunction<Void, DataSource> dataSourceProviderFn)
            throws SQLException {
        DataSource dataSource = dataSourceProviderFn.apply(null);
        return dataSource.getConnection();
    }

    /**
     * Replaces every {@code gs://} in {@code bucketDir} with {@code gcs://}.
     *
     * <p>The replacement is not limited to a leading prefix, so it also applies to quoted
     * locations such as {@code 'gs://...'}.
     */
    private String getProperBucketDir(String bucketDir) {
        return bucketDir.replace(GCS_PREFIX, SNOWFLAKE_GCS_PREFIX);
    }

    /** Formats a fully qualified {@code database.schema.table} path. */
    private String getTablePath(String database, String schema, String table) {
        return String.format("%s.%s.%s", database, schema, table);
    }
}

