package org.apache.flink.test.checkpointing;

import java.io.File;
import java.time.Duration;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.time.Deadline;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.client.program.ClusterClient;
import org.apache.flink.configuration.CheckpointingOptions;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.configuration.StateBackendOptions;
import org.apache.flink.runtime.jobgraph.JobGraph;
import org.apache.flink.runtime.jobgraph.SavepointRestoreSettings;
import org.apache.flink.runtime.state.FunctionInitializationContext;
import org.apache.flink.runtime.state.FunctionSnapshotContext;
import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
import org.apache.flink.runtime.testutils.MiniClusterResourceConfiguration;
import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
import org.apache.flink.streaming.api.datastream.KeyedStream;
import org.apache.flink.streaming.api.environment.CheckpointConfig;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
import org.apache.flink.streaming.api.functions.source.SourceFunction;
import org.apache.flink.test.util.MiniClusterWithClientResource;
import org.apache.flink.test.util.TestUtils;
import org.apache.flink.util.Collector;
import org.apache.flink.util.TestLogger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

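/**
 * Integration test that takes a retained checkpoint of a keyed-state job running on the RocksDB
 * state backend with incremental checkpoints, then restores the job from that checkpoint at a
 * different parallelism and checks the per-key results.
 */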
public class RescaleCheckpointManuallyITCase extends TestLogger {
    // Number of TaskManagers the test MiniCluster is sized for.
    private static final int NUM_TASK_MANAGERS = 2;
    // Task slots per TaskManager the test MiniCluster is sized for.
    private static final int SLOTS_PER_TASK_MANAGER = 2;
    // Per-test MiniCluster: created in setup() and torn down in shutDownExistingCluster().
    private static MiniClusterWithClientResource cluster;
    // Per-test checkpoint directory, allocated from temporaryFolder in setup().
    private File checkpointDir;

    // Class-wide temporary folder that backs the per-test checkpoint directories.
    @ClassRule
    public static TemporaryFolder temporaryFolder = new TemporaryFolder();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase$CollectionSink.class */
    public static class CollectionSink<IN> implements SinkFunction<IN> {
        private static final Set<Object> elements = Collections.newSetFromMap(new ConcurrentHashMap());
        private static final long serialVersionUID = 1;

        private CollectionSink() {
        }

        public static <IN> Set<IN> getElementsSet() {
            return (Set<IN>) elements;
        }

        public static void clearElementsSet() {
            elements.clear();
        }

        public void invoke(IN in) throws Exception {
            elements.add(in);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase$NotifyingDefiniteKeySource.class */
    public static class NotifyingDefiniteKeySource extends RichParallelSourceFunction<Integer> {
        private static final long serialVersionUID = 1;
        private final int numberKeys;
        private final int numberElements;
        private final boolean terminateAfterEmission;
        protected int counter = 0;
        private boolean running = true;

        public NotifyingDefiniteKeySource(int i, int i2, boolean z) {
            this.numberKeys = i;
            this.numberElements = i2;
            this.terminateAfterEmission = z;
        }

        public void run(SourceFunction.SourceContext<Integer> sourceContext) throws Exception {
            int indexOfThisSubtask = getRuntimeContext().getIndexOfThisSubtask();
            while (this.running) {
                if (this.counter < this.numberElements) {
                    synchronized (sourceContext.getCheckpointLock()) {
                        int i = indexOfThisSubtask;
                        while (i < this.numberKeys) {
                            sourceContext.collect(Integer.valueOf(i));
                            i += getRuntimeContext().getNumberOfParallelSubtasks();
                        }
                        this.counter++;
                    }
                } else if (this.terminateAfterEmission) {
                    this.running = false;
                } else {
                    Thread.sleep(100L);
                }
            }
        }

        public void cancel() {
            this.running = false;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/flink/test/checkpointing/RescaleCheckpointManuallyITCase$SubtaskIndexFlatMapper.class */
    public static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> implements CheckpointedFunction {
        private static final long serialVersionUID = 1;
        public static CountDownLatch workCompletedLatch = new CountDownLatch(1);
        private transient ValueState<Integer> counter;
        private transient ValueState<Integer> sum;
        private final int numberElements;

        public SubtaskIndexFlatMapper(int i) {
            this.numberElements = i;
        }

        public void flatMap(Integer num, Collector<Tuple2<Integer, Integer>> collector) throws Exception {
            Integer num2 = (Integer) this.counter.value();
            int intValue = num2 == null ? 1 : num2.intValue() + 1;
            this.counter.update(Integer.valueOf(intValue));
            Integer num3 = (Integer) this.sum.value();
            int intValue2 = num3 == null ? num.intValue() : num3.intValue() + num.intValue();
            this.sum.update(Integer.valueOf(intValue2));
            if (intValue % this.numberElements == 0) {
                collector.collect(Tuple2.of(Integer.valueOf(getRuntimeContext().getIndexOfThisSubtask()), Integer.valueOf(intValue2)));
                workCompletedLatch.countDown();
            }
        }

        public void snapshotState(FunctionSnapshotContext functionSnapshotContext) throws Exception {
        }

        public void initializeState(FunctionInitializationContext functionInitializationContext) throws Exception {
            this.counter = functionInitializationContext.getKeyedStateStore().getState(new ValueStateDescriptor("counter", Integer.class));
            this.sum = functionInitializationContext.getKeyedStateStore().getState(new ValueStateDescriptor("sum", Integer.class));
        }

        public /* bridge */ /* synthetic */ void flatMap(Object obj, Collector collector) throws Exception {
            flatMap((Integer) obj, (Collector<Tuple2<Integer, Integer>>) collector);
        }
    }

    /**
     * Starts a fresh MiniCluster for each test. The cluster uses the RocksDB state backend,
     * incremental checkpoints, and a new checkpoint directory under {@link #temporaryFolder}.
     */
    @Before
    public void setup() throws Exception {
        checkpointDir = temporaryFolder.newFolder();

        final Configuration config = new Configuration();
        config.setString(StateBackendOptions.STATE_BACKEND, "rocksdb");
        config.setString(
                CheckpointingOptions.CHECKPOINTS_DIRECTORY, checkpointDir.toURI().toString());
        config.setBoolean(CheckpointingOptions.INCREMENTAL_CHECKPOINTS, true);

        // Use the named constants rather than repeating their literal values.
        cluster =
                new MiniClusterWithClientResource(
                        new MiniClusterResourceConfiguration.Builder()
                                .setConfiguration(config)
                                .setNumberTaskManagers(NUM_TASK_MANAGERS)
                                .setNumberSlotsPerTaskManager(SLOTS_PER_TASK_MANAGER)
                                .build());
        cluster.before();
    }

    /** Tears down the per-test MiniCluster, if one was started, and forgets it. */
    @After
    public void shutDownExistingCluster() {
        if (cluster == null) {
            return;
        }
        cluster.after();
        cluster = null;
    }

    /** Scale in: checkpoint taken at parallelism 4, restored at parallelism 3. */
    @Test
    public void testCheckpointRescalingInKeyedState() throws Exception {
        final boolean scaleOut = false;
        testCheckpointRescalingKeyedState(scaleOut);
    }

    /** Scale out: checkpoint taken at parallelism 3, restored at parallelism 4. */
    @Test
    public void testCheckpointRescalingOutKeyedState() throws Exception {
        final boolean scaleOut = true;
        testCheckpointRescalingKeyedState(scaleOut);
    }

    /**
     * Runs a keyed job, takes a retained checkpoint, and restores it at a different parallelism,
     * checking the emitted per-key sums on both sides.
     *
     * @param scaleOut {@code true} to go from parallelism 3 to 4, {@code false} for 4 to 3
     */
    public void testCheckpointRescalingKeyedState(boolean scaleOut) throws Exception {
        final int numberKeys = 42;
        final int numberElements = 1000;
        final int numberElementsAfterRestore = 500;
        final int maxParallelism = 13;
        final int parallelismBefore = scaleOut ? 3 : 4;
        final int parallelismAfter = scaleOut ? 4 : 3;

        final ClusterClient<?> client = cluster.getClusterClient();
        final String checkpointPath =
                runJobAndGetCheckpoint(
                        numberKeys,
                        numberElements,
                        parallelismBefore,
                        maxParallelism,
                        client,
                        checkpointDir);
        Assert.assertNotNull(checkpointPath);

        // After restore each key's sum continues from the checkpointed state.
        restoreAndAssert(
                parallelismAfter,
                maxParallelism,
                numberKeys,
                numberElementsAfterRestore,
                numberElements + numberElementsAfterRestore,
                client,
                checkpointPath);
    }

    /**
     * Runs the keyed job until every key has been emitted {@code numberElements} times.
     * It then verifies the sink output, triggers a checkpoint, cancels the job, and returns the
     * most recent completed checkpoint in {@code checkpointDir}. The collected sink elements are
     * cleared on every exit path.
     *
     * @param numberKeys number of distinct keys
     * @param numberElements how often each key is emitted
     * @param parallelism job parallelism
     * @param maxParallelism max parallelism used for key-group assignment
     * @param client client used to submit and cancel the job
     * @param checkpointDir directory the retained checkpoints are written to
     * @return absolute path of the most recent completed checkpoint
     */
    private static String runJobAndGetCheckpoint(
            int numberKeys,
            int numberElements,
            int parallelism,
            int maxParallelism,
            ClusterClient<?> client,
            File checkpointDir)
            throws Exception {
        try {
            final Deadline deadline = Deadline.now().plus(Duration.ofMinutes(5L));
            final JobGraph jobGraph =
                    createJobGraphWithKeyedState(
                            parallelism, maxParallelism, numberKeys, numberElements, false, 100);
            client.submitJob(jobGraph).get();

            // The flat-mapper counts the latch down once per key that reached numberElements.
            Assert.assertTrue(
                    "Timed out waiting for all keys to be processed",
                    SubtaskIndexFlatMapper.workCompletedLatch.await(
                            deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS));

            final Set<Tuple2<Integer, Integer>> actual = CollectionSink.getElementsSet();
            final Set<Tuple2<Integer, Integer>> expected = new HashSet<>();
            for (int key = 0; key < numberKeys; key++) {
                final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                final int subtask =
                        KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                maxParallelism, parallelism, keyGroup);
                // Each key was emitted numberElements times, so its sum is key * numberElements.
                expected.add(Tuple2.of(subtask, numberElements * key));
            }
            Assert.assertEquals(expected, actual);

            cluster.getMiniCluster().triggerCheckpoint(jobGraph.getJobID()).get();
            client.cancel(jobGraph.getJobID()).get();
            TestUtils.waitUntilJobCanceled(jobGraph.getJobID(), client);

            return TestUtils.getMostRecentCompletedCheckpoint(checkpointDir).getAbsolutePath();
        } finally {
            CollectionSink.clearElementsSet();
        }
    }

    /**
     * Restores the keyed job from {@code restorePath} at {@code parallelism}, runs it to
     * completion, and asserts that each key's sum equals {@code key * numberElementsExpect}.
     * The collected sink elements are cleared on every exit path.
     *
     * @param parallelism parallelism of the restored job
     * @param maxParallelism max parallelism used for key-group assignment
     * @param numberKeys number of distinct keys
     * @param numberElements how often each key is emitted after the restore
     * @param numberElementsExpect total emissions per key including checkpointed state
     * @param client client used to submit the job
     * @param restorePath checkpoint path to restore from
     */
    private void restoreAndAssert(
            int parallelism,
            int maxParallelism,
            int numberKeys,
            int numberElements,
            int numberElementsExpect,
            ClusterClient<?> client,
            String restorePath)
            throws Exception {
        try {
            final JobGraph jobGraph =
                    createJobGraphWithKeyedState(
                            parallelism, maxParallelism, numberKeys, numberElements, true, 100);
            jobGraph.setSavepointRestoreSettings(SavepointRestoreSettings.forPath(restorePath));
            TestUtils.submitJobAndWaitForResult(client, jobGraph, getClass().getClassLoader());

            final Set<Tuple2<Integer, Integer>> actual = CollectionSink.getElementsSet();
            final Set<Tuple2<Integer, Integer>> expected = new HashSet<>();
            for (int key = 0; key < numberKeys; key++) {
                final int keyGroup = KeyGroupRangeAssignment.assignToKeyGroup(key, maxParallelism);
                final int subtask =
                        KeyGroupRangeAssignment.computeOperatorIndexForKeyGroup(
                                maxParallelism, parallelism, keyGroup);
                expected.add(Tuple2.of(subtask, key * numberElementsExpect));
            }
            Assert.assertEquals(expected, actual);
        } finally {
            CollectionSink.clearElementsSet();
        }
    }

    /**
     * Builds the job {@code NotifyingDefiniteKeySource -> keyBy(identity) ->
     * SubtaskIndexFlatMapper -> CollectionSink}. Checkpoints are retained on cancellation,
     * restarts are disabled, and snapshot compression is on. As a side effect, resets
     * {@link SubtaskIndexFlatMapper#workCompletedLatch} to count {@code numberKeys}.
     *
     * @param parallelism job parallelism
     * @param maxParallelism max parallelism; applied only when positive
     * @param numberKeys number of distinct keys emitted by the source
     * @param numberElements how often each key is emitted, and the flat-mapper's emit period
     * @param terminateAfterEmission whether the source returns after its last round
     * @param checkpointingInterval interval passed to {@code enableCheckpointing}
     * @return the job graph for submission
     */
    private static JobGraph createJobGraphWithKeyedState(
            int parallelism,
            int maxParallelism,
            int numberKeys,
            int numberElements,
            boolean terminateAfterEmission,
            int checkpointingInterval) {
        final StreamExecutionEnvironment env =
                StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(parallelism);
        if (maxParallelism > 0) {
            env.getConfig().setMaxParallelism(maxParallelism);
        }
        env.enableCheckpointing(checkpointingInterval);
        env.getCheckpointConfig()
                .setExternalizedCheckpointCleanup(
                        CheckpointConfig.ExternalizedCheckpointCleanup.RETAIN_ON_CANCELLATION);
        env.setRestartStrategy(RestartStrategies.noRestart());
        env.getConfig().setUseSnapshotCompression(true);

        final KeyedStream<Integer, Integer> keyed =
                env.addSource(
                                new NotifyingDefiniteKeySource(
                                        numberKeys, numberElements, terminateAfterEmission))
                        .keyBy(
                                // Anonymous class (not a lambda) keeps the key type explicit.
                                new KeySelector<Integer, Integer>() {
                                    private static final long serialVersionUID = 1L;

                                    @Override
                                    public Integer getKey(Integer value) {
                                        return value;
                                    }
                                });

        SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
        keyed.flatMap(new SubtaskIndexFlatMapper(numberElements))
                .addSink(new CollectionSink<>());
        return env.getStreamGraph().getJobGraph();
    }
}
