package org.apache.flink.streaming.runtime.tasks;

import java.util.HashMap;
import java.util.concurrent.Future;
import java.util.concurrent.RunnableFuture;
import java.util.concurrent.atomic.AtomicBoolean;
import javax.annotation.Nonnegative;
import javax.annotation.Nullable;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.JobID;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
import org.apache.flink.runtime.checkpoint.CheckpointMetricsBuilder;
import org.apache.flink.runtime.checkpoint.JobManagerTaskRestore;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.checkpoint.StateObjectCollection;
import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
import org.apache.flink.runtime.clusterframework.types.AllocationID;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
import org.apache.flink.runtime.state.DoneFuture;
import org.apache.flink.runtime.state.InputChannelStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.LocalRecoveryConfig;
import org.apache.flink.runtime.state.LocalRecoveryDirectoryProviderImpl;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.ResultSubpartitionStateHandle;
import org.apache.flink.runtime.state.SnapshotResult;
import org.apache.flink.runtime.state.StateObject;
import org.apache.flink.runtime.state.TaskLocalStateStoreImpl;
import org.apache.flink.runtime.state.TaskStateManagerImpl;
import org.apache.flink.runtime.state.TestTaskStateManager;
import org.apache.flink.runtime.state.changelog.inmemory.InMemoryStateChangelogStorage;
import org.apache.flink.runtime.taskmanager.TestCheckpointResponder;
import org.apache.flink.streaming.api.operators.OperatorSnapshotFutures;
import org.apache.flink.streaming.runtime.tasks.StreamTaskITCase;
import org.apache.flink.util.TestLogger;
import org.apache.flink.util.concurrent.Executors;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.mockito.Mockito;

/**
 * Tests that checkpoint state is handed over unchanged along the reporting chain.
 *
 * <p>The first test runs an {@link AsyncCheckpointRunnable} over prepared operator snapshot
 * futures and checks what arrives at a {@link TestTaskStateManager}. The second test calls
 * {@link TaskStateManagerImpl#reportTaskStateSnapshots} and checks what arrives at the checkpoint
 * responder and at the task-local state store.
 */
public class LocalStateForwardingTest extends TestLogger {

    /** Supplies the directory handed to the local recovery directory provider. */
    @Rule
    public TemporaryFolder temporaryFolder = new TemporaryFolder();

