/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.extensions.sql.impl.rule;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.beam.sdk.extensions.sql.impl.rel.BeamIOSourceRel;
import org.apache.beam.sdk.extensions.sql.impl.utils.CalciteUtils;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.schemas.FieldAccessDescriptor;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.utils.SelectHelpers;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Calc;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.RelFactories;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelRecordType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexInputRef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLocalRef;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexProgram;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.Pair;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;

public class BeamIOPushDownRule
extends RelOptRule {
    public static final BeamIOPushDownRule INSTANCE = new BeamIOPushDownRule(RelFactories.LOGICAL_BUILDER);

    public BeamIOPushDownRule(RelBuilderFactory relBuilderFactory) {
        super(BeamIOPushDownRule.operand(Calc.class, (RelOptRuleOperand)BeamIOPushDownRule.operand(BeamIOSourceRel.class, (RelOptRuleOperandChildren)BeamIOPushDownRule.any()), (RelOptRuleOperand[])new RelOptRuleOperand[0]), relBuilderFactory, null);
    }

    public void onMatch(RelOptRuleCall call) {
        BeamIOSourceRel ioSourceRel = (BeamIOSourceRel)call.rel(1);
        BeamSqlTable beamSqlTable = ioSourceRel.getBeamSqlTable();
        if (!beamSqlTable.supportsProjects()) {
            return;
        }
        for (RelDataTypeField field : ioSourceRel.getRowType().getFieldList()) {
            if (!(field.getType() instanceof RelRecordType)) continue;
            return;
        }
        Calc calc = (Calc)call.rel(0);
        RexProgram program = calc.getProgram();
        Pair projectFilter = program.split();
        RelDataType calcInputRowType = program.getInputRowType();
        HashSet<String> usedFields = new HashSet<String>();
        for (RexNode project : (ImmutableList)projectFilter.left) {
            this.findUtilizedInputRefs(calcInputRowType, project, usedFields);
        }
        for (RexNode filter : (ImmutableList)projectFilter.right) {
            this.findUtilizedInputRefs(calcInputRowType, filter, usedFields);
        }
        FieldAccessDescriptor resolved = FieldAccessDescriptor.withFieldNames(usedFields).resolve(beamSqlTable.getSchema());
        Schema newSchema = SelectHelpers.getOutputSchema((Schema)ioSourceRel.getBeamSqlTable().getSchema(), (FieldAccessDescriptor)resolved);
        RelDataType calcInputType = CalciteUtils.toCalciteRowType(newSchema, ioSourceRel.getCluster().getTypeFactory());
        if (this.isProjectRenameOnlyProgram(program)) {
            call.transformTo((RelNode)ioSourceRel.copy(calc.getRowType(), newSchema.getFieldNames()));
            return;
        }
        if (usedFields.size() == ioSourceRel.getRowType().getFieldCount()) {
            return;
        }
        BeamIOSourceRel newIoSourceRel = ioSourceRel.copy(calcInputType, newSchema.getFieldNames());
        RelBuilder relBuilder = call.builder();
        relBuilder.push((RelNode)newIoSourceRel);
        ArrayList<RexNode> newProjects = new ArrayList<RexNode>();
        ArrayList<RexNode> newFilter = new ArrayList<RexNode>();
        List<Integer> mapping = resolved.getFieldsAccessed().stream().map(FieldAccessDescriptor.FieldDescriptor::getFieldId).collect(Collectors.toList());
        for (RexNode filter : (ImmutableList)projectFilter.right) {
            newFilter.add(this.reMapRexNodeToNewInputs(filter, mapping));
        }
        for (RexNode project : (ImmutableList)projectFilter.left) {
            newProjects.add(this.reMapRexNodeToNewInputs(project, mapping));
        }
        relBuilder.filter(newFilter);
        relBuilder.project(newProjects, (Iterable)calc.getRowType().getFieldNames());
        RelNode result = relBuilder.build();
        call.transformTo(result);
    }

    @VisibleForTesting
    void findUtilizedInputRefs(RelDataType inputRowType, RexNode startNode, Set<String> usedFields) {
        ArrayDeque<RexNode> prerequisites = new ArrayDeque<RexNode>();
        prerequisites.add(startNode);
        while (!prerequisites.isEmpty()) {
            RexNode node = (RexNode)prerequisites.poll();
            if (node instanceof RexCall) {
                RexCall compositeNode = (RexCall)node;
                prerequisites.addAll(compositeNode.getOperands());
                continue;
            }
            if (node instanceof RexInputRef) {
                int inputFieldIndex = ((RexInputRef)node).getIndex();
                RelDataTypeField field = (RelDataTypeField)inputRowType.getFieldList().get(inputFieldIndex);
                usedFields.add(field.getName());
                continue;
            }
            if (node instanceof RexLiteral) continue;
            throw new RuntimeException("Unexpected RexNode encountered: " + node.getClass().getSimpleName());
        }
    }

    @VisibleForTesting
    RexNode reMapRexNodeToNewInputs(RexNode node, List<Integer> inputRefMapping) {
        if (node instanceof RexInputRef) {
            int oldInputIndex = ((RexInputRef)node).getIndex();
            int newInputIndex = inputRefMapping.indexOf(oldInputIndex);
            return new RexInputRef(newInputIndex, node.getType());
        }
        if (node instanceof RexCall) {
            RexCall compositeNode = (RexCall)node;
            ArrayList<RexNode> newOperands = new ArrayList<RexNode>();
            for (RexNode operand : compositeNode.getOperands()) {
                newOperands.add(this.reMapRexNodeToNewInputs(operand, inputRefMapping));
            }
            return compositeNode.clone(compositeNode.getType(), newOperands);
        }
        Preconditions.checkArgument((boolean)(node instanceof RexLiteral), (Object)("RexLiteral node expected, but was: " + node.getClass().getSimpleName()));
        return node;
    }

    @VisibleForTesting
    boolean isProjectRenameOnlyProgram(RexProgram program) {
        if (program.getCondition() != null) {
            return false;
        }
        int fieldCount = program.getInputRowType().getFieldCount();
        for (RexLocalRef ref : program.getProjectList()) {
            if (ref.getIndex() < fieldCount) continue;
            return false;
        }
        return true;
    }
}

