package org.apache.calcite.sql2rel;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.calcite.rel.RelHomogeneousShuttle;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.CorrelationId;
import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalCorrelate;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCorrelVariable;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexOver;
import org.apache.calcite.rex.RexPatternFieldRef;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexTableInputRef;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableList;
import org.apache.flink.calcite.shaded.com.google.common.collect.UnmodifiableIterator;
import org.apiguardian.api.API;

/**
 * A {@link RelHomogeneousShuttle} that rewrites each {@link LogicalCorrelate} so that
 * expressions on its right-hand side that depend only on the correlation variable are
 * evaluated on the left-hand side instead.
 *
 * <p>For a correlate whose right input contains such expressions inside {@link Project} or
 * {@link Filter} nodes (for example {@code $cor0.a + 1}), the shuttle:
 *
 * <ol>
 *   <li>appends each expression as an extra column on top of the left input, with field
 *       accesses of the correlation variable turned into input references;</li>
 *   <li>replaces those expressions on the right side with field accesses of the correlation
 *       variable that point at the new columns;</li>
 *   <li>rebuilds the correlate with the new columns as its required columns; and</li>
 *   <li>projects the extra columns away again, keeping only the original left fields (plus the
 *       right fields for INNER and LEFT joins).</li>
 * </ol>
 *
 * <p>If every correlated expression found is a plain field access, nothing is moved and the
 * correlate is returned unchanged (or copied, if visiting its inputs changed them).
 */
@API(since = "1.27", status = API.Status.EXPERIMENTAL)
public final class CorrelateProjectExtractor extends RelHomogeneousShuttle {

    /** Factory for the {@link RelBuilder} used to assemble the rewritten plan. */
    private final RelBuilderFactory builderFactory;

    /**
     * Creates the shuttle.
     *
     * @param relBuilderFactory factory used to create the {@link RelBuilder} for each rewrite
     */
    public CorrelateProjectExtractor(RelBuilderFactory relBuilderFactory) {
        this.builderFactory = relBuilderFactory;
    }

