/*
 * Decompiled with CFR 0.152.
 */
package io.trino.testing.statistics;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.execution.querystats.PlanOptimizersStatsCollector;
import io.trino.execution.warnings.WarningCollector;
import io.trino.security.AccessControl;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.plan.OutputNode;
import io.trino.testing.MaterializedRow;
import io.trino.testing.QueryRunner;
import io.trino.testing.statistics.Metric;
import io.trino.testing.statistics.MetricComparison;
import io.trino.testing.statistics.StatsContext;
import io.trino.transaction.TransactionBuilder;
import io.trino.transaction.TransactionManager;
import java.util.List;
import java.util.Map;
import java.util.OptionalDouble;
import java.util.stream.Collectors;

/**
 * Static helpers that, for a query and a list of metrics, pair the value each metric derives
 * from the planner's statistics estimate with the value obtained by actually running the query.
 */
final class MetricComparator {
    private MetricComparator() {
    }

    /**
     * Builds one {@link MetricComparison} per entry of {@code metrics}, in the same order.
     *
     * <p>The estimated values are computed first (by planning {@code query}), then the actual
     * values (by executing an aggregation over {@code query}).
     *
     * @param query SQL text of the query under test
     * @param runner runner used both to plan and to execute the query
     * @param metrics metrics to evaluate
     * @return comparisons of estimated vs. actual value, index-aligned with {@code metrics}
     * @throws RuntimeException if the aggregation query computing the actual values fails
     */
    static List<MetricComparison> getMetricComparisons(String query, QueryRunner runner, List<Metric> metrics) {
        List<OptionalDouble> estimatedValues = getEstimatedValues(metrics, query, runner);
        List<OptionalDouble> actualValues = getActualValues(metrics, query, runner);
        ImmutableList.Builder<MetricComparison> metricComparisons = ImmutableList.builder();
        for (int i = 0; i < metrics.size(); ++i) {
            metricComparisons.add(new MetricComparison(metrics.get(i), estimatedValues.get(i), actualValues.get(i)));
        }
        return metricComparisons.build();
    }

    /**
     * Computes the estimated value of every metric inside a single-statement transaction opened
     * on the runner's default session.
     */
    private static List<OptionalDouble> getEstimatedValues(List<Metric> metrics, String query, QueryRunner runner) {
        // The lambda parameter is typed explicitly so that the value-returning form of execute()
        // is selected unambiguously and no raw (List) cast is needed on the result.
        return TransactionBuilder.transaction(runner.getTransactionManager(), runner.getAccessControl())
                .singleStatement()
                .execute(runner.getDefaultSession(), (Session session) -> getEstimatedValuesInternal(metrics, query, runner, session));
    }

    /**
     * Plans {@code query} in the given session and evaluates each metric against the statistics
     * estimate recorded for the plan's root node.
     */
    private static List<OptionalDouble> getEstimatedValuesInternal(List<Metric> metrics, String query, QueryRunner runner, Session session) {
        Plan queryPlan = runner.createPlan(session, query, WarningCollector.NOOP, PlanOptimizersStatsCollector.createPlanOptimizersStatsCollector());
        // The root of the plan is cast to OutputNode; a different root type fails here with ClassCastException.
        OutputNode outputNode = (OutputNode) queryPlan.getRoot();
        // Fall back to an "unknown" estimate when the plan carries no stats entry for the root node.
        PlanNodeStatsEstimate outputNodeStats = queryPlan.getStatsAndCosts().getStats().getOrDefault(outputNode.getId(), PlanNodeStatsEstimate.unknown());
        StatsContext statsContext = buildStatsContext(queryPlan, outputNode);
        return getEstimatedValues(metrics, outputNodeStats, statsContext);
    }

    /**
     * Creates a {@link StatsContext} mapping each output column name of {@code outputNode} to the
     * output symbol at the same position, together with the plan's types.
     */
    private static StatsContext buildStatsContext(Plan queryPlan, OutputNode outputNode) {
        List<String> columnNames = outputNode.getColumnNames();
        List<Symbol> outputSymbols = outputNode.getOutputSymbols();
        ImmutableMap.Builder<String, Symbol> columnSymbols = ImmutableMap.builder();
        for (int columnId = 0; columnId < columnNames.size(); ++columnId) {
            columnSymbols.put(columnNames.get(columnId), outputSymbols.get(columnId));
        }
        // buildOrThrow() rejects duplicate keys, i.e. two output columns sharing a name.
        return new StatsContext(columnSymbols.buildOrThrow(), queryPlan.getTypes());
    }

    /**
     * Executes a single aggregation query over {@code query} (one select item per metric) and
     * lets each metric extract its value from the corresponding field of the only result row.
     *
     * @throws RuntimeException wrapping any exception raised while executing the aggregation
     *         query or reading its result; the message contains the generated SQL
     */
    private static List<OptionalDouble> getActualValues(List<Metric> metrics, String query, QueryRunner runner) {
        String aggregations = metrics.stream()
                .map(Metric::getComputingAggregationSql)
                .collect(Collectors.joining(","));
        String statsQuery = "SELECT " + aggregations + " FROM (" + query + ")";
        try {
            MaterializedRow actualValuesRow = Iterables.getOnlyElement(runner.execute(statsQuery).getMaterializedRows());
            ImmutableList.Builder<OptionalDouble> actualValues = ImmutableList.builder();
            for (int i = 0; i < metrics.size(); ++i) {
                // Field i of the row holds the result of the aggregation generated for metric i.
                actualValues.add(metrics.get(i).getValueFromAggregationQueryResult(actualValuesRow.getField(i)));
            }
            return actualValues.build();
        }
        catch (Exception e) {
            throw new RuntimeException(String.format("Failed to execute query to compute actual values: %s", statsQuery), e);
        }
    }

    /**
     * Evaluates every metric against the given statistics estimate, preserving the order of
     * {@code metrics}.
     */
    private static List<OptionalDouble> getEstimatedValues(List<Metric> metrics, PlanNodeStatsEstimate outputNodeStatisticsEstimates, StatsContext statsContext) {
        return metrics.stream()
                .map(metric -> metric.getValueFromPlanNodeEstimate(outputNodeStatisticsEstimates, statsContext))
                .collect(ImmutableList.toImmutableList());
    }
}

