package org.apache.flink.table.planner.plan.nodes.exec.common;

import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexFieldAccess;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.dag.Transformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.ReadableConfig;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
import org.apache.flink.streaming.api.operators.StreamOperator;
import org.apache.flink.streaming.api.transformations.OneInputTransformation;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.connector.Projection;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.functions.python.PythonFunctionInfo;
import org.apache.flink.table.functions.python.PythonFunctionKind;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.ProjectionCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeConfig;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeContext;
import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.utils.CommonPythonUtil;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.runtime.generated.GeneratedProjection;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.RowType;
import org.apache.flink.util.Preconditions;

/* loaded from: input_file:flink-table-planner.jar:org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecPythonCalc.class */
/**
 * Base class for exec nodes that evaluate Python scalar functions as part of a projection
 * ("Python Calc").
 *
 * <p>Entries of {@link #projection} that are {@link RexInputRef}s are forwarded unchanged from the
 * input row, and entries that are {@link RexCall}s are evaluated by a Python scalar function
 * operator. The produced row contains all forwarded fields first (in projection order), followed by
 * one field per Python call (in projection order).
 */
public abstract class CommonExecPythonCalc extends ExecNodeBase<RowData>
        implements SingleTransformationTranslator<RowData> {

    public static final String PYTHON_CALC_TRANSFORMATION = "python-calc";

    public static final String FIELD_NAME_PROJECTION = "projection";

    private static final String PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.PythonScalarFunctionOperator";

    private static final String EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.EmbeddedPythonScalarFunctionOperator";

    private static final String ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME =
            "org.apache.flink.table.runtime.operators.python.scalar.arrow.ArrowPythonScalarFunctionOperator";

    /** Projection expressions evaluated by this node; serialized under {@code "projection"}. */
    @JsonProperty(FIELD_NAME_PROJECTION)
    private final List<RexNode> projection;

    /**
     * Creates a Python calc node.
     *
     * @param id node id, forwarded to {@link ExecNodeBase}
     * @param context node context, forwarded to {@link ExecNodeBase}
     * @param persistedConfig node configuration, forwarded to {@link ExecNodeBase}
     * @param projection projection expressions to evaluate; must not be {@code null}
     * @param inputProperties input edge properties; must contain exactly one element
     * @param outputType output row type, forwarded to {@link ExecNodeBase}
     * @param description node description, forwarded to {@link ExecNodeBase}
     * @throws IllegalArgumentException if {@code inputProperties} does not have exactly one entry
     * @throws NullPointerException if {@code projection} is {@code null}
     */
    public CommonExecPythonCalc(
            int id,
            ExecNodeContext context,
            ReadableConfig persistedConfig,
            List<RexNode> projection,
            List<InputProperty> inputProperties,
            RowType outputType,
            String description) {
        super(id, context, persistedConfig, inputProperties, outputType, description);
        Preconditions.checkArgument(inputProperties.size() == 1);
        this.projection = Preconditions.checkNotNull(projection);
    }

    /**
     * Translates the single input and wraps it in a Python scalar function transformation.
     *
     * <p>When the Python worker is configured to use managed memory, the transformation declares
     * the {@link ManagedMemoryUseCase#PYTHON} use case at slot scope.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected Transformation<RowData> translateToPlanInternal(
            PlannerBase planner, ExecNodeConfig config) {
        // The only input edge is expected to produce RowData records.
        final Transformation<RowData> inputTransform =
                (Transformation<RowData>) getInputEdges().get(0).translateToPlan(planner);
        final ClassLoader classLoader = planner.getFlinkContext().getClassLoader();
        final Configuration pythonConfig =
                CommonPythonUtil.extractPythonConfiguration(
                        planner.getExecEnv(), config, classLoader);
        final OneInputTransformation<RowData, RowData> pythonTransform =
                createPythonOneInputTransformation(
                        inputTransform, config, classLoader, pythonConfig);
        if (CommonPythonUtil.isPythonWorkerUsingManagedMemory(pythonConfig, classLoader)) {
            pythonTransform.declareManagedMemoryUseCaseAtSlotScope(ManagedMemoryUseCase.PYTHON);
        }
        return pythonTransform;
    }

    /**
     * Builds the one-input transformation that forwards input references and evaluates the
     * Python calls of {@link #projection}.
     */
    private OneInputTransformation<RowData, RowData> createPythonOneInputTransformation(
            Transformation<RowData> inputTransform,
            ExecNodeConfig config,
            ClassLoader classLoader,
            Configuration pythonConfig) {
        // Split the projection into Python calls and forwarded input references.
        // NOTE(review): other RexNode kinds (e.g. literals) are silently ignored here — confirm the
        // planner never hands such projections to this node.
        final List<RexCall> pythonRexCalls = new ArrayList<>();
        final List<Integer> forwardedFields = new ArrayList<>();
        for (RexNode node : projection) {
            if (node instanceof RexCall) {
                pythonRexCalls.add((RexCall) node);
            } else if (node instanceof RexInputRef) {
                forwardedFields.add(((RexInputRef) node).getIndex());
            }
        }

        final Tuple2<int[], PythonFunctionInfo[]> extracted =
                extractPythonScalarFunctionInfos(pythonRexCalls, classLoader);
        final int[] udfInputOffsets = extracted.f0;
        final PythonFunctionInfo[] pythonFunctionInfos = extracted.f1;

        final InternalTypeInfo<RowData> inputTypeInfo =
                (InternalTypeInfo<RowData>) inputTransform.getOutputType();
        final LogicalType[] inputFieldTypes = inputTypeInfo.toRowFieldTypes();

        // Output layout: forwarded input fields first, then one field per Python call result.
        final List<LogicalType> outputFieldTypes =
                new ArrayList<>(forwardedFields.size() + pythonRexCalls.size());
        for (int fieldIndex : forwardedFields) {
            outputFieldTypes.add(inputFieldTypes[fieldIndex]);
        }
        for (RexCall call : pythonRexCalls) {
            outputFieldTypes.add(FlinkTypeFactory.toLogicalType(call.getType()));
        }
        final InternalTypeInfo<RowData> outputTypeInfo =
                InternalTypeInfo.ofFields(outputFieldTypes.toArray(new LogicalType[0]));

        final int[] forwardedFieldIndices =
                forwardedFields.stream().mapToInt(Integer::intValue).toArray();
        // Any Pandas (vectorized) call switches the whole node to the Arrow-based operator.
        final boolean isArrow =
                pythonRexCalls.stream()
                        .anyMatch(
                                call ->
                                        PythonUtil.containsPythonCall(
                                                call, PythonFunctionKind.PANDAS));

        final OneInputStreamOperator<RowData, RowData> pythonOperator =
                getPythonScalarFunctionOperator(
                        config,
                        classLoader,
                        pythonConfig,
                        inputTypeInfo,
                        outputTypeInfo,
                        udfInputOffsets,
                        pythonFunctionInfos,
                        forwardedFieldIndices,
                        isArrow);
        return ExecNodeUtil.createOneInputTransformation(
                inputTransform,
                createTransformationMeta(PYTHON_CALC_TRANSFORMATION, config),
                pythonOperator,
                outputTypeInfo,
                inputTransform.getParallelism());
    }

    /**
     * Creates one {@link PythonFunctionInfo} per call and collects the offsets of the distinct
     * UDF input nodes they reference.
     *
     * @return {@code f0}: UDF input offsets, in first-seen order; {@code f1}: one function info per
     *     call, in call order
     * @throws TableException if a UDF input node is neither a {@link RexInputRef} nor a {@link
     *     RexFieldAccess}
     */
    private Tuple2<int[], PythonFunctionInfo[]> extractPythonScalarFunctionInfos(
            List<RexCall> pythonRexCalls, ClassLoader classLoader) {
        // createPythonFunctionInfo is the only code writing to this map, so the function infos
        // MUST be built before the key set is read. Reading it first (e.g. as an earlier argument
        // of the same call) yields an empty offset array.
        final LinkedHashMap<RexNode, Integer> inputNodes = new LinkedHashMap<>();
        final PythonFunctionInfo[] pythonFunctionInfos =
                new PythonFunctionInfo[pythonRexCalls.size()];
        for (int i = 0; i < pythonFunctionInfos.length; i++) {
            pythonFunctionInfos[i] =
                    CommonPythonUtil.createPythonFunctionInfo(
                            pythonRexCalls.get(i), inputNodes, classLoader);
        }

        final int[] udfInputOffsets =
                inputNodes.keySet().stream()
                        .mapToInt(CommonExecPythonCalc::toUdfInputOffset)
                        .toArray();
        return Tuple2.of(udfInputOffsets, pythonFunctionInfos);
    }

    /**
     * Returns the offset for a UDF input node: the index of a {@link RexInputRef}, or the index of
     * the accessed field of a {@link RexFieldAccess}.
     */
    private static int toUdfInputOffset(RexNode inputNode) {
        if (inputNode instanceof RexInputRef) {
            return ((RexInputRef) inputNode).getIndex();
        }
        if (inputNode instanceof RexFieldAccess) {
            return ((RexFieldAccess) inputNode).getField().getIndex();
        }
        throw new TableException("Unsupported input node for Python scalar function: " + inputNode);
    }

    /**
     * Reflectively instantiates the Python scalar function operator.
     *
     * <p>Operator class selection: the Arrow operator when {@code isArrow} is set, otherwise the
     * process-mode operator or the embedded operator depending on the Python worker mode. In
     * process mode the operator receives generated projections for the UDF input and the
     * forwarded fields. Otherwise it receives the raw UDF input offsets plus a forwarded-field
     * projection, which is {@code null} when no field is forwarded.
     *
     * @throws TableException if code generation, constructor lookup or instantiation fails
     */
    @SuppressWarnings("unchecked")
    private OneInputStreamOperator<RowData, RowData> getPythonScalarFunctionOperator(
            ExecNodeConfig config,
            ClassLoader classLoader,
            Configuration pythonConfig,
            InternalTypeInfo<RowData> inputRowTypeInfo,
            InternalTypeInfo<RowData> outputRowTypeInfo,
            int[] udfInputOffsets,
            PythonFunctionInfo[] pythonFunctionInfos,
            int[] forwardedFields,
            boolean isArrow) {
        final boolean isInProcessMode =
                CommonPythonUtil.isPythonWorkerInProcessMode(pythonConfig, classLoader);
        final String operatorClassName;
        if (isArrow) {
            operatorClassName = ARROW_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME;
        } else if (isInProcessMode) {
            operatorClassName = PYTHON_SCALAR_FUNCTION_OPERATOR_NAME;
        } else {
            operatorClassName = EMBEDDED_PYTHON_SCALAR_FUNCTION_OPERATOR_NAME;
        }
        final Class<?> operatorClass = CommonPythonUtil.loadClass(operatorClassName, classLoader);

        final RowType inputType = inputRowTypeInfo.toRowType();
        final RowType outputType = outputRowTypeInfo.toRowType();
        final RowType udfInputType = (RowType) Projection.of(udfInputOffsets).project(inputType);
        final RowType forwardedFieldType =
                (RowType) Projection.of(forwardedFields).project(inputType);
        // UDF results follow the forwarded fields in the output row (see output layout above).
        final RowType udfOutputType =
                (RowType)
                        Projection.range(forwardedFields.length, outputType.getFieldCount())
                                .project(outputType);

        try {
            if (isInProcessMode) {
                final GeneratedProjection udfInputProjection =
                        generateProjection(
                                config,
                                classLoader,
                                "UdfInputProjection",
                                inputType,
                                udfInputType,
                                udfInputOffsets);
                final GeneratedProjection forwardedFieldProjection =
                        generateProjection(
                                config,
                                classLoader,
                                "ForwardedFieldProjection",
                                inputType,
                                forwardedFieldType,
                                forwardedFields);
                return (OneInputStreamOperator<RowData, RowData>)
                        operatorClass
                                .getConstructor(
                                        Configuration.class,
                                        PythonFunctionInfo[].class,
                                        RowType.class,
                                        RowType.class,
                                        RowType.class,
                                        GeneratedProjection.class,
                                        GeneratedProjection.class)
                                .newInstance(
                                        pythonConfig,
                                        pythonFunctionInfos,
                                        inputType,
                                        udfInputType,
                                        udfOutputType,
                                        udfInputProjection,
                                        forwardedFieldProjection);
            }

            final GeneratedProjection forwardedFieldProjection =
                    forwardedFields.length > 0
                            ? generateProjection(
                                    config,
                                    classLoader,
                                    "ForwardedFieldProjection",
                                    inputType,
                                    forwardedFieldType,
                                    forwardedFields)
                            : null;
            return (OneInputStreamOperator<RowData, RowData>)
                    operatorClass
                            .getConstructor(
                                    Configuration.class,
                                    PythonFunctionInfo[].class,
                                    RowType.class,
                                    RowType.class,
                                    RowType.class,
                                    int[].class,
                                    GeneratedProjection.class)
                            .newInstance(
                                    pythonConfig,
                                    pythonFunctionInfos,
                                    inputType,
                                    udfInputType,
                                    udfOutputType,
                                    udfInputOffsets,
                                    forwardedFieldProjection);
        } catch (Exception e) {
            throw new TableException("Python Scalar Function Operator constructed failed.", e);
        }
    }

    /** Generates a projection named {@code name} from {@code inType} onto {@code outType}. */
    private static GeneratedProjection generateProjection(
            ExecNodeConfig config,
            ClassLoader classLoader,
            String name,
            RowType inType,
            RowType outType,
            int[] inputMapping) {
        return ProjectionCodeGenerator.generateProjection(
                new CodeGeneratorContext(config, classLoader), name, inType, outType, inputMapping);
    }
}
