/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.io.snowflake.services;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import org.apache.beam.sdk.io.snowflake.enums.CloudProvider;
import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServiceConfig;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * JDBC-backed {@link SnowflakeService} that moves data between Snowflake and a staging bucket
 * by issuing {@code COPY INTO} statements.
 *
 * <p>{@link #read} unloads a table or query result into the staging directory.
 * {@link #write} loads staged files into a table, after applying the configured
 * {@link WriteDisposition}.
 */
public class SnowflakeServiceImpl
implements SnowflakeService<SnowflakeServiceConfig> {
    private static final Logger LOG = LoggerFactory.getLogger(SnowflakeServiceImpl.class);
    private static final String SNOWFLAKE_GCS_PREFIX = "gcs://";
    /**
     * Value substituted into {@code FIELD_OPTIONALLY_ENCLOSED_BY='%s'}. The doubled single quote
     * is an escaped {@code '} inside the SQL string literal.
     */
    private static final String CSV_QUOTE_CHAR_FOR_COPY = "''";

    /** Loads the staged files described by {@code config} into the target table. */
    @Override
    public void write(SnowflakeServiceConfig config) throws Exception {
        this.copyToTable(config);
    }

    /**
     * Unloads the configured table or query into the staging bucket.
     *
     * @return the staging bucket directory followed by a {@code *} wildcard
     */
    @Override
    public String read(SnowflakeServiceConfig config) throws Exception {
        return this.copyIntoStage(config);
    }

    /**
     * Runs {@code COPY INTO '<staging dir>' FROM <source>} to unload data as GZIP-compressed CSV.
     *
     * @param config supplies the data source provider, table, optional query, storage integration
     *     name and staging bucket directory; a non-null query takes precedence over the table
     * @return {@code stagingBucketDir + "*"}, a wildcard over the unloaded files
     * @throws SQLException if obtaining a connection or executing the statement fails
     */
    public String copyIntoStage(SnowflakeServiceConfig config) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn = config.getDataSourceProviderFn();
        String table = config.getTable();
        String query = config.getQuery();
        String storageIntegrationName = config.getstorageIntegrationName();
        String stagingBucketDir = config.getStagingBucketDir();
        // A query is used as a parenthesised subquery; otherwise the table name is the source.
        String source = query != null ? String.format("(%s)", query) : table;
        String copyQuery = String.format(
            "COPY INTO '%s' FROM %s STORAGE_INTEGRATION=%s FILE_FORMAT=(TYPE=CSV COMPRESSION=GZIP FIELD_OPTIONALLY_ENCLOSED_BY='%s');",
            this.getProperBucketDir(stagingBucketDir), source, storageIntegrationName, CSV_QUOTE_CHAR_FOR_COPY);
        runStatement(copyQuery, this.getConnection(dataSourceProviderFn), null);
        return stagingBucketDir.concat("*");
    }

    /**
     * Prepares the target table according to the write disposition. Then runs
     * {@code COPY INTO <table> FROM <source> FILES=(...)} over the listed staged files.
     *
     * @param config supplies the data source provider, file list, table, optional query, write
     *     disposition, storage integration name and staging bucket directory
     * @throws SQLException if table preparation or the COPY statement fails
     */
    public void copyToTable(SnowflakeServiceConfig config) throws SQLException {
        SerializableFunction<Void, DataSource> dataSourceProviderFn = config.getDataSourceProviderFn();
        List<String> filesList = config.getFilesList();
        String table = config.getTable();
        String query = config.getQuery();
        WriteDisposition writeDisposition = config.getWriteDisposition();
        String storageIntegrationName = config.getstorageIntegrationName();
        String stagingBucketDir = config.getStagingBucketDir();
        String source = query != null ? String.format("(%s)", query) : String.format("'%s'", stagingBucketDir);
        // Quote each file name and strip the staging dir so FILES=(...) lists paths relative to it.
        // String.replace is a literal match. The previous replaceAll treated the bucket path as a
        // regex, so characters like '.', '+', '(' or '?' in the path could mis-match.
        String files = filesList.stream()
            .map(file -> String.format("'%s'", file))
            .collect(Collectors.joining(", "))
            .replace(stagingBucketDir, "");
        DataSource dataSource = dataSourceProviderFn.apply(null);
        // NOTE: this may TRUNCATE the table, so nothing after it should fail on bad config
        // that could have been detected earlier (hence the null-safe integration check below).
        this.prepareTableAccordingWriteDisposition(dataSource, table, writeDisposition);
        String copyQuery;
        if (storageIntegrationName != null && !storageIntegrationName.isEmpty()) {
            copyQuery = String.format(
                "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP) STORAGE_INTEGRATION=%s;",
                table, this.getProperBucketDir(source), files, CSV_QUOTE_CHAR_FOR_COPY, storageIntegrationName);
        } else {
            copyQuery = String.format(
                "COPY INTO %s FROM %s FILES=(%s) FILE_FORMAT=(TYPE=CSV FIELD_OPTIONALLY_ENCLOSED_BY='%s' COMPRESSION=GZIP);",
                table, source, files, CSV_QUOTE_CHAR_FOR_COPY);
        }
        runStatement(copyQuery, dataSource.getConnection(), null);
    }

    /** Issues {@code TRUNCATE <table>} on a fresh connection from {@code dataSource}. */
    private void truncateTable(DataSource dataSource, String table) throws SQLException {
        String query = String.format("TRUNCATE %s;", table);
        runConnectionWithStatement(dataSource, query, null);
    }

    /**
     * Counts the rows of {@code table} and throws a {@link RuntimeException} if the count is not
     * below one (used for the EMPTY write disposition).
     */
    private static void checkIfTableIsEmpty(DataSource dataSource, String table) throws SQLException {
        String selectQuery = String.format("SELECT count(*) FROM %s LIMIT 1;", table);
        runConnectionWithStatement(dataSource, selectQuery, resultSet -> {
            assert (resultSet != null);
            checkIfTableIsEmpty(resultSet);
        });
    }

    /**
     * Reads the count from the first row of {@code resultSet}. Throws if there is no row or the
     * count is at least one.
     *
     * @throws RuntimeException if the table is not empty, wrapping any {@link SQLException}
     */
    private static void checkIfTableIsEmpty(ResultSet resultSet) {
        int columnId = 1;
        try {
            if (!resultSet.next() || !checkIfTableIsEmpty(resultSet, columnId)) {
                throw new RuntimeException("Table is not empty. Aborting COPY with disposition EMPTY");
            }
        }
        catch (SQLException e) {
            throw new RuntimeException("Unable run pipeline with EMPTY disposition.", e);
        }
    }

    /** Returns {@code true} when the integer in column {@code columnId} of the current row is < 1. */
    private static boolean checkIfTableIsEmpty(ResultSet resultSet, int columnId) throws SQLException {
        int rowCount = resultSet.getInt(columnId);
        return rowCount < 1;
    }

    /**
     * Applies the write disposition before loading. TRUNCATE empties the table; EMPTY aborts
     * unless the table has no rows; any other disposition leaves the table untouched.
     */
    private void prepareTableAccordingWriteDisposition(DataSource dataSource, String table, WriteDisposition writeDisposition) throws SQLException {
        switch (writeDisposition) {
            case TRUNCATE: {
                this.truncateTable(dataSource, table);
                break;
            }
            case EMPTY: {
                checkIfTableIsEmpty(dataSource, table);
                break;
            }
            default:
                // No table preparation needed for other dispositions.
                break;
        }
    }

    /**
     * Opens a connection from {@code dataSource} and runs {@code query} on it.
     * {@link #runStatement} owns and closes the connection, so no extra close is needed here.
     */
    private static void runConnectionWithStatement(DataSource dataSource, String query, Consumer<ResultSet> resultSetMethod) throws SQLException {
        runStatement(query, dataSource.getConnection(), resultSetMethod);
    }

    /**
     * Executes {@code query} on {@code connection} and always closes the statement and the
     * connection. The connection is closed even if preparing the statement fails.
     *
     * @param resultSetMethod if non-null, the query is run with {@code executeQuery} and this
     *     consumer receives the (auto-closed) result set; if null, {@code execute} is used
     */
    private static void runStatement(String query, Connection connection, Consumer<ResultSet> resultSetMethod) throws SQLException {
        // Java 8-compatible try-with-resources: the connection is taken over as a resource so it
        // is closed even when prepareStatement or statement.close() throws.
        try (Connection ownedConnection = connection;
             PreparedStatement statement = ownedConnection.prepareStatement(query)) {
            if (resultSetMethod != null) {
                try (ResultSet resultSet = statement.executeQuery()) {
                    resultSetMethod.accept(resultSet);
                }
            } else {
                statement.execute();
            }
        }
    }

    /** Obtains a new connection from the data source produced by {@code dataSourceProviderFn}. */
    private Connection getConnection(SerializableFunction<Void, DataSource> dataSourceProviderFn) throws SQLException {
        DataSource dataSource = dataSourceProviderFn.apply(null);
        return dataSource.getConnection();
    }

    /** Rewrites the GCS prefix of {@code bucketDir} to Snowflake's {@code gcs://} form, if present. */
    private String getProperBucketDir(String bucketDir) {
        if (bucketDir.contains(CloudProvider.GCS.getPrefix())) {
            return bucketDir.replace(CloudProvider.GCS.getPrefix(), SNOWFLAKE_GCS_PREFIX);
        }
        return bucketDir;
    }
}

