package org.apache.flink.runtime.state;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.runtime.checkpoint.CheckpointOptions;
import org.apache.flink.runtime.checkpoint.CheckpointType;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriter;
import org.apache.flink.runtime.checkpoint.channel.ChannelStateWriterImpl;
import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
import org.apache.flink.runtime.checkpoint.channel.SequentialChannelStateReaderImpl;
import org.apache.flink.runtime.io.network.buffer.Buffer;
import org.apache.flink.runtime.io.network.buffer.FreeingBufferRecycler;
import org.apache.flink.runtime.io.network.buffer.NetworkBuffer;
import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
import org.apache.flink.runtime.io.network.partition.BufferWritingResultPartition;
import org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
import org.apache.flink.runtime.io.network.partition.ResultPartitionType;
import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
import org.apache.flink.runtime.io.network.partition.consumer.InputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGateBuilder;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.state.memory.NonPersistentMetadataCheckpointStorageLocation;
import org.apache.flink.util.CloseableIterator;
import org.apache.flink.util.function.SupplierWithException;
import org.junit.Assert;
import org.junit.Test;

/**
 * Integration test for channel state persistence.
 *
 * <p>Channel state (input channel data and result subpartition data) is written with {@link
 * ChannelStateWriterImpl}, read back with {@link SequentialChannelStateReaderImpl}, and the
 * recovered bytes are compared with what was written. It also checks how a result partition
 * behaves right after an (empty) output state has been recovered.
 */
public class ChannelPersistenceITCase {
    /** Source of the random payloads; seeded from the wall clock, so every run uses new data. */
    private static final Random RANDOM = new Random(System.currentTimeMillis());

    /**
     * Extra bytes added on top of the raw payload size when sizing the in-memory checkpoint
     * location in {@link #write}. NOTE(review): presumably covers per-stream serialization
     * overhead — confirm.
     */
    private static final int STATE_SIZE_OVERHEAD_BYTES = 24;

    /**
     * For a {@link ResultPartitionType#PIPELINED} partition, checks the behavior after output
     * state is recovered. The subpartition must first emit a RECOVERY_COMPLETION buffer, then
     * report no data until consumption is resumed.
     */
    @Test
    public void testUpstreamBlocksAfterRecoveringState() throws Exception {
        upstreamBlocksAfterRecoveringState(ResultPartitionType.PIPELINED);
    }

    /**
     * For a {@link ResultPartitionType#PIPELINED_APPROXIMATE} partition, checks that recovering
     * output state does not block the subpartition: the emitted record is available immediately.
     */
    @Test
    public void testNotBlocksAfterRecoveringStateForApproximateLocalRecovery() throws Exception {
        upstreamBlocksAfterRecoveringState(ResultPartitionType.PIPELINED_APPROXIMATE);
    }

    /**
     * Writes one input channel and two result subpartitions, then reads everything back. One
     * subpartition is handed over directly and one via a future. Each channel must yield exactly
     * the bytes written to it.
     */
    @Test
    public void testReadWritten() throws Exception {
        byte[] inputChannelData = randomBytes(1024);
        byte[] subpartitionData = randomBytes(1024);
        byte[] futureSubpartitionData = randomBytes(1024);

        SequentialChannelStateReaderImpl reader =
                new SequentialChannelStateReaderImpl(
                        toTaskStateSnapshot(
                                write(
                                        1L,
                                        Collections.singletonMap(
                                                new InputChannelInfo(0, 0), inputChannelData),
                                        Collections.singletonMap(
                                                new ResultSubpartitionInfo(0, 0), subpartitionData),
                                        Collections.singletonMap(
                                                new ResultSubpartitionInfo(0, 1),
                                                futureSubpartitionData))));

        NetworkBufferPool networkBufferPool = new NetworkBufferPool(6, 1024);
        try {
            InputGate gate = buildGate(networkBufferPool, 1);
            reader.readInputData(new InputGate[] {gate});
            Assert.assertArrayEquals(
                    inputChannelData,
                    collectBytes(gate::pollNext, bufferOrEvent -> bufferOrEvent.getBuffer()));

            BufferWritingResultPartition resultPartition =
                    buildResultPartition(networkBufferPool, ResultPartitionType.PIPELINED, 0, 2);
            reader.readOutputData(new BufferWritingResultPartition[] {resultPartition}, false);
            assertSubpartitionContent(subpartitionData, resultPartition, 0);
            assertSubpartitionContent(futureSubpartitionData, resultPartition, 1);
        } finally {
            networkBufferPool.destroy();
        }
    }

