package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.List;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.PassiveAggressive;
import jsat.distributions.Distribution;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.IndexTable;

/* loaded from: input_file:jsat/classifiers/linear/SPA.class */
/**
 * Online multi-class passive-aggressive classifier that keeps one weight vector
 * (and one optional bias term) per class.
 * <p>
 * For every training example:
 * <ol>
 * <li>The hinge loss of each wrong class against the target class is computed.</li>
 * <li>The wrong classes are sorted by that loss.</li>
 * <li>The highest-loss classes are admitted into a "support" set while their
 * accumulated loss stays below a mode-dependent goal.</li>
 * <li>Each support class is pushed away from the example and the target class
 * is pulled towards it, by a step size that depends on the chosen
 * {@link PassiveAggressive.Mode}.</li>
 * </ol>
 * <p>
 * NOTE(review): this appears to implement the "support class" PA algorithm of
 * Matsushima et al. (2010). Confirm against the paper before touching the
 * goal/step-size formulas.
 */
public class SPA extends BaseUpdateableClassifier implements Parameterized, SimpleWeightVectorModel {
    private static final long serialVersionUID = 3613279663279244169L;
    /** One weight vector per class; allocated by {@link #setUp}. */
    private Vec[] w;
    /** One bias term per class; only modified by {@link #update} when {@link #useBias} is set. */
    private double[] bias;
    /** Aggressiveness parameter; always positive and finite (enforced by {@link #setC}). */
    private double C;
    /** Whether a per-class bias term is learned. */
    private boolean useBias;
    /** Which passive-aggressive variant drives the goal and step-size computations. */
    private PassiveAggressive.Mode mode;
    /** Scratch buffer of per-class losses, overwritten on every call to {@link #update}. */
    private double[] loss;
    /** Scratch index table used to order {@link #loss}; overwritten on every update. */
    private IndexTable it;

    /**
     * Creates a classifier that trains for 10 epochs in {@code PA2} mode.
     */
    public SPA() {
        this(10, PassiveAggressive.Mode.PA2);
    }

    /**
     * Creates a classifier with the given epoch count and PA variant.
     *
     * @param epochs the number of training epochs, forwarded to {@code setEpochs}
     * @param mode   the passive-aggressive variant to use
     */
    public SPA(int epochs, PassiveAggressive.Mode mode) {
        this.C = 1.0d;
        this.useBias = false;
        setEpochs(epochs);
        setMode(mode);
    }

    /**
     * @param z {@code true} to learn a bias term per class
     */
    public void setUseBias(boolean z) {
        this.useBias = z;
    }

    /**
     * @return {@code true} if a bias term per class is learned
     */
    public boolean isUseBias() {
        return this.useBias;
    }

