package org.apache.flink.runtime.checkpoint;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import org.apache.flink.api.common.JobID;
import org.apache.flink.runtime.concurrent.Executors;
import org.apache.flink.runtime.execution.ExecutionState;
import org.apache.flink.runtime.executiongraph.Execution;
import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
import org.apache.flink.runtime.executiongraph.ExecutionJobVertex;
import org.apache.flink.runtime.executiongraph.ExecutionVertex;
import org.apache.flink.runtime.jobgraph.JobVertexID;
import org.apache.flink.runtime.jobgraph.OperatorID;
import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
import org.apache.flink.runtime.state.ChainedStateHandle;
import org.apache.flink.runtime.state.KeyGroupRange;
import org.apache.flink.runtime.state.KeyGroupsStateHandle;
import org.apache.flink.runtime.state.KeyedStateHandle;
import org.apache.flink.runtime.state.OperatorStateHandle;
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.runtime.state.TaskStateHandles;
import org.apache.flink.runtime.util.SerializableObject;
import org.hamcrest.BaseMatcher;
import org.hamcrest.Description;
import org.junit.Assert;
import org.junit.Test;
import org.mockito.Mockito;

/* loaded from: input_file:org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.class */
public class CheckpointStateRestoreTest {
    /**
     * Completes a single checkpoint in which only the three subtasks of the first job vertex
     * acknowledge with state, restores the latest checkpoint, and verifies that exactly those
     * three executions receive the expected initial state while the two subtasks of the second
     * job vertex never have an initial state set.
     *
     * @throws Exception if triggering, acknowledging or restoring the checkpoint fails; the
     *     exception is propagated so the test runner reports the complete stack trace and cause
     *     instead of only its (possibly null) message
     */
    @Test
    public void testSetState() throws Exception {
        final ChainedStateHandle<StreamStateHandle> serializedState =
                CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject());
        final KeyGroupsStateHandle serializedKeyGroupState =
                CheckpointCoordinatorTest.generateKeyGroupState(
                        KeyGroupRange.of(0, 0),
                        Collections.singletonList(new SerializableObject()));

        final JobID jobId = new JobID();
        final JobVertexID statefulVertexId = new JobVertexID();
        final JobVertexID statelessVertexId = new JobVertexID();

        // Three subtasks that will report state, two that will acknowledge without any.
        final Execution statefulExecution1 = mockExecution();
        final Execution statefulExecution2 = mockExecution();
        final Execution statefulExecution3 = mockExecution();
        final Execution statelessExecution1 = mockExecution();
        final Execution statelessExecution2 = mockExecution();

        final ExecutionVertex statefulVertex1 = mockExecutionVertex(statefulExecution1, statefulVertexId, 0, 3);
        final ExecutionVertex statefulVertex2 = mockExecutionVertex(statefulExecution2, statefulVertexId, 1, 3);
        final ExecutionVertex statefulVertex3 = mockExecutionVertex(statefulExecution3, statefulVertexId, 2, 3);
        final ExecutionVertex statelessVertex1 = mockExecutionVertex(statelessExecution1, statelessVertexId, 0, 2);
        final ExecutionVertex statelessVertex2 = mockExecutionVertex(statelessExecution2, statelessVertexId, 1, 2);

        final ExecutionJobVertex statefulJobVertex = mockExecutionJobVertex(
                statefulVertexId,
                new ExecutionVertex[]{statefulVertex1, statefulVertex2, statefulVertex3});
        final ExecutionJobVertex statelessJobVertex = mockExecutionJobVertex(
                statelessVertexId,
                new ExecutionVertex[]{statelessVertex1, statelessVertex2});

        final HashMap<JobVertexID, ExecutionJobVertex> tasks = new HashMap<JobVertexID, ExecutionJobVertex>();
        tasks.put(statefulVertexId, statefulJobVertex);
        tasks.put(statelessVertexId, statelessJobVertex);

        // The same five vertices are passed in two separate arrays, followed by an empty third array.
        final CheckpointCoordinator coordinator = new CheckpointCoordinator(
                jobId,
                200000L,
                200000L,
                0L,
                Integer.MAX_VALUE,
                ExternalizedCheckpointSettings.none(),
                new ExecutionVertex[]{statefulVertex1, statefulVertex2, statefulVertex3, statelessVertex1, statelessVertex2},
                new ExecutionVertex[]{statefulVertex1, statefulVertex2, statefulVertex3, statelessVertex1, statelessVertex2},
                new ExecutionVertex[0],
                new StandaloneCheckpointIDCounter(),
                new StandaloneCompletedCheckpointStore(1),
                (String) null,
                Executors.directExecutor());

