/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.classification.LogisticRegressionTrainingSummary;
import org.apache.spark.ml.param.ParamPair;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Java API tests for {@link LogisticRegression} and {@link LogisticRegressionModel}.
 *
 * <p>Each test runs against a fresh local {@link JavaSparkContext}. Its training data is
 * registered as the temp table {@code "dataset"}.
 */
public class JavaLogisticRegressionSuite
implements Serializable {
    /** Absolute tolerance used for all floating-point comparisons. */
    private static final double EPS = 1.0E-5;

    private transient JavaSparkContext jsc;
    private transient SQLContext jsql;
    private transient DataFrame dataset;
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Starts a local Spark context and builds the training DataFrame.
     *
     * <p>The 100 input points come from
     * {@code LogisticRegressionSuite.generateLogisticInputAsList}. They are split into 2 RDD
     * partitions and registered as the temp table {@code "dataset"}.
     */
    @Before
    public void setUp() {
        this.jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
        this.jsql = new SQLContext(this.jsc);
        List<LabeledPoint> points = LogisticRegressionSuite.generateLogisticInputAsList(1.0, 1.0, 100, 42);
        this.datasetRDD = this.jsc.parallelize(points, 2);
        this.dataset = this.jsql.createDataFrame(this.datasetRDD, LabeledPoint.class);
        this.dataset.registerTempTable("dataset");
    }

    /**
     * Stops the Spark context and drops references to per-test state.
     *
     * <p>JUnit 4 runs {@code @After} even when {@code @Before} throws. The null check keeps a
     * failed context construction from being masked by a NullPointerException here.
     */
    @After
    public void tearDown() {
        if (this.jsc != null) {
            this.jsc.stop();
        }
        this.jsc = null;
        this.jsql = null;
        this.dataset = null;
        this.datasetRDD = null;
    }

    /** Fits with default parameters and checks the default column names and threshold. */
    @Test
    public void logisticRegressionDefaultParams() {
        LogisticRegression lr = new LogisticRegression();
        Assert.assertEquals("label", lr.getLabelCol());
        LogisticRegressionModel model = lr.fit(this.dataset);
        model.transform(this.dataset).registerTempTable("prediction");
        DataFrame predictions = this.jsql.sql("SELECT label, probability, prediction FROM prediction");
        // Result intentionally discarded: collecting forces the query to execute end to end.
        predictions.collectAsList();
        Assert.assertEquals(0.5, model.getThreshold(), EPS);
        Assert.assertEquals("features", model.getFeaturesCol());
        Assert.assertEquals("prediction", model.getPredictionCol());
        Assert.assertEquals("probability", model.getProbabilityCol());
    }

    /**
     * Checks that setter-configured params and per-call param overrides take effect. Covers
     * params set on the estimator, params set on the fitted model, and params passed to
     * {@code transform}/{@code fit}.
     */
    @Test
    public void logisticRegressionWithSetters() {
        LogisticRegression lr = new LogisticRegression()
            .setMaxIter(10)
            .setRegParam(1.0)
            .setThreshold(0.6)
            .setProbabilityCol("myProbability");
        LogisticRegressionModel model = lr.fit(this.dataset);
        LogisticRegression parent = (LogisticRegression) model.parent();
        Assert.assertEquals(10, parent.getMaxIter());
        Assert.assertEquals(1.0, parent.getRegParam(), EPS);
        // setThreshold(0.6) is expected to be mirrored into the two-class thresholds {0.4, 0.6}.
        Assert.assertEquals(0.4, parent.getThresholds()[0], EPS);
        Assert.assertEquals(0.6, parent.getThresholds()[1], EPS);
        Assert.assertEquals(0.6, parent.getThreshold(), EPS);
        Assert.assertEquals(0.6, model.getThreshold(), EPS);

        // A threshold of 1.0 on the model should make every prediction 0.0.
        model.setThreshold(1.0);
        model.transform(this.dataset).registerTempTable("predAllZero");
        DataFrame predAllZero = this.jsql.sql("SELECT prediction, myProbability FROM predAllZero");
        for (Row r : predAllZero.collectAsList()) {
            Assert.assertEquals(0.0, r.getDouble(0), EPS);
        }

        // Params passed to transform() override the model's own: threshold 0.0 should yield
        // at least one non-zero prediction, and the probability column is renamed.
        model.transform(this.dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
            .registerTempTable("predNotAllZero");
        DataFrame predNotAllZero = this.jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
        boolean foundNonZero = false;
        for (Row r : predNotAllZero.collectAsList()) {
            if (r.getDouble(0) != 0.0) {
                foundNonZero = true;
                break;
            }
        }
        Assert.assertTrue(foundNonZero);

        // Params passed to fit() should be reflected in the resulting model and its parent.
        LogisticRegressionModel model2 = lr.fit(this.dataset,
            lr.maxIter().w(5),
            lr.regParam().w(0.1),
            lr.threshold().w(0.4),
            lr.probabilityCol().w("theProb"));
        LogisticRegression parent2 = (LogisticRegression) model2.parent();
        Assert.assertEquals(5, parent2.getMaxIter());
        Assert.assertEquals(0.1, parent2.getRegParam(), EPS);
        Assert.assertEquals(0.4, parent2.getThreshold(), EPS);
        Assert.assertEquals(0.4, model2.getThreshold(), EPS);
        Assert.assertEquals("theProb", model2.getProbabilityCol());
    }

    /**
     * Checks the classifier output columns.
     *
     * <p>Each raw-prediction and probability vector has size 2. {@code probability[1]} equals
     * the sigmoid of {@code rawPrediction[1]}, and {@code probability[0]} is its complement.
     * The predicted class has the largest probability in its row.
     */
    @Test
    public void logisticRegressionPredictorClassifierMethods() {
        LogisticRegression lr = new LogisticRegression();
        LogisticRegressionModel model = lr.fit(this.dataset);
        Assert.assertEquals(2, model.numClasses());

        model.transform(this.dataset).registerTempTable("transformed");
        DataFrame trans1 = this.jsql.sql("SELECT rawPrediction, probability FROM transformed");
        for (Row row : trans1.collect()) {
            Vector raw = (Vector) row.get(0);
            Vector prob = (Vector) row.get(1);
            Assert.assertEquals(2, raw.size());
            Assert.assertEquals(2, prob.size());
            double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
            Assert.assertEquals(probFromRaw1, prob.apply(1), EPS);
            Assert.assertEquals(1.0 - probFromRaw1, prob.apply(0), EPS);
        }

        DataFrame trans2 = this.jsql.sql("SELECT prediction, probability FROM transformed");
        for (Row row : trans2.collect()) {
            double pred = row.getDouble(0);
            Vector prob = (Vector) row.get(1);
            double probOfPred = prob.apply((int) pred);
            for (int i = 0; i < prob.size(); ++i) {
                Assert.assertTrue(probOfPred >= prob.apply(i));
            }
        }
    }

    /** Checks that the training summary records one objective value per reported iteration. */
    @Test
    public void logisticRegressionTrainingSummary() {
        LogisticRegression lr = new LogisticRegression();
        LogisticRegressionModel model = lr.fit(this.dataset);
        LogisticRegressionTrainingSummary summary = model.summary();
        Assert.assertEquals(summary.totalIterations(), summary.objectiveHistory().length);
    }
}

