/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.piglet;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.calcite.piglet.ImmutablePigToSqlAggregateRule;
import org.apache.calcite.piglet.PigRelUdfConverter;
import org.apache.calcite.piglet.PigTypes;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexVisitor;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;
import org.immutables.value.Value;

@Value.Enclosing
public class PigToSqlAggregateRule
extends RelRule<Config> {
    /** Operator name used to recognise a Pig multiset projection call. */
    private static final String MULTISET_PROJECTION = "MULTISET_PROJECTION";

    /**
     * Shared rule instance. Its operand tree matches a {@link Project} whose single input is a
     * {@link Project}, whose single input is an {@link Aggregate}, whose single input is a
     * {@link Project} with any inputs.
     */
    public static final PigToSqlAggregateRule INSTANCE =
        ImmutablePigToSqlAggregateRule.Config.builder()
            .withOperandSupplier(top ->
                top.operand(Project.class).oneInput(middle ->
                    middle.operand(Project.class).oneInput(agg ->
                        agg.operand(Aggregate.class).oneInput(bottom ->
                            bottom.operand(Project.class).anyInputs()))))
            .build()
            .toRule();

    /**
     * Creates a PigToSqlAggregateRule.
     *
     * @param config rule configuration, handed unchanged to the {@link RelRule} constructor
     */
    protected PigToSqlAggregateRule(Config config) {
        // No upcast to RelRule.Config here: the superclass is RelRule<Config>, so its
        // constructor parameter is this class's own Config type.
        super(config);
    }

    /**
     * Rewrites the matched {@code Project - Project - Aggregate - Project} chain into a new
     * {@code Project - Aggregate - Project} chain built with the call's {@link RelBuilder}.
     *
     * <p>The rewrite only happens when the matched aggregate has exactly one aggregate call and
     * that call is of kind {@link SqlKind#COLLECT}; otherwise the method returns without
     * registering a transformation. Pig aggregate calls found in the top project are turned
     * into aggregate calls of the new aggregate, and the top project is rebuilt to reference
     * the new aggregate's output columns, keeping the old top project's field names.
     *
     * @param call rule call whose rels 0..3 are the top project, middle project, aggregate
     *             and bottom project, in that order
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final Project oldTopProject = call.rel(0);
        final Project oldMiddleProject = call.rel(1);
        final Aggregate oldAgg = call.rel(2);
        final Project oldBottomProject = call.rel(3);
        final RelBuilder relBuilder = call.builder();

        // Only a lone COLLECT aggregate call is handled; anything else is left untouched.
        if (oldAgg.getAggCallList().size() != 1
                || oldAgg.getAggCallList().get(0).getAggregation().getKind() != SqlKind.COLLECT) {
            return;
        }

        // Step 0: gather the distinct Pig aggregate calls used by the top project. An
        // expression that contains none of them but references input column 1 means the
        // collected grouping column has to be kept in the new aggregate.
        final List<RexCall> pigAggUdfs = new ArrayList<>();
        boolean needGroupingCol = false;
        for (RexNode rex : oldTopProject.getProjects()) {
            final PigAggUdfFinder udfVisitor = new PigAggUdfFinder(1);
            rex.accept(udfVisitor);
            if (!udfVisitor.pigAggCalls.isEmpty()) {
                for (RexCall pigAgg : udfVisitor.pigAggCalls) {
                    if (!pigAggUdfs.contains(pigAgg)) {
                        pigAggUdfs.add(pigAgg);
                    }
                }
            } else if (udfVisitor.projectColReferred) {
                needGroupingCol = true;
            }
        }

        // Step 1: build the new bottom project on top of the old bottom project's input.
        final List<RexNode> newBottomProjects = new ArrayList<>();
        relBuilder.push(oldBottomProject.getInput());
        // The group keys are copied from the old bottom project as they are.
        for (int i = 0; i < oldAgg.getGroupCount(); ++i) {
            newBottomProjects.add(oldBottomProject.getProjects().get(i));
        }
        // When the grouping column is needed, project the whole input row as one ROW value.
        if (needGroupingCol) {
            final RexNode row = relBuilder.getRexBuilder().makeCall(
                relBuilder.peek().getRowType(), SqlStdOperatorTable.ROW, relBuilder.fields());
            newBottomProjects.add(row);
        }
        final int groupCount = oldAgg.getGroupCount() + (needGroupingCol ? 1 : 0);

        // Maps an input column index to its position in the new bottom project.
        final Map<Integer, Integer> projectedAggColumns = new HashMap<>();
        for (int i = 0; i < newBottomProjects.size(); ++i) {
            final RexNode projected = newBottomProjects.get(i);
            if (projected instanceof RexInputRef) {
                projectedAggColumns.put(((RexInputRef) projected).getIndex(), i);
            }
        }

        // For every Pig aggregate call, the positions in the new bottom project it reads.
        final Map<RexCall, List<Integer>> aggCallColumns = new HashMap<>();
        for (RexCall rexCall : pigAggUdfs) {
            final List<Integer> requiredColumns = getAggColumns(rexCall);
            final List<Integer> newColIndexes = new ArrayList<>();
            for (int col : requiredColumns) {
                final Integer newCol = projectedAggColumns.get(col);
                if (newCol != null) {
                    // Already projected: reuse the existing position.
                    newColIndexes.add(newCol);
                    continue;
                }
                // Not projected yet: take the operand at this index from the call that the
                // old bottom project holds right after its group keys (expected to be the
                // ROW call) and append it to the new bottom project.
                final RexCall rowCall =
                    (RexCall) oldBottomProject.getProjects().get(oldAgg.getGroupCount());
                final RexInputRef columnRef = (RexInputRef) rowCall.getOperands().get(col);
                final int newIndex = newBottomProjects.size();
                newBottomProjects.add(columnRef);
                projectedAggColumns.put(columnRef.getIndex(), newIndex);
                newColIndexes.add(newIndex);
            }
            aggCallColumns.put(rexCall, newColIndexes);
        }
        relBuilder.project(newBottomProjects);

        // Step 2: build the new aggregate, reusing the old group set and grouping sets.
        // NOTE(review): the raw Iterable cast is kept from the original; presumably it steers
        // overload resolution of groupKey - confirm before replacing it with a typed cast.
        final RelBuilder.GroupKey groupKey =
            relBuilder.groupKey(oldAgg.getGroupSet(), (Iterable) oldAgg.groupSets);
        final List<RelBuilder.AggCall> aggCalls = new ArrayList<>();
        if (needGroupingCol) {
            // The ROW column was added last among the first groupCount projections.
            aggCalls.add(relBuilder.aggregateCall(
                SqlStdOperatorTable.COLLECT, relBuilder.field(groupCount - 1)));
        }
        for (RexCall rexCall : pigAggUdfs) {
            final List<Integer> columns = aggCallColumns.get(rexCall);
            final List<RexNode> aggOperands = new ArrayList<>();
            for (int column : columns) {
                aggOperands.add(relBuilder.field(column));
            }
            if (isMultisetProjection(rexCall)) {
                if (aggOperands.size() == 1) {
                    // A single projected column is collected directly.
                    aggCalls.add(
                        relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, aggOperands));
                } else {
                    // Otherwise the columns are wrapped in a ROW before being collected.
                    final RelDataType rowType = createRecordType(relBuilder, columns);
                    final RexNode row = relBuilder.getRexBuilder()
                        .makeCall(rowType, SqlStdOperatorTable.ROW, aggOperands);
                    aggCalls.add(relBuilder.aggregateCall(SqlStdOperatorTable.COLLECT, row));
                }
            } else {
                final SqlAggFunction udf = PigRelUdfConverter.getSqlAggFuncForPigUdf(rexCall);
                aggCalls.add(relBuilder.aggregateCall(udf, aggOperands));
            }
        }
        relBuilder.aggregate(groupKey, aggCalls);

        // Step 3: map every old Pig aggregate call to the matching output column of the new
        // aggregate, adding a cast when the column type differs from the old call's type.
        final RelDataType aggType = relBuilder.peek().getRowType();
        final Map<RexNode, RexNode> pigCallToNewProjections = new HashMap<>();
        for (int i = 0; i < pigAggUdfs.size(); ++i) {
            final RexCall pigAgg = pigAggUdfs.get(i);
            final int colIndex = i + groupCount;
            final RelDataType fieldType = aggType.getFieldList().get(colIndex).getType();
            final RelDataType oldFieldType = pigAgg.getType();
            if (fieldType.equals(oldFieldType)) {
                pigCallToNewProjections.put(pigAgg, relBuilder.field(colIndex));
            } else {
                pigCallToNewProjections.put(pigAgg,
                    relBuilder.getRexBuilder().makeCast(oldFieldType, relBuilder.field(colIndex)));
            }
        }

        // Step 4: rebuild the top project expression by expression.
        final List<RexNode> newTopProjects = new ArrayList<>();
        for (RexNode rexNode : oldTopProject.getProjects()) {
            final int groupRefIndex = getGroupRefIndex(rexNode);
            if (groupRefIndex >= 0) {
                // Field access on input column 0: project that field of the new aggregate.
                newTopProjects.add(relBuilder.field(groupRefIndex));
            } else if (rexNode instanceof RexInputRef && ((RexInputRef) rexNode).getIndex() == 0) {
                // Direct reference to input column 0: reuse the old middle project's first
                // expression.
                newTopProjects.add(oldMiddleProject.getProjects().get(0));
            } else {
                // Anything else: substitute the Pig aggregate calls (and, when kept, the
                // reference to old column 1) inside the expression.
                final RexCallReplacer replacer = needGroupingCol
                    ? new RexCallReplacer(relBuilder.getRexBuilder(), pigCallToNewProjections,
                        1, relBuilder.field(groupCount - 1))
                    : new RexCallReplacer(relBuilder.getRexBuilder(), pigCallToNewProjections);
                newTopProjects.add(rexNode.accept(replacer));
            }
        }
        relBuilder.project(newTopProjects, oldTopProject.getRowType().getFieldNames());
        call.transformTo(relBuilder.build());
    }

    /**
     * Builds a struct type from selected fields of the row type on top of the builder stack.
     *
     * @param relBuilder builder whose top relational expression supplies the fields
     * @param fields     indexes into that expression's field list, in output order
     * @return struct type created by {@link PigTypes#TYPE_FACTORY} with the selected fields'
     *         names and types
     */
    private static RelDataType createRecordType(RelBuilder relBuilder, List<Integer> fields) {
        final List<RelDataTypeField> inputFields = relBuilder.peek().getRowType().getFieldList();
        final List<String> names = new ArrayList<>();
        final List<RelDataType> types = new ArrayList<>();
        for (Integer position : fields) {
            final RelDataTypeField selected = inputFields.get(position);
            names.add(selected.getName());
            types.add(selected.getType());
        }
        return PigTypes.TYPE_FACTORY.createStructType(types, names);
    }

    /**
     * Returns the index of the accessed field when the expression is a field access whose
     * reference expression is an input reference with index 0.
     *
     * @param rex expression to inspect
     * @return the accessed field's index, or -1 when the expression has any other shape
     */
    private static int getGroupRefIndex(RexNode rex) {
        if (!(rex instanceof RexFieldAccess)) {
            return -1;
        }
        final RexFieldAccess access = (RexFieldAccess) rex;
        final RexNode target = access.getReferenceExpr();
        if (!(target instanceof RexInputRef)) {
            return -1;
        }
        final boolean refersToFirstInput = ((RexInputRef) target).getIndex() == 0;
        return refersToFirstInput ? access.getField().getIndex() : -1;
    }

    /**
     * Returns the column indexes a Pig aggregate call reads.
     *
     * <p>A multiset projection yields its own column list. Any other call is expected (checked
     * by assertions only) to have one operand that is a call with one operand; when that inner
     * operand is itself a call, its columns are returned, otherwise the list is empty.
     *
     * @param pigAggCall the Pig aggregate call or multiset projection call
     * @return column indexes, possibly empty
     */
    private static List<Integer> getAggColumns(RexCall pigAggCall) {
        if (isMultisetProjection(pigAggCall)) {
            return getColsFromMultisetProjection(pigAggCall);
        }
        final List<RexNode> udfArgs = pigAggCall.getOperands();
        assert (udfArgs.size() == 1 && udfArgs.get(0) instanceof RexCall);
        final RexCall bagCall = (RexCall) udfArgs.get(0);
        final List<RexNode> bagArgs = bagCall.getOperands();
        assert (bagArgs.size() == 1);
        final RexNode bagSource = bagArgs.get(0);
        if (!(bagSource instanceof RexCall)) {
            // The bag is not built from a projection call: no columns to report.
            return new ArrayList<Integer>();
        }
        final RexCall projection = (RexCall) bagSource;
        assert (isMultisetProjection(projection));
        return getColsFromMultisetProjection(projection);
    }

    /**
     * Reads the column indexes out of a multiset projection call.
     *
     * <p>The first operand is skipped; every following operand is cast to a
     * {@link RexLiteral} whose value is cast to {@link BigDecimal} and converted to an int.
     *
     * @param multisetProjection the multiset projection call
     * @return the column indexes in operand order
     */
    private static List<Integer> getColsFromMultisetProjection(RexCall multisetProjection) {
        final List<RexNode> operands = multisetProjection.getOperands();
        final List<Integer> result = new ArrayList<>();
        assert (operands.size() >= 1);
        final int operandCount = operands.size();
        // Start at 1: operand 0 is not a column index.
        for (int pos = 1; pos < operandCount; ++pos) {
            final RexLiteral literal = (RexLiteral) operands.get(pos);
            final BigDecimal value = (BigDecimal) literal.getValue();
            result.add(value.intValue());
        }
        return result;
    }

    /**
     * Tells whether a call's operator is named {@code MULTISET_PROJECTION}.
     *
     * @param rexCall call to test
     * @return true when the operator name equals {@link #MULTISET_PROJECTION}
     */
    private static boolean isMultisetProjection(RexCall rexCall) {
        final String operatorName = rexCall.getOperator().getName();
        return operatorName.equals(MULTISET_PROJECTION);
    }

    @Value.Immutable(singleton=false)
    public static interface Config
    extends RelRule.Config {
        default public PigToSqlAggregateRule toRule() {
            return new PigToSqlAggregateRule(this);
        }
    }

    private static class RexCallReplacer
    extends RexShuttle {
        private final Map<RexNode, RexNode> replacementMap;
        private final RexBuilder builder;
        private final int oldProjectCol;
        private final RexNode newProjectCol;

        RexCallReplacer(RexBuilder builder, Map<RexNode, RexNode> replacementMap, int oldProjectCol, RexNode newProjectCol) {
            this.replacementMap = replacementMap;
            this.builder = builder;
            this.oldProjectCol = oldProjectCol;
            this.newProjectCol = newProjectCol;
        }

        RexCallReplacer(RexBuilder builder, Map<RexNode, RexNode> replacementMap) {
            this(builder, replacementMap, -1, null);
        }

        public RexNode visitCall(RexCall call) {
            if (this.replacementMap.containsKey(call)) {
                return this.replacementMap.get(call);
            }
            ArrayList<Object> newOperands = new ArrayList<Object>();
            for (RexNode operand : call.operands) {
                if (this.replacementMap.containsKey(operand)) {
                    newOperands.add(this.replacementMap.get(operand));
                    continue;
                }
                newOperands.add(operand.accept((RexVisitor)this));
            }
            return this.builder.makeCall(call.type, call.op, newOperands);
        }

        public RexNode visitInputRef(RexInputRef inputRef) {
            if (inputRef.getIndex() == this.oldProjectCol && this.newProjectCol != null && inputRef.getType() == this.newProjectCol.getType()) {
                return this.newProjectCol;
            }
            return inputRef;
        }
    }

    private static class PigAggUdfFinder
    extends RexVisitorImpl<Void> {
        private final int projectCol;
        private final List<RexCall> pigAggCalls;
        private boolean projectColReferred;
        private boolean ignoreMultisetProject = false;

        PigAggUdfFinder(int projectCol) {
            super(true);
            this.projectCol = projectCol;
            this.pigAggCalls = new ArrayList<RexCall>();
            this.projectColReferred = false;
        }

        public Void visitCall(RexCall call) {
            if (PigRelUdfConverter.getSqlAggFuncForPigUdf(call) != null) {
                this.pigAggCalls.add(call);
                this.ignoreMultisetProject = true;
            } else if (PigToSqlAggregateRule.isMultisetProjection(call) && !this.ignoreMultisetProject) {
                this.pigAggCalls.add(call);
            }
            this.visitEach((Iterable)call.operands);
            return null;
        }

        public Void visitInputRef(RexInputRef inputRef) {
            if (inputRef.getIndex() == this.projectCol) {
                this.projectColReferred = true;
            }
            return null;
        }
    }
}

