/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.classification;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.classification.NaiveBayesModel;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Java API tests for {@link NaiveBayes} and {@link NaiveBayesModel}.
 *
 * <p>The Spark handles are {@code transient}, so they are excluded if an instance of this suite is
 * ever serialized.
 */
public class JavaNaiveBayesSuite
implements Serializable {
    private transient JavaSparkContext jsc;
    private transient SQLContext jsql;

    /** Creates a fresh local Spark context, and an SQL context on top of it, for each test. */
    @Before
    public void setUp() {
        // App name previously read "JavaLogisticRegressionSuite" (copy-paste slip from another suite).
        this.jsc = new JavaSparkContext("local", "JavaNaiveBayesSuite");
        this.jsql = new SQLContext(this.jsc);
    }

    /**
     * Stops the Spark context and clears both handles.
     *
     * <p>Guarded against a null context so that a failure inside {@link #setUp()} is reported as
     * itself rather than being masked by a NullPointerException here.
     */
    @After
    public void tearDown() {
        if (this.jsc != null) {
            this.jsc.stop();
        }
        this.jsc = null;
        // The SQLContext was built from the now-stopped context; do not keep it around.
        this.jsql = null;
    }

    /**
     * Asserts that every row's prediction equals its label.
     *
     * @param predictionAndLabels a frame whose column 0 holds the predicted value and column 1 the
     *     true label, both as doubles
     * @throws AssertionError if the frame has no rows, or if any prediction differs from its label
     *     by more than 1e-5
     */
    public void validatePrediction(DataFrame predictionAndLabels) {
        int checked = 0;
        for (Row r : predictionAndLabels.collect()) {
            double prediction = (Double) r.getAs(0);
            double label = (Double) r.getAs(1);
            Assert.assertEquals("prediction mismatch for row " + r, label, prediction, 1e-5);
            checked++;
        }
        // Without this check an empty frame would pass vacuously.
        Assert.assertTrue("expected at least one prediction row", checked > 0);
    }

    /** Checks the default column names and hyper-parameters of a freshly constructed estimator. */
    @Test
    public void naiveBayesDefaultParams() {
        NaiveBayes nb = new NaiveBayes();
        Assert.assertEquals("label", nb.getLabelCol());
        Assert.assertEquals("features", nb.getFeaturesCol());
        Assert.assertEquals("prediction", nb.getPredictionCol());
        Assert.assertEquals(1.0, nb.getSmoothing(), 1e-5);
        Assert.assertEquals("multinomial", nb.getModelType());
    }

    /**
     * Fits a multinomial model with smoothing 0.5 on six training points and checks that it
     * predicts every training label back.
     *
     * <p>The data uses two points per class (labels 0, 1, 2). Each class has non-zero values only
     * on its own feature axis.
     */
    @Test
    public void testNaiveBayes() {
        List<Row> data = Arrays.asList(
            RowFactory.create(0.0, Vectors.dense(1.0, 0.0, 0.0)),
            RowFactory.create(0.0, Vectors.dense(2.0, 0.0, 0.0)),
            RowFactory.create(1.0, Vectors.dense(0.0, 1.0, 0.0)),
            RowFactory.create(1.0, Vectors.dense(0.0, 2.0, 0.0)),
            RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 1.0)),
            RowFactory.create(2.0, Vectors.dense(0.0, 0.0, 2.0)));
        StructType schema = new StructType(new StructField[]{
            new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
            new StructField("features", new VectorUDT(), false, Metadata.empty())
        });
        DataFrame dataset = this.jsql.createDataFrame(data, schema);

        NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
        NaiveBayesModel model = nb.fit(dataset);

        DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
        this.validatePrediction(predictionAndLabels);
    }
}

