package jsat.regression;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionBase;
import jsat.math.rootfinding.Zeroin;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.DoubleList;
import jsat.utils.FakeExecutor;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/regression/StochasticGradientBoosting.class */
/**
 * Stochastic gradient boosting for regression under a squared-error objective.
 * <p>
 * Training first fits an initial model on the full data set. The initial model is a copy of the
 * strong learner if one was supplied, otherwise a copy of the weak learner. Then, for
 * {@link #getMaxIterations()} rounds, a fresh copy of the weak learner is fit to a random subset
 * of the current residuals. Each model's output is scaled by a coefficient found with a 1-D root
 * search, and that coefficient is shrunk by the learning rate. A prediction is the sum of the
 * scaled model outputs.
 */
public class StochasticGradientBoosting implements Regressor, Parameterized {
    private static final long serialVersionUID = -2855154397476855293L;
    /** Default fraction of the data set used to train each weak learner. */
    public static final double DEFAULT_TRAINING_PROPORTION = 0.5d;
    /** Default shrinkage factor applied to every model's coefficient. */
    public static final double DEFAULT_LEARNING_RATE = 0.1d;
    /** Step of the backward finite difference built by {@link #getDerivativeFunc}. */
    private static final double DERIV_EPS = 1.0E-5d;

    /** Fraction of the data set drawn for each boosting round; always in (0, 1]. */
    private double trainingProportion;
    /** Prototype for the models fit to residuals; only its clones are trained. */
    private Regressor weakLearner;
    /** Optional prototype for the initial model; {@code null} means use the weak learner. */
    private Regressor strongLearner;
    /** Trained models, in the order they were fit. */
    private List<Regressor> F;
    /** Coefficient of each model in {@link #F}, already multiplied by the learning rate. */
    private List<Double> coef;
    /** Shrinkage factor; always in (0, 1]. */
    private double learningRate;
    /** Number of boosting rounds performed after the initial model. */
    private int maxIterations;

    /**
     * Creates a booster with an optional strong learner for the initial model.
     *
     * @param strongLearner      prototype for the initial model, or {@code null} to use the weak learner
     * @param weakLearner        prototype for the models fit to residuals
     * @param maxIterations      number of boosting rounds after the initial model
     * @param learningRate       shrinkage factor in (0, 1]
     * @param trainingProportion fraction of the data used per round, in (0, 1]
     * @throws ArithmeticException if {@code learningRate} or {@code trainingProportion} is invalid
     */
    public StochasticGradientBoosting(Regressor strongLearner, Regressor weakLearner, int maxIterations,
            double learningRate, double trainingProportion) {
        checkLearningRate(learningRate);
        checkTrainingProportion(trainingProportion);
        this.trainingProportion = trainingProportion;
        this.strongLearner = strongLearner;
        this.weakLearner = weakLearner;
        this.learningRate = learningRate;
        this.maxIterations = maxIterations;
    }

