package org.apache.flink.runtime.checkpoint.channel;

import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import org.apache.flink.api.common.JobID;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.PendingCheckpointTest;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.state.CheckpointStorageLocationReference;
import org.apache.flink.runtime.state.storage.JobManagerCheckpointStorage;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameter;
import org.apache.flink.testutils.junit.extensions.parameterized.ParameterizedTestExtension;
import org.apache.flink.testutils.junit.extensions.parameterized.Parameters;
import org.apache.flink.util.CloseableIterator;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.TestTemplate;
import org.junit.jupiter.api.extension.ExtendWith;

/**
 * Parameterized tests for {@link ChannelStateWriteRequestDispatcherImpl}.
 *
 * <p>Each parameter set consists of an optional expected exception type and a sequence of requests.
 * The test registers a single subtask, dispatches the sequence for one checkpoint of that subtask,
 * and checks that the sequence either succeeds or fails with the expected exception type.
 */
@ExtendWith(ParameterizedTestExtension.class)
public class ChannelStateWriteRequestDispatcherTest {
    private static final JobID JOB_ID = new JobID();
    private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
    private static final int SUBTASK_INDEX = 0;
    private static final long CHECKPOINT_ID = 42;

    /** Exception type the request sequence must fail with, or empty if it must succeed. */
    @Parameter public Optional<Class<Exception>> expectedException;

    /** Requests dispatched (in order) after the subtask has been registered. */
    @Parameter(1)
    public List<ChannelStateWriteRequest> requests;

    /**
     * Provides the parameter sets as {@code {expectedException, requests}} rows.
     *
     * @return one row per tested request sequence
     */
    @Parameters(name = "expectedException={0} requests={1}")
    public static List<Object[]> data() {
        return Arrays.asList(
                // valid sequences
                new Object[] {Optional.empty(), Arrays.asList(start(), completeIn(), completeOut())},
                new Object[] {Optional.empty(), Arrays.asList(start(), writeIn(), completeIn())},
                new Object[] {Optional.empty(), Arrays.asList(start(), writeOut(), completeOut())},
                new Object[] {
                    Optional.empty(), Arrays.asList(start(), writeOutFuture(), completeOut())
                },
                new Object[] {Optional.empty(), Arrays.asList(start(), completeIn(), writeOut())},
                new Object[] {
                    Optional.empty(), Arrays.asList(start(), completeIn(), writeOutFuture())
                },
                new Object[] {Optional.empty(), Arrays.asList(start(), completeOut(), writeIn())},
                // requests without a preceding start
                new Object[] {
                    Optional.of(IllegalArgumentException.class), Collections.singletonList(writeIn())
                },
                new Object[] {
                    Optional.of(IllegalArgumentException.class), Collections.singletonList(writeOut())
                },
                new Object[] {
                    Optional.of(IllegalArgumentException.class),
                    Collections.singletonList(writeOutFuture())
                },
                new Object[] {
                    Optional.of(IllegalArgumentException.class),
                    Collections.singletonList(completeIn())
                },
                new Object[] {
                    Optional.of(IllegalArgumentException.class),
                    Collections.singletonList(completeOut())
                },
                // the same side completed twice
                new Object[] {
                    Optional.of(IllegalArgumentException.class),
                    Arrays.asList(start(), completeIn(), completeIn())
                },
                new Object[] {
                    Optional.of(IllegalArgumentException.class),
                    Arrays.asList(start(), completeOut(), completeOut())
                },
                // writes after the same side was completed
                new Object[] {
                    Optional.of(IllegalStateException.class),
                    Arrays.asList(start(), completeIn(), writeIn())
                },
                new Object[] {
                    Optional.of(IllegalStateException.class),
                    Arrays.asList(start(), completeOut(), writeOut())
                },
                new Object[] {
                    Optional.of(IllegalStateException.class),
                    Arrays.asList(start(), completeOut(), writeOutFuture())
                },
                // the same checkpoint started twice
                new Object[] {
                    Optional.of(IllegalStateException.class), Arrays.asList(start(), start())
                });
    }

    /** Creates a {@code completeOutput} request for the test subtask and checkpoint. */
    private static CheckpointInProgressRequest completeOut() {
        return ChannelStateWriteRequest.completeOutput(JOB_VERTEX_ID, SUBTASK_INDEX, CHECKPOINT_ID);
    }

