package jsat.classifiers.svm;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.classifiers.svm.SupportVectorLearner;
import jsat.distributions.kernels.KernelTrick;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IndexTable;
import jsat.utils.ListUtils;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:jsat/classifiers/svm/SBP.class */
/**
 * Kernel-based binary classifier trained by a stochastic, perceptron-style procedure.
 * <p>
 * NOTE(review): the name suggests the kernelized Stochastic Batch Perceptron of Cotter,
 * Shalev-Shwartz and Srebro (2012) — confirm.
 * <p>
 * Training tracks the responses {@code A_i = y_i * sum_j alpha_j * y_j * K(x_j, x_i)} with labels
 * {@code y_i in {-1, +1}}. Each iteration:
 * <ol>
 * <li>finds the threshold {@code gamma} with {@code sum_i max(0, gamma - A_i) = nu * m};</li>
 * <li>samples a training point whose response is at or below {@code gamma};</li>
 * <li>adds a step of size {@code 1 / sqrt(t * max_i K(x_i, x_i))} to that point's coefficient;</li>
 * <li>rescales all coefficients so the weight vector's squared norm is at most 1.</li>
 * </ol>
 * The trained model averages the coefficients of the iterations after the burn-in period, divides
 * them by the {@code gamma} of the averaged responses, and keeps only the points with non-zero
 * coefficients as support vectors.
 */
public class SBP extends SupportVectorLearner implements BinaryScoreClassifier, Parameterized {
    private static final long serialVersionUID = 6112916782260792833L;

    /** Controls the target total shortfall {@code nu * m} used to pick gamma; in (0, 1). */
    private double nu = 0.1;
    /** Number of stochastic iterations run by {@link #trainC(ClassificationDataSet)}; at least 1. */
    private int iterations;
    /** Fraction of the iterations discarded before averaging; in [0, 1). */
    private double burnIn = 0.2;
    /** Scratch index table reused by {@link #findGamma} across iterations; cleared after training. */
    private IndexTable it;

    /**
     * Creates a new SBP learner.
     *
     * @param kernelTrick the kernel to use
     * @param cacheMode the kernel cache mode used during training
     * @param iterations the number of training iterations; must be at least 1
     * @param nu the nu parameter, in the open range (0, 1)
     * @throws IllegalArgumentException if {@code iterations} or {@code nu} is out of range
     */
    public SBP(KernelTrick kernelTrick, SupportVectorLearner.CacheMode cacheMode, int iterations, double nu) {
        super(kernelTrick, cacheMode);
        setIterations(iterations);
        setNu(nu);
    }

    /**
     * Copy constructor. Copies all hyper-parameters and, if present, the trained model.
     * <p>
     * The support-vector list is copied, but the vectors it holds are shared with {@code sbp}.
     * SBP itself does not modify them after training.
     *
     * @param sbp the learner to copy
     */
    protected SBP(SBP sbp) {
        this(sbp.getKernel().mo143clone(), sbp.getCacheMode(), sbp.iterations, sbp.nu);
        // The delegated constructor only knows iterations and nu; burnIn must be carried over here.
        this.burnIn = sbp.burnIn;
        if (sbp.vecs != null) {
            this.vecs = new ArrayList<>(sbp.vecs);
        }
        if (sbp.alphas != null) {
            this.alphas = Arrays.copyOf(sbp.alphas, sbp.alphas.length);
        }
    }

    /**
     * Returns an independent copy of this learner, including any trained model.
     * (This is the decompiler-renamed {@code clone()}.)
     *
     * @return a copy of this learner
     */
    @Override
    public SBP m98clone() {
        return new SBP(this);
    }

    /**
     * Sets the number of training iterations.
     *
     * @param iterations the number of iterations; must be at least 1
     * @throws IllegalArgumentException if {@code iterations < 1}
     */
    public void setIterations(int iterations) {
        if (iterations < 1) {
            throw new IllegalArgumentException("iterations must be positive, not " + iterations);
        }
        this.iterations = iterations;
    }

    /**
     * @return the number of training iterations
     */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets the nu parameter. Training uses {@code m * nu} (with {@code m} training points) as the
     * target total shortfall when choosing gamma.
     *
     * @param nu a value in the open range (0, 1)
     * @throws IllegalArgumentException if {@code nu} is NaN or outside (0, 1)
     */
    public void setNu(double nu) {
        if (Double.isNaN(nu) || nu <= 0.0 || nu >= 1.0) {
            throw new IllegalArgumentException("nu must be in the range (0, 1)");
        }
        this.nu = nu;
    }

    /**
     * @return the nu parameter
     */
    public double getNu() {
        return this.nu;
    }

