/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.optimizer.custompartition;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.PartitionOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SinkPlanNode;
import org.apache.flink.optimizer.testfunctions.IdentityPartitionerMapper;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.junit.Assert;
import org.junit.Test;

/**
 * Optimizer tests for custom partitioning ({@code partitionCustom}).
 *
 * <p>For each way of selecting the partitioning key (tuple field position, POJO field expression,
 * {@link KeySelector}) this verifies the compiled plan's chain of ship strategies and parallelism,
 * and that a {@link Partitioner} whose key type does not match the selected key is rejected with an
 * {@link InvalidProgramException} when the program is built.
 */
public class CustomPartitioningTest extends CompilerTestBase {

    /** Parallelism configured on every environment and expected on every plan node. */
    private static final int PARALLELISM = 4;

    @Test
    public void testPartitionTuples() {
        try {
            Partitioner<Integer> part = new TestPartitionerInt();
            ExecutionEnvironment env = createEnvironment();

            PartitionOperator<Tuple2<Integer, Integer>> data =
                    env.fromElements(new Tuple2<Integer, Integer>(0, 0)).rebalance();

            data.partitionCustom(part, 0)
                    .mapPartition(new IdentityPartitionerMapper<Tuple2<Integer, Integer>>())
                    .output(new DiscardingOutputFormat<Tuple2<Integer, Integer>>());

            Plan p = env.createProgramPlan();
            OptimizedPlan op = compileNoStats(p);

            // Walk the plan backwards from the sink: sink <- mapper <- partitioner <- balancer.
            SinkPlanNode sink = (SinkPlanNode) op.getDataSinks().iterator().next();
            SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
            SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
            SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();

            assertInput(sink, ShipStrategyType.FORWARD);
            assertInput(mapper, ShipStrategyType.FORWARD);
            assertInput(partitioner, ShipStrategyType.PARTITION_CUSTOM);
            Assert.assertEquals(part, partitioner.getInput().getPartitioner());
            assertInput(balancer, ShipStrategyType.PARTITION_FORCED_REBALANCE);
        } catch (Exception e) {
            throw unexpected(e);
        }
    }

    @Test
    public void testPartitionTuplesInvalidType() {
        try {
            PartitionOperator<Tuple2<Integer, Integer>> data =
                    createEnvironment().fromElements(new Tuple2<Integer, Integer>(0, 0)).rebalance();
            try {
                data.partitionCustom(new TestPartitionerLong(), 0);
                Assert.fail("Should throw an exception");
            } catch (InvalidProgramException expected) {
                // expected: a Long partitioner cannot partition on the Integer field 0
            }
        } catch (Exception e) {
            throw unexpected(e);
        }
    }

    @Test
    public void testPartitionPojo() {
        try {
            Partitioner<Integer> part = new TestPartitionerInt();
            ExecutionEnvironment env = createEnvironment();

            PartitionOperator<Pojo> data = env.fromElements(new Pojo()).rebalance();

            data.partitionCustom(part, "a")
                    .mapPartition(new IdentityPartitionerMapper<Pojo>())
                    .output(new DiscardingOutputFormat<Pojo>());

            Plan p = env.createProgramPlan();
            OptimizedPlan op = compileNoStats(p);

            // Walk the plan backwards from the sink: sink <- mapper <- partitioner <- balancer.
            SinkPlanNode sink = (SinkPlanNode) op.getDataSinks().iterator().next();
            SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
            SingleInputPlanNode partitioner = (SingleInputPlanNode) mapper.getInput().getSource();
            SingleInputPlanNode balancer = (SingleInputPlanNode) partitioner.getInput().getSource();

            assertInput(sink, ShipStrategyType.FORWARD);
            assertInput(mapper, ShipStrategyType.FORWARD);
            assertInput(partitioner, ShipStrategyType.PARTITION_CUSTOM);
            Assert.assertEquals(part, partitioner.getInput().getPartitioner());
            assertInput(balancer, ShipStrategyType.PARTITION_FORCED_REBALANCE);
        } catch (Exception e) {
            throw unexpected(e);
        }
    }

    @Test
    public void testPartitionPojoInvalidType() {
        try {
            PartitionOperator<Pojo> data = createEnvironment().fromElements(new Pojo()).rebalance();
            try {
                data.partitionCustom(new TestPartitionerLong(), "a");
                Assert.fail("Should throw an exception");
            } catch (InvalidProgramException expected) {
                // expected: a Long partitioner cannot partition on the int field "a"
            }
        } catch (Exception e) {
            throw unexpected(e);
        }
    }

