/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.streaming.util;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.functions.sink.RichSinkFunction;
import org.apache.flink.streaming.api.functions.sink.SinkFunction;
import org.junit.jupiter.api.extension.AfterEachCallback;
import org.junit.jupiter.api.extension.BeforeEachCallback;
import org.junit.jupiter.api.extension.ExtensionContext;

public class StreamCollectorExtension
implements BeforeEachCallback,
AfterEachCallback {
    private static final AtomicLong counter = new AtomicLong();
    private static final Map<Long, CountDownLatch> latches = new ConcurrentHashMap<Long, CountDownLatch>();
    private static final Map<Long, Queue> resultQueues = new ConcurrentHashMap<Long, Queue>();
    private List<Long> ids;

    public void beforeEach(ExtensionContext context) throws Exception {
        this.ids = new ArrayList<Long>();
    }

    public <IN> CompletableFuture<Collection<IN>> collect(DataStream<IN> stream) {
        long id = counter.getAndIncrement();
        this.ids.add(id);
        int parallelism = stream.getParallelism();
        if (parallelism == -1) {
            parallelism = stream.getExecutionEnvironment().getParallelism();
        }
        CountDownLatch latch = new CountDownLatch(parallelism);
        latches.put(id, latch);
        ConcurrentLinkedDeque results = new ConcurrentLinkedDeque();
        resultQueues.put(id, results);
        stream.addSink(new CollectingSink(id));
        return CompletableFuture.runAsync(() -> {
            try {
                latch.await();
            }
            catch (InterruptedException e) {
                throw new RuntimeException("Failed to collect results");
            }
        }).thenApply(ignore -> results);
    }

    public void afterEach(ExtensionContext context) throws Exception {
        for (Long id : this.ids) {
            latches.remove(id);
            resultQueues.remove(id);
        }
    }

    private static class CollectingSink<IN>
    extends RichSinkFunction<IN> {
        private final long id;
        private transient CountDownLatch latch;
        private transient Queue<IN> results;

        private CollectingSink(long id) {
            this.id = id;
        }

        public void open(Configuration parameters) throws Exception {
            this.latch = (CountDownLatch)latches.get(this.id);
            this.results = (Queue)resultQueues.get(this.id);
        }

        public void invoke(IN value, SinkFunction.Context context) throws Exception {
            this.results.add(value);
        }

        public void close() throws Exception {
            this.latch.countDown();
        }
    }
}

