package org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rules;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.vendor.calcite.v1_20_0.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptUtil;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Aggregate;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.AggregateCall;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Join;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.JoinRelType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.RelFactories;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalJoin;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexUtil;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RelBuilder;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.tools.RelBuilderFactory;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.ImmutableBitSet;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.util.mapping.Mappings;

/* loaded from: input_file:org/apache/beam/vendor/calcite/v1_20_0/org/apache/calcite/rel/rules/AggregateJoinJoinRemoveRule.class */
/**
 * Planner rule that matches an {@link Aggregate} on top of a LEFT {@link Join} whose
 * left input is itself a LEFT {@link Join}, and removes the bottom join when its right
 * input is not needed.
 *
 * <p>The rewrite only fires when all of the following hold:
 * <ul>
 *   <li>the aggregate references no field coming from the bottom join's right input;</li>
 *   <li>every aggregate call is DISTINCT;</li>
 *   <li>the top join's left keys reference only fields of the bottom join's left input;</li>
 *   <li>the left keys of the top join and of the bottom join are equal.</li>
 * </ul>
 *
 * <p>When it fires, the result is an aggregate over a join of the bottom join's left
 * input with the top join's right input.
 */
public class AggregateJoinJoinRemoveRule extends RelOptRule {
    /** Default instance, matching {@link LogicalAggregate} over {@link LogicalJoin}s. */
    public static final AggregateJoinJoinRemoveRule INSTANCE = new AggregateJoinJoinRemoveRule(LogicalAggregate.class, LogicalJoin.class, RelFactories.LOGICAL_BUILDER);

    /**
     * Creates an AggregateJoinJoinRemoveRule.
     *
     * @param cls aggregate class to match at the root of the pattern
     * @param cls2 join class to match for both the top and the bottom join
     * @param relBuilderFactory factory for the builder used to create the replacement
     */
    public AggregateJoinJoinRemoveRule(Class<? extends Aggregate> cls, Class<? extends Join> cls2, RelBuilderFactory relBuilderFactory) {
        super(
            operand(cls,
                operandJ(cls2, null,
                    join -> join.getJoinType() == JoinRelType.LEFT,
                    operandJ(cls2, null,
                        join -> join.getJoinType() == JoinRelType.LEFT,
                        any()))),
            relBuilderFactory, null);
    }

    /**
     * Replaces the matched aggregate / join / join tree with an aggregate over a single
     * join, provided the bottom join's right input is not referenced anywhere above it.
     *
     * @param relOptRuleCall rule call holding the aggregate (0), top join (1) and
     *     bottom join (2)
     */
    @Override // org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelOptRule
    public void onMatch(RelOptRuleCall relOptRuleCall) {
        final Aggregate aggregate = relOptRuleCall.rel(0);
        final Join topJoin = relOptRuleCall.rel(1);
        final Join bottomJoin = relOptRuleCall.rel(2);
        final int leftFieldCount = bottomJoin.getLeft().getRowType().getFieldCount();
        final int bottomFieldCount = bottomJoin.getRowType().getFieldCount();

        // The aggregate must not use any column of the bottom join's right input.
        final Set<Integer> allFields = RelOptUtil.getAllFields(aggregate);
        if (allFields.stream().anyMatch(index -> index >= leftFieldCount && index < bottomFieldCount)) {
            return;
        }

        // Every aggregate call has to be DISTINCT; a single non-distinct call aborts the rewrite.
        if (aggregate.getAggCallList().stream().anyMatch(aggCall -> !aggCall.isDistinct())) {
            return;
        }

        // The top join must not use any column of the bottom join's right input.
        final List<Integer> topLeftKeys = new ArrayList<>();
        RelOptUtil.splitJoinCondition(topJoin.getLeft(), topJoin.getRight(), topJoin.getCondition(),
            topLeftKeys, new ArrayList<>(), new ArrayList<>());
        if (topLeftKeys.stream().anyMatch(key -> key >= leftFieldCount)) {
            return;
        }

        // Both joins must use the same left-side keys.
        final List<Integer> bottomLeftKeys = new ArrayList<>();
        RelOptUtil.splitJoinCondition(bottomJoin.getLeft(), bottomJoin.getRight(), bottomJoin.getCondition(),
            bottomLeftKeys, new ArrayList<>(), new ArrayList<>());
        if (!topLeftKeys.equals(bottomLeftKeys)) {
            return;
        }

        // Width of the input being dropped; every field to its right moves left by this much.
        final int offset = bottomJoin.getRight().getRowType().getFieldCount();
        final RelBuilder relBuilder = relOptRuleCall.builder();
        final RelNode newJoin = relBuilder
            .push(bottomJoin.getLeft())
            .push(topJoin.getRight())
            .join(topJoin.getJoinType(), RexUtil.shift(topJoin.getCondition(), leftFieldCount, -offset))
            .build();

        // Map each field used by the aggregate to its position in the new join's row type:
        // fields of the kept left input stay put, fields of the top join's right input shift
        // down by the removed width (mirrors the RexUtil.shift applied to the condition above).
        final Map<Integer, Integer> fieldMap = new HashMap<>();
        for (Integer index : allFields) {
            fieldMap.put(index, index < leftFieldCount ? index : index - offset);
        }
        final ImmutableBitSet newGroupSet = aggregate.getGroupSet().permute(fieldMap);

        final ImmutableList.Builder<AggregateCall> newAggCalls = ImmutableList.builder();
        final int sourceCount = aggregate.getInput().getRowType().getFieldCount();
        final Mappings.TargetMapping targetMapping = Mappings.target(fieldMap, sourceCount, sourceCount);
        for (AggregateCall aggCall : aggregate.getAggCallList()) {
            newAggCalls.add(aggCall.transform(targetMapping));
        }

        final RelNode newAggregate = relBuilder
            .push(newJoin)
            .aggregate(relBuilder.groupKey(newGroupSet), newAggCalls.build())
            .build();
        relOptRuleCall.transformTo(newAggregate);
    }
}
