/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.regression;

import java.io.Serializable;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.regression.RidgeRegressionModel;
import org.apache.spark.mllib.regression.RidgeRegressionWithSGD;
import org.apache.spark.mllib.util.LinearDataGenerator;
import org.apache.spark.rdd.RDD;
import org.jblas.DoubleMatrix;
import org.jblas.util.Random;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * JUnit tests for {@link RidgeRegressionWithSGD} as used from Java.
 *
 * <p>Both tests train one model without regularization and one with a
 * regularization parameter of 0.1 on the same generated data, and assert that
 * the regularized model has the lower mean squared error on a held-out
 * validation split.
 *
 * <p>A new local {@link JavaSparkContext} is created before each test and
 * stopped afterwards.
 */
public class JavaRidgeRegressionSuite
implements Serializable {
    /** Number of points in each of the training and validation splits. */
    private static final int NUM_EXAMPLES = 50;

    /** Number of features in every generated point. */
    private static final int NUM_FEATURES = 20;

    /** Value passed as {@code std} to {@link #generateRidgeData}. */
    private static final double NOISE_STD = 10.0;

    // transient: the suite is Serializable, but the context field is excluded
    // from serialization of the suite instance.
    private transient JavaSparkContext sc;

    /** Creates the local Spark context used by the test about to run. */
    @Before
    public void setUp() {
        this.sc = new JavaSparkContext("local", "JavaRidgeRegressionSuite");
    }

    /**
     * Stops the Spark context created in {@link #setUp()} and drops the reference.
     *
     * <p>JUnit 4 still invokes {@code @After} methods when a {@code @Before}
     * method throws, so the context may be {@code null} here; guarding avoids a
     * secondary {@link NullPointerException} obscuring the original failure.
     */
    @After
    public void tearDown() {
        if (this.sc != null) {
            this.sc.stop();
            this.sc = null;
        }
    }

    /**
     * Computes the mean squared error of {@code model} over {@code validationData}.
     *
     * @param validationData labeled points to score; an empty list yields
     *                       {@code NaN} (0.0 divided by 0)
     * @param model          model whose {@code predict} is applied to each point's features
     * @return the sum of squared differences between prediction and label,
     *         divided by the number of points
     */
    double predictionError(List<LabeledPoint> validationData, RidgeRegressionModel model) {
        double errorSum = 0.0;
        for (LabeledPoint point : validationData) {
            // Primitive locals: avoids boxing the prediction and re-reading the label.
            double residual = model.predict(point.features()) - point.label();
            errorSum += residual * residual;
        }
        return errorSum / validationData.size();
    }

    /**
     * Generates labeled points for the tests via {@link LinearDataGenerator}.
     *
     * <p>The jblas random generator is seeded with 42 before the weight vector
     * is drawn, so the weights are the same on every call. The weights are
     * {@code DoubleMatrix.rand} values shifted down by 0.5.
     *
     * @param numPoints   number of points requested from the generator
     * @param numFeatures length of the weight vector
     * @param std         forwarded unchanged as the generator's last argument
     *                    (presumably the noise standard deviation - confirm
     *                    against LinearDataGenerator)
     * @return the list produced by {@code generateLinearInputAsList}
     */
    List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
        Random.seed(42L);
        DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5);
        // NOTE(review): the leading 0.0 and the literal 42 look like the
        // intercept and the generator seed respectively - confirm against
        // LinearDataGenerator.generateLinearInputAsList.
        return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std);
    }

    /**
     * Trains through a {@link RidgeRegressionWithSGD} instance configured via
     * its optimizer (step size 1.0, 200 iterations), first with regularization
     * parameter 0.0 and then 0.1, and checks the regularized error is lower.
     */
    @Test
    public void runRidgeRegressionUsingConstructor() {
        List<LabeledPoint> data = this.generateTestData();
        JavaRDD<LabeledPoint> testRDD = this.sc.parallelize(trainingSplit(data));
        List<LabeledPoint> validationData = validationSplit(data);

        RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
        ridgeSGDImpl.optimizer()
            .setStepSize(1.0)
            .setRegParam(0.0)
            .setNumIterations(200);
        RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
        double unRegularizedErr = this.predictionError(validationData, model);

        // Same algorithm instance, only the regularization parameter changes.
        ridgeSGDImpl.optimizer().setRegParam(0.1);
        model = ridgeSGDImpl.run(testRDD.rdd());
        double regularizedErr = this.predictionError(validationData, model);

        Assert.assertTrue(regularizedErr < unRegularizedErr);
    }

    /**
     * Same comparison as {@link #runRidgeRegressionUsingConstructor()}, but
     * trains through the static {@code RidgeRegressionWithSGD.train} method.
     */
    @Test
    public void runRidgeRegressionUsingStaticMethods() {
        List<LabeledPoint> data = this.generateTestData();
        JavaRDD<LabeledPoint> testRDD = this.sc.parallelize(trainingSplit(data));
        List<LabeledPoint> validationData = validationSplit(data);

        // Only the last argument differs between the two runs (0.0 vs 0.1),
        // mirroring the regParam change in the constructor-based test.
        RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
        double unRegularizedErr = this.predictionError(validationData, model);

        model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.1);
        double regularizedErr = this.predictionError(validationData, model);

        Assert.assertTrue(regularizedErr < unRegularizedErr);
    }

    /** Generates the combined training + validation data set shared by both tests. */
    private List<LabeledPoint> generateTestData() {
        return this.generateRidgeData(2 * NUM_EXAMPLES, NUM_FEATURES, NOISE_STD);
    }

    /** Returns the first {@code NUM_EXAMPLES} points, used for training. */
    private static List<LabeledPoint> trainingSplit(List<LabeledPoint> data) {
        return data.subList(0, NUM_EXAMPLES);
    }

    /** Returns the second {@code NUM_EXAMPLES} points, used for validation. */
    private static List<LabeledPoint> validationSplit(List<LabeledPoint> data) {
        return data.subList(NUM_EXAMPLES, 2 * NUM_EXAMPLES);
    }
}