    /**
     * Runs an {@link AsyncCheckpointRunnable} for a single operator whose snapshot futures are
     * already complete, then verifies that the job-manager-owned and task-local parts of every
     * future show up in the corresponding snapshots recorded by the {@link TestTaskStateManager}.
     */
    @Test
    public void testReportingFromSnapshotToTaskStateManager() throws Exception {
        TestTaskStateManager taskStateManager = new TestTaskStateManager();

        StreamMockEnvironment streamMockEnvironment =
                new StreamMockEnvironment(
                        new Configuration(),
                        new Configuration(),
                        new ExecutionConfig(),
                        1048576L,
                        new MockInputSplitProvider(),
                        0,
                        taskStateManager);
        StreamTaskITCase.NoOpStreamTask testStreamTask =
                new StreamTaskITCase.NoOpStreamTask(streamMockEnvironment);

        CheckpointMetaData checkpointMetaData = new CheckpointMetaData(0L, 0L);
        CheckpointMetricsBuilder checkpointMetrics = new CheckpointMetricsBuilder();

        // One operator, with a distinct (JM-owned, task-local) pair for every kind of state.
        OperatorSnapshotFutures osFuture = new OperatorSnapshotFutures();
        osFuture.setKeyedStateManagedFuture(createSnapshotResult(KeyedStateHandle.class));
        osFuture.setKeyedStateRawFuture(createSnapshotResult(KeyedStateHandle.class));
        osFuture.setOperatorStateManagedFuture(createSnapshotResult(OperatorStateHandle.class));
        osFuture.setOperatorStateRawFuture(createSnapshotResult(OperatorStateHandle.class));
        osFuture.setInputChannelStateFuture(
                createSnapshotCollectionResult(InputChannelStateHandle.class));
        osFuture.setResultSubpartitionStateFuture(
                createSnapshotCollectionResult(ResultSubpartitionStateHandle.class));

        OperatorID operatorID = new OperatorID();
        HashMap<OperatorID, OperatorSnapshotFutures> snapshots = new HashMap<>(1);
        snapshots.put(operatorID, osFuture);

        AsyncCheckpointRunnable checkpointRunnable =
                new AsyncCheckpointRunnable(
                        snapshots,
                        checkpointMetaData,
                        checkpointMetrics,
                        0L,
                        testStreamTask.getName(),
                        runnable -> {},
                        testStreamTask.getEnvironment(),
                        testStreamTask,
                        false,
                        false,
                        () -> true);

        checkpointMetrics.setAlignmentDurationNanos(0L);
        checkpointMetrics.setBytesProcessedDuringAlignment(0L);
        checkpointRunnable.run();

        TaskStateSnapshot lastJobManagerSnapshot =
                taskStateManager.getLastJobManagerTaskStateSnapshot();
        TaskStateSnapshot lastTaskManagerSnapshot =
                taskStateManager.getLastTaskManagerTaskStateSnapshot();

        OperatorSubtaskState jmState =
                lastJobManagerSnapshot.getSubtaskStateByOperatorID(operatorID);
        OperatorSubtaskState tmState =
                lastTaskManagerSnapshot.getSubtaskStateByOperatorID(operatorID);

        performCheck(
                osFuture.getKeyedStateManagedFuture(),
                jmState.getManagedKeyedState(),
                tmState.getManagedKeyedState());
        performCheck(
                osFuture.getKeyedStateRawFuture(),
                jmState.getRawKeyedState(),
                tmState.getRawKeyedState());
        performCheck(
                osFuture.getOperatorStateManagedFuture(),
                jmState.getManagedOperatorState(),
                tmState.getManagedOperatorState());
        performCheck(
                osFuture.getOperatorStateRawFuture(),
                jmState.getRawOperatorState(),
                tmState.getRawOperatorState());
        performCollectionCheck(
                osFuture.getInputChannelStateFuture(),
                jmState.getInputChannelState(),
                tmState.getInputChannelState());
        performCollectionCheck(
                osFuture.getResultSubpartitionStateFuture(),
                jmState.getResultSubpartitionState(),
                tmState.getResultSubpartitionState());
    }

    /**
     * Calls {@link TaskStateManagerImpl#reportTaskStateSnapshots} and verifies that both the
     * checkpoint responder and the task-local state store are invoked with the expected
     * arguments.
     */
    @Test
    public void testReportingFromTaskStateManagerToResponderAndTaskLocalStateStore()
            throws Exception {
        final JobID jobID = new JobID();
        final AllocationID allocationID = new AllocationID();
        final ExecutionAttemptID executionAttemptID = new ExecutionAttemptID();
        final CheckpointMetaData checkpointMetaData = new CheckpointMetaData(42L, 4711L);
        final CheckpointMetrics checkpointMetrics = new CheckpointMetrics();
        final JobVertexID jobVertexID = new JobVertexID();
        // Shared by the state store and its directory provider so the two always agree.
        final int subtaskIndex = 42;

        final TaskStateSnapshot jmSnapshot = new TaskStateSnapshot();
        final TaskStateSnapshot tmSnapshot = new TaskStateSnapshot();

        final AtomicBoolean jmReported = new AtomicBoolean(false);
        final AtomicBoolean tmReported = new AtomicBoolean(false);

        LocalRecoveryDirectoryProviderImpl directoryProvider =
                new LocalRecoveryDirectoryProviderImpl(
                        temporaryFolder.newFolder(), jobID, jobVertexID, subtaskIndex);
        LocalRecoveryConfig localRecoveryConfig = new LocalRecoveryConfig(directoryProvider);

        // Direct executor: anything the store schedules runs inline on the test thread.
        TaskLocalStateStoreImpl taskLocalStateStore =
                new TaskLocalStateStoreImpl(
                        jobID,
                        allocationID,
                        jobVertexID,
                        subtaskIndex,
                        localRecoveryConfig,
                        Executors.directExecutor()) {
                    @Override
                    public void storeLocalState(
                            @Nonnegative long checkpointId,
                            @Nullable TaskStateSnapshot localState) {
                        Assert.assertEquals(tmSnapshot, localState);
                        tmReported.set(true);
                    }
                };

        TestCheckpointResponder checkpointResponder =
                new TestCheckpointResponder() {
                    @Override
                    public void acknowledgeCheckpoint(
                            JobID reportedJobID,
                            ExecutionAttemptID reportedExecutionAttemptID,
                            long reportedCheckpointId,
                            CheckpointMetrics reportedCheckpointMetrics,
                            TaskStateSnapshot reportedSubtaskState) {
                        Assert.assertEquals(jobID, reportedJobID);
                        Assert.assertEquals(executionAttemptID, reportedExecutionAttemptID);
                        Assert.assertEquals(
                                checkpointMetaData.getCheckpointId(), reportedCheckpointId);
                        Assert.assertEquals(checkpointMetrics, reportedCheckpointMetrics);
                        jmReported.set(true);
                    }
                };

        TaskStateManagerImpl taskStateManager =
                new TaskStateManagerImpl(
                        jobID,
                        executionAttemptID,
                        taskLocalStateStore,
                        new InMemoryStateChangelogStorage(),
                        (JobManagerTaskRestore) null,
                        checkpointResponder);

        taskStateManager.reportTaskStateSnapshots(
                checkpointMetaData, checkpointMetrics, jmSnapshot, tmSnapshot);

        Assert.assertTrue("Reporting for JM state was not called.", jmReported.get());
        Assert.assertTrue("Reporting for TM state was not called.", tmReported.get());
    }

