/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.ann;

import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.ml.ann.ANNSuite$;
import org.apache.spark.ml.ann.FeedForwardModel$;
import org.apache.spark.ml.ann.FeedForwardTopology;
import org.apache.spark.ml.ann.FeedForwardTopology$;
import org.apache.spark.ml.ann.FeedForwardTrainer;
import org.apache.spark.ml.ann.Topology;
import org.apache.spark.ml.ann.TopologyModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$class;
import org.apache.spark.mllib.util.TestingUtils$;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SQLContext;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalactic.Equality$;
import org.scalactic.TripleEqualsSupport;
import org.scalatest.Args;
import org.scalatest.BeforeAndAfterAll;
import org.scalatest.ConfigMap;
import org.scalatest.FunSuiteLike;
import org.scalatest.Status;
import org.scalatest.Tag;
import scala.Array$;
import scala.Function0;
import scala.Function1;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.GenIterable;
import scala.collection.Seq;
import scala.math.package$;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/**
 * Test suite for the feed-forward neural network code in {@code org.apache.spark.ml.ann}.
 *
 * <p>This file is decompiler output (see the header comment) for a class carrying a
 * {@code @ScalaSignature}, so several constructs below are renderings of compiled Scala
 * rather than hand-written Java: trait forwarders written as {@code Trait.class.method(...)},
 * anonymous {@code Serializable} closures taking constructor arguments, and synthetic
 * {@code $outer} references.
 *
 * <p>The constructor registers two tests through {@code test(...)}; both train a small
 * network on the four XOR input rows and assert on the model's predictions.
 */
