package org.apache.hadoop.hive.ql.optimizer.calcite.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.JoinCommuteRule;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.Pair;
import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelOptUtil;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveJoin;
import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveMultiJoin;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/hadoop/hive/ql/optimizer/calcite/rules/HiveJoinToMultiJoinRule.class */
public class HiveJoinToMultiJoinRule extends RelOptRule {
    public static final HiveJoinToMultiJoinRule INSTANCE;
    private final RelFactories.ProjectFactory projectFactory;
    private static final transient Logger LOG;
    static final /* synthetic */ boolean $assertionsDisabled;

    public HiveJoinToMultiJoinRule(Class<? extends Join> cls, RelFactories.ProjectFactory projectFactory) {
        super(operand(cls, operand(RelNode.class, any()), new RelOptRuleOperand[]{operand(RelNode.class, any())}));
        this.projectFactory = projectFactory;
    }

    public void onMatch(RelOptRuleCall relOptRuleCall) {
        HiveJoin hiveJoin;
        HiveJoin hiveJoin2 = (HiveJoin) relOptRuleCall.rel(0);
        RelNode rel = relOptRuleCall.rel(1);
        RelNode rel2 = relOptRuleCall.rel(2);
        RelNode mergeJoin = mergeJoin(hiveJoin2, rel, rel2);
        if (mergeJoin != null) {
            relOptRuleCall.transformTo(mergeJoin);
            return;
        }
        Project swap = JoinCommuteRule.swap(hiveJoin2, true);
        if (!$assertionsDisabled && swap == null) {
            throw new AssertionError();
        }
        Project project = null;
        if (swap instanceof HiveJoin) {
            hiveJoin = (HiveJoin) swap;
        } else {
            project = swap;
            hiveJoin = (HiveJoin) swap.getInput(0);
        }
        RelNode mergeJoin2 = mergeJoin(hiveJoin, rel2, rel);
        if (mergeJoin2 != null) {
            if (project != null) {
                mergeJoin2 = this.projectFactory.createProject(mergeJoin2, project.getChildExps(), project.getRowType().getFieldNames());
            }
            relOptRuleCall.transformTo(mergeJoin2);
        }
    }