    /**
     * Asserts that subpartition {@code subpartitionIndex} of {@code partition} yields exactly
     * {@code expected} when drained through a freshly created view.
     */
    private void assertSubpartitionContent(
            byte[] expected, BufferWritingResultPartition partition, int subpartitionIndex)
            throws Exception {
        ResultSubpartitionView view =
                partition.createSubpartitionView(
                        subpartitionIndex, new NoOpBufferAvailablityListener());
        Assert.assertArrayEquals(
                expected,
                collectBytes(
                        () -> Optional.ofNullable(view.getNextBuffer()),
                        bufferAndBacklog -> bufferAndBacklog.buffer()));
    }

    /**
     * Recovers an empty output state into a single-subpartition partition of the given type,
     * emits one record, and checks when that record becomes visible to a consumer.
     *
     * <p>For every type other than {@link ResultPartitionType#PIPELINED_APPROXIMATE}, the
     * subpartition must first deliver a RECOVERY_COMPLETION buffer. It must then return no buffer
     * until {@code resumeConsumption()} is called. For the approximate type the record must be
     * available straight away.
     *
     * @param resultPartitionType type of the partition under test
     */
    private void upstreamBlocksAfterRecoveringState(ResultPartitionType resultPartitionType)
            throws Exception {
        NetworkBufferPool networkBufferPool = new NetworkBufferPool(4, 1024);
        byte[] record = randomBytes(1024);
        try {
            BufferWritingResultPartition resultPartition =
                    buildResultPartition(networkBufferPool, resultPartitionType, 0, 1);
            // Nothing is restored from the empty snapshot. NOTE(review): the `true` flag
            // presumably requests recovery-completion notification and blocking (see the
            // RECOVERY_COMPLETION assertion below) — confirm against the reader's contract.
            new SequentialChannelStateReaderImpl(new TaskStateSnapshot())
                    .readOutputData(new BufferWritingResultPartition[] {resultPartition}, true);
            resultPartition.emitRecord(ByteBuffer.wrap(record), 0);

            ResultSubpartitionView view =
                    resultPartition.createSubpartitionView(0, new NoOpBufferAvailablityListener());
            if (resultPartitionType != ResultPartitionType.PIPELINED_APPROXIMATE) {
                Buffer recoveryCompletion = view.getNextBuffer().buffer();
                Assert.assertEquals(
                        Buffer.DataType.RECOVERY_COMPLETION, recoveryCompletion.getDataType());
                // The consumer owns buffers returned by the view; release it after inspection.
                recoveryCompletion.recycleBuffer();
                Assert.assertNull(view.getNextBuffer());
                view.resumeConsumption();
            }
            Assert.assertArrayEquals(record, collectBytes(view.getNextBuffer().buffer()));
        } finally {
            networkBufferPool.destroy();
        }
    }

    /**
     * Builds and sets up a result partition whose buffer pool is created from {@code
     * networkBufferPool}.
     *
     * @param networkBufferPool global pool the partition's local buffer pool is carved from
     * @param resultPartitionType type of the partition
     * @param partitionIndex index of the partition
     * @param numberOfSubpartitions number of subpartitions; also used as the pool's required
     *     segments and its subpartition count
     * @return the partition, already {@code setup()}
     */
    private BufferWritingResultPartition buildResultPartition(
            NetworkBufferPool networkBufferPool,
            ResultPartitionType resultPartitionType,
            int partitionIndex,
            int numberOfSubpartitions)
            throws IOException {
        BufferWritingResultPartition resultPartition =
                new ResultPartitionBuilder()
                        .setResultPartitionIndex(partitionIndex)
                        .setResultPartitionType(resultPartitionType)
                        .setNumberOfSubpartitions(numberOfSubpartitions)
                        .setBufferPoolFactory(
                                () ->
                                        networkBufferPool.createBufferPool(
                                                numberOfSubpartitions,
                                                Integer.MAX_VALUE,
                                                numberOfSubpartitions,
                                                Integer.MAX_VALUE,
                                                0))
                        .build();
        resultPartition.setup();
        return resultPartition;
    }

    /**
     * Builds and sets up an input gate with {@code numberOfChannels} remote recovered channels.
     * Its buffers come from {@code networkBufferPool}.
     */
    private SingleInputGate buildGate(NetworkBufferPool networkBufferPool, int numberOfChannels)
            throws IOException {
        SingleInputGate gate =
                new SingleInputGateBuilder()
                        .setChannelFactory(
                                (channelBuilder, inputGate) ->
                                        channelBuilder.buildRemoteRecoveredChannel(inputGate))
                        .setBufferPoolFactory(
                                networkBufferPool.createBufferPool(
                                        numberOfChannels, Integer.MAX_VALUE))
                        .setSegmentProvider(networkBufferPool)
                        .setNumberOfChannels(numberOfChannels)
                        .build();
        gate.setup();
        return gate;
    }