    /**
     * Sets the aggressiveness parameter used by the PA1 and PA2 modes.
     *
     * @param d the new aggressiveness; must be positive and finite
     * @throws ArithmeticException if {@code d} is NaN, infinite, or not positive
     */
    public void setC(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0.0d) {
            throw new ArithmeticException("Aggressiveness must be a positive constant");
        }
        this.C = d;
    }

    /**
     * @return the aggressiveness parameter
     */
    public double getC() {
        return this.C;
    }

    /**
     * @param mode the passive-aggressive variant to use
     */
    public void setMode(PassiveAggressive.Mode mode) {
        this.mode = mode;
    }

    /**
     * @return the passive-aggressive variant in use
     */
    public PassiveAggressive.Mode getMode() {
        return this.mode;
    }

    /**
     * Returns the live (not copied) weight vector of class {@code i}.
     */
    @Override
    public Vec getRawWeight(int i) {
        return this.w[i];
    }

    /**
     * Returns the bias term of class {@code i}; always 0 unless {@link #isUseBias()}.
     */
    @Override
    public double getBias(int i) {
        return this.bias[i];
    }

    /**
     * Returns the number of classes, i.e. the number of weight vectors.
     */
    @Override
    public int numWeightsVecs() {
        return this.w.length;
    }

    /**
     * Deep copy of this classifier: weights, biases, hyper-parameters, and the
     * epoch count.
     * <p>
     * The method is named {@code mo0clone} by the decompiler; it is the bridged
     * {@code clone()} override.
     */
    @Override
    public SPA mo0clone() {
        // Construct with this instance's epoch count; the no-arg constructor
        // would silently reset it to 10.
        SPA copy = new SPA(getEpochs(), this.mode);
        if (this.w != null) {
            copy.w = new Vec[this.w.length];
            for (int c = 0; c < this.w.length; c++) {
                copy.w[c] = this.w[c].mo45clone();
            }
        }
        if (this.it != null) {
            // Scratch space only, so a fresh table of the same size suffices.
            copy.it = new IndexTable(this.it.length());
        }
        if (this.loss != null) {
            copy.loss = Arrays.copyOf(this.loss, this.loss.length);
        }
        copy.C = this.C;
        if (this.bias != null) {
            copy.bias = Arrays.copyOf(this.bias, this.bias.length);
        }
        copy.useBias = this.useBias;
        return copy;
    }

    /**
     * Allocates one zero weight vector and one zero bias per class, plus the
     * scratch buffers used by {@link #update}.
     *
     * @param categoricalDataArr categorical feature descriptions (unused)
     * @param i                  the number of numerical features
     * @param categoricalData    the target variable; its category count sets the number of classes
     */
    @Override
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        this.w = new Vec[categoricalData.getNumOfCategories()];
        for (int c = 0; c < this.w.length; c++) {
            this.w[c] = new DenseVector(i);
        }
        this.bias = new double[this.w.length];
        this.loss = new double[this.w.length];
        this.it = new IndexTable(this.w.length);
    }

    /**
     * Computes the goal used while growing the support set.
     * <p>
     * {@link #update} stops admitting classes as soon as the accumulated loss of
     * the classes already admitted reaches this value.
     *
     * @param normSq        squared norm of the (bias-augmented) input
     * @param k             sorted position of the candidate class (1-based among wrong classes)
     * @param candidateLoss loss of the candidate class
     */
    private double getSupportClassGoal(double normSq, int k, double candidateLoss) {
        if (this.mode == PassiveAggressive.Mode.PA1) {
            return Math.min(((k - 1) * candidateLoss) + (this.C * normSq), k * candidateLoss);
        }
        if (this.mode == PassiveAggressive.Mode.PA2) {
            return (((k * normSq) + ((k - 1) / (2.0d * this.C))) / (normSq + (1.0d / (2.0d * this.C)))) * candidateLoss;
        }
        // Any other mode, including null, is treated as plain PA.
        return k * candidateLoss;
    }

    /**
     * Computes the (non-negative) step size applied for one support class.
     *
     * @param classLoss loss of the support class being updated
     * @param normSq    squared norm of the (bias-augmented) input; must be non-zero
     * @param k         one more than the number of support classes
     * @param lossSum   total loss of all support classes
     */
    private double getStepSize(double classLoss, double normSq, int k, double lossSum) {
        double threshold;
        if (this.mode == PassiveAggressive.Mode.PA1) {
            threshold = Math.max((lossSum / (k - 1)) - ((this.C / (k - 1)) * normSq), lossSum / k);
        } else if (this.mode == PassiveAggressive.Mode.PA2) {
            threshold = ((normSq + (1.0d / (2.0d * this.C))) / ((k * normSq) + ((k - 1) / (2.0d * this.C)))) * lossSum;
        } else {
            threshold = lossSum / k;
        }
        return Math.max(0.0d, classLoss - threshold) / normSq;
    }

    /**
     * Performs one passive-aggressive update on a single labelled example.
     *
     * @param dataPoint   the example; only its numerical values are used
     * @param targetClass index of the correct class
     */
    @Override
    public void update(DataPoint dataPoint, int targetClass) {
        Vec x = dataPoint.getNumericalValues();

        // Squared norm of the input augmented with a constant 1 feature when a
        // bias is learned: ||[x; 1]||^2 = ||x||^2 + 1 (not (||x|| + 1)^2).
        double xNorm = x.pNorm(2.0d);
        double normSq = xNorm * xNorm + (this.useBias ? 1.0d : 0.0d);
        if (normSq == 0.0d) {
            // No step can change any margin. Dividing by zero would yield an
            // Infinity/NaN step that poisons dense weight vectors.
            return;
        }

        double targetScore = this.w[targetClass].dot(x) + this.bias[targetClass];
        for (int c = 0; c < this.w.length; c++) {
            if (c == targetClass) {
                // +Infinity keeps the target class at sorted position 0, so it
                // is never considered as a support class below.
                this.loss[c] = Double.POSITIVE_INFINITY;
            } else {
                this.loss[c] = Math.max(0.0d, 1.0d - ((targetScore - this.w[c].dot(x)) - this.bias[c]));
            }
        }
        // NOTE(review): sortR is assumed to order indices by descending loss,
        // which is why the scan below starts at position 1.
        this.it.sortR(this.loss);

        // Grow the support set from sorted position 1 while the accumulated
        // loss is below the mode's goal. Afterwards positions 1..k-1 are the
        // support classes and lossSum is their total loss.
        int k = 1;
        double lossSum = 0.0d;
        while (k < this.loss.length) {
            double goal = getSupportClassGoal(normSq, k, this.loss[this.it.index(k)]);
            if (lossSum >= goal) {
                break;
            }
            lossSum += this.loss[this.it.index(k)];
            k++;
        }

        for (int pos = 1; pos < k; pos++) {
            int supportClass = this.it.index(pos);
            double tau = getStepSize(this.loss[supportClass], normSq, k, lossSum);
            this.w[targetClass].mutableAdd(tau, x);
            this.w[supportClass].mutableSubtract(tau, x);
            if (this.useBias) {
                this.bias[targetClass] += tau;
                this.bias[supportClass] -= tau;
            }
        }
    }

    /**
     * Predicts the class with the highest score {@code w[c]·x + bias[c]}.
     * <p>
     * On ties the lowest class index wins. The result is one-hot: the predicted
     * class gets probability 1.
     */
    @Override
    public CategoricalResults classify(DataPoint dataPoint) {
        Vec x = dataPoint.getNumericalValues();
        CategoricalResults results = new CategoricalResults(this.w.length);
        int best = 0;
        double bestScore = this.w[0].dot(x) + this.bias[0];
        for (int c = 1; c < this.w.length; c++) {
            double score = this.w[c].dot(x) + this.bias[c];
            if (score > bestScore) {
                bestScore = score;
                best = c;
            }
        }
        results.setProb(best, 1.0d);
        return results;
    }

    /**
     * @return {@code false}; example weights are ignored
     */
    @Override
    public boolean supportsWeightedData() {
        return false;
    }

    @Override
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Suggests a distribution of values for {@link #setC(double)}.
     * <p>
     * Delegates to {@code PassiveAggressive.guessC}.
     */
    public static Distribution guessC(DataSet dataSet) {
        return PassiveAggressive.guessC(dataSet);
    }
}