    /**
     * Creates a booster whose initial model is also a copy of the weak learner.
     *
     * @throws ArithmeticException if {@code learningRate} or {@code trainingProportion} is invalid
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations, double learningRate,
            double trainingProportion) {
        this(null, weakLearner, maxIterations, learningRate, trainingProportion);
    }

    /**
     * Creates a booster using {@link #DEFAULT_TRAINING_PROPORTION}.
     *
     * @throws ArithmeticException if {@code learningRate} is invalid
     */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations, double learningRate) {
        this(weakLearner, maxIterations, learningRate, DEFAULT_TRAINING_PROPORTION);
    }

    /** Creates a booster using {@link #DEFAULT_LEARNING_RATE} and {@link #DEFAULT_TRAINING_PROPORTION}. */
    public StochasticGradientBoosting(Regressor weakLearner, int maxIterations) {
        this(weakLearner, maxIterations, DEFAULT_LEARNING_RATE);
    }

    /** Rejects learning rates outside (0, 1], including NaN. */
    private static void checkLearningRate(double rate) {
        if (rate > 1.0d || rate <= 0.0d || Double.isNaN(rate)) {
            throw new ArithmeticException("Invalid learning rate");
        }
    }

    /** Rejects training proportions outside (0, 1], including NaN. */
    private static void checkTrainingProportion(double proportion) {
        if (proportion > 1.0d || proportion <= 0.0d || Double.isNaN(proportion)) {
            throw new ArithmeticException("Training Proportion is invalid");
        }
    }

    /** Sets the number of boosting rounds performed after the initial model. */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** Returns the number of boosting rounds performed after the initial model. */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the shrinkage factor applied to every model's coefficient.
     *
     * @throws ArithmeticException if {@code learningRate} is not in (0, 1]
     */
    public void setLearningRate(double learningRate) {
        checkLearningRate(learningRate);
        this.learningRate = learningRate;
    }

    /** Returns the shrinkage factor applied to every model's coefficient. */
    public double getLearningRate() {
        return this.learningRate;
    }

    /**
     * Sets the fraction of the data set used to train each weak learner.
     *
     * @throws ArithmeticException if {@code trainingProportion} is not in (0, 1]
     */
    public void setTrainingProportion(double trainingProportion) {
        checkTrainingProportion(trainingProportion);
        this.trainingProportion = trainingProportion;
    }

    /** Returns the fraction of the data set used to train each weak learner. */
    public double getTrainingProportion() {
        return this.trainingProportion;
    }

    /**
     * Predicts the target as the coefficient-weighted sum of every trained model's output.
     *
     * @throws UntrainedModelException if {@link #train} has not been called
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.F == null || this.F.isEmpty()) {
            throw new UntrainedModelException();
        }
        double sum = 0.0d;
        for (int i = 0; i < this.F.size(); i++) {
            sum += this.F.get(i).regress(dataPoint) * this.coef.get(i);
        }
        return sum;
    }

    /**
     * Fits the boosted ensemble, replacing any previously trained models.
     *
     * @param regressionDataSet training data; its own targets are read, never overwritten here
     * @param executorService   pool passed to the learners' parallel {@code train}; {@code null} or a
     *                          {@code FakeExecutor} selects single-threaded training
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        // NOTE(review): the residual loop below overwrites targets through `resids` while reading the
        // originals from regressionDataSet, so getAsDPPList is assumed to return copies — confirm.
        List<DataPointPair<Double>> residList = regressionDataSet.getAsDPPList();
        this.F = new ArrayList<Regressor>(this.maxIterations);
        this.coef = new DoubleList(this.maxIterations);

        Regressor initial = this.strongLearner == null ? this.weakLearner.clone() : this.strongLearner.clone();
        trainLearner(initial, regressionDataSet, executorService);
        this.F.add(initial);
        this.coef.add(this.learningRate * getMinimizingErrorConst(residList, initial));

        // predictions[j] caches sum_k coef[k] * F[k](x_j) over the models folded in so far,
        // so earlier models are never re-evaluated.
        double[] predictions = new double[regressionDataSet.getSampleSize()];
        RegressionDataSet resids = RegressionDataSet.usingDPPList(residList);
        // At least one point, so a small proportion on a small data set never yields an empty sample.
        int sampleSize = Math.max(1, (int) Math.round(resids.getSampleSize() * this.trainingProportion));
        List<DataPointPair<Double>> sample = new ArrayList<DataPointPair<Double>>(sampleSize);
        Random random = new Random();

        for (int iter = 0; iter < this.maxIterations; iter++) {
            double lastCoef = this.coef.get(iter);
            Regressor lastModel = this.F.get(iter);
            // Fold the newest model into the cached predictions and refresh the residual targets.
            for (int j = 0; j < resids.getSampleSize(); j++) {
                predictions[j] += lastCoef * lastModel.regress(resids.getDataPoint(j));
                resids.setTargetValue(j, regressionDataSet.getTargetValue(j) - predictions[j]);
            }

            sample.clear();
            ListUtils.randomSample(residList, sample, sampleSize, random);
            Regressor learner = this.weakLearner.clone();
            trainLearner(learner, RegressionDataSet.usingDPPList(sample), executorService);
            // The coefficient is fit against all residuals, not just the sampled subset.
            double bestConst = getMinimizingErrorConst(residList, learner);
            this.F.add(learner);
            this.coef.add(this.learningRate * bestConst);
        }
    }

    /**
     * Trains {@code learner} on {@code data}. The executor is used only when it is a real pool,
     * i.e. neither {@code null} nor a {@code FakeExecutor}.
     */
    private static void trainLearner(Regressor learner, RegressionDataSet data, ExecutorService executorService) {
        if (executorService == null || executorService instanceof FakeExecutor) {
            learner.train(data);
        } else {
            learner.train(data, executorService);
        }
    }

    /**
     * Finds the scalar c that (approximately) minimizes sum (c * h(x) - t)^2 over {@code list}, where h
     * is {@code regressor} and t is the stored target. It does so by locating the root of the
     * finite-difference derivative with Zeroin.
     * NOTE(review): the Zeroin arguments presumably mean tolerance 1e-4, at most 50 iterations,
     * search bracket [-2.5, 2.5], root in coordinate 0, start value 1.0 — confirm against Zeroin.
     */
    private double getMinimizingErrorConst(List<DataPointPair<Double>> list, Regressor regressor) {
        return new Zeroin().root(1.0E-4d, 50, new double[]{-2.5d, 2.5d}, getDerivativeFunc(list, regressor), 0, 1.0d);
    }

    /**
     * Builds g(c) = L(c) - L(c - eps), with L(c) = sum (c * h - t)^2, eps = {@link #DERIV_EPS},
     * h = {@code regressor}'s prediction and t = the stored target. g is a backward difference (not
     * divided by eps), so its sign and root match those of L's derivative up to O(eps).
     * <p>
     * Algebraically g(c) = eps * ((2c - eps) * sum(h^2) - 2 * sum(h * t)). Both sums do not depend
     * on c, so they are computed once here. Each evaluation during the root search is then O(1)
     * instead of re-running the regressor over the whole list.
     */
    private Function getDerivativeFunc(final List<DataPointPair<Double>> list, final Regressor regressor) {
        double hh = 0.0d;
        double ht = 0.0d;
        for (DataPointPair<Double> dpp : list) {
            double h = regressor.regress(dpp.getDataPoint());
            hh += h * h;
            ht += h * dpp.getPair();
        }
        final double sumHH = hh;
        final double sumHT = ht;
        return new FunctionBase() { // from class: jsat.regression.StochasticGradientBoosting.1
            private static final long serialVersionUID = -2211642040228795719L;

            @Override // jsat.math.Function
            public double f(Vec vec) {
                // (c1 + c2) with c2 = c1 - eps
                double c1PlusC2 = (vec.get(0) * 2.0d) - DERIV_EPS;
                return ((c1PlusC2 * sumHH) - (2.0d * sumHT)) * DERIV_EPS;
            }
        };
    }

    /** Trains single-threaded; equivalent to {@code train(regressionDataSet, null)}. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        train(regressionDataSet, null);
    }

    /** Weighted data is supported only if every configured learner prototype supports it. */
    @Override // jsat.regression.Regressor
    public boolean supportsWeightedData() {
        if (this.strongLearner != null && !this.strongLearner.supportsWeightedData()) {
            return false;
        }
        return this.weakLearner.supportsWeightedData();
    }

    /** Returns a deep copy: learner prototypes, trained models and coefficients are all copied. */
    @Override // jsat.regression.Regressor
    public StochasticGradientBoosting clone() {
        StochasticGradientBoosting copy = new StochasticGradientBoosting(this.weakLearner.clone(),
                this.maxIterations, this.learningRate, this.trainingProportion);
        if (this.F != null) {
            copy.F = new ArrayList<Regressor>(this.F.size());
            for (Regressor model : this.F) {
                copy.F.add(model.clone());
            }
        }
        if (this.coef != null) {
            copy.coef = new DoubleList(this.coef.size());
            for (double c : this.coef) {
                copy.coef.add(c);
            }
        }
        if (this.strongLearner != null) {
            copy.strongLearner = this.strongLearner.clone();
        }
        return copy;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