    /**
     * Visits both inputs of {@code correlate}, then moves its simple correlated expressions from
     * the right side to the left side.
     *
     * @param correlate the correlate to rewrite
     * @return one of the following:
     *     <ul>
     *       <li>{@code correlate} itself, if nothing changed;</li>
     *       <li>a copy with the visited inputs, if only plain correlation field accesses were
     *           found;</li>
     *       <li>otherwise, the rewritten plan built by the {@link RelBuilder}.</li>
     *     </ul>
     * @throws AssertionError if a rewrite is needed and the join type is not SEMI, ANTI, LEFT or
     *     INNER
     */
    @Override
    public RelNode visit(LogicalCorrelate correlate) {
        RelNode left = correlate.getLeft().accept(this);
        RelNode right = correlate.getRight().accept(this);
        int oldLeftFieldCount = left.getRowType().getFieldCount();
        CorrelationId correlationId = correlate.getCorrelationId();

        // Correlated expressions on the right side that could be evaluated on the left side.
        Set<RexNode> correlatedExpressions = findCorrelationDependentCalls(correlationId, right);

        // Plain field accesses of the correlation variable gain nothing from being moved.
        boolean isTrivialCorrelation =
                correlatedExpressions.stream().allMatch(e -> e instanceof RexFieldAccess);
        if (isTrivialCorrelation) {
            if (correlate.getLeft().equals(left) && correlate.getRight().equals(right)) {
                return correlate;
            }
            return correlate.copy(
                    correlate.getTraitSet(),
                    left,
                    right,
                    correlationId,
                    correlate.getRequiredColumns(),
                    correlate.getJoinType());
        }

        RelBuilder builder = builderFactory.create(correlate.getCluster(), null);
        builder.push(left);

        // Re-express every correlated expression over the left input and append it as a column.
        List<RexNode> expressionsOverLeft = new ArrayList<>(correlatedExpressions.size());
        for (RexNode expression : correlatedExpressions) {
            expressionsOverLeft.add(replaceCorrelationsWithInputRef(expression, builder));
        }
        builder.projectPlus(expressionsOverLeft);

        // Map each correlated expression to a field access on a correlation variable typed after
        // the widened left input. The indexes follow the same (insertion) order as the appended
        // columns above, because both loops iterate the same LinkedHashSet.
        RexBuilder rexBuilder = builder.getRexBuilder();
        RexNode widenedCorrelVariable =
                rexBuilder.makeCorrel(builder.peek().getRowType(), correlationId);
        Map<RexNode, RexNode> replacements = new HashMap<>();
        int column = oldLeftFieldCount;
        for (RexNode expression : correlatedExpressions) {
            replacements.put(expression, rexBuilder.makeFieldAccess(widenedCorrelVariable, column++));
        }

        // The appended columns sit at the end of the left input; they are the ones the new
        // correlate requires.
        List<RexNode> requiredFields =
                builder.fields(
                        ImmutableBitSet.range(
                                        oldLeftFieldCount,
                                        oldLeftFieldCount + expressionsOverLeft.size())
                                .asList());
        int newLeftFieldCount = builder.fields().size();

        RelNode newRight = replaceExpressionsUsingMap(right, replacements);
        builder.push(newRight);
        builder.correlate(correlate.getJoinType(), correlationId, requiredFields);

        // Drop the helper columns so that only the original left fields (and, for INNER/LEFT,
        // the right fields) remain.
        List<Integer> retainedFields;
        switch (correlate.getJoinType()) {
            case SEMI:
            case ANTI:
                retainedFields = ImmutableBitSet.range(0, oldLeftFieldCount).asList();
                break;
            case LEFT:
            case INNER:
                retainedFields =
                        ImmutableBitSet.builder()
                                .set(0, oldLeftFieldCount)
                                .set(
                                        newLeftFieldCount,
                                        newLeftFieldCount + newRight.getRowType().getFieldCount())
                                .build()
                                .asList();
                break;
            default:
                throw new AssertionError(correlate.getJoinType());
        }
        builder.project(builder.fields(retainedFields));
        return builder.build();
    }

    /**
     * Collects the simple correlated expressions found in the {@link Project} and {@link Filter}
     * nodes of {@code plan}.
     *
     * @param correlationId the correlation whose variable the expressions must reference
     * @param plan the plan to search (the right input of a correlate)
     * @return the collected expressions, in encounter order, without duplicates
     */
    private static Set<RexNode> findCorrelationDependentCalls(
            CorrelationId correlationId, RelNode plan) {
        final SimpleCorrelationCollector collector = new SimpleCorrelationCollector(correlationId);
        plan.accept(
                new RelHomogeneousShuttle() {
                    @Override
                    public RelNode visit(RelNode other) {
                        if (other instanceof Project || other instanceof Filter) {
                            other.accept(collector);
                        }
                        return super.visit(other);
                    }
                });
        return collector.correlations;
    }

    /**
     * Rewrites every node of {@code plan} bottom-up, replacing row expressions according to
     * {@code mapping} (see {@link CallReplacer}).
     *
     * @param plan the plan to rewrite
     * @param mapping replacement for each expression to substitute
     * @return the rewritten plan
     */
    private static RelNode replaceExpressionsUsingMap(RelNode plan, Map<RexNode, RexNode> mapping) {
        final CallReplacer replacer = new CallReplacer(mapping);
        return plan.accept(
                new RelHomogeneousShuttle() {
                    @Override
                    public RelNode visit(RelNode other) {
                        return super.visitChildren(other).accept(replacer);
                    }
                });
    }

    /**
     * Returns whether {@code node} is a simple expression over the correlation variable
     * {@code correlationId}, as decided by {@link SimpleCorrelationDetector}.
     */
    private static boolean isSimpleCorrelatedExpression(RexNode node, CorrelationId correlationId) {
        Boolean result = node.accept(new SimpleCorrelationDetector(correlationId));
        return Boolean.TRUE.equals(result);
    }

