/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.mllib.tree;

import java.io.Serializable;
import java.util.HashMap;
import java.util.List;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.DecisionTree$;
import org.apache.spark.mllib.tree.DecisionTreeSuite;
import org.apache.spark.mllib.tree.configuration.Algo;
import org.apache.spark.mllib.tree.configuration.Strategy;
import org.apache.spark.mllib.tree.impurity.Gini;
import org.apache.spark.mllib.tree.impurity.Impurity;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Java-API tests for {@link DecisionTree}.
 *
 * <p>Each test trains a Gini-impurity classification tree on the data returned by
 * {@code DecisionTreeSuite.generateCategoricalDataPointsAsJavaList()}. The two tests use
 * different entry points:
 * <ul>
 *   <li>{@code runDTUsingConstructor}: the instance entry point
 *       ({@code new DecisionTree(strategy).run(...)}).</li>
 *   <li>{@code runDTUsingStaticMethods}: the static entry point
 *       ({@code DecisionTree$.MODULE$.train(...)}).</li>
 * </ul>
 * Each test then asserts that the resulting model predicts the label of every training point
 * exactly.
 */
public class JavaDecisionTreeSuite
implements Serializable {
    /** Spark context created per test; {@code transient} so it is never serialized with the suite. */
    private transient JavaSparkContext sc;

    /** Starts a fresh local Spark context before each test. */
    @Before
    public void setUp() {
        this.sc = new JavaSparkContext("local", "JavaDecisionTreeSuite");
    }

    /**
     * Stops the Spark context after each test.
     *
     * <p>The null guard keeps a failure in {@link #setUp()} from being masked by a
     * NullPointerException here.
     */
    @After
    public void tearDown() {
        if (this.sc != null) {
            this.sc.stop();
            this.sc = null;
        }
    }

    /**
     * Counts how many points the model labels correctly.
     *
     * @param validationData points whose {@code label()} is compared against the prediction
     * @param model          trained tree used to predict each point's {@code features()}
     * @return number of points whose prediction equals their label exactly
     */
    int validatePrediction(List<LabeledPoint> validationData, DecisionTreeModel model) {
        int numCorrect = 0;
        for (LabeledPoint point : validationData) {
            // Exact comparison is intended: class predictions are discrete label values.
            double prediction = model.predict(point.features());
            if (prediction == point.label()) {
                ++numCorrect;
            }
        }
        return numCorrect;
    }

    /** Trains via {@code new DecisionTree(strategy).run(...)} and expects perfect training accuracy. */
    @Test
    public void runDTUsingConstructor() {
        List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
        JavaRDD<LabeledPoint> rdd = this.sc.parallelize(arr);
        DecisionTree learner = new DecisionTree(classificationStrategy());
        DecisionTreeModel model = learner.run(rdd.rdd());
        int numCorrect = this.validatePrediction(arr, model);
        Assert.assertEquals(rdd.count(), (long) numCorrect);
    }

    /** Trains via the static {@code DecisionTree.train(...)} entry point and expects perfect training accuracy. */
    @Test
    public void runDTUsingStaticMethods() {
        List<LabeledPoint> arr = DecisionTreeSuite.generateCategoricalDataPointsAsJavaList();
        JavaRDD<LabeledPoint> rdd = this.sc.parallelize(arr);
        DecisionTreeModel model = DecisionTree$.MODULE$.train(rdd.rdd(), classificationStrategy());
        int numCorrect = this.validatePrediction(arr, model);
        Assert.assertEquals(rdd.count(), (long) numCorrect);
    }

    /**
     * Builds the classification strategy shared by both tests: Gini impurity, depth 4,
     * 2 classes, 100 bins.
     *
     * @return a new {@link Strategy} instance on each call
     */
    private static Strategy classificationStrategy() {
        HashMap<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
        // Entry (1 -> 2): feature index 1 is categorical with 2 categories.
        categoricalFeaturesInfo.put(1, 2);
        int maxDepth = 4;
        int numClasses = 2;
        int maxBins = 100;
        return new Strategy(Algo.Classification(), (Impurity) Gini.instance(), maxDepth, numClasses, maxBins, categoricalFeaturesInfo);
    }
}