@ScalaSignature(bytes="\u0006\u0001}1A!\u0001\u0002\u0001\u001b\tA\u0011I\u0014(Tk&$XM\u0003\u0002\u0004\t\u0005\u0019\u0011M\u001c8\u000b\u0005\u00151\u0011AA7m\u0015\t9\u0001\"A\u0003ta\u0006\u00148N\u0003\u0002\n\u0015\u00051\u0011\r]1dQ\u0016T\u0011aC\u0001\u0004_J<7\u0001A\n\u0004\u00019\u0011\u0002CA\b\u0011\u001b\u00051\u0011BA\t\u0007\u00055\u0019\u0006/\u0019:l\rVt7+^5uKB\u00111\u0003G\u0007\u0002))\u0011QCF\u0001\u0005kRLGN\u0003\u0002\u0018\r\u0005)Q\u000e\u001c7jE&\u0011\u0011\u0004\u0006\u0002\u0016\u001b2c\u0017N\u0019+fgR\u001c\u0006/\u0019:l\u0007>tG/\u001a=u\u0011\u0015Y\u0002\u0001\"\u0001\u001d\u0003\u0019a\u0014N\\5u}Q\tQ\u0004\u0005\u0002\u001f\u00015\t!\u0001")
public class ANNSuite
extends SparkFunSuite
implements MLlibTestSparkContext {
    // Backing storage for sc() / sc_$eq(...).
    private transient SparkContext sc;
    // Backing storage for sqlContext() / sqlContext_$eq(...).
    private transient SQLContext sqlContext;
    // NOTE(review): declared final yet assigned in the _setter_ method further down; javac
    // would reject that as written -- likely a decompilation artifact of a trait-initialised val.
    private final boolean invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected;

    /** Returns the SparkContext currently stored on this suite. */
    @Override
    public SparkContext sc() {
        return this.sc;
    }

    /** Stores the given SparkContext on this suite (Scala-style setter for {@code sc}). */
    @Override
    public void sc_$eq(SparkContext x$1) {
        this.sc = x$1;
    }

    /** Returns the SQLContext currently stored on this suite. */
    @Override
    public SQLContext sqlContext() {
        return this.sqlContext;
    }

    /** Stores the given SQLContext on this suite (Scala-style setter for {@code sqlContext}). */
    @Override
    public void sqlContext_$eq(SQLContext x$1) {
        this.sqlContext = x$1;
    }

    /** Forwarder that invokes the {@code BeforeAndAfterAll} version of {@code beforeAll()} on this suite. */
    @Override
    public void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        BeforeAndAfterAll.class.beforeAll((BeforeAndAfterAll)this);
    }

    /** Forwarder that invokes the {@code BeforeAndAfterAll} version of {@code afterAll()} on this suite. */
    @Override
    public void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        BeforeAndAfterAll.class.afterAll((BeforeAndAfterAll)this);
    }

    /**
     * Delegates to the {@code MLlibTestSparkContext} implementation of {@code beforeAll}.
     * Presumably this is where {@code sc}/{@code sqlContext} get populated -- the
     * implementation is not visible in this file; confirm there.
     */
    @Override
    public void beforeAll() {
        MLlibTestSparkContext$class.beforeAll(this);
    }

    /** Delegates to the {@code MLlibTestSparkContext} implementation of {@code afterAll}. */
    @Override
    public void afterAll() {
        MLlibTestSparkContext$class.afterAll(this);
    }

    /** Returns the stored {@code invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected} flag. */
    public boolean invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected() {
        return this.invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected;
    }

    /** Forwarder that invokes the {@code FunSuiteLike} version of {@code run} with the given arguments. */
    public Status org$scalatest$BeforeAndAfterAll$$super$run(Option testName, Args args) {
        return FunSuiteLike.class.run((FunSuiteLike)this, (Option)testName, (Args)args);
    }

    /** Setter for the {@code invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected} flag. */
    public void org$scalatest$BeforeAndAfterAll$_setter_$invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected_$eq(boolean x$1) {
        this.invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected = x$1;
    }

    /** Delegates to the {@code BeforeAndAfterAll} implementation of {@code beforeAll(ConfigMap)}. */
    public void beforeAll(ConfigMap configMap) {
        BeforeAndAfterAll.class.beforeAll((BeforeAndAfterAll)this, (ConfigMap)configMap);
    }

    /** Delegates to the {@code BeforeAndAfterAll} implementation of {@code afterAll(ConfigMap)}. */
    public void afterAll(ConfigMap configMap) {
        BeforeAndAfterAll.class.afterAll((BeforeAndAfterAll)this, (ConfigMap)configMap);
    }

    /** Runs the suite by delegating to the {@code BeforeAndAfterAll} implementation of {@code run}. */
    public Status run(Option<String> testName, Args args) {
        return BeforeAndAfterAll.class.run((BeforeAndAfterAll)this, testName, (Args)args);
    }

    /**
     * Runs the initialisers of the two mixed-in traits, then registers the suite's two tests.
     * Each test body is an anonymous {@code Function0} that holds a reference back to this
     * suite in {@code $outer}.
     */
    public ANNSuite() {
        BeforeAndAfterAll.class.$init$((BeforeAndAfterAll)this);
        MLlibTestSparkContext$class.$init$(this);
        // Test 1: single scalar output per XOR row, trained via trainer.LBFGSOptimizer().
        this.test("ANN with Sigmoid learns XOR function with LBFGS optimizer", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ ANNSuite $outer;

            public final void apply() {
                this.apply$mcV$sp();
            }

            public void apply$mcV$sp() {
                // The four XOR input rows and their scalar targets.
                double[][] inputs = (double[][])((Object[])new double[][]{{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}});
                double[] outputs = new double[]{0.0, 1.0, 1.0, 0.0};
                // Zip inputs with outputs and convert every pair to (dense features, one-element dense label).
                Tuple2[] data = (Tuple2[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])inputs).zip((GenIterable)Predef$.MODULE$.wrapDoubleArray(outputs), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Tuple2<Vector, Vector> apply(Tuple2<double[], Object> x0$1) {
                        Tuple2<double[], Object> tuple2 = x0$1;
                        if (tuple2 != null) {
                            double[] features = (double[])tuple2._1();
                            double label = tuple2._2$mcD$sp();
                            Tuple2 tuple22 = new Tuple2((Object)Vectors$.MODULE$.dense(features), (Object)Vectors$.MODULE$.dense(label, (Seq)Predef$.MODULE$.wrapDoubleArray(new double[0])));
                            return tuple22;
                        }
                        throw new MatchError(tuple2);
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
                // Second argument 1 is presumably the partition count -- confirm against SparkContext.parallelize.
                RDD rddData = this.$outer.sc().parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])data), 1, ClassTag$.MODULE$.apply(Tuple2.class));
                int[] hiddenLayersTopology = new int[]{5};
                Tuple2 dataSample = (Tuple2)rddData.first();
                int n = ((Vector)dataSample._1()).size();
                // layerSizes = feature-vector size, then the hidden sizes, then label-vector size
                // (prepend via +:, append via :+), all taken from the first data row.
                int[] layerSizes = (int[])Predef$.MODULE$.intArrayOps((int[])Predef$.MODULE$.intArrayOps(hiddenLayersTopology).$plus$colon((Object)BoxesRunTime.boxToInteger((int)n), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).$colon$plus((Object)BoxesRunTime.boxToInteger((int)((Vector)dataSample._2()).size()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
                FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(layerSizes, false);
                // 23124L is presumably a fixed seed so the initial weights are reproducible -- TODO confirm.
                Vector initialWeights = FeedForwardModel$.MODULE$.apply(topology, 23124L).weights();
                // NOTE(review): the trailing (2, 1) look like input/output sizes matching the data above -- confirm against FeedForwardTrainer.
                FeedForwardTrainer trainer = new FeedForwardTrainer((Topology)topology, 2, 1);
                trainer.setWeights(initialWeights);
                trainer.LBFGSOptimizer().setNumIterations(20);
                TopologyModel model = trainer.train(rddData);
                // Pair element 0 of each prediction with element 0 of its label, then collect to the driver.
                Tuple2[] predictionAndLabels = (Tuple2[])rddData.map((Function1)new Serializable(this, model){
                    public static final long serialVersionUID = 0L;
                    private final TopologyModel model$1;

                    public final Tuple2<Object, Object> apply(Tuple2<Vector, Vector> x0$2) {
                        Tuple2<Vector, Vector> tuple2 = x0$2;
                        if (tuple2 != null) {
                            Vector input = (Vector)tuple2._1();
                            Vector label = (Vector)tuple2._2();
                            Tuple2.mcDD.sp sp2 = new Tuple2.mcDD.sp(this.model$1.predict(input).apply(0), label.apply(0));
                            return sp2;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        // Captures the trained model for use inside apply (decompiler rendering of the closure constructor).
                        this.model$1 = model$1;
                    }
                }, ClassTag$.MODULE$.apply(Tuple2.class)).collect();
                // Assert, for every row, that the rounded prediction === the label.
                Predef$.MODULE$.refArrayOps((Object[])predictionAndLabels).foreach((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$1 $outer;

                    public final void apply(Tuple2<Object, Object> x0$3) {
                        Tuple2<Object, Object> tuple2 = x0$3;
                        if (tuple2 != null) {
                            double p = tuple2._1$mcD$sp();
                            double l = tuple2._2$mcD$sp();
                            // Left side is a boxed long (round(p)); right side is a boxed double (l).
                            TripleEqualsSupport.Equalizer $org_scalatest_assert_macro_left = this.$outer.org$apache$spark$ml$ann$ANNSuite$$anonfun$$$outer().convertToEqualizer(BoxesRunTime.boxToLong((long)package$.MODULE$.round(p)));
                            double $org_scalatest_assert_macro_right = l;
                            Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "===", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left.$eq$eq$eq((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), Equality$.MODULE$.default()));
                            this.$outer.org$apache$spark$ml$ann$ANNSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"");
                            BoxedUnit boxedUnit = BoxedUnit.UNIT;
                            return;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        // Rejects a null enclosing closure before storing it.
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                    }
                });
            }

            // Accessor used by nested closures to reach the enclosing ANNSuite.
            public /* synthetic */ ANNSuite org$apache$spark$ml$ann$ANNSuite$$anonfun$$$outer() {
                return this.$outer;
            }
            {
                // Rejects a null enclosing suite before storing it.
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
            }
        });
        // Test 2: two-element output vector per XOR row, trained via trainer.SGDOptimizer().
        this.test("ANN with SoftMax learns XOR function with 2-bit output and batch GD optimizer", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ ANNSuite $outer;

            public final void apply() {
                this.apply$mcV$sp();
            }

            public void apply$mcV$sp() {
                // The four XOR input rows; each target is a two-element vector ({1,0} where the scalar test used 0, {0,1} where it used 1).
                double[][] inputs = (double[][])((Object[])new double[][]{{0.0, 0.0}, {0.0, 1.0}, {1.0, 0.0}, {1.0, 1.0}});
                double[][] outputs = (double[][])((Object[])new double[][]{{1.0, 0.0}, {0.0, 1.0}, {0.0, 1.0}, {1.0, 0.0}});
                // Zip inputs with outputs and convert every pair to (dense features, dense label).
                Tuple2[] data = (Tuple2[])Predef$.MODULE$.refArrayOps((Object[])Predef$.MODULE$.refArrayOps((Object[])inputs).zip((GenIterable)Predef$.MODULE$.wrapRefArray((Object[])outputs), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Tuple2<Vector, Vector> apply(Tuple2<double[], double[]> x0$4) {
                        Tuple2<double[], double[]> tuple2 = x0$4;
                        if (tuple2 != null) {
                            double[] features = (double[])tuple2._1();
                            double[] label = (double[])tuple2._2();
                            Tuple2 tuple22 = new Tuple2((Object)Vectors$.MODULE$.dense(features), (Object)Vectors$.MODULE$.dense(label));
                            return tuple22;
                        }
                        throw new MatchError(tuple2);
                    }
                }, Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)));
                // Second argument 1 is presumably the partition count -- confirm against SparkContext.parallelize.
                RDD rddData = this.$outer.sc().parallelize((Seq)Predef$.MODULE$.wrapRefArray((Object[])data), 1, ClassTag$.MODULE$.apply(Tuple2.class));
                int[] hiddenLayersTopology = new int[]{5};
                Tuple2 dataSample = (Tuple2)rddData.first();
                int n = ((Vector)dataSample._1()).size();
                // layerSizes = feature-vector size, then the hidden sizes, then label-vector size
                // (prepend via +:, append via :+), all taken from the first data row.
                int[] layerSizes = (int[])Predef$.MODULE$.intArrayOps((int[])Predef$.MODULE$.intArrayOps(hiddenLayersTopology).$plus$colon((Object)BoxesRunTime.boxToInteger((int)n), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()))).$colon$plus((Object)BoxesRunTime.boxToInteger((int)((Vector)dataSample._2()).size()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.Int()));
                // NOTE(review): the flag is false here, same as in the "Sigmoid" test, although this
                // test's name mentions SoftMax -- confirm what the second parameter controls.
                FeedForwardTopology topology = FeedForwardTopology$.MODULE$.multiLayerPerceptron(layerSizes, false);
                // 23124L is presumably a fixed seed so the initial weights are reproducible -- TODO confirm.
                Vector initialWeights = FeedForwardModel$.MODULE$.apply(topology, 23124L).weights();
                // NOTE(review): the trailing (2, 2) look like input/output sizes matching the data above -- confirm against FeedForwardTrainer.
                FeedForwardTrainer trainer = new FeedForwardTrainer((Topology)topology, 2, 2);
                trainer.SGDOptimizer().setNumIterations(2000);
                trainer.setWeights(initialWeights);
                TopologyModel model = trainer.train(rddData);
                // Pair each full prediction vector with its label vector, then collect to the driver.
                Tuple2[] predictionAndLabels = (Tuple2[])rddData.map((Function1)new Serializable(this, model){
                    public static final long serialVersionUID = 0L;
                    private final TopologyModel model$2;

                    public final Tuple2<Vector, Vector> apply(Tuple2<Vector, Vector> x0$5) {
                        Tuple2<Vector, Vector> tuple2 = x0$5;
                        if (tuple2 != null) {
                            Vector input = (Vector)tuple2._1();
                            Vector label = (Vector)tuple2._2();
                            Tuple2 tuple22 = new Tuple2((Object)this.model$2.predict(input), (Object)label);
                            return tuple22;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        // Captures the trained model for use inside apply (decompiler rendering of the closure constructor).
                        this.model$2 = model$2;
                    }
                }, ClassTag$.MODULE$.apply(Tuple2.class)).collect();
                // Assert, for every row, prediction ~== label using the TestingUtils comparison with absTol(0.5).
                Predef$.MODULE$.refArrayOps((Object[])predictionAndLabels).foreach((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;
                    private final /* synthetic */ $anonfun$2 $outer;

                    public final void apply(Tuple2<Vector, Vector> x0$6) {
                        Tuple2<Vector, Vector> tuple2 = x0$6;
                        if (tuple2 != null) {
                            Vector p = (Vector)tuple2._1();
                            Vector l = (Vector)tuple2._2();
                            Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.simpleMacroBool(TestingUtils$.MODULE$.VectorWithAlmostEquals(p).$tilde$eq$eq(TestingUtils$.MODULE$.VectorWithAlmostEquals(l).absTol(0.5)), "org.apache.spark.mllib.util.TestingUtils.VectorWithAlmostEquals(p).~==(org.apache.spark.mllib.util.TestingUtils.VectorWithAlmostEquals(l).absTol(0.5))");
                            this.$outer.org$apache$spark$ml$ann$ANNSuite$$anonfun$$$outer().assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"");
                            BoxedUnit boxedUnit = BoxedUnit.UNIT;
                            return;
                        }
                        throw new MatchError(tuple2);
                    }
                    {
                        // Rejects a null enclosing closure before storing it.
                        if ($outer == null) {
                            throw new NullPointerException();
                        }
                        this.$outer = $outer;
                    }
                });
            }

            // Accessor used by nested closures to reach the enclosing ANNSuite.
            public /* synthetic */ ANNSuite org$apache$spark$ml$ann$ANNSuite$$anonfun$$$outer() {
                return this.$outer;
            }
            {
                // Rejects a null enclosing suite before storing it.
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
            }
        });
    }
}

