/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.runtime.operators;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
import org.apache.flink.runtime.operators.CoGroupDriver;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.testutils.DriverTestBase;
import org.apache.flink.runtime.operators.testutils.UniformRecordGenerator;
import org.apache.flink.runtime.testutils.recordutils.RecordComparator;
import org.apache.flink.runtime.testutils.recordutils.RecordPairComparatorFactory;
import org.apache.flink.types.IntValue;
import org.apache.flink.types.Record;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

public class CoGroupTaskExternalITCase
extends DriverTestBase<CoGroupFunction<Record, Record, Record>> {
    private static final long SORT_MEM = 0x300000L;
    private final RecordComparator comparator1 = new RecordComparator(new int[]{0}, new Class[]{IntValue.class});
    private final RecordComparator comparator2 = new RecordComparator(new int[]{0}, new Class[]{IntValue.class});
    private final DriverTestBase.CountingOutputCollector output = new DriverTestBase.CountingOutputCollector();

    public CoGroupTaskExternalITCase(ExecutionConfig config) {
        super(config, 0L, 2, 0x300000L);
    }

    @Test
    public void testExternalSortCoGroupTask() {
        int keyCnt1 = 131072;
        int valCnt1 = 32;
        int keyCnt2 = 262144;
        int valCnt2 = 4;
        int expCnt = valCnt1 * valCnt2 * Math.min(keyCnt1, keyCnt2) + (keyCnt1 > keyCnt2 ? (keyCnt1 - keyCnt2) * valCnt1 : (keyCnt2 - keyCnt1) * valCnt2);
        this.setOutput(this.output);
        this.addDriverComparator(this.comparator1);
        this.addDriverComparator(this.comparator2);
        this.getTaskConfig().setDriverPairComparator((TypePairComparatorFactory)RecordPairComparatorFactory.get());
        this.getTaskConfig().setDriverStrategy(DriverStrategy.CO_GROUP);
        CoGroupDriver testTask = new CoGroupDriver();
        try {
            this.addInputSorted(new UniformRecordGenerator(keyCnt1, valCnt1, false), this.comparator1.duplicate());
            this.addInputSorted(new UniformRecordGenerator(keyCnt2, valCnt2, false), this.comparator2.duplicate());
            this.testDriver((Driver)testTask, MockCoGroupStub.class);
        }
        catch (Exception e) {
            e.printStackTrace();
            Assert.fail((String)"The test caused an exception.");
        }
        Assert.assertEquals((String)"Wrong result set size.", (long)expCnt, (long)this.output.getNumberOfRecords());
    }

    public static final class MockCoGroupStub
    extends RichCoGroupFunction<Record, Record, Record> {
        private static final long serialVersionUID = 1L;
        private final Record res = new Record();

        public void coGroup(Iterable<Record> records1, Iterable<Record> records2, Collector<Record> out) {
            int val1Cnt = 0;
            int val2Cnt = 0;
            for (Record r : records1) {
                ++val1Cnt;
            }
            for (Record r : records2) {
                ++val2Cnt;
            }
            if (val1Cnt == 0) {
                for (int i = 0; i < val2Cnt; ++i) {
                    out.collect((Object)this.res);
                }
            } else if (val2Cnt == 0) {
                for (int i = 0; i < val1Cnt; ++i) {
                    out.collect((Object)this.res);
                }
            } else {
                for (int i = 0; i < val2Cnt * val1Cnt; ++i) {
                    out.collect((Object)this.res);
                }
            }
        }
    }
}

