package jsat.classifiers.linear;

import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.SingleWeightVectorModel;
import jsat.classifiers.BaseUpdateableClassifier;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.UpdateableClassifier;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.BaseUpdateableRegressor;
import jsat.regression.RegressionDataSet;
import jsat.regression.UpdateableRegressor;

/* loaded from: input_file:jsat/classifiers/linear/PassiveAggressive.class */
/**
 * Linear Passive-Aggressive learner that can be trained online either as a
 * binary classifier or as a regressor.
 *
 * <p>The model is a single weight vector {@code w} with no bias term
 * ({@link #getBias()} always returns {@code 0}). Each update moves {@code w}
 * along the current input vector by a step size chosen according to the
 * configured {@link Mode}.
 *
 * <p>Only the numeric attributes of a {@link DataPoint} are used.
 */
public class PassiveAggressive implements UpdateableClassifier, BinaryScoreClassifier, UpdateableRegressor, Parameterized, SingleWeightVectorModel {
    private static final long serialVersionUID = -7130964391528405832L;

    /** Number of passes over the data set performed by the batch training methods. */
    private int epochs;

    /** Aggressiveness constant; used by {@link Mode#PA1} and {@link Mode#PA2}. */
    private double C;

    /** Width of the insensitive zone used by the regression loss. */
    private double eps;

    /** Learned weight vector; {@code null} until one of the {@code setUp} methods is called. */
    private Vec w;

    /** Rule used to turn a loss value into a step size. */
    private Mode mode;

    /* loaded from: input_file:jsat/classifiers/linear/PassiveAggressive$Mode.class */
    /**
     * Selects how the step size is derived from the loss and the squared
     * norm of the input (see {@code getCorrection}).
     */
    public enum Mode {
        /** Step is {@code loss / ||x||^2}. */
        PA,
        /** Step is {@code min(C, loss / ||x||^2)}. */
        PA1,
        /** Step is {@code loss / (||x||^2 + 1 / (2 C))}. */
        PA2
    }

    /**
     * Creates a learner that trains for 10 epochs using {@link Mode#PA1}.
     */
    public PassiveAggressive() {
        this(10, Mode.PA1);
    }

    /**
     * Creates a learner with {@code C = 0.01} and {@code eps = 0.001}.
     * The arguments are stored as given; no validation is performed here.
     *
     * @param i    the number of training epochs
     * @param mode the step-size rule to use
     */
    public PassiveAggressive(int i, Mode mode) {
        this.C = 0.01d;
        this.eps = 0.001d;
        this.epochs = i;
        this.mode = mode;
    }