    /**
     * Asserts that the first element of {@code jmState} equals the job-manager-owned snapshot of
     * the future's result, and the first element of {@code tmState} equals its task-local
     * snapshot.
     *
     * @param resultFuture future holding the expected snapshot result
     * @param jmState state collection reported on the job manager side
     * @param tmState state collection reported on the task manager (local) side
     * @throws RuntimeException wrapping any exception raised while obtaining the future's result
     */
    private static <T extends StateObject> void performCheck(
            Future<SnapshotResult<T>> resultFuture,
            StateObjectCollection<T> jmState,
            StateObjectCollection<T> tmState) {
        SnapshotResult<T> expected;
        try {
            expected = resultFuture.get();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        Assert.assertEquals(expected.getJobManagerOwnedSnapshot(), jmState.iterator().next());
        Assert.assertEquals(expected.getTaskLocalSnapshot(), tmState.iterator().next());
    }

    /**
     * Asserts that {@code jmState} and {@code tmState} equal, as whole collections, the
     * job-manager-owned and task-local snapshots of the future's result.
     *
     * @param resultFuture future holding the expected snapshot result
     * @param jmState state collection reported on the job manager side
     * @param tmState state collection reported on the task manager (local) side
     * @throws RuntimeException wrapping any exception raised while obtaining the future's result
     */
    private static <T extends StateObject> void performCollectionCheck(
            Future<SnapshotResult<StateObjectCollection<T>>> resultFuture,
            StateObjectCollection<T> jmState,
            StateObjectCollection<T> tmState) {
        SnapshotResult<StateObjectCollection<T>> expected;
        try {
            expected = resultFuture.get();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        Assert.assertEquals(expected.getJobManagerOwnedSnapshot(), jmState);
        Assert.assertEquals(expected.getTaskLocalSnapshot(), tmState);
    }

    /**
     * Creates an already-completed future whose result pairs two separate mocks of {@code clazz}:
     * one as the job-manager-owned state and one as the task-local state.
     */
    private static <T extends StateObject> RunnableFuture<SnapshotResult<T>> createSnapshotResult(
            Class<T> clazz) {
        // No cast on the mocks: they must stay typed as T for the result type to be inferred.
        return DoneFuture.of(
                SnapshotResult.withLocalState(Mockito.mock(clazz), Mockito.mock(clazz)));
    }

    /**
     * Creates an already-completed future whose result pairs two singleton collections, each
     * wrapping its own mock of {@code clazz}: one as the job-manager-owned state and one as the
     * task-local state.
     */
    private static <T extends StateObject>
            RunnableFuture<SnapshotResult<StateObjectCollection<T>>>
                    createSnapshotCollectionResult(Class<T> clazz) {
        // No cast on the mocks: they must stay typed as T for the result type to be inferred.
        return DoneFuture.of(
                SnapshotResult.withLocalState(
                        StateObjectCollection.singleton(Mockito.mock(clazz)),
                        StateObjectCollection.singleton(Mockito.mock(clazz))));
    }
}