    /**
     * Drains {@code entrySupplier} until it returns an empty {@link Optional}. Concatenates the
     * readable bytes of every data buffer obtained through {@code bufferExtractor}.
     *
     * <p>Every buffer taken from an entry is recycled: data buffers after their bytes are copied,
     * non-data buffers immediately. Entries for which the extractor returns {@code null} are
     * skipped.
     *
     * @param entrySupplier source of entries; an empty Optional ends the drain
     * @param bufferExtractor maps an entry to its buffer (may return {@code null})
     * @return the concatenated payload of all data buffers, in arrival order
     */
    private <T> byte[] collectBytes(
            SupplierWithException<Optional<T>, Exception> entrySupplier,
            Function<T, Buffer> bufferExtractor)
            throws Exception {
        List<Buffer> dataBuffers = new ArrayList<>();
        for (Optional<T> entry = entrySupplier.get();
                entry.isPresent();
                entry = entrySupplier.get()) {
            Buffer buffer = bufferExtractor.apply(entry.get());
            if (buffer == null) {
                // The entry carries no buffer: nothing to collect and nothing to recycle.
                continue;
            }
            if (buffer.getDataType().isBuffer()) {
                dataBuffers.add(buffer);
            } else {
                // Not part of the payload, but we were handed ownership, so release it.
                buffer.recycleBuffer();
            }
        }

        // Size by readable bytes: that is exactly what getNioBufferReadable() exposes.
        ByteBuffer result =
                ByteBuffer.wrap(
                        new byte[dataBuffers.stream().mapToInt(Buffer::readableBytes).sum()]);
        for (Buffer buffer : dataBuffers) {
            result.put(buffer.getNioBufferReadable());
            buffer.recycleBuffer();
        }
        return result.array();
    }

    /**
     * Copies the readable bytes of a single buffer into a new array, then recycles the buffer.
     *
     * @param buffer buffer owned by the caller; it must not be used after this call
     * @return a copy of the buffer's readable bytes
     */
    private byte[] collectBytes(Buffer buffer) {
        ByteBuffer readable = buffer.getNioBufferReadable();
        // remaining(), not capacity(): the NIO view's capacity may exceed its readable window.
        byte[] bytes = new byte[readable.remaining()];
        readable.get(bytes);
        buffer.recycleBuffer();
        return bytes;
    }

    /** Returns a fresh array of {@code length} random bytes drawn from {@link #RANDOM}. */
    private byte[] randomBytes(int length) {
        byte[] bytes = new byte[length];
        RANDOM.nextBytes(bytes);
        return bytes;
    }

    /**
     * Writes the given channel payloads as checkpoint {@code checkpointId} with a {@link
     * ChannelStateWriterImpl}. The result is returned once its result subpartition handles have
     * completed.
     *
     * @param checkpointId id of the checkpoint to write
     * @param inputData payload per input channel; passed via a closeable iterator
     * @param outputData payload per result subpartition; passed directly
     * @param futureOutputData payload per result subpartition; passed via a future that is
     *     completed only after it has been registered with the writer
     * @return the write result of the checkpoint, already removed from the writer
     */
    private ChannelStateWriter.ChannelStateWriteResult write(
            long checkpointId,
            Map<InputChannelInfo, byte[]> inputData,
            Map<ResultSubpartitionInfo, byte[]> outputData,
            Map<ResultSubpartitionInfo, byte[]> futureOutputData)
            throws Exception {
        int maxStateSize =
                sizeOfBytes(inputData)
                        + sizeOfBytes(outputData)
                        + sizeOfBytes(futureOutputData)
                        + STATE_SIZE_OVERHEAD_BYTES;
        Map<InputChannelInfo, Buffer> inputBuffers = wrapWithBuffers(inputData);
        Map<ResultSubpartitionInfo, Buffer> outputBuffers = wrapWithBuffers(outputData);
        Map<ResultSubpartitionInfo, Buffer> futureOutputBuffers = wrapWithBuffers(futureOutputData);

        try (ChannelStateWriterImpl writer =
                new ChannelStateWriterImpl("test", 0, getStreamFactoryFactory(maxStateSize))) {
            writer.open();
            writer.start(
                    checkpointId,
                    new CheckpointOptions(
                            CheckpointType.CHECKPOINT,
                            new CheckpointStorageLocationReference("poly".getBytes())));

            for (Map.Entry<InputChannelInfo, Buffer> entry : inputBuffers.entrySet()) {
                writer.addInputData(
                        checkpointId,
                        entry.getKey(),
                        ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
                        CloseableIterator.ofElements(Buffer::recycleBuffer, entry.getValue()));
            }
            writer.finishInput(checkpointId);

            for (Map.Entry<ResultSubpartitionInfo, Buffer> entry :
                    futureOutputBuffers.entrySet()) {
                CompletableFuture<List<Buffer>> data = new CompletableFuture<>();
                writer.addOutputDataFuture(
                        checkpointId,
                        entry.getKey(),
                        ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
                        data);
                // Complete only after registration so the writer sees a pending future.
                data.complete(Collections.singletonList(entry.getValue()));
            }
            for (Map.Entry<ResultSubpartitionInfo, Buffer> entry : outputBuffers.entrySet()) {
                writer.addOutputData(
                        checkpointId,
                        entry.getKey(),
                        ChannelStateWriter.SEQUENCE_NUMBER_UNKNOWN,
                        entry.getValue());
            }
            writer.finishOutput(checkpointId);

            ChannelStateWriter.ChannelStateWriteResult result =
                    writer.getAndRemoveWriteResult(checkpointId);
            // Wait for the output handles before close() runs at the end of this try block.
            result.getResultSubpartitionStateHandles().join();
            return result;
        }
    }