        // Trigger one checkpoint and look up the id that was assigned to it.
        coordinator.triggerCheckpoint(34623786L, false);
        final PendingCheckpoint pending =
                (PendingCheckpoint) coordinator.getPendingCheckpoints().values().iterator().next();
        final long checkpointId = pending.getCheckpointId();

        final SubtaskState reportedState = new SubtaskState(
                serializedState,
                (ChainedStateHandle) null,
                (ChainedStateHandle) null,
                serializedKeyGroupState,
                (KeyedStateHandle) null);

        // Subtasks of the first vertex acknowledge with state ...
        coordinator.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
                jobId, statefulExecution1.getAttemptId(), checkpointId, new CheckpointMetrics(), reportedState));
        coordinator.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
                jobId, statefulExecution2.getAttemptId(), checkpointId, new CheckpointMetrics(), reportedState));
        coordinator.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
                jobId, statefulExecution3.getAttemptId(), checkpointId, new CheckpointMetrics(), reportedState));
        // ... while subtasks of the second vertex acknowledge without any state.
        coordinator.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
                jobId, statelessExecution1.getAttemptId(), checkpointId));
        coordinator.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(
                jobId, statelessExecution2.getAttemptId(), checkpointId));

        // All five acknowledgements arrived, so the checkpoint must be completed and retained.
        Assert.assertEquals(1L, coordinator.getNumberOfRetainedSuccessfulCheckpoints());
        Assert.assertEquals(0L, coordinator.getNumberOfPendingCheckpoints());

        coordinator.restoreLatestCheckpointedState(tasks, true, false);

        final TaskStateHandles expectedHandles = new TaskStateHandles(
                serializedState,
                Collections.singletonList(null),
                Collections.singletonList(null),
                Collections.singletonList(serializedKeyGroupState),
                (Collection) null);

        // Matches any TaskStateHandles argument that equals the expected handles.
        final BaseMatcher<TaskStateHandles> equalsExpectedHandles = new BaseMatcher<TaskStateHandles>() {
            @Override
            public boolean matches(Object candidate) {
                return candidate instanceof TaskStateHandles && candidate.equals(expectedHandles);
            }

            @Override
            public void describeTo(Description description) {
                description.appendValue(expectedHandles);
            }
        };

        Mockito.verify(statefulExecution1, Mockito.times(1))
                .setInitialState((TaskStateHandles) Mockito.argThat(equalsExpectedHandles));
        Mockito.verify(statefulExecution2, Mockito.times(1))
                .setInitialState((TaskStateHandles) Mockito.argThat(equalsExpectedHandles));
        Mockito.verify(statefulExecution3, Mockito.times(1))
                .setInitialState((TaskStateHandles) Mockito.argThat(equalsExpectedHandles));
        Mockito.verify(statelessExecution1, Mockito.times(0))
                .setInitialState((TaskStateHandles) Mockito.any());
        Mockito.verify(statelessExecution2, Mockito.times(0))
                .setInitialState((TaskStateHandles) Mockito.any());
    }

    /**
     * Verifies that restoring the latest checkpointed state throws an
     * {@link IllegalStateException} when the completed-checkpoint store contains no checkpoint.
     *
     * <p>The coordinator is built outside the guarded region so that only the restore call can
     * satisfy the expected-exception check; a failure while constructing the coordinator is
     * reported as an error instead of being mistaken for the expected outcome.
     *
     * @throws Exception if constructing the coordinator or restoring fails in an unexpected way;
     *     the exception is propagated so the test runner reports the full stack trace
     */
    @Test
    public void testNoCheckpointAvailable() throws Exception {
        final CheckpointCoordinator coordinator = new CheckpointCoordinator(
                new JobID(),
                200000L,
                200000L,
                0L,
                Integer.MAX_VALUE,
                ExternalizedCheckpointSettings.none(),
                new ExecutionVertex[]{(ExecutionVertex) Mockito.mock(ExecutionVertex.class)},
                new ExecutionVertex[]{(ExecutionVertex) Mockito.mock(ExecutionVertex.class)},
                new ExecutionVertex[0],
                new StandaloneCheckpointIDCounter(),
                new StandaloneCompletedCheckpointStore(1),
                (String) null,
                Executors.directExecutor());

        try {
            // NOTE(review): the second argument (true) presumably requests an error when no
            // checkpoint is available - confirm against CheckpointCoordinator.
            coordinator.restoreLatestCheckpointedState(
                    new HashMap<JobVertexID, ExecutionJobVertex>(), true, false);
            Assert.fail("this should throw an exception");
        } catch (IllegalStateException expected) {
            // Expected: nothing was ever added to the checkpoint store.
        }
    }

    /**
     * Exercises restore against checkpoints whose operator state does and does not map onto the
     * job being restored.
     *
     * <p>First a checkpoint holding state only for an operator of the job is restored with the
     * third restore argument both {@code false} and {@code true}. Then a second checkpoint that
     * additionally holds state for an operator id derived from an unrelated {@link JobVertexID}
     * is restored: with {@code true} the call must succeed, with {@code false} it must throw an
     * {@link IllegalStateException}.
     *
     * <p>NOTE(review): the third restore argument presumably toggles tolerance of state that
     * cannot be mapped to the job - confirm against CheckpointCoordinator.
     */
    @Test
    public void testNonRestoredState() throws Exception {
        final JobVertexID firstVertexId = new JobVertexID();
        final JobVertexID secondVertexId = new JobVertexID();
        final OperatorID firstOperatorId = OperatorID.fromJobVertexID(firstVertexId);

        // Subtasks of the two job vertices (parallelism 3 and 2 respectively).
        final ExecutionVertex[] firstSubtasks = {
            mockExecutionVertex(mockExecution(), firstVertexId, 0, 3),
            mockExecutionVertex(mockExecution(), firstVertexId, 1, 3),
            mockExecutionVertex(mockExecution(), firstVertexId, 2, 3)
        };
        final ExecutionVertex[] secondSubtasks = {
            mockExecutionVertex(mockExecution(), secondVertexId, 0, 2),
            mockExecutionVertex(mockExecution(), secondVertexId, 1, 2)
        };

        final HashMap<JobVertexID, ExecutionJobVertex> tasks = new HashMap<JobVertexID, ExecutionJobVertex>();
        final ExecutionJobVertex firstJobVertex = mockExecutionJobVertex(firstVertexId, firstSubtasks);
        final ExecutionJobVertex secondJobVertex = mockExecutionJobVertex(secondVertexId, secondSubtasks);
        tasks.put(firstVertexId, firstJobVertex);
        tasks.put(secondVertexId, secondJobVertex);

        final CheckpointCoordinator coordinator = new CheckpointCoordinator(
                new JobID(),
                2147483647L,
                2147483647L,
                0L,
                Integer.MAX_VALUE,
                ExternalizedCheckpointSettings.none(),
                new ExecutionVertex[0],
                new ExecutionVertex[0],
                new ExecutionVertex[0],
                new StandaloneCheckpointIDCounter(),
                new StandaloneCompletedCheckpointStore(1),
                (String) null,
                Executors.directExecutor());

        final StreamStateHandle stateHandle =
                CheckpointCoordinatorTest.generateChainedStateHandle(new SerializableObject()).get(0);

        // --- (i) checkpoint whose only operator state belongs to the first job vertex ---
        final HashMap<OperatorID, OperatorState> checkpointedStates = new HashMap<OperatorID, OperatorState>();
        final OperatorState knownOperatorState = new OperatorState(firstOperatorId, 3, 3);
        for (int subtask = 0; subtask < 3; subtask++) {
            knownOperatorState.putState(
                    subtask,
                    new OperatorSubtaskState(
                            stateHandle,
                            (OperatorStateHandle) null,
                            (OperatorStateHandle) null,
                            (KeyedStateHandle) null,
                            (KeyedStateHandle) null));
        }
        checkpointedStates.put(firstOperatorId, knownOperatorState);

        final CompletedCheckpoint firstCheckpoint = new CompletedCheckpoint(
                new JobID(),
                0L,
                1L,
                2L,
                new HashMap<OperatorID, OperatorState>(checkpointedStates),
                Collections.emptyList(),
                CheckpointProperties.forStandardCheckpoint(),
                (StreamStateHandle) null,
                (String) null);
        coordinator.getCheckpointStore().addCheckpoint(firstCheckpoint);

        // Both variants of the third flag must restore without throwing.
        coordinator.restoreLatestCheckpointedState(tasks, true, false);
        coordinator.restoreLatestCheckpointedState(tasks, true, true);

        // --- (ii) checkpoint that additionally carries state of an operator unknown to the job ---
        final OperatorID unknownOperatorId = OperatorID.fromJobVertexID(new JobVertexID());
        final OperatorState unknownOperatorState = new OperatorState(unknownOperatorId, 1, 1);
        unknownOperatorState.putState(
                0,
                new OperatorSubtaskState(
                        stateHandle,
                        (OperatorStateHandle) null,
                        (OperatorStateHandle) null,
                        (KeyedStateHandle) null,
                        (KeyedStateHandle) null));
        checkpointedStates.put(unknownOperatorId, unknownOperatorState);

        final CompletedCheckpoint secondCheckpoint = new CompletedCheckpoint(
                new JobID(),
                1L,
                2L,
                3L,
                new HashMap<OperatorID, OperatorState>(checkpointedStates),
                Collections.emptyList(),
                CheckpointProperties.forStandardCheckpoint(),
                (StreamStateHandle) null,
                (String) null);
        coordinator.getCheckpointStore().addCheckpoint(secondCheckpoint);

        // With the flag set the extra state is accepted ...
        coordinator.restoreLatestCheckpointedState(tasks, true, true);

        // ... without it the restore has to be rejected.
        try {
            coordinator.restoreLatestCheckpointedState(tasks, true, false);
            Assert.fail("Did not throw the expected Exception.");
        } catch (IllegalStateException expected) {
            // Expected outcome of this test; nothing to do.
        }
    }

    /**
     * Creates a mocked {@link Execution} whose reported state is {@link ExecutionState#RUNNING}.
     *
     * @return the mocked execution
     */
    private Execution mockExecution() {
        final ExecutionState defaultState = ExecutionState.RUNNING;
        return mockExecution(defaultState);
    }

    /**
     * Creates a mocked {@link Execution} with a freshly generated attempt id.
     *
     * @param executionState the state the mock returns from {@code getState()}
     * @return the mocked execution
     */
    private Execution mockExecution(ExecutionState executionState) {
        final Execution mocked = Mockito.mock(Execution.class);
        final ExecutionAttemptID attemptId = new ExecutionAttemptID();

        Mockito.when(mocked.getAttemptId()).thenReturn(attemptId);
        Mockito.when(mocked.getState()).thenReturn(executionState);

        return mocked;
    }

    /**
     * Creates a mocked {@link ExecutionVertex} for one subtask of a job vertex.
     *
     * @param execution the execution returned as the vertex's current execution attempt
     * @param jobVertexID the id returned from {@code getJobvertexId()}
     * @param i the value returned as the parallel subtask index
     * @param i2 the value returned both as the total number of parallel subtasks and as the
     *     max parallelism
     * @return the mocked execution vertex
     */
    private ExecutionVertex mockExecutionVertex(Execution execution, JobVertexID jobVertexID, int i, int i2) {
        final ExecutionVertex mocked = Mockito.mock(ExecutionVertex.class);

        // Identity of the subtask within its job vertex.
        Mockito.when(mocked.getJobvertexId()).thenReturn(jobVertexID);
        Mockito.when(mocked.getParallelSubtaskIndex()).thenReturn(i);
        Mockito.when(mocked.getCurrentExecutionAttempt()).thenReturn(execution);

        // Parallelism and max parallelism are stubbed with the same value.
        Mockito.when(mocked.getTotalNumberOfParallelSubtasks()).thenReturn(i2);
        Mockito.when(mocked.getMaxParallelism()).thenReturn(i2);

        return mocked;
    }

    /**
     * Creates a mocked {@link ExecutionJobVertex} that owns the given subtasks.
     *
     * <p>Parallelism and max parallelism are both stubbed with the number of subtasks, the single
     * operator id is derived from the job vertex id, and the user-defined operator ids are a
     * one-element list containing {@code null}. Every given vertex is additionally stubbed to
     * return the new mock from {@code getJobVertex()}.
     *
     * @param jobVertexID the id returned from {@code getJobVertexId()}
     * @param executionVertexArr the (mocked) subtasks returned from {@code getTaskVertices()}
     * @return the mocked execution job vertex
     */
    private ExecutionJobVertex mockExecutionJobVertex(JobVertexID jobVertexID, ExecutionVertex[] executionVertexArr) {
        final ExecutionJobVertex mocked = Mockito.mock(ExecutionJobVertex.class);
        final int subtaskCount = executionVertexArr.length;

        Mockito.when(mocked.getParallelism()).thenReturn(subtaskCount);
        Mockito.when(mocked.getMaxParallelism()).thenReturn(subtaskCount);
        Mockito.when(mocked.getJobVertexId()).thenReturn(jobVertexID);
        Mockito.when(mocked.getTaskVertices()).thenReturn(executionVertexArr);
        Mockito.when(mocked.getOperatorIDs()).thenReturn(Collections.singletonList(OperatorID.fromJobVertexID(jobVertexID)));
        Mockito.when(mocked.getUserDefinedOperatorIDs()).thenReturn(Collections.singletonList(null));

        // Wire the back-reference from every subtask to its job vertex.
        for (ExecutionVertex subtask : executionVertexArr) {
            Mockito.when(subtask.getJobVertex()).thenReturn(mocked);
        }

        return mocked;
    }
}
