/*
 * Decompiled with CFR 0.152.
 */
package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperandChildren;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.metadata.RelColumnOrigin;
import org.apache.calcite.rel.metadata.RelMetadataQuery;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;

public final class HiveExpandDistinctAggregatesRule
extends RelOptRule {
    public static final HiveExpandDistinctAggregatesRule INSTANCE = new HiveExpandDistinctAggregatesRule(HiveAggregate.class, HiveRelFactories.HIVE_PROJECT_FACTORY);
    private static RelFactories.ProjectFactory projFactory;

    public HiveExpandDistinctAggregatesRule(Class<? extends Aggregate> clazz, RelFactories.ProjectFactory projectFactory) {
        super(HiveExpandDistinctAggregatesRule.operand(clazz, (RelOptRuleOperandChildren)HiveExpandDistinctAggregatesRule.any()));
        projFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall call) {
        Aggregate aggregate = (Aggregate)call.rel(0);
        if (!aggregate.containsDistinctCall()) {
            return;
        }
        int nonDistinctCount = 0;
        LinkedHashSet argListSets = new LinkedHashSet();
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            if (!aggCall.isDistinct()) {
                ++nonDistinctCount;
                continue;
            }
            ArrayList<Integer> argList = new ArrayList<Integer>();
            for (Integer arg : aggCall.getArgList()) {
                argList.add(arg);
            }
            argListSets.add(argList);
        }
        Util.permAssert((argListSets.size() > 0 ? 1 : 0) != 0, (String)"containsDistinctCall lied");
        if (nonDistinctCount == 0 && argListSets.size() == 1) {
            for (Integer arg : (List)argListSets.iterator().next()) {
                Set colOrigs = RelMetadataQuery.getColumnOrigins((RelNode)aggregate, (int)arg);
                if (null == colOrigs) continue;
                for (RelColumnOrigin colOrig : colOrigs) {
                    RelOptHiveTable hiveTbl = (RelOptHiveTable)colOrig.getOriginTable();
                    if (!hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) continue;
                    return;
                }
            }
            RelNode converted = this.convertMonopole(aggregate, (List)argListSets.iterator().next());
            call.transformTo(converted);
            return;
        }
    }

    private RelNode convertMonopole(Aggregate aggregate, List<Integer> argList) {
        HashMap<Integer, Integer> sourceOf = new HashMap<Integer, Integer>();
        Aggregate distinct = HiveExpandDistinctAggregatesRule.createSelectDistinct(aggregate, argList, sourceOf);
        ArrayList<AggregateCall> newAggCalls = Lists.newArrayList(aggregate.getAggCallList());
        HiveExpandDistinctAggregatesRule.rewriteAggCalls(newAggCalls, argList, sourceOf);
        int cardinality = aggregate.getGroupSet().cardinality();
        return aggregate.copy(aggregate.getTraitSet(), (RelNode)distinct, aggregate.indicator, ImmutableBitSet.range((int)cardinality), null, newAggCalls);
    }

    private static void rewriteAggCalls(List<AggregateCall> newAggCalls, List<Integer> argList, Map<Integer, Integer> sourceOf) {
        for (int i = 0; i < newAggCalls.size(); ++i) {
            AggregateCall aggCall = newAggCalls.get(i);
            if (!aggCall.isDistinct() || !aggCall.getArgList().equals(argList)) continue;
            int argCount = aggCall.getArgList().size();
            ArrayList<Integer> newArgs = new ArrayList<Integer>(argCount);
            for (int j = 0; j < argCount; ++j) {
                Integer arg = (Integer)aggCall.getArgList().get(j);
                newArgs.add(sourceOf.get(arg));
            }
            AggregateCall newAggCall = new AggregateCall(aggCall.getAggregation(), false, newArgs, aggCall.getType(), aggCall.getName());
            newAggCalls.set(i, newAggCall);
        }
    }

    private static Aggregate createSelectDistinct(Aggregate aggregate, List<Integer> argList, Map<Integer, Integer> sourceOf) {
        ArrayList<Pair> projects = new ArrayList<Pair>();
        RelNode child = aggregate.getInput();
        List childFields = child.getRowType().getFieldList();
        Iterator<Integer> iterator = aggregate.getGroupSet().iterator();
        while (iterator.hasNext()) {
            int i = (Integer)iterator.next();
            sourceOf.put(i, projects.size());
            projects.add(RexInputRef.of2((int)i, (List)childFields));
        }
        for (Integer arg : argList) {
            if (sourceOf.get(arg) != null) continue;
            sourceOf.put(arg, projects.size());
            projects.add(RexInputRef.of2((int)arg, (List)childFields));
        }
        RelNode project = projFactory.createProject(child, Pair.left(projects), Pair.right(projects));
        return aggregate.copy(aggregate.getTraitSet(), project, false, ImmutableBitSet.range((int)projects.size()), null, ImmutableList.of());
    }
}

