/*
 * Decompiled with CFR 0.152.
 */
package io.trino.testing;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import com.google.common.graph.Traverser;
import io.airlift.units.Duration;
import io.trino.FeaturesConfig;
import io.trino.Session;
import io.trino.client.StageStats;
import io.trino.client.StatementStats;
import io.trino.execution.FailureInjector;
import io.trino.operator.OperatorStats;
import io.trino.operator.RetryPolicy;
import io.trino.server.DynamicFilterService;
import io.trino.spi.ErrorType;
import io.trino.spi.QueryId;
import io.trino.spi.predicate.Domain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.testing.AbstractTestQueryFramework;
import io.trino.testing.MaterializedResult;
import io.trino.testing.QueryAssertions;
import io.trino.testing.QueryRunner;
import io.trino.testing.ResultWithQueryId;
import io.trino.testing.sql.TestTable;
import io.trino.tpch.TpchTable;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.UUID;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.assertj.core.api.AbstractBooleanAssert;
import org.assertj.core.api.AbstractThrowableAssert;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

/**
 * Base class for tests that inject failures into query execution and verify that queries fail
 * without retries and (where applicable) succeed with the {@link RetryPolicy} passed to the constructor.
 *
 * <p>Subclasses provide the query runner and create the partitioned lineitem table used by join tests.
 */
public abstract class BaseFailureRecoveryTest extends AbstractTestQueryFramework {
    private static final String PARTITIONED_LINEITEM = "partitioned_lineitem";
    protected static final int INVOCATION_COUNT = 1;
    private static final Duration MAX_ERROR_DURATION = new Duration(5.0, TimeUnit.SECONDS);
    private static final Duration REQUEST_TIMEOUT = new Duration(5.0, TimeUnit.SECONDS);

    // Failure messages asserted by several scenarios below.
    private static final String INJECTED_FAILURE_MESSAGE = "This error is injected by the failure injection service";
    private static final String INTERNAL_SERVER_ERROR_PATTERN = "Error 500 Internal Server Error|Error closing remote buffer, expected 204 got 500";
    private static final String TOO_MANY_ERRORS_MESSAGE = "Encountered too many errors talking to a worker node";

    private final RetryPolicy retryPolicy;

    protected BaseFailureRecoveryTest(RetryPolicy retryPolicy) {
        this.retryPolicy = Objects.requireNonNull(retryPolicy, "retryPolicy is null");
    }

    /**
     * Builds the query runner with short error/timeout durations, a single retry attempt and the
     * configured retry policy, so injected failures surface quickly.
     */
    @Override
    protected final QueryRunner createQueryRunner() throws Exception {
        return createQueryRunner(
                ImmutableList.of(TpchTable.NATION, TpchTable.ORDERS, TpchTable.CUSTOMER, TpchTable.SUPPLIER),
                ImmutableMap.<String, String>builder()
                        .put("query.remote-task.max-error-duration", MAX_ERROR_DURATION.toString())
                        .put("exchange.max-error-duration", MAX_ERROR_DURATION.toString())
                        .put("retry-policy", retryPolicy.toString())
                        .put("retry-initial-delay", "0s")
                        .put("retry-attempts", "1")
                        // the injected request timeout is twice the client idle timeout
                        .put("failure-injection.request-timeout", new Duration(REQUEST_TIMEOUT.toMillis() * 2L, TimeUnit.MILLISECONDS).toString())
                        .put("exchange.http-client.idle-timeout", REQUEST_TIMEOUT.toString())
                        .put("query.initial-hash-partitions", "5")
                        .put("exchange.deduplication-buffer-size", "1kB")
                        .buildOrThrow(),
                ImmutableMap.of("scheduler.http-client.idle-timeout", REQUEST_TIMEOUT.toString()));
    }

    /**
     * Creates the connector-specific query runner.
     *
     * @param requiredTpchTables TPCH tables to load
     * @param configProperties properties passed as the second argument by {@link #createQueryRunner()}
     * @param coordinatorProperties properties passed as the third argument by {@link #createQueryRunner()}
     */
    protected abstract QueryRunner createQueryRunner(List<TpchTable<?>> requiredTpchTables, Map<String, String> configProperties, Map<String, String> coordinatorProperties) throws Exception;