    /**
     * Sets the fraction of iterations discarded before the iterates are averaged.
     *
     * @param burnIn a value in [0, 1)
     * @throws IllegalArgumentException if {@code burnIn} is NaN or outside [0, 1)
     */
    public void setBurnIn(double burnIn) {
        if (Double.isNaN(burnIn) || burnIn < 0.0 || burnIn >= 1.0) {
            throw new IllegalArgumentException("BurnInFraction must be in [0, 1), not " + burnIn);
        }
        this.burnIn = burnIn;
    }

    /**
     * @return the burn-in fraction
     */
    public double getBurnIn() {
        return this.burnIn;
    }

    /**
     * Classifies a point by the sign of its score: a negative score gives class 0, and any other
     * score (including 0) gives class 1. The chosen class receives probability 1.
     *
     * @param dataPoint the point to classify
     * @return the hard class assignment
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.vecs == null) {
            throw new UntrainedModelException("Classifier has yet to be trained");
        }
        CategoricalResults result = new CategoricalResults(2);
        if (getScore(dataPoint) < 0.0) {
            result.setProb(0, 1.0);
        } else {
            result.setProb(1, 1.0);
        }
        return result;
    }

    /**
     * Returns the kernel-weighted sum of the support vectors for {@code dataPoint}'s numeric features.
     *
     * @param dataPoint the point to score
     * @return the raw decision value
     */
    @Override
    public double getScore(DataPoint dataPoint) {
        return kEvalSum(dataPoint.getNumericalValues());
    }

    /**
     * Trains the model. The executor is ignored; training is sequential.
     *
     * @param dataSet the binary training data
     * @param executorService ignored
     */
    @Override
    public void trainC(ClassificationDataSet dataSet, ExecutorService executorService) {
        trainC(dataSet);
    }

    /**
     * Trains the model on a binary data set.
     *
     * @param dataSet the training data; must have exactly two classes and at least one point
     * @throws FailedToFitException if the data set is not binary or is empty
     */
    @Override
    public void trainC(ClassificationDataSet dataSet) {
        if (dataSet.getClassSize() != 2) {
            throw new FailedToFitException("SBP supports only binary classification");
        }
        final int m = dataSet.getSampleSize();
        if (m == 0) {
            throw new FailedToFitException("SBP requires at least one training point");
        }
        final int totalIterations = this.iterations;
        // Iterations 1..burn are discarded; iterations burn+1..T (exactly T - burn of them) are averaged.
        final int burn = (int) Math.min(this.burnIn * totalIterations, totalIterations - 1);

        double[] responses = new double[m];   // A_i for the current iterate
        double[] responseSum = new double[m]; // running sum of A after burn-in
        double[] alphaSum = new double[m];    // running sum of alphas after burn-in
        double[] labels = new double[m];      // y_i in {-1, +1}
        this.alphas = new double[m];
        this.vecs = new ArrayList<>(m);
        for (int i = 0; i < m; i++) {
            labels[i] = dataSet.getDataPointCategory(i) * 2 - 1;
            this.vecs.add(dataSet.getDataPoint(i).getNumericalValues());
        }
        // Re-apply the cache mode now that vecs is populated (presumably so the superclass can
        // size its kernel cache for this data — TODO confirm).
        setCacheMode(getCacheMode());
        // Drop any index table left over from an earlier training run that aborted with an exception.
        this.it = null;

        Random rand = new XORWOW();
        double maxSelfKernel = 0.0;
        for (int i = 0; i < m; i++) {
            maxSelfKernel = Math.max(maxSelfKernel, kEval(i, i));
        }
        final double eta0 = 1.0 / Math.sqrt(maxSelfKernel);
        final double target = m * this.nu;

        double normSqrd = 0.0;
        for (int t = 1; t <= totalIterations; t++) {
            double eta = eta0 / Math.sqrt(t);
            int c = sampleC(rand, m, responses, findGamma(responses, target));
            this.alphas[c] += eta;
            normSqrd = projectionStep(updateLoop(normSqrd, eta, responses, c, labels, m), m, responses);
            if (t > burn) {
                for (int i = 0; i < m; i++) {
                    alphaSum[i] += this.alphas[i];
                    responseSum[i] += responses[i];
                }
            }
        }

        final int averaged = totalIterations - burn; // >= 1 because burn <= T - 1
        for (int i = 0; i < m; i++) {
            this.alphas[i] = alphaSum[i] / averaged;
            responses[i] = responseSum[i] / averaged;
        }
        // NOTE(review): a non-positive gamma here would flip or blow up the coefficients — confirm
        // whether that case needs explicit handling.
        double gamma = findGamma(responses, target);
        for (int i = 0; i < m; i++) {
            this.alphas[i] /= gamma;
        }

        // Compact in place: move every point with a non-zero coefficient to the front, folding its
        // label into the coefficient. Writes only go to already-visited slots, so unvisited entries
        // still hold their original values.
        int kept = 0;
        for (int i = 0; i < m; i++) {
            if (this.alphas[i] != 0.0) {
                double weight = this.alphas[i] * labels[i];
                ListUtils.swap(this.vecs, kept, i);
                this.alphas[kept] = weight;
                kept++;
            }
        }
        this.vecs = new ArrayList<>(this.vecs.subList(0, kept));
        this.alphas = Arrays.copyOf(this.alphas, kept);
        this.it = null;
        setCacheMode(null);
        setAlphas(this.alphas);
    }

