package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.lossfunctions.LogisticLoss;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;

/* loaded from: input_file:jsat/classifiers/linear/BBR.class */
/**
 * Bayesian Binary Regression (BBR): binary logistic regression fit by cyclic
 * coordinate descent with a per-coordinate trust region, using either a Laplace
 * or a Gaussian prior on the weights. The update rules (the {@code F} curvature
 * bound, the trust-region growth/shrink rule and the convergence test on the
 * margins) follow the CLG algorithm of Genkin, Lewis &amp; Madigan,
 * "Large-Scale Bayesian Logistic Regression for Text Categorization" (2007).
 * <p>
 * Only the numeric features of a data set are used, and only two-class
 * problems are supported.
 */
public class BBR implements Classifier, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = 8297213093357011082L;

    /**
     * Weights/bias whose magnitude falls below this after an update are snapped
     * to exactly zero; also the lower bound for an auto-set regularization value.
     */
    private static final double ZERO_THRESHOLD = 1.0E-15d;

    /** Learned weight vector; {@code null} until the model has been trained. */
    private Vec w;
    /** Maximum number of full coordinate-descent passes in {@link #trainC(ClassificationDataSet)}. */
    private int maxIterations;
    /** Laplace prior: the penalty lambda. Gaussian prior: the prior variance. */
    private double regularization;
    /** When true, the regularization is derived from the training data instead. */
    private boolean autoSetRegularization = true;
    /** Learned bias term (0 when {@link #useBias} is false). */
    private double bias;
    /** Whether a bias term is learned. */
    private boolean useBias = true;
    /** Convergence threshold on the relative change of the sample margins. */
    private double tolerance = 5.0E-4d;
    /** Prior placed on the weights; a non-LAPLACE value (including null) is treated as Gaussian. */
    private Prior prior;

    /** The prior placed on the (non-bias) weights. */
    public enum Prior {
        /** Laplace prior: an L1-style penalty whose sign-aware update can set weights exactly to zero. */
        LAPLACE,
        /** Gaussian prior: an L2-style penalty; the regularization value is used as the prior variance. */
        GAUSSIAN
    }

    /**
     * Creates a BBR model with a Laplace prior and a fixed regularization value.
     *
     * @param regularization the Laplace penalty; must be positive and finite
     * @param maxIterations  the maximum number of coordinate-descent passes
     * @throws IllegalArgumentException if {@code regularization} is not positive and finite
     */
    public BBR(double regularization, int maxIterations) {
        this(regularization, maxIterations, Prior.LAPLACE);
    }

    /**
     * Creates a BBR model with a fixed (not auto-set) regularization value.
     *
     * @param regularization the regularization value; must be positive and finite
     * @param maxIterations  the maximum number of coordinate-descent passes
     * @param prior          the prior to place on the weights
     * @throws IllegalArgumentException if {@code regularization} is not positive and finite
     */
    public BBR(double regularization, int maxIterations, Prior prior) {
        setMaxIterations(maxIterations);
        setRegularization(regularization);
        setAutoSetRegularization(false);
        setPrior(prior);
    }

    /**
     * Creates a BBR model with a Laplace prior and a fixed regularization of 0.001.
     *
     * @param maxIterations the maximum number of coordinate-descent passes
     */
    public BBR(int maxIterations) {
        this(0.001d, maxIterations, Prior.LAPLACE);
    }

    /**
     * Creates a BBR model whose regularization is auto-set from the training data.
     *
     * @param maxIterations the maximum number of coordinate-descent passes
     * @param prior         the prior to place on the weights
     */
    public BBR(int maxIterations, Prior prior) {
        setMaxIterations(maxIterations);
        setRegularization(0.01d);
        setAutoSetRegularization(true);
        setPrior(prior);
    }

    /**
     * Copy constructor; the weight vector (if any) is deep-copied.
     *
     * @param toCopy the model to copy
     */
    protected BBR(BBR toCopy) {
        if (toCopy.w != null) {
            this.w = toCopy.w.mo45clone();
        }
        this.maxIterations = toCopy.maxIterations;
        this.regularization = toCopy.regularization;
        this.autoSetRegularization = toCopy.autoSetRegularization;
        this.bias = toCopy.bias;
        this.useBias = toCopy.useBias;
        this.tolerance = toCopy.tolerance;
        this.prior = toCopy.prior;
    }

    /**
     * Sets the regularization value. For the Laplace prior it is the penalty
     * lambda; for the Gaussian prior it is used as the prior variance. It is
     * ignored during training while auto-set regularization is enabled.
     *
     * @param regularization the regularization value; must be positive and finite
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setRegularization(double regularization) {
        if (Double.isNaN(regularization) || Double.isInfinite(regularization) || regularization <= 0.0d) {
            throw new IllegalArgumentException("Regularization must be positive, not " + regularization);
        }
        this.regularization = regularization;
    }

    /** @return the regularization value used when auto-set regularization is disabled */
    public double getRegularization() {
        return this.regularization;
    }

    /**
     * @param autoSetRegularization if true, training derives the regularization
     *                              from the data instead of using {@link #getRegularization()}
     */
    public void setAutoSetRegularization(boolean autoSetRegularization) {
        this.autoSetRegularization = autoSetRegularization;
    }

    /** @return whether training derives the regularization from the data */
    public boolean isAutoSetRegularization() {
        return this.autoSetRegularization;
    }

    /**
     * @param maxIterations the maximum number of coordinate-descent passes; a
     *                      value of 0 or less means training leaves all weights at zero
     */
    public void setMaxIterations(int maxIterations) {
        this.maxIterations = maxIterations;
    }

    /** @return the maximum number of coordinate-descent passes */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the convergence tolerance: training stops once
     * {@code sum|change in margins| / (1 + sum|margins|)} over one pass is at most this value.
     *
     * @param tolerance the tolerance; must be positive and finite
     * @throws IllegalArgumentException if the value is NaN, infinite, or not positive
     */
    public void setTolerance(double tolerance) {
        if (Double.isNaN(tolerance) || Double.isInfinite(tolerance) || tolerance <= 0.0d) {
            throw new IllegalArgumentException("Tolerance must be positive, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    /** @return the convergence tolerance */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * @param useBias whether to learn a bias term. The bias is always updated with
     *                the sign-aware (Laplace-style) step, whatever the prior.
     */
    public void setUseBias(boolean useBias) {
        this.useBias = useBias;
    }

    /** @return whether a bias term is learned */
    public boolean isUseBias() {
        return this.useBias;
    }

    /** @param prior the prior to place on the weights */
    public void setPrior(Prior prior) {
        this.prior = prior;
    }

    /** @return the prior placed on the weights */
    public Prior getPrior() {
        return this.prior;
    }

    /** @return the learned weight vector (the same object as {@link #getRawWeight()}), or null if untrained */
    public Vec getWeightVec() {
        return this.w;
    }

    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return this.bias;
    }

    /**
     * @param index must be 0, since this model has a single weight vector
     * @throws IndexOutOfBoundsException for any index other than 0
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int index) {
        if (index == 0) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @param index must be 0, since this model has a single bias
     * @throws IndexOutOfBoundsException for any index other than 0
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int index) {
        if (index == 0) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a point from the logistic of {@code w . x + bias}, computed on its
     * numeric features. The model must have been trained first ({@code w} is null before).
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        return LogisticLoss.classify(this.w.dot(dataPoint.getNumericalValues()) + this.bias);
    }

    /** Trains on the calling thread; the executor is not used. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        trainC(dataSet);
    }

    /**
     * Fits the weights (and bias, if enabled) by cyclic coordinate descent.
     *
     * @param dataSet a two-class data set with at least one numeric feature and one sample
     * @throws FailedToFitException if the data set has no numeric features, no
     *                              samples, or does not have exactly two classes
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet) {
        final int numVars = dataSet.getNumNumericalVars();
        if (numVars <= 0) {
            throw new FailedToFitException("Data set contains no numeric features");
        }
        final int n = dataSet.getSampleSize();
        if (n <= 0) {
            // Without samples the auto-set regularization and the bias update divide 0 by 0.
            throw new FailedToFitException("Data set contains no data points");
        }
        if (dataSet.getClassSize() != 2) {
            // Labels are mapped with 2*c - 1, which is only meaningful for c in {0, 1}.
            throw new FailedToFitException("BBR supports only binary classification, but the data set has "
                    + dataSet.getClassSize() + " classes");
        }

        final Vec[] columns = dataSet.getNumericColumns();
        this.w = new DenseVector(numVars);
        // The margins r below start at zero, which is only consistent with a zero bias.
        // Otherwise a bias from an earlier call would also leak into classify() when useBias is false.
        this.bias = 0.0d;

        // Per-coordinate trust-region half-widths. The extra slot at index numVars belongs to the bias.
        final int biasIndex = numVars;
        final double[] trust = new double[useBias ? numVars + 1 : numVars];
        Arrays.fill(trust, 1.0d);

        // Class labels {0, 1} mapped to {-1, +1}.
        final double[] y = new double[n];
        for (int i = 0; i < n; i++) {
            y[i] = dataSet.getDataPointCategory(i) * 2 - 1;
        }

        final double lambda = autoSetRegularization ? autoRegularization(dataSet, numVars, n) : regularization;

        // r[i] = y_i * (w . x_i + bias): the current margin of every sample.
        final double[] r = new double[n];
        // dr[i]: change of r[i] during the current pass, used for the convergence test.
        final double[] dr = new double[n];

        for (int iter = 0; iter < maxIterations; iter++) {
            for (int j = 0; j < numVars; j++) {
                final double wj = w.get(j);
                final double step = prior == Prior.LAPLACE
                        ? signAwareStep(columns, j, wj, y, r, lambda, trust)
                        : tenativeUpdate(columns, j, wj, y, r, lambda, 0.0d, trust);
                final double clipped = clip(step, trust[j]);
                applyToMargins(columns[j], clipped, y, r, dr);
                w.set(j, snapToZero(wj + clipped));
                trust[j] = Math.max(2.0d * Math.abs(clipped), trust[j] / 2.0d);
            }
            if (useBias) {
                // The bias always takes the sign-aware (Laplace-style) step, whatever the prior.
                final double step = signAwareStep(null, biasIndex, bias, y, r, lambda, trust);
                final double clipped = clip(step, trust[biasIndex]);
                applyToMargins(null, clipped, y, r, dr);
                bias = snapToZero(bias + clipped);
                trust[biasIndex] = Math.max(2.0d * Math.abs(clipped), trust[biasIndex] / 2.0d);
            }

            double marginChange = 0.0d;
            double marginTotal = 0.0d;
            for (int i = 0; i < n; i++) {
                marginChange += Math.abs(dr[i]);
                marginTotal += Math.abs(r[i]);
            }
            if (marginChange / (1.0d + marginTotal) <= tolerance) {
                return;
            }
            Arrays.fill(dr, 0.0d);
        }
    }

    /**
     * Derives a regularization value from the data: with
     * {@code sigma = numVars * n / sum_i ||x_i||^2}, returns {@code sqrt(2) / sigma}
     * for the Laplace prior and {@code sigma^2} otherwise, both floored at {@link #ZERO_THRESHOLD}.
     * <p>
     * NOTE(review): Genkin et al. appear to define this ratio as the prior
     * <em>variance</em> (sigma^2), which would call for {@code sqrt(2 / ratio)} and
     * {@code ratio} respectively. Confirm against the paper before changing.
     */
    private double autoRegularization(ClassificationDataSet dataSet, int numVars, int n) {
        double sumSquaredNorms = 0.0d;
        for (int i = 0; i < n; i++) {
            sumSquaredNorms += Math.pow(dataSet.getDataPoint(i).getNumericalValues().pNorm(2.0d), 2.0d);
        }
        // Multiply in double: numVars * n overflows int for large data sets.
        final double sigma = ((double) numVars * n) / sumSquaredNorms;
        return prior == Prior.LAPLACE
                ? Math.max(Math.sqrt(2.0d) / sigma, ZERO_THRESHOLD)
                : Math.max(sigma * sigma, ZERO_THRESHOLD);
    }

    /**
     * Computes a Laplace-style step for coordinate {@code j} that never crosses zero.
     * <ul>
     * <li>From zero, it tries the positive direction first, then the negative one, and otherwise stays at zero.</li>
     * <li>From a non-zero value, a step that would flip the sign is cut to land exactly on zero.</li>
     * </ul>
     *
     * @param columns column-major feature data, or null for the bias
     * @param j       the coordinate index (numVars for the bias)
     * @param current the coordinate's current value
     */
    private double signAwareStep(Vec[] columns, int j, double current, double[] y, double[] r,
                                 double lambda, double[] trust) {
        if (current == 0.0d) {
            double step = tenativeUpdate(columns, j, current, y, r, lambda, 1.0d, trust);
            if (step <= 0.0d) {
                step = tenativeUpdate(columns, j, current, y, r, lambda, -1.0d, trust);
                if (step >= 0.0d) {
                    step = 0.0d;
                }
            }
            return step;
        }
        final double sign = Math.signum(current);
        final double step = tenativeUpdate(columns, j, current, y, r, lambda, sign, trust);
        return sign * (current + step) < 0.0d ? -current : step;
    }

    /** Clamps {@code step} to the trust region {@code [-halfWidth, halfWidth]}. */
    private static double clip(double step, double halfWidth) {
        return Math.min(Math.max(step, -halfWidth), halfWidth);
    }

    /** Returns 0 for values whose magnitude is below {@link #ZERO_THRESHOLD}, else the value itself. */
    private static double snapToZero(double value) {
        return Math.abs(value) < ZERO_THRESHOLD ? 0.0d : value;
    }

    /**
     * Adds {@code step * x_ij * y_i} to {@code r[i]} and {@code dr[i]} for every row
     * with a stored entry in {@code column}. A null column means the bias, where
     * {@code x_ij = 1} for every row.
     */
    private static void applyToMargins(Vec column, double step, double[] y, double[] r, double[] dr) {
        if (column == null) {
            for (int i = 0; i < y.length; i++) {
                final double delta = step * y[i];
                r[i] += delta;
                dr[i] += delta;
            }
            return;
        }
        final Iterator<IndexValue> it = column.iterator();
        while (it.hasNext()) {
            final IndexValue entry = it.next();
            final int i = entry.getIndex();
            final double delta = step * entry.getValue() * y[i];
            r[i] += delta;
            dr[i] += delta;
        }
    }

    /**
     * Upper bound on the logistic loss's second derivative {@code 1/(2+e^t+e^-t)}
     * over the interval {@code |t - r| <= delta}. It equals 0.25 when that interval contains 0.
     */
    private static double F(double r, double delta) {
        final double absR = Math.abs(r);
        if (absR <= delta) {
            return 0.25d;
        }
        return 1.0d / (2.0d + Math.exp(absR - delta) + Math.exp(delta - absR));
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /** Returns a deep copy of this model (see the copy constructor). */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public BBR m42clone() {
        return new BBR(this);
    }

    /**
     * Computes the unclipped Newton-like step for one coordinate: the numerator
     * divided by a curvature bound (via {@link #F}) over the trust region.
     * <p>
     * NOTE(review): the prior term ({@code lambda * s}, or {@code wj / lambda} and
     * {@code 1 / lambda}) is applied once per non-zero entry of the column, and once
     * per sample for the bias. The reference CLG update applies it once per
     * coordinate, so the effective penalty scales with the column's nnz here.
     * Confirm this is intended before changing it.
     *
     * @param columns column-major feature data, or null for the bias (which always uses the
     *                {@code lambda * s} penalty, whatever the prior)
     * @param j       the coordinate index (numVars for the bias)
     * @param wj      the coordinate's current value (used by the Gaussian prior)
     * @param y       labels in {-1, +1}
     * @param r       current margins
     * @param lambda  the regularization value
     * @param s       the assumed sign of the coordinate for the Laplace penalty (0 for Gaussian)
     * @param trust   trust-region half-widths
     * @return the tentative step; 0 for a column with no stored entries
     */
    private double tenativeUpdate(Vec[] columns, int j, double wj, double[] y, double[] r,
                                  double lambda, double s, double[] trust) {
        double numer = 0.0d;
        double denom = 0.0d;
        if (columns != null) {
            final Vec column = columns[j];
            if (column.nnz() == 0) {
                return 0.0d;
            }
            final Iterator<IndexValue> it = column.iterator();
            while (it.hasNext()) {
                final IndexValue entry = it.next();
                final double x = entry.getValue();
                final int i = entry.getIndex();
                numer += x * y[i] / (1.0d + Math.exp(r[i]));
                denom += x * x * F(r[i], trust[j] * Math.abs(x));
                if (prior == Prior.LAPLACE) {
                    numer -= lambda * s;
                } else {
                    numer -= wj / lambda;
                    denom += 1.0d / lambda;
                }
            }
        } else {
            // Bias: every row has x_ij = 1.
            for (int i = 0; i < y.length; i++) {
                numer += y[i] / (1.0d + Math.exp(r[i])) - lambda * s;
                denom += F(r[i], trust[j]);
            }
        }
        return numer / denom;
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(getParameters()).get(paramName);
    }
}