    /**
     * Sets the aggressiveness constant.
     *
     * @param d the new value; must be finite and strictly positive
     * @throws ArithmeticException if {@code d} is NaN, infinite, or not positive
     */
    public void setC(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0.0d) {
            throw new ArithmeticException("Aggressiveness must be a positive constant");
        }
        this.C = d;
    }

    /**
     * @return the aggressiveness constant
     */
    public double getC() {
        return this.C;
    }

    /**
     * Sets the step-size rule. The value is stored without validation.
     *
     * @param mode the rule to use for subsequent updates
     */
    public void setMode(Mode mode) {
        this.mode = mode;
    }

    /**
     * @return the step-size rule currently in use
     */
    public Mode getMode() {
        return this.mode;
    }

    /**
     * Sets the width of the insensitive zone for regression updates: an
     * absolute prediction error of at most this value produces zero loss.
     * The value is stored without validation.
     *
     * @param d the new insensitivity width
     */
    public void setEps(double d) {
        this.eps = d;
    }

    /**
     * @return the insensitivity width used by regression updates
     */
    public double getEps() {
        return this.eps;
    }

    /**
     * Sets the number of passes used by the batch training methods.
     *
     * @param i the number of epochs; must be at least 1
     * @throws IllegalArgumentException if {@code i < 1}
     */
    public void setEpochs(int i) {
        if (i < 1) {
            throw new IllegalArgumentException("epochs must be a positive value");
        }
        this.epochs = i;
    }

    /**
     * @return the number of training epochs
     */
    public int getEpochs() {
        return this.epochs;
    }

    /**
     * Returns the live weight vector (not a copy), or {@code null} if the
     * model has not been set up yet.
     */
    @Override // jsat.SingleWeightVectorModel
    public Vec getRawWeight() {
        return this.w;
    }

    /**
     * @return always {@code 0}; this model has no bias term
     */
    @Override // jsat.SingleWeightVectorModel
    public double getBias() {
        return 0.0d;
    }

    /**
     * Returns the single weight vector for any index below 1.
     *
     * @param i the index of the requested weight vector
     * @throws IndexOutOfBoundsException if {@code i >= 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int i) {
        if (i < 1) {
            return getRawWeight();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * Returns the (zero) bias for any index below 1.
     *
     * @param i the index of the requested bias
     * @throws IndexOutOfBoundsException if {@code i >= 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int i) {
        if (i < 1) {
            return getBias();
        }
        throw new IndexOutOfBoundsException("Model has only 1 weight vector");
    }

    /**
     * @return always {@code 1}
     */
    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return 1;
    }

    /**
     * Classifies a point by the sign of its score: class 1 gets probability
     * 1 when the score is strictly positive, otherwise class 0 does.
     *
     * @param dataPoint the point to classify
     * @return a two-class result with all probability mass on one class
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults result = new CategoricalResults(2);
        int predicted = getScore(dataPoint) > 0.0d ? 1 : 0;
        result.setProb(predicted, 1.0d);
        return result;
    }

    /**
     * @param dataPoint the point to score
     * @return the dot product of the point's numeric values with the weight vector
     */
    @Override // jsat.classifiers.calibration.BinaryScoreClassifier
    public double getScore(DataPoint dataPoint) {
        return dataPoint.getNumericalValues().dot(this.w);
    }

    /**
     * Trains the classifier; the executor is ignored and training is
     * delegated to {@link #trainC(ClassificationDataSet)}.
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Trains the classifier by delegating to
     * {@code BaseUpdateableClassifier.trainEpochs} with the configured
     * number of epochs.
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        BaseUpdateableClassifier.trainEpochs(classificationDataSet, this, this.epochs);
    }

    /**
     * @return always {@code false}
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Prepares the model for classification by allocating a fresh weight
     * vector of length {@code i}, discarding any previous weights.
     *
     * @param categoricalDataArr not used by this method
     * @param i                  the number of numeric attributes; must be at least 1
     * @param categoricalData    the target class information; must have exactly 2 categories
     * @throws FailedToFitException if the problem is not binary or {@code i < 1}
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void setUp(CategoricalData[] categoricalDataArr, int i, CategoricalData categoricalData) {
        if (categoricalData.getNumOfCategories() != 2) {
            throw new FailedToFitException("Only supports binary classification problems");
        }
        if (i < 1) {
            throw new FailedToFitException("only suppors learning from numeric attributes");
        }
        this.w = new DenseVector(i);
    }

    /**
     * Prepares the model for regression by allocating a fresh weight vector
     * of length {@code i}, discarding any previous weights.
     *
     * @param categoricalDataArr not used by this method
     * @param i                  the number of numeric attributes; must be at least 1
     * @throws FailedToFitException if {@code i < 1}
     */
    @Override // jsat.regression.UpdateableRegressor
    public void setUp(CategoricalData[] categoricalDataArr, int i) {
        if (i < 1) {
            throw new FailedToFitException("only suppors learning from numeric attributes");
        }
        this.w = new DenseVector(i);
    }

    /**
     * Performs one classification update using the hinge loss
     * {@code max(0, 1 - y * (w . x))}. Nothing changes when the loss is zero.
     *
     * @param dataPoint the training point
     * @param i         the class label; 0 is mapped to -1 and 1 to +1
     */
    @Override // jsat.classifiers.UpdateableClassifier
    public void update(DataPoint dataPoint, int i) {
        Vec x = dataPoint.getNumericalValues();
        int y = (i * 2) - 1; // map label {0, 1} to sign {-1, +1}
        double loss = Math.max(0.0d, 1.0d - (y * x.dot(this.w)));
        if (loss == 0.0d) {
            return;
        }
        this.w.mutableAdd(y * getCorrection(loss, x), x);
    }

    /**
     * Performs one regression update using the loss
     * {@code max(0, |w . x - target| - eps)}. Nothing changes when the loss
     * is zero.
     *
     * @param dataPoint the training point
     * @param d         the regression target
     */
    @Override // jsat.regression.UpdateableRegressor
    public void update(DataPoint dataPoint, double d) {
        Vec x = dataPoint.getNumericalValues();
        double prediction = x.dot(this.w);
        double loss = Math.max(0.0d, Math.abs(prediction - d) - this.eps);
        if (loss == 0.0d) {
            return;
        }
        // Step towards the target: positive when the prediction is too low.
        this.w.mutableAdd(Math.signum(d - prediction) * getCorrection(loss, x), x);
    }

    /**
     * Computes the unsigned step size for the given loss according to the
     * current {@link Mode}. Any mode other than PA1 or PA2 (including
     * {@code null}) uses the plain {@code loss / ||x||^2} rule.
     *
     * @param d   the loss suffered on the current example
     * @param vec the input vector of the current example
     * @return the step size, or {@code 0} when the input has zero norm
     */
    private double getCorrection(double d, Vec vec) {
        double squaredNorm = Math.pow(vec.pNorm(2.0d), 2.0d);
        if (squaredNorm == 0.0d) {
            // Moving along a zero vector cannot change w. Returning 0 avoids
            // the division by zero, whose infinite step would otherwise be
            // multiplied by the zero entries of the input and yield NaN.
            return 0.0d;
        }
        if (this.mode == Mode.PA1) {
            return Math.min(this.C, d / squaredNorm);
        }
        if (this.mode == Mode.PA2) {
            return d / (squaredNorm + (1.0d / (2.0d * this.C)));
        }
        return d / squaredNorm;
    }

    /**
     * @param dataPoint the point to predict for
     * @return the dot product of the weight vector with the point's numeric values
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return this.w.dot(dataPoint.getNumericalValues());
    }

    /**
     * Trains the regressor; the executor is ignored and training is
     * delegated to {@link #train(RegressionDataSet)}.
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet);
    }

    /**
     * Trains the regressor by delegating to
     * {@code BaseUpdateableRegressor.trainEpochs} with the configured number
     * of epochs.
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        BaseUpdateableRegressor.trainEpochs(regressionDataSet, this, this.epochs);
    }

    /**
     * Creates an independent copy of this learner, including its
     * hyper-parameters and, if present, its learned weights.
     *
     * @return a copy whose weight vector is not shared with this instance
     */
    @Override // jsat.regression.Regressor
    /* renamed from: clone, reason: merged with bridge method [inline-methods] and merged with bridge method [inline-methods] */
    public PassiveAggressive mo104clone() {
        PassiveAggressive copy = new PassiveAggressive(this.epochs, this.mode);
        copy.eps = this.eps;
        copy.C = this.C;
        if (this.w != null) {
            // The weights are mutated in place by update(), so the copy must
            // own its vector; sharing the reference would let training one
            // instance silently alter the other.
            copy.w = new DenseVector(this.w);
        }
        return copy;
    }

    /**
     * @return the parameters obtained from {@code Parameter.getParamsFromMethods}
     */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks up a single parameter by name.
     *
     * @param str the parameter name
     * @return the matching entry of the parameter map built from
     *         {@link #getParameters()}
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /**
     * Suggests a search distribution for the {@code C} parameter.
     *
     * @param dataSet not used by this method
     * @return a {@link LogUniform} distribution constructed with the
     *         arguments {@code 0.001} and {@code 100}
     */
    public static Distribution guessC(DataSet dataSet) {
        return new LogUniform(0.001d, 100.0d);
    }
}