    @Test
    public void testPartitionKeySelector() {
        try {
            Partitioner<Integer> part = new TestPartitionerInt();
            ExecutionEnvironment env = createEnvironment();

            PartitionOperator<Pojo> data = env.fromElements(new Pojo()).rebalance();

            data.partitionCustom(part, new TestKeySelectorInt<Pojo>())
                    .mapPartition(new IdentityPartitionerMapper<Pojo>())
                    .output(new DiscardingOutputFormat<Pojo>());

            Plan p = env.createProgramPlan();
            OptimizedPlan op = compileNoStats(p);

            // With a key selector the plan gains a key-extracting node before the partitioner and
            // a key-removing node after it:
            // sink <- mapper <- keyRemover <- partitioner <- keyExtractor <- balancer.
            SinkPlanNode sink = (SinkPlanNode) op.getDataSinks().iterator().next();
            SingleInputPlanNode mapper = (SingleInputPlanNode) sink.getInput().getSource();
            SingleInputPlanNode keyRemover = (SingleInputPlanNode) mapper.getInput().getSource();
            SingleInputPlanNode partitioner = (SingleInputPlanNode) keyRemover.getInput().getSource();
            SingleInputPlanNode keyExtractor = (SingleInputPlanNode) partitioner.getInput().getSource();
            SingleInputPlanNode balancer = (SingleInputPlanNode) keyExtractor.getInput().getSource();

            assertInput(sink, ShipStrategyType.FORWARD);
            assertInput(mapper, ShipStrategyType.FORWARD);
            assertInput(keyRemover, ShipStrategyType.FORWARD);
            assertInput(partitioner, ShipStrategyType.PARTITION_CUSTOM);
            Assert.assertEquals(part, partitioner.getInput().getPartitioner());
            assertInput(keyExtractor, ShipStrategyType.FORWARD);
            assertInput(balancer, ShipStrategyType.PARTITION_FORCED_REBALANCE);
        } catch (Exception e) {
            throw unexpected(e);
        }
    }

    @Test
    public void testPartitionKeySelectorInvalidType() {
        try {
            // The generic signature of partitionCustom(Partitioner<K>, KeySelector<T, K>) would
            // reject a Long partitioner paired with an Integer key selector at compile time, so the
            // mismatch is deliberately hidden behind an unchecked cast to exercise the check that
            // happens when the program is built.
            @SuppressWarnings("unchecked")
            Partitioner<Integer> part =
                    (Partitioner<Integer>) (Partitioner<?>) new TestPartitionerLong();

            PartitionOperator<Pojo> data = createEnvironment().fromElements(new Pojo()).rebalance();
            try {
                data.partitionCustom(part, new TestKeySelectorInt<Pojo>());
                Assert.fail("Should throw an exception");
            } catch (InvalidProgramException expected) {
                // expected: the partitioner's key type (Long) differs from the selector's (Integer)
            }
        } catch (Exception e) {
            throw unexpected(e);
        }
    }

    /** Creates an execution environment configured with {@link #PARALLELISM}. */
    private static ExecutionEnvironment createEnvironment() {
        ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(PARALLELISM);
        return env;
    }

    /**
     * Asserts that {@code node} receives its input with {@code expectedStrategy} and runs with
     * {@link #PARALLELISM}.
     */
    private static void assertInput(SingleInputPlanNode node, ShipStrategyType expectedStrategy) {
        Assert.assertEquals(expectedStrategy, node.getInput().getShipStrategy());
        Assert.assertEquals(PARALLELISM, node.getParallelism());
    }

    /** Converts an unexpected exception into a test failure that keeps the original cause. */
    private static AssertionError unexpected(Exception e) {
        return new AssertionError("Unexpected exception: " + e.getMessage(), e);
    }

    /** Key selector with an {@link Integer} key; it is only compiled into plans, never executed. */
    private static class TestKeySelectorInt<T> implements KeySelector<T, Integer> {
        private static final long serialVersionUID = 1L;

        private TestKeySelectorInt() {
        }

        @Override
        public Integer getKey(T value) {
            return null;
        }
    }

    /** Partitioner over {@link Long} keys; used to provoke key-type mismatches. */
    private static class TestPartitionerLong implements Partitioner<Long> {
        private static final long serialVersionUID = 1L;

        private TestPartitionerLong() {
        }

        @Override
        public int partition(Long key, int numPartitions) {
            return 0;
        }
    }

    /** Partitioner over {@link Integer} keys; matches the key types used by the valid tests. */
    private static class TestPartitionerInt implements Partitioner<Integer> {
        private static final long serialVersionUID = 1L;

        private TestPartitionerInt() {
        }

        @Override
        public int partition(Integer key, int numPartitions) {
            return 0;
        }
    }

    /** Simple POJO with two int fields; {@code a} is used as the partitioning field expression. */
    public static class Pojo {
        public int a;
        public int b;
    }
}

