/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.regression;

import java.io.File;
import org.apache.spark.SparkContext;
import org.apache.spark.SparkFunSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.LinearRegressionModel;
import org.apache.spark.mllib.regression.LinearRegressionWithSGD;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionModel$;
import org.apache.spark.mllib.regression.RidgeRegressionSuite$;
import org.apache.spark.mllib.regression.RidgeRegressionWithSGD;
import org.apache.spark.mllib.util.LinearDataGenerator$;
import org.apache.spark.mllib.util.MLlibTestSparkContext;
import org.apache.spark.mllib.util.MLlibTestSparkContext$class;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils$;
import org.jblas.DoubleMatrix;
import org.jblas.util.Random;
import org.scalactic.Bool;
import org.scalactic.Bool$;
import org.scalatest.Args;
import org.scalatest.BeforeAndAfterAll;
import org.scalatest.ConfigMap;
import org.scalatest.FunSuiteLike;
import org.scalatest.Status;
import org.scalatest.Tag;
import scala.Function0;
import scala.Function1;
import scala.Function2;
import scala.MatchError;
import scala.Option;
import scala.Predef$;
import scala.Serializable;
import scala.Tuple2;
import scala.collection.Seq;
import scala.collection.Seq$;
import scala.collection.TraversableLike;
import scala.collection.TraversableOnce;
import scala.collection.mutable.StringBuilder;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/**
 * Test suite for {@link RidgeRegressionWithSGD} and {@link RidgeRegressionModel}.
 *
 * <p>This is decompiler output of a Scala test class (see the file header), not
 * hand-written Java, and it does not compile as written. Calls such as
 * {@code BeforeAndAfterAll.class.beforeAll(this)} appear to stand for the Scala
 * trait-implementation class (compare the {@code MLlibTestSparkContext$class} calls below).
 * The anonymous-class {@code $outer} wiring is synthesized, and a {@code private final}
 * field is assigned from a setter. Treat this file as a read-only record of what the
 * original Scala suite does.
 *
 * <p>The constructor registers two tests:
 * <ul>
 *   <li>"ridge regression can help avoid overfitting": trains plain linear SGD and ridge
 *       SGD on the same data and asserts that ridge has the lower validation error.</li>
 *   <li>"model save/load": saves the companion object's model to a temp directory, loads
 *       it back, and compares weights and intercept.</li>
 * </ul>
 *
 * <p>The {@code ScalaSignature} annotation holds encoded Scala compiler metadata. Do not
 * edit it by hand.
 */