    /** Creates a {@code completeInput} request for the test subtask and checkpoint. */
    private static CheckpointInProgressRequest completeIn() {
        return ChannelStateWriteRequest.completeInput(JOB_VERTEX_ID, SUBTASK_INDEX, CHECKPOINT_ID);
    }

    /**
     * Creates an input-channel write request carrying a single one-byte buffer. The buffer
     * iterator is given {@link Buffer#recycleBuffer()} as its cleanup callback.
     */
    private static ChannelStateWriteRequest writeIn() {
        return ChannelStateWriteRequest.write(
                JOB_VERTEX_ID,
                SUBTASK_INDEX,
                CHECKPOINT_ID,
                new InputChannelInfo(1, 1),
                CloseableIterator.ofElement(
                        new NetworkBuffer(
                                MemorySegmentFactory.allocateUnpooledSegment(1),
                                FreeingBufferRecycler.INSTANCE),
                        Buffer::recycleBuffer));
    }

    /** Creates a result-subpartition write request carrying a single one-byte buffer. */
    private static ChannelStateWriteRequest writeOut() {
        return ChannelStateWriteRequest.write(
                JOB_VERTEX_ID,
                SUBTASK_INDEX,
                CHECKPOINT_ID,
                new ResultSubpartitionInfo(1, 1),
                new Buffer[] {
                    new NetworkBuffer(
                            MemorySegmentFactory.allocateUnpooledSegment(1),
                            FreeingBufferRecycler.INSTANCE)
                });
    }

    /**
     * Creates a result-subpartition write request backed by a future of buffers.
     *
     * <p>The future is deliberately completed only after the request has been created, so the
     * request is constructed while its data is still pending.
     */
    private static ChannelStateWriteRequest writeOutFuture() {
        CompletableFuture<List<Buffer>> dataFuture = new CompletableFuture<>();
        ChannelStateWriteRequest write =
                ChannelStateWriteRequest.write(
                        JOB_VERTEX_ID,
                        SUBTASK_INDEX,
                        CHECKPOINT_ID,
                        new ResultSubpartitionInfo(1, 1),
                        dataFuture);
        dataFuture.complete(
                Collections.singletonList(
                        new NetworkBuffer(
                                MemorySegmentFactory.allocateUnpooledSegment(1),
                                FreeingBufferRecycler.INSTANCE)));
        return write;
    }

    /** Creates the request registering the test subtask with the dispatcher. */
    private static SubtaskRegisterRequest register() {
        return new SubtaskRegisterRequest(JOB_VERTEX_ID, SUBTASK_INDEX);
    }

    /** Creates a request starting the test checkpoint for the test subtask. */
    private static CheckpointStartRequest start() {
        return new CheckpointStartRequest(
                JOB_VERTEX_ID,
                SUBTASK_INDEX,
                CHECKPOINT_ID,
                new ChannelStateWriter.ChannelStateWriteResult(),
                new CheckpointStorageLocationReference(new byte[] {1}));
    }

    /**
     * Registers the test subtask, dispatches all parameterized requests in order and checks the
     * outcome against {@link #expectedException}.
     *
     * <p>The "expected exception" assertion is placed after the {@code try} block on purpose. If it
     * were inside, the {@link AssertionError} raised by {@code fail} would be caught by the {@code
     * catch (Throwable)} clause and re-reported as an "unexpected exception", hiding the real
     * reason for the failure.
     */
    @TestTemplate
    void doRun() {
        ChannelStateWriteRequestDispatcherImpl dispatcher =
                new ChannelStateWriteRequestDispatcherImpl(
                        new JobManagerCheckpointStorage(), JOB_ID, new ChannelStateSerializerImpl());
        try {
            dispatcher.dispatch(register());
            for (ChannelStateWriteRequest request : requests) {
                dispatcher.dispatch(request);
            }
        } catch (Throwable t) {
            if (expectedException.filter(expected -> expected.isInstance(t)).isPresent()) {
                // The sequence failed exactly as this parameter set expects.
                return;
            }
            throw new RuntimeException("unexpected exception", t);
        }
        expectedException.ifPresent(
                expected -> {
                    Assertions.fail("expected exception " + expected);
                });
    }
}