    /**
     * Projects the iterate back onto the unit ball.
     * <p>
     * If {@code normSqrd > 1}, the first {@code m} entries of {@code responses} and of the
     * coefficients are scaled by {@code 1 / sqrt(normSqrd)} and 1 is returned. Otherwise nothing is
     * changed and {@code normSqrd} is returned.
     *
     * @param normSqrd the current squared norm of the weight vector
     * @param m the number of training points
     * @param responses the responses, scaled in place
     * @return the squared norm after projection
     */
    private double projectionStep(double normSqrd, int m, double[] responses) {
        if (normSqrd > 1.0) {
            double scale = 1.0 / Math.sqrt(normSqrd);
            for (int i = 0; i < m; i++) {
                responses[i] *= scale;
                this.alphas[i] *= scale;
            }
            return 1.0;
        }
        return normSqrd;
    }

    /**
     * Picks, approximately uniformly at random, the index of a point whose response is at or below
     * {@code gamma}.
     * <p>
     * It first tries up to five plain random draws (rejection sampling). If none qualifies, it
     * chooses uniformly among all qualifying points directly.
     *
     * @param random the source of randomness
     * @param m the number of training points
     * @param responses the current responses
     * @param gamma the threshold
     * @return the chosen index
     * @throws FailedToFitException if no response is at or below {@code gamma}
     */
    private int sampleC(Random random, int m, double[] responses, double gamma) throws FailedToFitException {
        for (int attempt = 0; attempt < 5; attempt++) {
            int candidate = random.nextInt(m);
            if (responses[candidate] <= gamma) {
                return candidate;
            }
        }
        // Use the same "<= gamma" acceptance test as the rejection step above.
        int qualifying = 0;
        for (int i = 0; i < m; i++) {
            if (responses[i] <= gamma) {
                qualifying++;
            }
        }
        if (qualifying == 0) {
            throw new FailedToFitException("BUG: please report");
        }
        int remaining = random.nextInt(qualifying);
        for (int i = 0; i < m; i++) {
            if (responses[i] <= gamma) {
                if (remaining == 0) {
                    return i;
                }
                remaining--;
            }
        }
        // Unreachable: exactly `qualifying` indices pass the test above.
        throw new FailedToFitException("BUG: please report");
    }

    /**
     * Adds {@code eta} times training point {@code c} (with its label) to the weight vector.
     * <p>
     * Every response {@code A_i} increases by {@code eta * y_c * y_i * K(c, i)}.
     *
     * @param normSqrd the squared norm before the update
     * @param eta the step size
     * @param responses the responses, updated in place
     * @param c the index of the sampled point
     * @param labels the labels in {-1, +1}
     * @param m the number of training points
     * @return the updated squared norm {@code normSqrd + 2*eta*A_c + eta^2*K(c, c)}, using
     *         {@code A_c} from before the update
     */
    private double updateLoop(double normSqrd, double eta, double[] responses, int c, double[] labels, int m) {
        double newNormSqrd = normSqrd + (2.0 * eta * responses[c]) + (eta * eta * kEval(c, c));
        double yc = labels[c];
        for (int i = 0; i < m; i++) {
            responses[i] += eta * yc * labels[i] * kEval(c, i);
        }
        return newNormSqrd;
    }

    /**
     * @return {@code false}; sample weights are not used by this learner
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Finds {@code gamma} such that {@code sum_i max(0, gamma - A_i) = target}.
     * <p>
     * The responses are visited in ascending order. While the first {@code k} of them lie below
     * gamma, the solution is {@code (target + sum of those k) / k}. The walk stops at the first value
     * {@code a} with {@code sum_{j before a} (a - a_j) >= target}, which means gamma is at most
     * {@code a}. At that point the candidate built from the preceding values is returned. If no
     * such value exists, gamma lies above every response and is built from all of them.
     *
     * @param responses the responses; they are not modified, only their sort order is cached
     * @param target the required total shortfall, expected to be positive
     * @return gamma
     */
    private double findGamma(double[] responses, double target) {
        if (this.it == null) {
            this.it = new IndexTable(responses);
        } else {
            this.it.sort(responses);
        }
        double prefixSum = 0.0; // sum of the values visited so far, excluding the current one
        double gamma = 0.0;
        for (int i = 0; i < this.it.length(); i++) {
            double a = responses[this.it.index(i)];
            // i values precede a in sorted order; their total shortfall at gamma = a is i*a - prefixSum.
            if (i * a - prefixSum >= target) {
                return gamma;
            }
            prefixSum += a;
            gamma = (target + prefixSum) / (i + 1);
        }
        return gamma;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }
}