    /** Same as {@link #getStreamFactoryFactory(int)} with a value of 42. */
    public static CheckpointStorageWorkerView getStreamFactoryFactory() {
        return getStreamFactoryFactory(42);
    }

    /**
     * Returns a worker view that resolves every checkpoint location to a new {@link
     * NonPersistentMetadataCheckpointStorageLocation} built with {@code maxStateSize}. Creating
     * task-owned streams or toolsets is unsupported and throws.
     *
     * @param maxStateSize value passed to each created location; {@link #write} uses it as an
     *     upper bound on the bytes it writes
     */
    public static CheckpointStorageWorkerView getStreamFactoryFactory(final int maxStateSize) {
        return new CheckpointStorageWorkerView() {
            @Override
            public CheckpointStreamFactory resolveCheckpointStorageLocation(
                    long checkpointId, CheckpointStorageLocationReference reference) {
                return new NonPersistentMetadataCheckpointStorageLocation(maxStateSize);
            }

            @Override
            public CheckpointStateOutputStream createTaskOwnedStateStream() {
                throw new UnsupportedOperationException();
            }

            @Override
            public CheckpointStateToolset createTaskOwnedCheckpointStateToolset() {
                throw new UnsupportedOperationException();
            }
        };
    }

    /**
     * Wraps the handles of {@code writeResult} into a task snapshot for a single new operator.
     * Blocks until both the input channel and result subpartition handles are available.
     */
    private TaskStateSnapshot toTaskStateSnapshot(
            ChannelStateWriter.ChannelStateWriteResult writeResult) throws Exception {
        OperatorID operatorId = new OperatorID();
        OperatorSubtaskState subtaskState =
                OperatorSubtaskState.builder()
                        .setInputChannelState(
                                new StateObjectCollection<>(
                                        writeResult.getInputChannelStateHandles().get()))
                        .setResultSubpartitionState(
                                new StateObjectCollection<>(
                                        writeResult.getResultSubpartitionStateHandles().get()))
                        .build();
        return new TaskStateSnapshot(Collections.singletonMap(operatorId, subtaskState));
    }

    /** Returns the total number of bytes across all values of {@code payloads}. */
    private static int sizeOfBytes(Map<?, byte[]> payloads) {
        return payloads.values().stream().mapToInt(bytes -> bytes.length).sum();
    }

    /** Returns a map with the same keys whose values are wrapped via {@link #wrapWithBuffer}. */
    private <K> Map<K, Buffer> wrapWithBuffers(Map<K, byte[]> payloads) {
        return payloads.entrySet().stream()
                .collect(
                        Collectors.toMap(
                                Map.Entry::getKey, entry -> wrapWithBuffer(entry.getValue())));
    }

    /**
     * Copies {@code bytes} into a new {@link NetworkBuffer}. The buffer is backed by an unpooled
     * segment of exactly {@code bytes.length} and uses {@link FreeingBufferRecycler#INSTANCE}.
     */
    public static Buffer wrapWithBuffer(byte[] bytes) {
        NetworkBuffer buffer =
                new NetworkBuffer(
                        MemorySegmentFactory.allocateUnpooledSegment(bytes.length),
                        FreeingBufferRecycler.INSTANCE);
        buffer.writeBytes(bytes);
        return buffer;
    }
}