    /**
     * Replaces every field access on a correlation variable within {@code expression} by an
     * input reference to the field with the same index in the builder's current top relation.
     *
     * @param expression the expression to rewrite
     * @param builder builder whose top of stack is the (left) relation to reference
     * @return the rewritten expression
     */
    private static RexNode replaceCorrelationsWithInputRef(
            RexNode expression, final RelBuilder builder) {
        return expression.accept(
                new RexShuttle() {
                    @Override
                    public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
                        if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
                            return builder.field(fieldAccess.getField().getIndex());
                        }
                        return super.visitFieldAccess(fieldAccess);
                    }
                });
    }

    /**
     * A shuttle that replaces calls that appear as keys in a mapping by the mapped expression.
     * Calls that are not in the mapping are visited recursively.
     *
     * <p>NOTE(review): only {@code visitCall} is overridden, so {@link RexFieldAccess} keys that
     * {@link SimpleCorrelationCollector} may place into the mapping are never substituted —
     * confirm this is intended.
     */
    private static final class CallReplacer extends RexShuttle {
        private final Map<RexNode, RexNode> mapping;

        CallReplacer(Map<RexNode, RexNode> mapping) {
            this.mapping = mapping;
        }

        @Override
        public RexNode visitCall(RexCall call) {
            RexNode replacement = mapping.get(call);
            return replacement != null ? replacement : super.visitCall(call);
        }
    }

    /**
     * A shuttle that records the outermost calls and field accesses that are simple correlated
     * expressions for a given correlation. It does not descend into an expression once it has
     * recorded it.
     */
    private static final class SimpleCorrelationCollector extends RexShuttle {
        private final CorrelationId correlationId;

        /** Recorded expressions, in encounter order. */
        private final Set<RexNode> correlations = new LinkedHashSet<>();

        SimpleCorrelationCollector(CorrelationId correlationId) {
            this.correlationId = correlationId;
        }

        @Override
        public RexNode visitCall(RexCall call) {
            if (isSimpleCorrelatedExpression(call, correlationId)) {
                correlations.add(call);
                return call;
            }
            return super.visitCall(call);
        }

        @Override
        public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
            if (isSimpleCorrelatedExpression(fieldAccess, correlationId)) {
                correlations.add(fieldAccess);
                return fieldAccess;
            }
            return super.visitFieldAccess(fieldAccess);
        }
    }

    /**
     * A visitor that decides whether an expression references only the correlation variable with
     * the given id.
     *
     * <p>Input refs, local refs, table input refs, pattern field refs, windowed aggregates and
     * sub-queries yield {@code FALSE}, and so does a correlation variable with another id. A
     * field access is judged by its reference expression. A call is {@code TRUE} only if all
     * operands with a non-null verdict are {@code TRUE} and there is at least one such operand.
     * Operands whose visit yields {@code null} are ignored; these are kinds not overridden here,
     * presumably e.g. literals.
     */
    private static final class SimpleCorrelationDetector extends RexVisitorImpl<Boolean> {
        private final CorrelationId corrId;

        private SimpleCorrelationDetector(CorrelationId corrId) {
            super(true);
            this.corrId = corrId;
        }

        @Override
        public Boolean visitOver(RexOver over) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitSubQuery(RexSubQuery subQuery) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitCall(RexCall call) {
            Boolean result = null;
            for (RexNode operand : call.operands) {
                Boolean operandResult = operand.accept(this);
                if (operandResult != null) {
                    result = result == null ? operandResult : result && operandResult;
                }
            }
            return result == null ? Boolean.FALSE : result;
        }

        @Override
        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
            return fieldAccess.getReferenceExpr().accept(this);
        }

        @Override
        public Boolean visitInputRef(RexInputRef inputRef) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitCorrelVariable(RexCorrelVariable correlVariable) {
            return correlVariable.id.equals(corrId);
        }

        @Override
        public Boolean visitTableInputRef(RexTableInputRef ref) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitLocalRef(RexLocalRef localRef) {
            return Boolean.FALSE;
        }

        @Override
        public Boolean visitPatternFieldRef(RexPatternFieldRef fieldRef) {
            return Boolean.FALSE;
        }
    }
}
