package org.apache.flink.api.java.operators.translation;

import java.io.Serializable;
import java.util.Iterator;
import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.LongSumAggregator;
import org.apache.flink.api.common.functions.RichCoGroupFunction;
import org.apache.flink.api.common.functions.RichJoinFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.GenericDataSinkBase;
import org.apache.flink.api.common.operators.base.DeltaIterationBase;
import org.apache.flink.api.common.operators.base.InnerJoinOperatorBase;
import org.apache.flink.api.common.operators.base.JoinOperatorBase;
import org.apache.flink.api.common.operators.base.MapOperatorBase;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.operators.JoinOperator;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/api/java/operators/translation/DeltaIterationTranslationTest.class */
public class DeltaIterationTranslationTest implements Serializable {

    /* loaded from: input_file:org/apache/flink/api/java/operators/translation/DeltaIterationTranslationTest$IdentityMapper.class */
    private static class IdentityMapper<T> extends RichMapFunction<T, T> {
        private IdentityMapper() {
        }

        public T map(T t) throws Exception {
            return t;
        }
    }

    /* loaded from: input_file:org/apache/flink/api/java/operators/translation/DeltaIterationTranslationTest$NextWorksetMapper.class */
    private static class NextWorksetMapper extends RichMapFunction<Tuple3<Double, Long, String>, Tuple2<Double, String>> {
        private NextWorksetMapper() {
        }

        public Tuple2<Double, String> map(Tuple3<Double, Long, String> tuple3) {
            return null;
        }
    }

    /* loaded from: input_file:org/apache/flink/api/java/operators/translation/DeltaIterationTranslationTest$SolutionWorksetCoGroup1.class */
    private static class SolutionWorksetCoGroup1 extends RichCoGroupFunction<Tuple2<Double, String>, Tuple3<Double, Long, String>, Tuple3<Double, Long, String>> {
        private SolutionWorksetCoGroup1() {
        }

        public void coGroup(Iterable<Tuple2<Double, String>> iterable, Iterable<Tuple3<Double, Long, String>> iterable2, Collector<Tuple3<Double, Long, String>> collector) {
        }
    }

    /* loaded from: input_file:org/apache/flink/api/java/operators/translation/DeltaIterationTranslationTest$SolutionWorksetCoGroup2.class */
    private static class SolutionWorksetCoGroup2 extends RichCoGroupFunction<Tuple3<Double, Long, String>, Tuple2<Double, String>, Tuple3<Double, Long, String>> {
        private SolutionWorksetCoGroup2() {
        }

        public void coGroup(Iterable<Tuple3<Double, Long, String>> iterable, Iterable<Tuple2<Double, String>> iterable2, Collector<Tuple3<Double, Long, String>> collector) {
        }
    }

    /* loaded from: input_file:org/apache/flink/api/java/operators/translation/DeltaIterationTranslationTest$SolutionWorksetJoin.class */
    private static class SolutionWorksetJoin extends RichJoinFunction<Tuple2<Double, String>, Tuple3<Double, Long, String>, Tuple3<Double, Long, String>> {
        private SolutionWorksetJoin() {
        }

        public Tuple3<Double, Long, String> join(Tuple2<Double, String> tuple2, Tuple3<Double, Long, String> tuple3) {
            return null;
        }
    }

