package org.apache.flink.runtime.executiongraph;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.flink.runtime.JobException;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.jobgraph.DistributionPattern;
import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
import org.apache.flink.runtime.jobgraph.JobVertex;
import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
import org.apache.flink.runtime.scheduler.SchedulerBase;
import org.apache.flink.runtime.scheduler.strategy.ConsumedPartitionGroup;
import org.apache.flink.testutils.TestingUtils;
import org.apache.flink.testutils.executor.TestExecutorResource;
import org.junit.Assert;
import org.junit.ClassRule;
import org.junit.Test;

/**
 * Tests which upstream result partitions each downstream execution vertex consumes when two job
 * vertices are connected with {@link DistributionPattern#POINTWISE}, covering equal parallelism,
 * integer-multiple ratios in both directions, and irregular (non-divisible) ratios.
 */
public class PointwisePatternTest {

    @ClassRule
    public static final TestExecutorResource<ScheduledExecutorService> EXECUTOR_RESOURCE =
            TestingUtils.defaultExecutorResource();

    @Test
    public void testNToN() throws Exception {
        final int parallelism = 23;
        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(parallelism, parallelism)
                        .getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());
            ConsumedPartitionGroup group = vertex.getConsumedPartitionGroup(0);
            Assert.assertEquals(1L, group.size());
            // With equal parallelism, subtask k reads exactly partition k.
            Assert.assertEquals(
                    vertex.getParallelSubtaskIndex(), group.getFirst().getPartitionNumber());
        }
    }

    @Test
    public void test2NToN() throws Exception {
        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(34, 17).getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());
            ConsumedPartitionGroup group = vertex.getConsumedPartitionGroup(0);
            Assert.assertEquals(2L, group.size());
            // Subtask k reads the contiguous partitions [2k, 2k + 1] in order.
            int offset = 0;
            for (IntermediateResultPartitionID partitionId : group) {
                Assert.assertEquals(
                        vertex.getParallelSubtaskIndex() * 2 + offset,
                        partitionId.getPartitionNumber());
                offset++;
            }
        }
    }

    @Test
    public void test3NToN() throws Exception {
        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(51, 17).getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());
            ConsumedPartitionGroup group = vertex.getConsumedPartitionGroup(0);
            Assert.assertEquals(3L, group.size());
            // Subtask k reads the contiguous partitions [3k, 3k + 2] in order.
            int offset = 0;
            for (IntermediateResultPartitionID partitionId : group) {
                Assert.assertEquals(
                        vertex.getParallelSubtaskIndex() * 3 + offset,
                        partitionId.getPartitionNumber());
                offset++;
            }
        }
    }

    @Test
    public void testNTo2N() throws Exception {
        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(41, 82).getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());
            ConsumedPartitionGroup group = vertex.getConsumedPartitionGroup(0);
            Assert.assertEquals(1L, group.size());
            // Two consecutive downstream subtasks share one upstream partition.
            Assert.assertEquals(
                    vertex.getParallelSubtaskIndex() / 2, group.getFirst().getPartitionNumber());
        }
    }

    @Test
    public void testNTo7N() throws Exception {
        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(11, 77).getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());
            ConsumedPartitionGroup group = vertex.getConsumedPartitionGroup(0);
            Assert.assertEquals(1L, group.size());
            // Seven consecutive downstream subtasks share one upstream partition.
            Assert.assertEquals(
                    vertex.getParallelSubtaskIndex() / 7, group.getFirst().getPartitionNumber());
        }
    }

    @Test
    public void testLowHighIrregular() throws Exception {
        testLowToHigh(3, 16);
        testLowToHigh(19, 21);
        testLowToHigh(15, 20);
        testLowToHigh(11, 31);
        testLowToHigh(11, 29);
    }

    @Test
    public void testHighLowIrregular() throws Exception {
        testHighToLow(16, 3);
        testHighToLow(21, 19);
        testHighToLow(20, 15);
        testHighToLow(31, 11);
    }

    /**
     * Pins the exact partition numbers consumed by every downstream subtask for a set of
     * parallelism pairs. Row {@code k} of each table lists, in order, the upstream partition
     * numbers consumed by downstream subtask {@code k}.
     */
    @Test
    public void testPointwiseConnectionSequence() throws Exception {
        testConnections(3, 5, new int[][] {{0}, {0}, {1}, {1}, {2}});
        testConnections(
                3, 10, new int[][] {{0}, {0}, {0}, {0}, {1}, {1}, {1}, {2}, {2}, {2}});
        testConnections(4, 6, new int[][] {{0}, {0}, {1}, {2}, {2}, {3}});
        testConnections(
                6, 10, new int[][] {{0}, {0}, {1}, {1}, {2}, {3}, {3}, {4}, {4}, {5}});
        testConnections(5, 3, new int[][] {{0}, {1, 2}, {3, 4}});
        testConnections(10, 3, new int[][] {{0, 1, 2}, {3, 4, 5}, {6, 7, 8, 9}});
        testConnections(6, 4, new int[][] {{0}, {1, 2}, {3}, {4, 5}});
        testConnections(10, 6, new int[][] {{0}, {1, 2}, {3, 4}, {5}, {6, 7}, {8, 9}});
    }

    /**
     * Checks a low-to-high pointwise connection. Every downstream subtask must consume exactly
     * one partition, and every upstream partition must be consumed by either
     * {@code highDop / lowDop} subtasks or one more than that.
     *
     * @param lowDop upstream parallelism
     * @param highDop downstream parallelism, must be at least {@code lowDop}
     * @throws IllegalArgumentException if {@code highDop < lowDop}
     */
    private void testLowToHigh(int lowDop, int highDop) throws Exception {
        if (highDop < lowDop) {
            throw new IllegalArgumentException(
                    "highDop (" + highDop + ") must not be smaller than lowDop (" + lowDop + ")");
        }

        final int factor = highDop / lowDop;
        // A non-divisible ratio lets some partitions have one extra consumer.
        final int delta = highDop % lowDop == 0 ? 0 : 1;
        final int[] timesUsed = new int[lowDop];

        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(lowDop, highDop).getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());
            ConsumedPartitionGroup group = vertex.getConsumedPartitionGroup(0);
            Assert.assertEquals(1L, group.size());
            timesUsed[group.getFirst().getPartitionNumber()]++;
        }

        for (int used : timesUsed) {
            Assert.assertTrue(used >= factor && used <= factor + delta);
        }
    }

    /**
     * Checks a high-to-low pointwise connection. Every downstream subtask must consume
     * {@code highDop / lowDop} partitions or one more than that, and every upstream partition
     * must be consumed exactly once.
     *
     * @param highDop upstream parallelism
     * @param lowDop downstream parallelism, must be at most {@code highDop}
     * @throws IllegalArgumentException if {@code highDop < lowDop}
     */
    private void testHighToLow(int highDop, int lowDop) throws Exception {
        if (highDop < lowDop) {
            throw new IllegalArgumentException(
                    "highDop (" + highDop + ") must not be smaller than lowDop (" + lowDop + ")");
        }

        final int factor = highDop / lowDop;
        // A non-divisible ratio lets some subtasks consume one extra partition.
        final int delta = highDop % lowDop == 0 ? 0 : 1;
        final int[] timesUsed = new int[highDop];

        for (ExecutionVertex vertex :
                setUpExecutionGraphAndGetDownstreamVertex(highDop, lowDop).getTaskVertices()) {
            Assert.assertEquals(1L, vertex.getNumberOfInputs());

            List<IntermediateResultPartitionID> consumedPartitions = new ArrayList<>();
            for (ConsumedPartitionGroup group : vertex.getAllConsumedPartitionGroups()) {
                for (IntermediateResultPartitionID partitionId : group) {
                    consumedPartitions.add(partitionId);
                }
            }

            Assert.assertTrue(
                    consumedPartitions.size() >= factor
                            && consumedPartitions.size() <= factor + delta);

            for (IntermediateResultPartitionID partitionId : consumedPartitions) {
                timesUsed[partitionId.getPartitionNumber()]++;
            }
        }

        for (int used : timesUsed) {
            Assert.assertEquals(1L, used);
        }
    }

    /**
     * Builds an execution graph with two job vertices, where "vertex2" consumes "vertex1"
     * through a POINTWISE, PIPELINED edge.
     *
     * @param upstreamParallelism parallelism of "vertex1"
     * @param downstreamParallelism parallelism of "vertex2"
     * @return the execution job vertex of "vertex2"
     * @throws AssertionError if attaching the job graph throws a {@link JobException}; the
     *     original exception is kept as the cause
     */
    private ExecutionJobVertex setUpExecutionGraphAndGetDownstreamVertex(
            int upstreamParallelism, int downstreamParallelism) throws Exception {
        JobVertex upstream = new JobVertex("vertex1");
        JobVertex downstream = new JobVertex("vertex2");

        upstream.setParallelism(upstreamParallelism);
        downstream.setParallelism(downstreamParallelism);

        upstream.setInvokableClass(AbstractInvokable.class);
        downstream.setInvokableClass(AbstractInvokable.class);

        downstream.connectNewDataSetAsInput(
                upstream, DistributionPattern.POINTWISE, ResultPartitionType.PIPELINED);

        // Upstream first, then downstream: the order the vertices are attached in.
        List<JobVertex> orderedVertices = new ArrayList<>(Arrays.asList(upstream, downstream));

        DefaultExecutionGraph executionGraph =
                TestingDefaultExecutionGraphBuilder.newBuilder()
                        .setVertexParallelismStore(
                                SchedulerBase.computeVertexParallelismStore(orderedVertices))
                        .build(EXECUTOR_RESOURCE.getExecutor());
        try {
            executionGraph.attachJobGraph(orderedVertices);
        } catch (JobException e) {
            // Keep the cause so the test report shows the real failure.
            throw new AssertionError("Job failed with exception: " + e.getMessage(), e);
        }
        return executionGraph.getAllVertices().get(downstream.getID());
    }

    /**
     * Asserts the exact, ordered partition numbers consumed by each downstream subtask.
     *
     * @param sourceParallelism upstream parallelism
     * @param targetParallelism downstream parallelism
     * @param expectedConsumedPartitionNumbers row {@code k} lists the partition numbers that
     *     downstream subtask {@code k} is expected to consume, in order; must have one row per
     *     downstream subtask
     */
    private void testConnections(
            int sourceParallelism, int targetParallelism, int[][] expectedConsumedPartitionNumbers)
            throws Exception {
        ExecutionVertex[] taskVertices =
                setUpExecutionGraphAndGetDownstreamVertex(sourceParallelism, targetParallelism)
                        .getTaskVertices();

        // Fail clearly on a table of the wrong size instead of going out of bounds later.
        Assert.assertEquals(expectedConsumedPartitionNumbers.length, taskVertices.length);

        for (int vertexIndex = 0; vertexIndex < taskVertices.length; vertexIndex++) {
            ConsumedPartitionGroup group = taskVertices[vertexIndex].getConsumedPartitionGroup(0);
            int[] expected = expectedConsumedPartitionNumbers[vertexIndex];
            Assert.assertEquals(expected.length, group.size());

            int position = 0;
            for (IntermediateResultPartitionID partitionId : group) {
                Assert.assertEquals(expected[position], partitionId.getPartitionNumber());
                position++;
            }
        }
    }
}