    private static RelNode mergeJoin(HiveJoin hiveJoin, RelNode relNode, RelNode relNode2) {
        RexNode condition;
        List<Pair<Integer, Integer>> joinInputs;
        List<JoinRelType> joinTypes;
        List<RexNode> joinFilters;
        boolean z;
        RexBuilder rexBuilder = hiveJoin.getCluster().getRexBuilder();
        ArrayList newArrayList = Lists.newArrayList();
        ArrayList newArrayList2 = Lists.newArrayList();
        ArrayList newArrayList3 = Lists.newArrayList();
        ArrayList newArrayList4 = Lists.newArrayList();
        ArrayList newArrayList5 = Lists.newArrayList();
        if (!(relNode instanceof HiveJoin) && !(relNode instanceof HiveMultiJoin)) {
            return null;
        }
        if (relNode instanceof HiveJoin) {
            HiveJoin hiveJoin2 = (HiveJoin) relNode;
            condition = hiveJoin2.getCondition();
            joinInputs = ImmutableList.of(Pair.of(0, 1));
            joinTypes = ImmutableList.of(hiveJoin2.getJoinType());
            joinFilters = ImmutableList.of(hiveJoin2.getJoinFilter());
            try {
                z = isCombinableJoin(hiveJoin, hiveJoin2);
            } catch (CalciteSemanticException e) {
                LOG.trace("Failed to merge join-join", e);
                z = false;
            }
        } else {
            HiveMultiJoin hiveMultiJoin = (HiveMultiJoin) relNode;
            condition = hiveMultiJoin.getCondition();
            joinInputs = hiveMultiJoin.getJoinInputs();
            joinTypes = hiveMultiJoin.getJoinTypes();
            joinFilters = hiveMultiJoin.getJoinFilters();
            try {
                z = isCombinableJoin(hiveJoin, hiveMultiJoin);
            } catch (CalciteSemanticException e2) {
                LOG.trace("Failed to merge join-multijoin", e2);
                z = false;
            }
        }
        if (!z) {
            return null;
        }
        newArrayList2.add(condition);
        for (int i = 0; i < joinInputs.size(); i++) {
            newArrayList3.add(joinInputs.get(i));
            newArrayList4.add(joinTypes.get(i));
            newArrayList5.add(joinFilters.get(i));
        }
        newArrayList.addAll(relNode.getInputs());
        int size = newArrayList.size();
        newArrayList.add(relNode2);
        newArrayList2.add(hiveJoin.getCondition());
        if (newArrayList2.size() == 1) {
            return null;
        }
        ImmutableList of = ImmutableList.of();
        ArrayList arrayList = new ArrayList();
        ArrayList arrayList2 = new ArrayList();
        for (int i2 = 0; i2 < newArrayList.size(); i2++) {
            arrayList.add(new ArrayList());
        }
        try {
            RexNode splitHiveJoinCondition = HiveRelOptUtil.splitHiveJoinCondition(of, newArrayList, hiveJoin.getCondition(), arrayList, arrayList2, null);
            ImmutableBitSet.Builder builder = ImmutableBitSet.builder();
            for (int i3 = 0; i3 < newArrayList.size(); i3++) {
                if (!((List) arrayList.get(i3)).isEmpty()) {
                    builder.set(i3);
                }
            }
            ImmutableBitSet build = builder.build();
            ImmutableBitSet intersect = build.intersect(ImmutableBitSet.range(size));
            ImmutableBitSet intersect2 = build.intersect(ImmutableBitSet.range(size, newArrayList.size()));
            if (hiveJoin.getJoinType() != JoinRelType.INNER && (intersect.cardinality() > 1 || intersect2.cardinality() > 1)) {
                return null;
            }
            if (hiveJoin.getJoinType() != JoinRelType.INNER) {
                newArrayList3.add(Pair.of(Integer.valueOf(build.nextSetBit(0)), Integer.valueOf(build.nextSetBit(size))));
                newArrayList4.add(hiveJoin.getJoinType());
                newArrayList5.add(splitHiveJoinCondition);
            } else {
                Iterator it = intersect.iterator();
                while (it.hasNext()) {
                    int intValue = ((Integer) it.next()).intValue();
                    Iterator it2 = intersect2.iterator();
                    while (it2.hasNext()) {
                        newArrayList3.add(Pair.of(Integer.valueOf(intValue), Integer.valueOf(((Integer) it2.next()).intValue())));
                        newArrayList4.add(hiveJoin.getJoinType());
                        newArrayList5.add(splitHiveJoinCondition);
                    }
                }
            }
            return new HiveMultiJoin(hiveJoin.getCluster(), newArrayList, RexUtil.flatten(rexBuilder, RexUtil.composeConjunction(rexBuilder, newArrayList2, false)), hiveJoin.getRowType(), newArrayList3, newArrayList4, newArrayList5);
        } catch (CalciteSemanticException e3) {
            LOG.trace("Failed to merge joins", e3);
            return null;
        }
    }

    private static boolean isCombinableJoin(HiveJoin hiveJoin, HiveJoin hiveJoin2) throws CalciteSemanticException {
        return isCombinablePredicate(HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(hiveJoin, hiveJoin.getCondition()), HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(hiveJoin2, hiveJoin2.getCondition()), hiveJoin2.getInputs().size());
    }

    private static boolean isCombinableJoin(HiveJoin hiveJoin, HiveMultiJoin hiveMultiJoin) throws CalciteSemanticException {
        return isCombinablePredicate(HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(hiveJoin, hiveJoin.getCondition()), HiveCalciteUtil.JoinPredicateInfo.constructJoinPredicateInfo(hiveMultiJoin, hiveMultiJoin.getCondition()), hiveMultiJoin.getInputs().size());
    }

    private static boolean isCombinablePredicate(HiveCalciteUtil.JoinPredicateInfo joinPredicateInfo, HiveCalciteUtil.JoinPredicateInfo joinPredicateInfo2, int i) throws CalciteSemanticException {
        Set<Integer> projsJoinKeysInChildSchema = joinPredicateInfo.getProjsJoinKeysInChildSchema(0);
        if (projsJoinKeysInChildSchema.isEmpty()) {
            return false;
        }
        for (int i2 = 0; i2 < i; i2++) {
            if (projsJoinKeysInChildSchema.equals(joinPredicateInfo2.getProjsJoinKeysInJoinSchema(i2))) {
                return true;
            }
        }
        return false;
    }

    static {
        $assertionsDisabled = !HiveJoinToMultiJoinRule.class.desiredAssertionStatus();
        INSTANCE = new HiveJoinToMultiJoinRule(HiveJoin.class, HiveRelFactories.HIVE_PROJECT_FACTORY);
        LOG = LoggerFactory.getLogger(HiveJoinToMultiJoinRule.class);
    }
}
