package org.apache.flink.optimizer.custompartition;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.testfunctions.IdentityPartitionerMapper;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.junit.Assert;
import org.junit.Test;

/**
 * Optimizer tests for {@code partitionCustom(...)} applied to a rebalanced data set, with the
 * partitioning key given as a tuple position, a POJO field name, or a {@link KeySelector}.
 *
 * <p>The positive tests compile the program with {@code compileNoStats} and walk the optimized
 * plan backwards from the sink, checking the ship strategy of each node's input channel, the
 * partitioner attached to the custom-partitioned channel, and the parallelism of each node.
 * The negative tests check that a partitioner whose key type does not match the actual key
 * type is rejected with an {@link InvalidProgramException} when the program is assembled.
 */
public class CustomPartitioningTest extends CompilerTestBase {

    /** Parallelism set on every test environment and expected on every inspected plan node. */
    private static final int PARALLELISM = 4;

    /**
     * Partitions a {@code Tuple2<Integer, Integer>} data set on tuple position 0 and expects the
     * plan sink &lt;- mapper &lt;- partitioner &lt;- balancer with input ship strategies FORWARD,
     * FORWARD, PARTITION_CUSTOM and PARTITION_FORCED_REBALANCE respectively.
     *
     * @throws Exception if assembling or compiling the program fails; JUnit reports it with the
     *     full stack trace
     */
    @Test
    public void testPartitionTuples() throws Exception {
        final Partitioner<Integer> customPartitioner = new TestPartitionerInt();

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        env.fromElements(new Tuple2<Integer, Integer>(0, 0))
                .rebalance()
                .partitionCustom(customPartitioner, 0)
                .mapPartition(new IdentityPartitionerMapper<Tuple2<Integer, Integer>>())
                .output(new DiscardingOutputFormat<Tuple2<Integer, Integer>>());

        SinkPlanNode sink = compileAndGetSink(env);
        // Channel#getSource() is typed as the general plan node, hence the downcasts.
        SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
        SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
        SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();

        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, sink);
        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, mapper);
        assertInputStrategyAndParallelism(ShipStrategyType.PARTITION_CUSTOM, partitioner);
        Assert.assertEquals(customPartitioner, partitioner.getInput().getPartitioner());
        assertInputStrategyAndParallelism(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer);
    }

    /**
     * Expects an {@link InvalidProgramException} when a {@code Partitioner<Long>} is applied to
     * tuple position 0 of a {@code Tuple2<Integer, Integer>} data set.
     *
     * @throws Exception if any other failure occurs while assembling the program
     */
    @Test
    public void testPartitionTuplesInvalidType() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        try {
            env.fromElements(new Tuple2<Integer, Integer>(0, 0))
                    .rebalance()
                    .partitionCustom(new TestPartitionerLong(), 0);
            Assert.fail("Should throw an exception");
        } catch (InvalidProgramException expected) {
            // Expected: the partitioner's key type does not match the key field's type.
        }
    }

    /**
     * Partitions a {@link Pojo} data set on field {@code "a"} and expects the plan
     * sink &lt;- mapper &lt;- partitioner &lt;- balancer with input ship strategies FORWARD,
     * FORWARD, PARTITION_CUSTOM and PARTITION_FORCED_REBALANCE respectively.
     *
     * @throws Exception if assembling or compiling the program fails; JUnit reports it with the
     *     full stack trace
     */
    @Test
    public void testPartitionPojo() throws Exception {
        final Partitioner<Integer> customPartitioner = new TestPartitionerInt();

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        env.fromElements(new Pojo())
                .rebalance()
                .partitionCustom(customPartitioner, "a")
                .mapPartition(new IdentityPartitionerMapper<Pojo>())
                .output(new DiscardingOutputFormat<Pojo>());

        SinkPlanNode sink = compileAndGetSink(env);
        // Channel#getSource() is typed as the general plan node, hence the downcasts.
        SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
        SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
        SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();

        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, sink);
        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, mapper);
        assertInputStrategyAndParallelism(ShipStrategyType.PARTITION_CUSTOM, partitioner);
        Assert.assertEquals(customPartitioner, partitioner.getInput().getPartitioner());
        assertInputStrategyAndParallelism(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer);
    }

    /**
     * Expects an {@link InvalidProgramException} when a {@code Partitioner<Long>} is applied to
     * the {@code int} field {@code "a"} of a {@link Pojo} data set.
     *
     * @throws Exception if any other failure occurs while assembling the program
     */
    @Test
    public void testPartitionPojoInvalidType() throws Exception {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        try {
            env.fromElements(new Pojo())
                    .rebalance()
                    .partitionCustom(new TestPartitionerLong(), "a");
            Assert.fail("Should throw an exception");
        } catch (InvalidProgramException expected) {
            // Expected: the partitioner's key type does not match the key field's type.
        }
    }

    /**
     * Partitions a {@link Pojo} data set using a {@link KeySelector} and expects two additional
     * FORWARD-connected nodes around the custom-partitioned node, giving the plan
     * sink &lt;- mapper &lt;- node &lt;- partitioner &lt;- node &lt;- balancer.
     *
     * @throws Exception if assembling or compiling the program fails; JUnit reports it with the
     *     full stack trace
     */
    @Test
    public void testPartitionKeySelector() throws Exception {
        final Partitioner<Integer> customPartitioner = new TestPartitionerInt();

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        env.fromElements(new Pojo())
                .rebalance()
                .partitionCustom(customPartitioner, new TestKeySelectorInt<Pojo>())
                .mapPartition(new IdentityPartitionerMapper<Pojo>())
                .output(new DiscardingOutputFormat<Pojo>());

        SinkPlanNode sink = compileAndGetSink(env);
        // Channel#getSource() is typed as the general plan node, hence the downcasts.
        SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
        // NOTE(review): the nodes directly after and before the partitioner are presumably the
        // key-removing and key-extracting steps inserted for selector keys; confirm against
        // the API implementation. This test only checks their ship strategy and parallelism.
        SingleInputPlanNode keyRemover = (SingleInputPlanNode) mapper.getInput().getSource();
        SingleInputPlanNode partitioner = (SingleInputPlanNode) keyRemover.getInput().getSource();
        SingleInputPlanNode keyExtractor = (SingleInputPlanNode) partitioner.getInput().getSource();
        SingleInputPlanNode balancer = (SingleInputPlanNode) keyExtractor.getInput().getSource();

        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, sink);
        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, mapper);
        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, keyRemover);
        assertInputStrategyAndParallelism(ShipStrategyType.PARTITION_CUSTOM, partitioner);
        Assert.assertEquals(customPartitioner, partitioner.getInput().getPartitioner());
        assertInputStrategyAndParallelism(ShipStrategyType.FORWARD, keyExtractor);
        assertInputStrategyAndParallelism(ShipStrategyType.PARTITION_FORCED_REBALANCE, balancer);
    }

    /**
     * Expects an {@link InvalidProgramException} when a partitioner for {@code Long} keys is
     * combined with a key selector that produces {@code Integer} keys.
     *
     * @throws Exception if any other failure occurs while assembling the program
     */
    @Test
    public void testPartitionKeySelectorInvalidType() throws Exception {
        // The generic signature ties the partitioner's key type to the selector's key type, so
        // the mismatch has to be smuggled past the compiler to exercise the runtime check.
        @SuppressWarnings("unchecked")
        final Partitioner<Integer> mismatchedPartitioner =
                (Partitioner<Integer>) (Partitioner<?>) new TestPartitionerLong();

        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);

        try {
            env.fromElements(new Pojo())
                    .rebalance()
                    .partitionCustom(mismatchedPartitioner, new TestKeySelectorInt<Pojo>());
            Assert.fail("Should throw an exception");
        } catch (InvalidProgramException expected) {
            // Expected: the partitioner's key type does not match the selector's key type.
        }
    }

    /**
     * Builds the program plan of the given environment, compiles it with {@code compileNoStats}
     * and returns the first data sink of the optimized plan.
     *
     * @param env environment on which the program under test has been defined
     * @return the first sink node of the optimized plan
     */
    private SinkPlanNode compileAndGetSink(ExecutionEnvironment env) {
        return (SinkPlanNode) compileNoStats(env.createProgramPlan()).getDataSinks().iterator().next();
    }

    /**
     * Asserts the ship strategy of the node's input channel and that the node runs with
     * {@link #PARALLELISM}.
     *
     * @param expectedStrategy ship strategy expected on the node's input channel
     * @param node plan node to inspect
     */
    private static void assertInputStrategyAndParallelism(
            ShipStrategyType expectedStrategy, SingleInputPlanNode node) {
        Assert.assertEquals(expectedStrategy, node.getInput().getShipStrategy());
        Assert.assertEquals(PARALLELISM, node.getParallelism());
    }

    /**
     * Minimal element type with two public {@code int} fields; the tests partition on field
     * {@code a}.
     */
    public static class Pojo {
        public int a;
        public int b;
    }

    /**
     * Key selector stub declaring {@code Integer} keys. It always returns {@code null}; the
     * tests in this class only assemble and compile programs and never call it themselves.
     *
     * @param <T> type of the elements the key is taken from
     */
    private static class TestKeySelectorInt<T> implements KeySelector<T, Integer> {

        @Override
        public Integer getKey(T value) {
            return null;
        }
    }

    /** Partitioner stub for {@code Integer} keys that assigns every key to partition 0. */
    private static class TestPartitionerInt implements Partitioner<Integer> {

        @Override
        public int partition(Integer key, int numPartitions) {
            return 0;
        }
    }

    /** Partitioner stub for {@code Long} keys that assigns every key to partition 0. */
    private static class TestPartitionerLong implements Partitioner<Long> {

        @Override
        public int partition(Long key, int numPartitions) {
            return 0;
        }
    }
}