    @Test
    public void testCorrectTranslation() {
        try {
            int[] iArr = {2};
            ExecutionEnvironment executionEnvironment = ExecutionEnvironment.getExecutionEnvironment();
            executionEnvironment.setParallelism(133);
            DeltaIteration iterateDelta = executionEnvironment.fromElements(new Tuple3[]{new Tuple3(Double.valueOf(3.44d), 5L, "abc")}).iterateDelta(executionEnvironment.fromElements(new Tuple2[]{new Tuple2(Double.valueOf(1.23d), "abc")}), 13, iArr);
            iterateDelta.name("Test Name").parallelism(77);
            iterateDelta.registerAggregator("AggregatorName", new LongSumAggregator());
            JoinOperator.EquiJoin with = iterateDelta.getWorkset().map(new IdentityMapper()).join(iterateDelta.getWorkset()).where(new int[]{1}).equalTo(new int[]{1}).projectFirst(new int[]{0, 1}).join(iterateDelta.getSolutionSet()).where(new int[]{1}).equalTo(new int[]{2}).with(new SolutionWorksetJoin());
            DataSet closeWith = iterateDelta.closeWith(with, with.map(new NextWorksetMapper()).name("Some Mapper"));
            closeWith.output(new DiscardingOutputFormat());
            closeWith.writeAsText("/dev/null");
            Plan createProgramPlan = executionEnvironment.createProgramPlan("Test JobName");
            Assert.assertEquals("Test JobName", createProgramPlan.getJobName());
            Assert.assertEquals(133L, createProgramPlan.getDefaultParallelism());
            Iterator it = createProgramPlan.getDataSinks().iterator();
            GenericDataSinkBase genericDataSinkBase = (GenericDataSinkBase) it.next();
            GenericDataSinkBase genericDataSinkBase2 = (GenericDataSinkBase) it.next();
            DeltaIterationBase input = genericDataSinkBase.getInput();
            Assert.assertEquals(input, genericDataSinkBase2.getInput());
            Assert.assertEquals(13L, input.getMaximumNumberOfIterations());
            Assert.assertArrayEquals(iArr, input.getSolutionSetKeyFields());
            Assert.assertEquals(77L, input.getParallelism());
            Assert.assertEquals("Test Name", input.getName());
            MapOperatorBase nextWorkset = input.getNextWorkset();
            InnerJoinOperatorBase solutionSetDelta = input.getSolutionSetDelta();
            Assert.assertEquals(IdentityMapper.class, solutionSetDelta.getFirstInput().getFirstInput().getUserCodeWrapper().getUserCodeClass());
            Assert.assertEquals(NextWorksetMapper.class, nextWorkset.getUserCodeWrapper().getUserCodeClass());
            if (solutionSetDelta.getUserCodeWrapper().getUserCodeObject() instanceof WrappingFunction) {
                Assert.assertEquals(SolutionWorksetJoin.class, ((WrappingFunction) solutionSetDelta.getUserCodeWrapper().getUserCodeObject()).getWrappedFunction().getClass());
            } else {
                Assert.assertEquals(SolutionWorksetJoin.class, solutionSetDelta.getUserCodeWrapper().getUserCodeClass());
            }
            Assert.assertEquals("Some Mapper", nextWorkset.getName());
            Assert.assertEquals("AggregatorName", ((AggregatorWithName) input.getAggregators().getAllRegisteredAggregators().iterator().next()).getName());
        } catch (Exception e) {
            System.err.println(e.getMessage());
            e.printStackTrace();
            Assert.fail(e.getMessage());
        }
    }

    @Test
    public void testRejectWhenSolutionSetKeysDontMatchJoin() {
        try {
            ExecutionEnvironment executionEnvironment = ExecutionEnvironment.getExecutionEnvironment();
            DeltaIteration iterateDelta = executionEnvironment.fromElements(new Tuple3[]{new Tuple3(Double.valueOf(3.44d), 5L, "abc")}).iterateDelta(executionEnvironment.fromElements(new Tuple2[]{new Tuple2(Double.valueOf(1.23d), "abc")}), 10, new int[]{1});
            try {
                iterateDelta.getWorkset().join(iterateDelta.getSolutionSet()).where(new int[]{1}).equalTo(new int[]{2});
                Assert.fail("Accepted invalid program.");
            } catch (InvalidProgramException e) {
            }
            try {
                iterateDelta.getSolutionSet().join(iterateDelta.getWorkset()).where(new int[]{2}).equalTo(new int[]{1});
                Assert.fail("Accepted invalid program.");
            } catch (InvalidProgramException e2) {
            }
        } catch (Exception e3) {
            System.err.println(e3.getMessage());
            e3.printStackTrace();
            Assert.fail(e3.getMessage());
        }
    }

    @Test
    public void testRejectWhenSolutionSetKeysDontMatchCoGroup() {
        try {
            ExecutionEnvironment executionEnvironment = ExecutionEnvironment.getExecutionEnvironment();
            DeltaIteration iterateDelta = executionEnvironment.fromElements(new Tuple3[]{new Tuple3(Double.valueOf(3.44d), 5L, "abc")}).iterateDelta(executionEnvironment.fromElements(new Tuple2[]{new Tuple2(Double.valueOf(1.23d), "abc")}), 10, new int[]{1});
            try {
                iterateDelta.getWorkset().coGroup(iterateDelta.getSolutionSet()).where(new int[]{1}).equalTo(new int[]{2}).with(new SolutionWorksetCoGroup1());
                Assert.fail("Accepted invalid program.");
            } catch (InvalidProgramException e) {
            }
            try {
                iterateDelta.getSolutionSet().coGroup(iterateDelta.getWorkset()).where(new int[]{2}).equalTo(new int[]{1}).with(new SolutionWorksetCoGroup2());
                Assert.fail("Accepted invalid program.");
            } catch (InvalidProgramException e2) {
            }
        } catch (Exception e3) {
            System.err.println(e3.getMessage());
            e3.printStackTrace();
            Assert.fail(e3.getMessage());
        }
    }
}
