/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelTrait;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.TraitsUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveExcept;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveFilter;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveTableFunctionScan;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveUnion;
import org.apache.hadoop.hive.ql.optimizer.calcite.rules.HiveIntersectRewriteRule;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.SqlFunctionConverter;
import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hive.com.google.common.base.Function;
import org.apache.hive.com.google.common.collect.ImmutableList;
import org.apache.hive.com.google.common.collect.Lists;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Planner rule that rewrites a {@link HiveExcept} into aggregates, a union, and either a
 * filter (EXCEPT DISTINCT) or a set-op UDTF (EXCEPT ALL).
 *
 * <p>Plan shape, where "keys" are all output columns of an EXCEPT input:
 * <ol>
 *   <li>Per input branch, a first-level aggregate grouping on keys plus a constant tag column
 *       VCol (2 for the left input, 1 for the right input), computing {@code c = count(VCol)}.</li>
 *   <li>A {@link HiveUnion} of both branches, with schema {@code keys + VCol + c}.</li>
 *   <li>A second-level aggregate grouping on keys, computing {@code a = sum(c)} and
 *       {@code b = sum(VCol * c)}.</li>
 * </ol>
 *
 * <p>If a key occurs m times in the left input and n times in the right input, then
 * {@code a = m + n} and {@code b = 2m + n}. Therefore:
 * <ul>
 *   <li>EXCEPT DISTINCT keeps a key iff {@code a > 0 AND 2a = b}, i.e. m &gt; 0 and n = 0.
 *       It then projects a and b away.</li>
 *   <li>EXCEPT ALL projects {@code 2b - 3a} (= m - n) in front of the keys. It feeds that to
 *       the UDTF built by {@link HiveCalciteUtil#createUDTFForSetOp}, then projects away the
 *       UDTF's first output column.</li>
 * </ul>
 */
public class HiveExceptRewriteRule
extends RelOptRule {
    public static final HiveExceptRewriteRule INSTANCE = new HiveExceptRewriteRule();
    // Bound to this class. It was previously (mis)bound to HiveIntersectRewriteRule.
    protected static final Logger LOG = LoggerFactory.getLogger(HiveExceptRewriteRule.class);

    private HiveExceptRewriteRule() {
        super(HiveExceptRewriteRule.operand(HiveExcept.class, HiveExceptRewriteRule.any()));
    }

    /**
     * Replaces the matched {@link HiveExcept} with its rewritten plan.
     *
     * <p>Exactly one alternative is registered via {@code transformTo}. The EXCEPT DISTINCT and
     * EXCEPT ALL plans are not equivalent. The previous version fell through from the DISTINCT
     * branch into the ALL branch, registering both plans for a DISTINCT query.
     *
     * @throws RuntimeException wrapping any {@link SemanticException} raised while building the
     *     plan; the cause is also logged at debug level
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        HiveExcept hiveExcept = (HiveExcept)call.rel(0);
        try {
            call.transformTo(this.createExceptPlan(hiveExcept));
        }
        catch (SemanticException e) {
            // CalciteSemanticException is a SemanticException, so this covers every helper.
            LOG.debug(e.toString());
            throw new RuntimeException(e);
        }
    }

    /**
     * Builds the full replacement plan for {@code hiveExcept}, choosing the DISTINCT or ALL
     * tail according to {@code hiveExcept.all}.
     */
    private RelNode createExceptPlan(HiveExcept hiveExcept) throws SemanticException {
        RelOptCluster cluster = hiveExcept.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();

        ImmutableList.Builder<RelNode> bldr = new ImmutableList.Builder<RelNode>();
        bldr.add(this.createFirstGB(hiveExcept.getInputs().get(0), true, cluster, rexBuilder));
        bldr.add(this.createFirstGB(hiveExcept.getInputs().get(1), false, cluster, rexBuilder));
        // Union schema: keys + VCol + c
        HiveUnion union = new HiveUnion(cluster, TraitsUtil.getDefaultTraitSet(cluster), bldr.build());
        int unionColumnSize = union.getRowType().getFieldList().size();

        // Aggregate schema: keys + a + b. Same width as the union, so a is at
        // unionColumnSize - 2 and b is at unionColumnSize - 1.
        HiveAggregate aggregateRel = this.createSecondGB(union, cluster, rexBuilder);

        if (!hiveExcept.all) {
            return this.createExceptDistinctTail(aggregateRel, unionColumnSize, cluster, rexBuilder);
        }
        return this.createExceptAllTail(aggregateRel, unionColumnSize, cluster, rexBuilder);
    }

    /**
     * Builds the first-level aggregate for one EXCEPT input: GROUP BY (all input columns, VCol),
     * computing {@code count(VCol)}.
     *
     * @param left true for the left (first) input, which gets VCol = 2; the right input gets 1
     */
    private RelNode createFirstGB(RelNode input, boolean left, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        int inputColumnSize = input.getRowType().getFieldList().size();
        List<RexNode> gbChildProjLst = new ArrayList<RexNode>();
        List<Integer> groupSetPositions = new ArrayList<Integer>();
        for (int cInd = 0; cInd < inputColumnSize; ++cInd) {
            gbChildProjLst.add(rexBuilder.makeInputRef(input, cInd));
            groupSetPositions.add(cInd);
        }
        // VCol: constant branch tag appended after the input columns, also grouped on.
        gbChildProjLst.add(rexBuilder.makeBigintLiteral(BigDecimal.valueOf(left ? 2L : 1L)));
        groupSetPositions.add(inputColumnSize);
        HiveProject gbInputRel = HiveProject.create(input, gbChildProjLst, null);
        ImmutableBitSet groupSet = ImmutableBitSet.of(groupSetPositions);

        RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
        List<AggregateCall> aggregateCalls = new ArrayList<AggregateCall>();
        // c = count(VCol). VCol is a non-null literal, so this counts the rows in each group.
        aggregateCalls.add(HiveCalciteUtil.createSingleArgAggCall("count", cluster, TypeInfoFactory.longTypeInfo, inputColumnSize, aggFnRetType));
        return new HiveAggregate(cluster, cluster.traitSetOf((RelTrait)HiveRelNode.CONVENTION), gbInputRel, false, groupSet, null, aggregateCalls);
    }

    /**
     * Builds the second-level aggregate over the union (keys + VCol + c): GROUP BY keys,
     * computing {@code a = sum(c)} and {@code b = sum(VCol * c)}.
     */
    private HiveAggregate createSecondGB(RelNode union, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        int unionColumnSize = union.getRowType().getFieldList().size();
        List<RexNode> gbChildProjLst = new ArrayList<RexNode>();
        List<Integer> groupSetPositions = new ArrayList<Integer>();
        for (int cInd = 0; cInd < unionColumnSize; ++cInd) {
            gbChildProjLst.add(rexBuilder.makeInputRef(union, cInd));
            // The trailing VCol and c columns are aggregated, not grouped on.
            if (cInd < unionColumnSize - 2) {
                groupSetPositions.add(cInd);
            }
        }
        // The aggregate argument must be an input column, so VCol * c is materialised by a
        // project. Project schema: keys + VCol + c + VCol*c.
        gbChildProjLst.add(this.multiply(rexBuilder.makeInputRef(union, unionColumnSize - 2), rexBuilder.makeInputRef(union, unionColumnSize - 1), cluster, rexBuilder));
        HiveProject gbInputRel = HiveProject.create(union, gbChildProjLst, null);

        RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
        List<AggregateCall> aggregateCalls = new ArrayList<AggregateCall>();
        // a = sum(c)
        aggregateCalls.add(HiveCalciteUtil.createSingleArgAggCall("sum", cluster, TypeInfoFactory.longTypeInfo, unionColumnSize - 1, aggFnRetType));
        // b = sum(VCol * c)
        aggregateCalls.add(HiveCalciteUtil.createSingleArgAggCall("sum", cluster, TypeInfoFactory.longTypeInfo, unionColumnSize, aggFnRetType));
        ImmutableBitSet groupSet = ImmutableBitSet.of(groupSetPositions);
        return new HiveAggregate(cluster, cluster.traitSetOf((RelTrait)HiveRelNode.CONVENTION), gbInputRel, false, groupSet, null, aggregateCalls);
    }

    /**
     * EXCEPT DISTINCT tail: filter {@code a > 0 AND 2a = b}, then project away a and b.
     */
    private RelNode createExceptDistinctTail(HiveAggregate aggregateRel, int columnSize, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        HiveFilter filterRel = new HiveFilter(cluster, cluster.traitSetOf((RelTrait)HiveRelNode.CONVENTION), aggregateRel, this.makeFilterExprForExceptDistinct(aggregateRel, columnSize, cluster, rexBuilder));
        int filterColumnSize = filterRel.getRowType().getFieldList().size();
        Set<Integer> projectOutColumnPositions = new HashSet<Integer>();
        projectOutColumnPositions.add(filterColumnSize - 2);
        projectOutColumnPositions.add(filterColumnSize - 1);
        return HiveCalciteUtil.createProjectWithoutColumn(filterRel, projectOutColumnPositions);
    }

    /**
     * EXCEPT ALL tail: project {@code (2b - 3a), keys}, apply the set-op UDTF, and drop the
     * UDTF's column 0.
     *
     * <p>NOTE(review): this assumes the UDTF from {@code HiveCalciteUtil.createUDTFForSetOp}
     * emits each row as many times as the leading count, with no rows for a count of 0 or
     * less. Confirm against that helper.
     */
    private RelNode createExceptAllTail(HiveAggregate aggregateRel, int columnSize, RelOptCluster cluster, RexBuilder rexBuilder) throws SemanticException {
        List<RelDataTypeField> fields = aggregateRel.getRowType().getFieldList();
        List<RexNode> copyInputRefs = new ArrayList<RexNode>();
        copyInputRefs.add(this.makeExprForExceptAll(aggregateRel, columnSize, cluster, rexBuilder));
        // Copy the key columns only; the trailing a and b are dropped.
        for (int i = 0; i < fields.size() - 2; ++i) {
            RelDataTypeField field = fields.get(i);
            copyInputRefs.add(new RexInputRef(field.getIndex(), field.getType()));
        }
        HiveProject srcRel = HiveProject.create(aggregateRel, copyInputRefs, null);
        HiveTableFunctionScan udtf = HiveCalciteUtil.createUDTFForSetOp(cluster, srcRel);
        Set<Integer> projectOutColumnPositions = new HashSet<Integer>();
        projectOutColumnPositions.add(0);
        return HiveCalciteUtil.createProjectWithoutColumn(udtf, projectOutColumnPositions);
    }

    /** Returns {@code r1 * r2} typed as BIGINT. */
    private RexNode multiply(RexNode r1, RexNode r2, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        return HiveExceptRewriteRule.makeBigintBinaryCall("*", true, r1, r2, cluster, rexBuilder);
    }

    /**
     * Builds the EXCEPT DISTINCT predicate {@code a > 0 AND 2 * a = b}.
     * Here a is input column {@code columnSize - 2} and b is column {@code columnSize - 1}.
     */
    private RexNode makeFilterExprForExceptDistinct(HiveRelNode input, int columnSize, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        RexInputRef a = rexBuilder.makeInputRef(input, columnSize - 2);
        RexInputRef b = rexBuilder.makeInputRef(input, columnSize - 1);
        RexNode aMoreThanZero = HiveExceptRewriteRule.makeBigintBinaryCall(">", false, a, rexBuilder.makeBigintLiteral(BigDecimal.ZERO), cluster, rexBuilder);
        RexNode twoA = HiveExceptRewriteRule.makeBigintBinaryCall("*", false, a, rexBuilder.makeBigintLiteral(BigDecimal.valueOf(2L)), cluster, rexBuilder);
        RexNode twoAEqualsB = HiveExceptRewriteRule.makeBigintBinaryCall("=", false, twoA, b, cluster, rexBuilder);
        return HiveExceptRewriteRule.makeBigintBinaryCall("and", false, aMoreThanZero, twoAEqualsB, cluster, rexBuilder);
    }

    /**
     * Builds the EXCEPT ALL replication count {@code 2 * b - 3 * a} (= m - n).
     * Here a is input column {@code columnSize - 2} and b is column {@code columnSize - 1}.
     */
    private RexNode makeExprForExceptAll(HiveRelNode input, int columnSize, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        RexInputRef a = rexBuilder.makeInputRef(input, columnSize - 2);
        RexInputRef b = rexBuilder.makeInputRef(input, columnSize - 1);
        RexNode threeA = HiveExceptRewriteRule.makeBigintBinaryCall("*", false, rexBuilder.makeBigintLiteral(BigDecimal.valueOf(3L)), a, cluster, rexBuilder);
        RexNode twoB = HiveExceptRewriteRule.makeBigintBinaryCall("*", false, rexBuilder.makeBigintLiteral(BigDecimal.valueOf(2L)), b, cluster, rexBuilder);
        return HiveExceptRewriteRule.makeBigintBinaryCall("-", false, twoB, threeA, cluster, rexBuilder);
    }

    /**
     * Builds {@code op(left, right)} through {@link SqlFunctionConverter#getCalciteFn}.
     * Both argument types and the return type are declared as BIGINT, matching what every
     * call site in this rule passed before.
     */
    private static RexNode makeBigintBinaryCall(String op, boolean deterministic, RexNode left, RexNode right, RelOptCluster cluster, RexBuilder rexBuilder) throws CalciteSemanticException {
        RelDataType longType = TypeConverter.convert(TypeInfoFactory.longTypeInfo, cluster.getTypeFactory());
        ImmutableList<RelDataType> argTypes = ImmutableList.of(longType, longType);
        List<RexNode> operands = new ArrayList<RexNode>(2);
        operands.add(left);
        operands.add(right);
        return rexBuilder.makeCall(SqlFunctionConverter.getCalciteFn(op, argTypes, longType, deterministic), operands);
    }
}

