/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import scala.collection.JavaConverters;

/**
 * JUnit test exercising {@link OneVsRest} from Java with its default parameters.
 *
 * <p>Each test runs against a small synthetic dataset built in {@link #setUp()} from
 * {@code LogisticRegressionSuite.generateMultinomialLogisticInput}, using a local
 * {@link JavaSparkContext} that is stopped again in {@link #tearDown()}.
 */
public class JavaOneVsRestSuite implements Serializable {
    // The Spark handles are transient so they are skipped if this (Serializable) suite
    // instance is ever serialized.
    private transient JavaSparkContext jsc;
    private transient SQLContext jsql;
    private transient DataFrame dataset;
    private transient JavaRDD<LabeledPoint> datasetRDD;

    /**
     * Starts a local Spark context and builds the labeled-point dataset shared by the tests,
     * both as an RDD and as a {@link DataFrame}.
     */
    @Before
    public void setUp() {
        this.jsc = new JavaSparkContext("local", "JavaLOneVsRestSuite");
        this.jsql = new SQLContext(this.jsc);

        int nPoints = 3;
        // Inputs for the synthetic-data generator; the values are fixed so every run sees
        // the same data (the generator is also given a fixed seed of 42 below).
        double[] coefficients = new double[]{
            -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
            -0.16624, -0.84355, -0.048509, -0.301789, 4.170682};
        double[] xMean = new double[]{5.843, 3.057, 3.758, 1.199};
        double[] xVariance = new double[]{0.6856, 0.1899, 3.116, 0.581};

        // Convert the Scala Seq returned by the generator into a typed java.util.List.
        List<LabeledPoint> points = JavaConverters.seqAsJavaListConverter(
            LogisticRegressionSuite.generateMultinomialLogisticInput(
                coefficients, xMean, xVariance, true, nPoints, 42)).asJava();

        this.datasetRDD = this.jsc.parallelize(points, 2);
        this.dataset = this.jsql.createDataFrame(this.datasetRDD, LabeledPoint.class);
    }

    /**
     * Stops the Spark context created in {@link #setUp()}.
     *
     * <p>JUnit runs {@code @After} methods even when {@code @Before} failed, so the context
     * may never have been assigned; guard against that to avoid masking the original failure
     * with a {@link NullPointerException}.
     */
    @After
    public void tearDown() {
        if (this.jsc != null) {
            this.jsc.stop();
            this.jsc = null;
        }
    }

    /**
     * Fits a {@link OneVsRest} wrapping a default {@link LogisticRegression}, runs the
     * resulting model over the training data, and checks that both the estimator and the
     * fitted model report the column names {@code "label"} and {@code "prediction"}.
     */
    @Test
    public void oneVsRestDefaultParams() {
        OneVsRest ova = new OneVsRest();
        ova.setClassifier(new LogisticRegression());
        // JUnit's contract is assertEquals(expected, actual).
        Assert.assertEquals("label", ova.getLabelCol());
        Assert.assertEquals("prediction", ova.getPredictionCol());

        OneVsRestModel ovaModel = ova.fit(this.dataset);
        DataFrame predictions = ovaModel.transform(this.dataset).select("label", "prediction");
        // Collect the rows so the transform is actually executed; the result itself is
        // not inspected.
        predictions.collectAsList();

        Assert.assertEquals("label", ovaModel.getLabelCol());
        Assert.assertEquals("prediction", ovaModel.getPredictionCol());
    }
}