@ScalaSignature(bytes="\u0006\u0001i;Q!\u0001\u0002\t\n5\tACU5eO\u0016\u0014Vm\u001a:fgNLwN\\*vSR,'BA\u0002\u0005\u0003)\u0011Xm\u001a:fgNLwN\u001c\u0006\u0003\u000b\u0019\tQ!\u001c7mS\nT!a\u0002\u0005\u0002\u000bM\u0004\u0018M]6\u000b\u0005%Q\u0011AB1qC\u000eDWMC\u0001\f\u0003\ry'oZ\u0002\u0001!\tqq\"D\u0001\u0003\r\u0015\u0001\"\u0001#\u0003\u0012\u0005Q\u0011\u0016\u000eZ4f%\u0016<'/Z:tS>t7+^5uKN\u0019qB\u0005\r\u0011\u0005M1R\"\u0001\u000b\u000b\u0003U\tQa]2bY\u0006L!a\u0006\u000b\u0003\r\u0005s\u0017PU3g!\t\u0019\u0012$\u0003\u0002\u001b)\ta1+\u001a:jC2L'0\u00192mK\")Ad\u0004C\u0001;\u00051A(\u001b8jiz\"\u0012!\u0004\u0005\b?=\u0011\r\u0011\"\u0001!\u0003\u0015iw\u000eZ3m+\u0005\t\u0003C\u0001\b#\u0013\t\u0019#A\u0001\u000bSS\u0012<WMU3he\u0016\u001c8/[8o\u001b>$W\r\u001c\u0005\u0007K=\u0001\u000b\u0011B\u0011\u0002\r5|G-\u001a7!\u0011\u001d9s\"!A\u0005\n!\n1B]3bIJ+7o\u001c7wKR\t\u0011\u0006\u0005\u0002+_5\t1F\u0003\u0002-[\u0005!A.\u00198h\u0015\u0005q\u0013\u0001\u00026bm\u0006L!\u0001M\u0016\u0003\r=\u0013'.Z2u\r\u0011\u0001\"\u0001\u0001\u001a\u0014\u0007E\u001at\u0007\u0005\u00025k5\ta!\u0003\u00027\r\ti1\u000b]1sW\u001a+hnU;ji\u0016\u0004\"\u0001O\u001e\u000e\u0003eR!A\u000f\u0003\u0002\tU$\u0018\u000e\\\u0005\u0003ye\u0012Q#\u0014'mS\n$Vm\u001d;Ta\u0006\u00148nQ8oi\u0016DH\u000fC\u0003\u001dc\u0011\u0005a\bF\u0001@!\tq\u0011\u0007C\u0003Bc\u0011\u0005!)A\bqe\u0016$\u0017n\u0019;j_:,%O]8s)\r\u0019e\t\u0016\t\u0003'\u0011K!!\u0012\u000b\u0003\r\u0011{WO\u00197f\u0011\u00159\u0005\t1\u0001I\u0003-\u0001(/\u001a3jGRLwN\\:\u0011\u0007%\u000b6I\u0004\u0002K\u001f:\u00111JT\u0007\u0002\u0019*\u0011Q\nD\u0001\u0007yI|w\u000e\u001e \n\u0003UI!\u0001\u0015\u000b\u0002\u000fA\f7m[1hK&\u0011!k\u0015\u0002\u0004'\u0016\f(B\u0001)\u0015\u0011\u0015)\u0006\t1\u0001W\u0003\u0015Ig\u000e];u!\rI\u0015k\u0016\t\u0003\u001daK!!\u0017\u0002\u0003\u00191\u000b'-\u001a7fIB{\u0017N\u001c;")
public class RidgeRegressionSuite
extends SparkFunSuite
implements MLlibTestSparkContext {
    // Spark handles exposed through the Scala-style accessor/mutator pairs below
    // (sc()/sc_$eq and sqlContext()/sqlContext_$eq), which MLlibTestSparkContext requires.
    private transient SparkContext sc;
    private transient SQLContext sqlContext;
    // Backing field for the BeforeAndAfterAll flag. It is written by the
    // org$scalatest$BeforeAndAfterAll$_setter_$... method below, presumably from the trait's
    // $init$ in the constructor. (A final field assigned in a method is a decompiler artifact.)
    private final boolean invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected;

    /**
     * Static forwarder to the companion object's {@code model()}.
     * The "model save/load" test reads its model from the same companion.
     *
     * @return the companion object's ridge regression model
     */
    public static RidgeRegressionModel model() {
        return RidgeRegressionSuite$.MODULE$.model();
    }

    /** Returns the SparkContext field (Scala getter for {@code sc}). */
    @Override
    public SparkContext sc() {
        return this.sc;
    }

    /** Assigns the SparkContext field (Scala setter {@code sc_=}). */
    @Override
    public void sc_$eq(SparkContext x$1) {
        this.sc = x$1;
    }

    /** Returns the SQLContext field (Scala getter for {@code sqlContext}). */
    @Override
    public SQLContext sqlContext() {
        return this.sqlContext;
    }

    /** Assigns the SQLContext field (Scala setter {@code sqlContext_=}). */
    @Override
    public void sqlContext_$eq(SQLContext x$1) {
        this.sqlContext = x$1;
    }

    // Bridge for MLlibTestSparkContext's super.beforeAll() call; forwards to the
    // BeforeAndAfterAll trait implementation.
    @Override
    public void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$beforeAll() {
        BeforeAndAfterAll.class.beforeAll((BeforeAndAfterAll)this);
    }

    // Bridge for MLlibTestSparkContext's super.afterAll() call; forwards to the
    // BeforeAndAfterAll trait implementation.
    @Override
    public void org$apache$spark$mllib$util$MLlibTestSparkContext$$super$afterAll() {
        BeforeAndAfterAll.class.afterAll((BeforeAndAfterAll)this);
    }

    // Suite setup hook, delegated to MLlibTestSparkContext's trait implementation.
    // Presumably this is where sc/sqlContext are created; confirm in MLlibTestSparkContext.
    @Override
    public void beforeAll() {
        MLlibTestSparkContext$class.beforeAll(this);
    }

    // Suite teardown hook, delegated to MLlibTestSparkContext's trait implementation.
    @Override
    public void afterAll() {
        MLlibTestSparkContext$class.afterAll(this);
    }

    /** Returns the BeforeAndAfterAll flag stored in the backing field above. */
    public boolean invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected() {
        return this.invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected;
    }

    // Bridge for BeforeAndAfterAll's super.run(...): runs the tests via FunSuiteLike.
    public Status org$scalatest$BeforeAndAfterAll$$super$run(Option testName, Args args) {
        return FunSuiteLike.class.run((FunSuiteLike)this, (Option)testName, (Args)args);
    }

    // Trait field setter invoked by BeforeAndAfterAll's initializer (per its synthetic name).
    public void org$scalatest$BeforeAndAfterAll$_setter_$invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected_$eq(boolean x$1) {
        this.invokeBeforeAllAndAfterAllEvenIfNoTestsAreExpected = x$1;
    }

    // ConfigMap overload of beforeAll, forwarded to the BeforeAndAfterAll trait implementation.
    public void beforeAll(ConfigMap configMap) {
        BeforeAndAfterAll.class.beforeAll((BeforeAndAfterAll)this, (ConfigMap)configMap);
    }

    // ConfigMap overload of afterAll, forwarded to the BeforeAndAfterAll trait implementation.
    public void afterAll(ConfigMap configMap) {
        BeforeAndAfterAll.class.afterAll((BeforeAndAfterAll)this, (ConfigMap)configMap);
    }

    // Suite entry point, routed through BeforeAndAfterAll so the before/after hooks wrap the run.
    public Status run(Option<String> testName, Args args) {
        return BeforeAndAfterAll.class.run((BeforeAndAfterAll)this, testName, (Args)args);
    }

    /**
     * Computes the mean squared error of {@code predictions} against the labels of {@code input}.
     *
     * <p>Pairs the two sequences with {@code zip} and squares each {@code prediction - label}
     * difference. It sums the squares with {@code reduceLeft} and divides by
     * {@code predictions.size()}.
     *
     * <p>NOTE(review): {@code zip} stops at the shorter sequence, but the divisor is always
     * {@code predictions.size()}, so sequences of different lengths give a skewed result.
     * Also, {@code reduceLeft} on an empty Scala sequence throws, so empty input is not
     * supported.
     *
     * @param predictions predicted values, boxed; the first element of each zipped pair is
     *                    read as a {@code double}
     * @param input labelled points whose {@code label()} is the expected value
     * @return the sum of squared errors divided by the number of predictions
     */
    public double predictionError(Seq<Object> predictions, Seq<LabeledPoint> input) {
        return BoxesRunTime.unboxToDouble((Object)((TraversableOnce)((TraversableLike)predictions.zip(input, Seq$.MODULE$.canBuildFrom())).map((Function1)new Serializable(this){
            public static final long serialVersionUID = 0L;

            // Maps one (prediction, point) pair to its squared error.
            public final double apply(Tuple2<Object, LabeledPoint> x0$1) {
                Tuple2<Object, LabeledPoint> tuple2 = x0$1;
                if (tuple2 != null) {
                    double prediction = tuple2._1$mcD$sp();
                    LabeledPoint expected = (LabeledPoint)tuple2._2();
                    double d = (prediction - expected.label()) * (prediction - expected.label());
                    return d;
                }
                // Pattern-match fall-through emitted by scalac for a null tuple.
                throw new MatchError(tuple2);
            }
        }, Seq$.MODULE$.canBuildFrom())).reduceLeft((Function2)new Serializable(this){
            public static final long serialVersionUID = 0L;

            // Sums two squared errors; the boxed apply delegates to the double-specialized variant.
            public final double apply(double x$1, double x$2) {
                return this.apply$mcDDD$sp(x$1, x$2);
            }

            public double apply$mcDDD$sp(double x$1, double x$2) {
                return x$1 + x$2;
            }
        })) / (double)predictions.size();
    }

    /**
     * Runs the mixed-in trait initializers, then registers the suite's two tests.
     */
    public RidgeRegressionSuite() {
        BeforeAndAfterAll.class.$init$((BeforeAndAfterAll)this);
        MLlibTestSparkContext$class.$init$(this);
        this.test("ridge regression can help avoid overfitting", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ RidgeRegressionSuite $outer;

            public final void apply() {
                this.apply$mcV$sp();
            }

            public void apply$mcV$sp() {
                double ridgeErr;
                int numExamples = 50;
                int numFeatures = 20;
                // Seed jblas's RNG so the weight vector w below is reproducible.
                Random.seed((long)42L);
                // w: a numFeatures x 1 matrix of random values, each reduced by 0.5 (subi works in place).
                DoubleMatrix w = DoubleMatrix.rand((int)numFeatures, (int)1).subi(0.5);
                // Requests 2 * numExamples points from the generator. The meaning of 3.0, 42 and 10.0 is
                // defined by LinearDataGenerator (presumably intercept, seed and noise scale; confirm there).
                Seq data = LinearDataGenerator$.MODULE$.generateLinearInput(3.0, w.toArray(), 2 * numExamples, 42, 10.0);
                // NOTE(review): despite its name, "testData" (the first half) is the TRAINING set.
                // The last half is held out for validation. The two halves are disjoint only if the
                // generator returns exactly 2 * numExamples points.
                Seq testData = (Seq)data.take(numExamples);
                Seq validationData = (Seq)data.takeRight(numExamples);
                // NOTE(review): both RDDs are cached, but this test never unpersists them.
                RDD testRDD = this.$outer.sc().parallelize(testData, 2, ClassTag$.MODULE$.apply(LabeledPoint.class)).cache();
                RDD validationRDD = this.$outer.sc().parallelize(validationData, 2, ClassTag$.MODULE$.apply(LabeledPoint.class)).cache();
                // Baseline: unregularized linear regression, 200 SGD iterations, step size 1.0.
                LinearRegressionWithSGD linearReg = new LinearRegressionWithSGD();
                linearReg.optimizer().setNumIterations(200).setStepSize(1.0);
                LinearRegressionModel linearModel = (LinearRegressionModel)linearReg.run(testRDD);
                // Validation error: map each validation point to its features, predict, collect the
                // predictions to the driver, and score them against validationData.
                double linearErr = this.$outer.predictionError((Seq<Object>)Predef$.MODULE$.wrapDoubleArray((double[])linearModel.predict(validationRDD.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Vector apply(LabeledPoint x$3) {
                        return x$3.features();
                    }
                }, ClassTag$.MODULE$.apply(Vector.class))).collect()), (Seq<LabeledPoint>)validationData);
                // Ridge: same iterations and step size, plus regParam 0.1.
                RidgeRegressionWithSGD ridgeReg = new RidgeRegressionWithSGD();
                ridgeReg.optimizer().setNumIterations(200).setRegParam(0.1).setStepSize(1.0);
                RidgeRegressionModel ridgeModel = (RidgeRegressionModel)ridgeReg.run(testRDD);
                // Expanded ScalaTest assert(ridgeErr < linearErr, ...) macro: ridge must have the
                // strictly lower validation error.
                double $org_scalatest_assert_macro_left = ridgeErr = this.$outer.predictionError((Seq<Object>)Predef$.MODULE$.wrapDoubleArray((double[])ridgeModel.predict(validationRDD.map((Function1)new Serializable(this){
                    public static final long serialVersionUID = 0L;

                    public final Vector apply(LabeledPoint x$4) {
                        return x$4.features();
                    }
                }, ClassTag$.MODULE$.apply(Vector.class))).collect()), (Seq<LabeledPoint>)validationData);
                double $org_scalatest_assert_macro_right = linearErr;
                Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_left), "<", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right), $org_scalatest_assert_macro_left < $org_scalatest_assert_macro_right);
                this.$outer.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)new StringBuilder().append((Object)"ridgeError (").append((Object)BoxesRunTime.boxToDouble((double)ridgeErr)).append((Object)") was not less than linearError(").append((Object)BoxesRunTime.boxToDouble((double)linearErr)).append((Object)")").toString());
            }
            // Decompiler rendering of the closure's captured outer instance. $outer appears to be a
            // synthetic constructor argument that this decompiled form does not declare.
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
            }
        });
        this.test("model save/load", (Seq)Predef$.MODULE$.wrapRefArray((Object[])new Tag[0]), (Function0)new Serializable(this){
            public static final long serialVersionUID = 0L;
            private final /* synthetic */ RidgeRegressionSuite $outer;

            public final void apply() {
                this.apply$mcV$sp();
            }

            public void apply$mcV$sp() {
                // Save the companion object's model to a fresh temp directory and load it back.
                RidgeRegressionModel model = RidgeRegressionSuite$.MODULE$.model();
                File tempDir = Utils$.MODULE$.createTempDir(Utils$.MODULE$.createTempDir$default$1(), Utils$.MODULE$.createTempDir$default$2());
                String path = tempDir.toURI().toString();
                try {
                    model.save(this.$outer.sc(), path);
                    RidgeRegressionModel sameModel = RidgeRegressionModel$.MODULE$.load(this.$outer.sc(), path);
                    // Expanded assert(model.weights == sameModel.weights): a null-safe equals comparison.
                    Vector $org_scalatest_assert_macro_left = model.weights();
                    Vector $org_scalatest_assert_macro_right = sameModel.weights();
                    Vector vector = $org_scalatest_assert_macro_left;
                    Vector vector2 = $org_scalatest_assert_macro_right;
                    Bool $org_scalatest_assert_macro_expr = Bool$.MODULE$.binaryMacroBool((Object)$org_scalatest_assert_macro_left, "==", (Object)$org_scalatest_assert_macro_right, !(vector != null ? !vector.equals(vector2) : vector2 != null));
                    this.$outer.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr, (Object)"");
                    // Expanded assert(model.intercept == sameModel.intercept): exact double equality.
                    double $org_scalatest_assert_macro_left2 = model.intercept();
                    double $org_scalatest_assert_macro_right2 = sameModel.intercept();
                    Bool $org_scalatest_assert_macro_expr2 = Bool$.MODULE$.binaryMacroBool((Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_left2), "==", (Object)BoxesRunTime.boxToDouble((double)$org_scalatest_assert_macro_right2), $org_scalatest_assert_macro_left2 == $org_scalatest_assert_macro_right2);
                    this.$outer.assertionsHelper().macroAssert($org_scalatest_assert_macro_expr2, (Object)"");
                    return;
                }
                finally {
                    // Always remove the temp directory, even when save, load or an assertion fails.
                    Utils$.MODULE$.deleteRecursively(tempDir);
                }
            }
            // Decompiler rendering of the closure's captured outer instance (see the first test).
            {
                if ($outer == null) {
                    throw new NullPointerException();
                }
                this.$outer = $outer;
            }
        });
    }
}

