/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.ml.regression;

import java.io.Serializable;
import java.util.HashMap;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.ml.regression.DecisionTreeRegressionModel;
import org.apache.spark.ml.regression.DecisionTreeRegressor;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

/**
 * Java-API smoke test for {@link DecisionTreeRegressor}.
 *
 * <p>The test configures a regressor through its setters, fits it on a small
 * generated dataset, and invokes the resulting model's accessors. No values
 * are asserted; the test passes as long as every call compiles and completes
 * without throwing.
 */
public class JavaDecisionTreeRegressorSuite implements Serializable {
    // transient: the context must not be serialized along with the suite instance.
    private transient JavaSparkContext sc;

    /** Creates a fresh {@link JavaSparkContext} before each test. */
    @Before
    public void setUp() {
        this.sc = new JavaSparkContext("local", "JavaDecisionTreeRegressorSuite");
    }

    /**
     * Stops the context created by {@link #setUp()} and drops the reference.
     *
     * <p>JUnit still runs {@code @After} methods when a {@code @Before} method
     * throws, so the context may never have been assigned; the null guard keeps
     * a failed {@code setUp()} from being accompanied by a NullPointerException.
     */
    @After
    public void tearDown() {
        if (this.sc != null) {
            this.sc.stop();
            this.sc = null;
        }
    }

    /**
     * Fits a decision tree regressor on generated data and exercises the
     * fitted model's {@code transform}, {@code numNodes}, {@code depth} and
     * {@code toDebugString} methods.
     */
    @Test
    public void runDT() {
        int nPoints = 20;
        double A = 2.0;
        double B = -1.5;

        JavaRDD<LabeledPoint> data = this.sc
            .parallelize(LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2)
            .cache();
        // Empty map: no feature is declared as categorical.
        HashMap<Integer, Integer> categoricalFeatures = new HashMap<Integer, Integer>();
        DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);

        // Call every setter once so the Java-facing builder API is covered.
        DecisionTreeRegressor dt = new DecisionTreeRegressor()
            .setMaxDepth(2)
            .setMaxBins(10)
            .setMinInstancesPerNode(5)
            .setMinInfoGain(0.0)
            .setMaxMemoryInMB(256)
            .setCacheNodeIds(false)
            .setCheckpointInterval(10);

        // Only the setter is exercised for each supported impurity; the model
        // below is fit once, with whichever impurity was set last.
        for (String impurity : DecisionTreeRegressor.supportedImpurities()) {
            dt.setImpurity(impurity);
        }

        DecisionTreeRegressionModel model = dt.fit(dataFrame);

        // Return values are intentionally ignored: these calls only need to succeed.
        model.transform(dataFrame);
        model.numNodes();
        model.depth();
        model.toDebugString();
    }
}