    @BeforeClass
    public void initTables() throws Exception {
        createPartitionedLineitemTable(PARTITIONED_LINEITEM, ImmutableList.of("orderkey", "partkey", "suppkey"), "suppkey");
    }

    /**
     * Creates a lineitem-derived table with the given columns; the last argument is the column
     * passed as {@code "suppkey"} by {@link #initTables()} (presumably the partitioning column — confirm in subclasses).
     */
    protected abstract void createPartitionedLineitemTable(String tableName, List<String> columns, String partitionColumn);

    /** Returns whether the connector under test supports retrying queries that write data. */
    protected abstract boolean areWriteRetriesSupported();

    @Test(invocationCount = INVOCATION_COUNT)
    public void testSimpleSelect() {
        testSelect("SELECT * FROM nation");
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testAggregation() {
        testSelect("SELECT orderStatus, count(*) FROM orders GROUP BY orderStatus");
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testJoinDynamicFilteringDisabled() {
        String selectQuery = "SELECT * FROM partitioned_lineitem JOIN supplier ON partitioned_lineitem.suppkey = supplier.suppkey AND supplier.name = 'Supplier#000000001'";
        testSelect(selectQuery, Optional.of(enableDynamicFiltering(false)));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testJoinDynamicFilteringEnabled() {
        String selectQuery = "SELECT * FROM partitioned_lineitem JOIN supplier ON partitioned_lineitem.suppkey = supplier.suppkey AND supplier.name = 'Supplier#000000001'";
        testSelect(selectQuery, Optional.of(enableDynamicFiltering(true)), queryId -> {
            DynamicFilterService.DynamicFiltersStats dynamicFiltersStats = getDynamicFilteringStats(queryId);
            Assertions.assertThat(dynamicFiltersStats.getLazyDynamicFilters()).isEqualTo(1);
            DynamicFilterService.DynamicFilterDomainStats domainStats = Iterables.getOnlyElement(dynamicFiltersStats.getDynamicFilterDomainStats());
            Assertions.assertThat(domainStats.getSimplifiedDomain())
                    .isEqualTo(Domain.singleValue(BigintType.BIGINT, 1L).toString(getSession().toConnectorSession()));
            OperatorStats probeStats = searchScanFilterAndProjectOperatorStats(queryId, getQualifiedTableName(PARTITIONED_LINEITEM));
            // NOTE(review): 1230 == 2 * 615; presumably the doubled value covers input re-read by a retried task — confirm
            Assertions.assertThat(probeStats.getInputPositions()).isIn(615L, 1230L);
        });
    }

    protected void testSelect(String query) {
        testSelect(query, Optional.empty());
    }

    protected void testSelect(String query, Optional<Session> session) {
        testSelect(query, session, queryId -> {});
    }

    /**
     * Runs {@code query} under several injected failures. Each scenario must fail without retries
     * and then finish successfully (with the same result as a failure-free run) when retries are on.
     * {@code queryAssertion} is applied to the successful run for the scenarios that use {@code session}.
     */
    protected void testSelect(String query, Optional<Session> session, Consumer<QueryId> queryAssertion) {
        assertThatQuery(query)
                .withSession(session)
                .experiencing(FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_FAILURE)
                .at(leafStage())
                .failsWithoutRetries(failure -> failure.hasMessageFindingMatch(INTERNAL_SERVER_ERROR_PATTERN))
                .finishesSuccessfully(queryAssertion);
        assertThatQuery(query)
                .withSession(session)
                .experiencing(FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_FAILURE)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageFindingMatch(INTERNAL_SERVER_ERROR_PATTERN))
                .finishesSuccessfully(queryAssertion);
        assertThatQuery(query)
                .withSession(session)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
                .at(leafStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE))
                .finishesSuccessfully(queryAssertion);
        assertThatQuery(query)
                .withSession(session)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.EXTERNAL))
                .at(intermediateDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE))
                .finishesSuccessfully(queryAssertion);
        // The timeout scenarios run with the default session.
        assertThatQuery(query)
                .experiencing(FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_TIMEOUT)
                .at(intermediateDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(TOO_MANY_ERRORS_MESSAGE))
                .finishesSuccessfully();
        assertThatQuery(query)
                .experiencing(FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_TIMEOUT)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageFindingMatch("Encountered too many errors talking to a worker node|Error closing remote buffer.*3 failures"))
                .finishesSuccessfully();
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testUserFailure() {
        Assertions.assertThatThrownBy(() -> getQueryRunner().execute("SELECT * FROM nation WHERE regionKey / nationKey - 1 = 0"))
                .hasMessageContaining("Division by zero");
        assertThatQuery("SELECT * FROM nation")
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.USER_ERROR))
                .at(leafStage())
                .failsAlways(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testCreateTable() {
        testTableModification(Optional.empty(), "CREATE TABLE <table> AS SELECT * FROM orders", Optional.of("DROP TABLE <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testInsert() {
        testTableModification(
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders WITH NO DATA"),
                "INSERT INTO <table> SELECT * FROM orders",
                Optional.of("DROP TABLE <table>"));
    }

    // The modification targets the per-run <table> copy, never the shared orders table.
    @Test(invocationCount = INVOCATION_COUNT)
    public void testDelete() {
        testTableModification(
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders"),
                "DELETE FROM <table> WHERE orderkey = 1",
                Optional.of("DROP TABLE <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testDeleteWithSubquery() {
        testTableModification(
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders"),
                "DELETE FROM <table> WHERE custkey IN (SELECT custkey FROM customer WHERE nationkey = 1)",
                Optional.of("DROP TABLE <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testUpdate() {
        testTableModification(
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders"),
                "UPDATE <table> SET shippriority = 101 WHERE custkey = 1",
                Optional.of("DROP TABLE <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testUpdateWithSubquery() {
        testTableModification(
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders"),
                "UPDATE <table> SET shippriority = 101 WHERE custkey = (SELECT min(custkey) FROM customer)",
                Optional.of("DROP TABLE <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testAnalyzeStatistics() {
        testNonSelect(
                Optional.empty(),
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders"),
                "ANALYZE <table>",
                Optional.of("DROP TABLE <table>"),
                false);
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testRefreshMaterializedView() {
        testTableModification(
                Optional.of("CREATE MATERIALIZED VIEW <table> AS SELECT * FROM orders"),
                "REFRESH MATERIALIZED VIEW <table>",
                Optional.of("DROP MATERIALIZED VIEW <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testExplainAnalyze() {
        testSelect("EXPLAIN ANALYZE SELECT orderStatus, count(*) FROM orders GROUP BY orderStatus");
        testTableModification(
                Optional.of("CREATE TABLE <table> AS SELECT * FROM orders WITH NO DATA"),
                "EXPLAIN ANALYZE INSERT INTO <table> SELECT * FROM orders",
                Optional.of("DROP TABLE <table>"));
    }

    @Test(invocationCount = INVOCATION_COUNT)
    public void testRequestTimeouts() {
        assertThatQuery("SELECT * FROM nation")
                .experiencing(FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_TIMEOUT)
                .at(leafStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(TOO_MANY_ERRORS_MESSAGE))
                .finishesSuccessfully();
        assertThatQuery("SELECT * FROM nation")
                .experiencing(FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_TIMEOUT)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(TOO_MANY_ERRORS_MESSAGE))
                .finishesSuccessfully();
        if (areWriteRetriesSupported()) {
            assertThatQuery("INSERT INTO <table> SELECT * FROM orders")
                    .withSetupQuery(Optional.of("CREATE TABLE <table> AS SELECT * FROM orders WITH NO DATA"))
                    .withCleanupQuery(Optional.of("DROP TABLE <table>"))
                    .experiencing(FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_TIMEOUT)
                    .at(leafStage())
                    .failsWithoutRetries(failure -> failure.hasMessageContaining(TOO_MANY_ERRORS_MESSAGE))
                    .finishesSuccessfullyWithoutTaskFailures();
        }
    }

    protected void testTableModification(Optional<String> setupQuery, String query, Optional<String> cleanupQuery) {
        testTableModification(Optional.empty(), setupQuery, query, cleanupQuery);
    }

    protected void testTableModification(Optional<Session> session, Optional<String> setupQuery, String query, Optional<String> cleanupQuery) {
        testNonSelect(session, setupQuery, query, cleanupQuery, true);
    }

    /**
     * Runs a non-SELECT statement under several injected failures.
     *
     * <p>If {@code writesData} is set and the connector does not support write retries, only asserts
     * that the statement is rejected with retries enabled. Otherwise, failures in coordinator-only
     * stages must always fail the query, while failures in distributed stages must fail without
     * retries and succeed with them.
     */
    protected void testNonSelect(Optional<Session> session, Optional<String> setupQuery, String query, Optional<String> cleanupQuery, boolean writesData) {
        if (writesData && !areWriteRetriesSupported()) {
            assertThatQuery(query)
                    .withSession(session)
                    .withSetupQuery(setupQuery)
                    .withCleanupQuery(cleanupQuery)
                    .failsDespiteRetries(failure -> failure.hasMessageMatching("This connector does not support query retries"));
            return;
        }

        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
                .at(boundaryCoordinatorStage())
                .failsAlways(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE));
        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
                .at(rootStage())
                .failsAlways(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE));
        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
                .at(leafStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE))
                .finishesSuccessfully();
        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE))
                .finishesSuccessfully();
        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_FAILURE, Optional.of(ErrorType.INTERNAL_ERROR))
                .at(intermediateDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(INJECTED_FAILURE_MESSAGE))
                .finishesSuccessfully();
        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_FAILURE)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageFindingMatch(INTERNAL_SERVER_ERROR_PATTERN))
                .finishesSuccessfully();
        assertThatQuery(query)
                .withSession(session)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_FAILURE)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageFindingMatch(INTERNAL_SERVER_ERROR_PATTERN))
                .finishesSuccessfully();
        // The timeout scenarios run with the default session.
        assertThatQuery(query)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_MANAGEMENT_REQUEST_TIMEOUT)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(TOO_MANY_ERRORS_MESSAGE))
                .finishesSuccessfully();
        assertThatQuery(query)
                .withSetupQuery(setupQuery)
                .withCleanupQuery(cleanupQuery)
                .experiencing(FailureInjector.InjectedFailureType.TASK_GET_RESULTS_REQUEST_TIMEOUT)
                .at(boundaryDistributedStage())
                .failsWithoutRetries(failure -> failure.hasMessageContaining(TOO_MANY_ERRORS_MESSAGE))
                .finishesSuccessfully();
    }

    private FailureRecoveryAssert assertThatQuery(String query) {
        return new FailureRecoveryAssert(query);
    }

    /** Selects the root stage of the plan. */
    protected static Function<MaterializedResult, Integer> rootStage() {
        return result -> Integer.parseInt(getRootStage(result).getStageId());
    }

    /** Selects a random coordinator-only stage none of whose sub-stages are coordinator-only. */
    protected static Function<MaterializedResult, Integer> boundaryCoordinatorStage() {
        return result -> findStageId(result, BaseFailureRecoveryTest::isBoundaryCoordinatorStage);
    }

    /**
     * Selects the root stage when it is distributed; otherwise a random direct sub-stage of a
     * boundary coordinator stage.
     */
    protected static Function<MaterializedResult, Integer> boundaryDistributedStage() {
        return result -> {
            StageStats rootStage = getRootStage(result);
            if (!rootStage.isCoordinatorOnly()) {
                return Integer.parseInt(rootStage.getStageId());
            }
            StageStats boundaryCoordinatorStage = findStage(result, BaseFailureRecoveryTest::isBoundaryCoordinatorStage);
            List<StageStats> subStages = boundaryCoordinatorStage.getSubStages();
            // Fail with a descriptive message instead of Random.nextInt(0)'s "bound must be positive".
            Preconditions.checkArgument(!subStages.isEmpty(), "boundary coordinator stage %s has no sub stages", boundaryCoordinatorStage.getStageId());
            return Integer.parseInt(randomElement(subStages).getStageId());
        };
    }

    /** Selects a random distributed stage that has sub-stages. */
    protected static Function<MaterializedResult, Integer> intermediateDistributedStage() {
        return result -> findStageId(result, stage -> !stage.isCoordinatorOnly() && !stage.getSubStages().isEmpty());
    }

    /** Selects a random stage without sub-stages. */
    protected static Function<MaterializedResult, Integer> leafStage() {
        return result -> findStageId(result, stage -> stage.getSubStages().isEmpty());
    }

    // True for a coordinator-only stage whose sub-stages are all distributed.
    private static boolean isBoundaryCoordinatorStage(StageStats stage) {
        return stage.isCoordinatorOnly() && stage.getSubStages().stream().noneMatch(StageStats::isCoordinatorOnly);
    }

    private static int findStageId(MaterializedResult result, Predicate<StageStats> predicate) {
        return Integer.parseInt(findStage(result, predicate).getStageId());
    }

    /**
     * Returns a uniformly random stage matching {@code predicate}.
     *
     * @throws IllegalArgumentException if no stage matches
     */
    private static StageStats findStage(MaterializedResult result, Predicate<StageStats> predicate) {
        List<StageStats> stages = Streams.stream(allStages(result))
                .filter(predicate)
                .collect(ImmutableList.toImmutableList());
        if (stages.isEmpty()) {
            throw new IllegalArgumentException("stage not found");
        }
        return randomElement(stages);
    }

    private static StageStats getStageStats(MaterializedResult result, int stageId) {
        return Streams.stream(allStages(result))
                .filter(stage -> Integer.parseInt(stage.getStageId()) == stageId)
                .findFirst()
                .orElseThrow(() -> new IllegalArgumentException("stage stats not found: " + stageId));
    }

    // Breadth-first walk over the stage tree, starting at the root stage.
    private static Iterable<StageStats> allStages(MaterializedResult result) {
        return Traverser.<StageStats>forTree(StageStats::getSubStages).breadthFirst(getRootStage(result));
    }

    private static <T> T randomElement(List<T> elements) {
        return elements.get(ThreadLocalRandom.current().nextInt(elements.size()));
    }

    private static StageStats getRootStage(MaterializedResult result) {
        StatementStats statementStats = result.getStatementStats()
                .orElseThrow(() -> new IllegalArgumentException("statement stats is not present"));
        return Objects.requireNonNull(statementStats.getRootStage(), "root stage is null");
    }

    private static boolean isAnalyze(MaterializedResult result) {
        return result.getUpdateType().filter("ANALYZE"::equals).isPresent();
    }

    /**
     * Returns a copy of the default session with dynamic filtering toggled, join reordering disabled,
     * partitioned joins forced and the catalog's dynamic filtering wait timeout set to 1h.
     */
    private Session enableDynamicFiltering(boolean enabled) {
        Session defaultSession = getQueryRunner().getDefaultSession();
        return Session.builder(defaultSession)
                .setSystemProperty("enable_dynamic_filtering", Boolean.toString(enabled))
                .setSystemProperty("join_reordering_strategy", FeaturesConfig.JoinReorderingStrategy.NONE.name())
                .setSystemProperty("join_distribution_type", FeaturesConfig.JoinDistributionType.PARTITIONED.name())
                .setCatalogSessionProperty(defaultSession.getCatalog().orElseThrow(), "dynamic_filtering_wait_timeout", "1h")
                .build();
    }

    /**
     * Fluent assertion over a query run with an optional injected failure. {@code <table>} in the
     * query and in the setup/cleanup statements is replaced by a random table name per execution.
     */
    protected class FailureRecoveryAssert {
        private final String query;
        private Session session;
        private Optional<Function<MaterializedResult, Integer>> stageSelector;
        private Optional<FailureInjector.InjectedFailureType> failureType;
        private Optional<ErrorType> errorType;
        private Optional<String> setup;
        private Optional<String> cleanup;

        public FailureRecoveryAssert(String query) {
            this.query = Objects.requireNonNull(query, "query is null");
            this.session = getQueryRunner().getDefaultSession();
            // Must be initialized: assertions without at(...) (e.g. failsDespiteRetries) read it.
            this.stageSelector = Optional.empty();
            this.failureType = Optional.empty();
            this.errorType = Optional.empty();
            this.setup = Optional.empty();
            this.cleanup = Optional.empty();
        }

        /** Uses the given session if present; otherwise keeps the runner's default session. */
        public FailureRecoveryAssert withSession(Optional<Session> session) {
            Objects.requireNonNull(session, "session is null");
            session.ifPresent(value -> this.session = value);
            return this;
        }

        public FailureRecoveryAssert withSetupQuery(Optional<String> query) {
            setup = Objects.requireNonNull(query, "query is null");
            return this;
        }

        public FailureRecoveryAssert withCleanupQuery(Optional<String> query) {
            cleanup = Objects.requireNonNull(query, "query is null");
            return this;
        }

        public FailureRecoveryAssert experiencing(FailureInjector.InjectedFailureType failureType) {
            return experiencing(failureType, Optional.empty());
        }

        /**
         * Sets the failure to inject.
         *
         * @throws IllegalArgumentException if {@code errorType} is absent for {@code TASK_FAILURE}
         *         or present for any other failure type
         */
        public FailureRecoveryAssert experiencing(FailureInjector.InjectedFailureType failureType, Optional<ErrorType> errorType) {
            this.failureType = Optional.of(Objects.requireNonNull(failureType, "failureType is null"));
            this.errorType = Objects.requireNonNull(errorType, "errorType is null");
            if (failureType == FailureInjector.InjectedFailureType.TASK_FAILURE) {
                Preconditions.checkArgument(errorType.isPresent(), "error type must be present when injection type is task failure");
            }
            else {
                Preconditions.checkArgument(errorType.isEmpty(), "error type must not be present when injection type is not task failure");
            }
            return this;
        }

        /** Sets the function that picks the stage id to inject the failure into. */
        public FailureRecoveryAssert at(Function<MaterializedResult, Integer> stageSelector) {
            this.stageSelector = Optional.of(Objects.requireNonNull(stageSelector, "stageSelector is null"));
            return this;
        }

        // Runs the query without retries and without injected failure (no trace token).
        private ExecutionResult executeExpected() {
            return execute(noRetries(session), query, Optional.empty());
        }

        private ExecutionResult executeActual(OptionalInt failureStageId) {
            return executeActual(session, failureStageId);
        }

        private ExecutionResult executeActualNoRetries(OptionalInt failureStageId) {
            return executeActual(noRetries(session), failureStageId);
        }

        // Registers the configured failure (if any) under a fresh trace token and runs the query with that token.
        private ExecutionResult executeActual(Session session, OptionalInt failureStageId) {
            String token = UUID.randomUUID().toString();
            if (failureType.isPresent()) {
                int stageId = failureStageId.orElseThrow(() -> new IllegalArgumentException("failure stageId not provided"));
                getQueryRunner().injectTaskFailure(token, stageId, 0, 0, failureType.get(), errorType);
            }
            return execute(session, query, Optional.of(token));
        }

        /**
         * Runs setup, the query, reads back the modified table content/statistics when applicable,
         * then runs cleanup. Cleanup runs even if the query or the read-back fails; a cleanup
         * failure is rethrown, or added as suppressed to an earlier failure.
         */
        private ExecutionResult execute(Session session, String query, Optional<String> traceToken) {
            String tableName = "table_" + TestTable.randomTableSuffix();
            setup.ifPresent(sql -> getQueryRunner().execute(noRetries(session), resolveTableName(sql, tableName)));

            ResultWithQueryId<MaterializedResult> resultWithQueryId = null;
            Optional<MaterializedResult> updatedTableContent = Optional.empty();
            Optional<MaterializedResult> updatedTableStatistics = Optional.empty();
            RuntimeException failure = null;
            try {
                resultWithQueryId = getDistributedQueryRunner().executeWithQueryId(withTraceToken(session, traceToken), resolveTableName(query, tableName));
                MaterializedResult result = resultWithQueryId.getResult();
                if (result != null && result.getUpdateCount().isPresent()) {
                    updatedTableContent = Optional.of(getQueryRunner().execute(noRetries(session), "SELECT * FROM " + tableName));
                }
                if (result != null && isAnalyze(result)) {
                    updatedTableStatistics = Optional.of(getQueryRunner().execute(noRetries(session), "SHOW STATS FOR " + tableName));
                }
            }
            catch (RuntimeException e) {
                failure = e;
            }

            try {
                cleanup.ifPresent(sql -> getQueryRunner().execute(noRetries(session), resolveTableName(sql, tableName)));
            }
            catch (RuntimeException e) {
                if (failure == null) {
                    failure = e;
                }
                else if (failure != e) {
                    failure.addSuppressed(e);
                }
            }

            if (failure != null) {
                throw failure;
            }
            return new ExecutionResult(resultWithQueryId, updatedTableContent, updatedTableStatistics);
        }

        public void finishesSuccessfully() {
            finishesSuccessfully(queryId -> {});
        }

        public void finishesSuccessfullyWithoutTaskFailures() {
            finishesSuccessfully(queryId -> {}, false);
        }

        private void finishesSuccessfully(Consumer<QueryId> queryAssertion) {
            finishesSuccessfully(queryAssertion, true);
        }

        /**
         * Runs the query once without retries or failures to get the expected outcome, then with the
         * configured failure and the session's retry settings, and compares the two.
         *
         * @param queryAssertion extra checks applied to the query id of the run with the failure
         * @param expectTaskFailures whether exactly one failed task is expected in the selected stage
         *        (checked only when a stage selector is set)
         */
        public void finishesSuccessfully(Consumer<QueryId> queryAssertion, boolean expectTaskFailures) {
            verifyFailureTypeAndStageSelector();
            ExecutionResult expected = executeExpected();
            MaterializedResult expectedQueryResult = expected.getQueryResult();
            OptionalInt failureStageId = getFailureStageId(() -> expectedQueryResult);
            ExecutionResult actual = executeActual(failureStageId);
            MaterializedResult actualQueryResult = actual.getQueryResult();

            // Without a stage selector no failure was injected, so there is no stage to inspect.
            if (failureStageId.isPresent()) {
                Assert.assertEquals(getStageStats(actualQueryResult, failureStageId.getAsInt()).getFailedTasks(), expectTaskFailures ? 1 : 0);
            }

            boolean isUpdate = expectedQueryResult.getUpdateCount().isPresent();
            boolean isExplain = query.trim().toUpperCase(Locale.ENGLISH).startsWith("EXPLAIN");
            if (isAnalyze(expectedQueryResult)) {
                Assert.assertEquals(actualQueryResult.getUpdateCount(), expectedQueryResult.getUpdateCount());
                Assertions.assertThat(expected.getUpdatedTableStatistics()).isPresent();
                Assertions.assertThat(actual.getUpdatedTableStatistics()).isPresent();
                QueryAssertions.assertEqualsIgnoreOrder(
                        actual.getUpdatedTableStatistics().get(),
                        expected.getUpdatedTableStatistics().get(),
                        "For query: \n " + query);
            }
            else if (isUpdate) {
                Assert.assertEquals(actualQueryResult.getUpdateCount(), expectedQueryResult.getUpdateCount());
                Assertions.assertThat(expected.getUpdatedTableContent()).isPresent();
                Assertions.assertThat(actual.getUpdatedTableContent()).isPresent();
                QueryAssertions.assertEqualsIgnoreOrder(
                        actual.getUpdatedTableContent().get(),
                        expected.getUpdatedTableContent().get(),
                        "For query: \n " + query);
            }
            else if (isExplain) {
                // Only the number of plan rows is compared for EXPLAIN output.
                Assert.assertEquals(actualQueryResult.getRowCount(), expectedQueryResult.getRowCount());
            }
            else {
                QueryAssertions.assertEqualsIgnoreOrder(actualQueryResult, expectedQueryResult, "For query: \n " + query);
            }
            queryAssertion.accept(actual.getQueryId());
        }

        /** Asserts the query fails both without and with retries. */
        public FailureRecoveryAssert failsAlways(Consumer<AbstractThrowableAssert> failureAssertion) {
            failsWithoutRetries(failureAssertion);
            failsDespiteRetries(failureAssertion);
            return this;
        }

        /** Asserts the query fails when retries are disabled in the session. */
        public FailureRecoveryAssert failsWithoutRetries(Consumer<AbstractThrowableAssert> failureAssertion) {
            verifyFailureTypeAndStageSelector();
            OptionalInt failureStageId = getFailureStageId(() -> executeExpected().getQueryResult());
            failureAssertion.accept(Assertions.assertThatThrownBy(() -> executeActualNoRetries(failureStageId)));
            return this;
        }

        /** Asserts the query fails with the session's retry settings. */
        public FailureRecoveryAssert failsDespiteRetries(Consumer<AbstractThrowableAssert> failureAssertion) {
            verifyFailureTypeAndStageSelector();
            OptionalInt failureStageId = getFailureStageId(() -> executeExpected().getQueryResult());
            failureAssertion.accept(Assertions.assertThatThrownBy(() -> executeActual(failureStageId)));
            return this;
        }

        private void verifyFailureTypeAndStageSelector() {
            Assertions.assertThat(failureType.isPresent() == stageSelector.isPresent())
                    .withFailMessage("Either both or none of failureType and stageSelector must be set")
                    .isTrue();
        }

        // Runs the expected query (lazily) only when a stage selector needs its stage tree.
        private OptionalInt getFailureStageId(Supplier<MaterializedResult> expectedQueryResult) {
            if (stageSelector.isEmpty()) {
                return OptionalInt.empty();
            }
            return OptionalInt.of(stageSelector.get().apply(expectedQueryResult.get()));
        }

        // Literal (non-regex) substitution of the <table> placeholder.
        private String resolveTableName(String query, String tableName) {
            return query.replace("<table>", tableName);
        }

        private Session noRetries(Session session) {
            return Session.builder(session)
                    .setSystemProperty("retry_policy", "NONE")
                    .build();
        }

        private Session withTraceToken(Session session, Optional<String> traceToken) {
            return Session.builder(session)
                    .setTraceToken(traceToken)
                    .build();
        }
    }

    /** Result of one execution: the query result, its id, and optional read-back of the modified table. */
    private static class ExecutionResult {
        private final MaterializedResult queryResult;
        private final QueryId queryId;
        private final Optional<MaterializedResult> updatedTableContent;
        private final Optional<MaterializedResult> updatedTableStatistics;

        private ExecutionResult(ResultWithQueryId<MaterializedResult> resultWithQueryId, Optional<MaterializedResult> updatedTableContent, Optional<MaterializedResult> updatedTableStatistics) {
            Objects.requireNonNull(resultWithQueryId, "resultWithQueryId is null");
            this.queryResult = resultWithQueryId.getResult();
            this.queryId = resultWithQueryId.getQueryId();
            this.updatedTableContent = Objects.requireNonNull(updatedTableContent, "updatedTableContent is null");
            this.updatedTableStatistics = Objects.requireNonNull(updatedTableStatistics, "updatedTableStatistics is null");
        }

        public MaterializedResult getQueryResult() {
            return queryResult;
        }

        public QueryId getQueryId() {
            return queryId;
        }

        public Optional<MaterializedResult> getUpdatedTableContent() {
            return updatedTableContent;
        }

        public Optional<MaterializedResult> getUpdatedTableStatistics() {
            return updatedTableStatistics;
        }
    }
}

