package jsat.classifiers.linear;

import java.util.Arrays;
import java.util.Iterator;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.classifiers.linear.StochasticSTLinearL1;
import jsat.exceptions.FailedToFitException;
import jsat.exceptions.UntrainedModelException;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.regression.RegressionDataSet;

/* loaded from: input_file:jsat/classifiers/linear/SMIDAS.class */
/**
 * SMIDAS: an L1-regularized linear learner trained by stochastic mirror descent.
 *
 * <p>Training keeps a "dual" vector {@code theta} (plus a dual bias). Each iteration:
 * <ol>
 *   <li>takes a gradient step on {@code theta} for one randomly chosen sample;</li>
 *   <li>soft-thresholds {@code theta} by {@code eta * lambda};</li>
 *   <li>maps {@code theta} back to the primal weights {@code w} through the p-norm link
 *       function with {@code p = 2 ln(d)}, where {@code d} is the number of features.</li>
 * </ol>
 *
 * <p>At least 3 features are required, so that {@code p - 1 > 0} in the link function.
 * Only binary classification is supported for the classification use case.
 */
public class SMIDAS extends StochasticSTLinearL1 {
    private static final long serialVersionUID = -4888083541600164597L;

    /** Learning rate (step size) for the stochastic updates; always positive and finite. */
    private double eta;

    /**
     * Creates a learner with 1000 iterations, {@code lambda = 1e-14}, the default loss,
     * and rescaling enabled.
     *
     * @param d the learning rate {@code eta}
     */
    public SMIDAS(double d) {
        this(d, 1000, 1.0E-14d, DEFAULT_LOSS);
    }

    /**
     * Creates a learner with rescaling enabled.
     *
     * @param d    the learning rate {@code eta}
     * @param i    the number of stochastic iterations
     * @param d2   the L1 regularization strength {@code lambda}
     * @param loss the loss function to minimize
     */
    public SMIDAS(double d, int i, double d2, StochasticSTLinearL1.Loss loss) {
        this(d, i, d2, loss, true);
    }

    /**
     * Creates a fully configured learner.
     *
     * @param d    the learning rate {@code eta}
     * @param i    the number of stochastic iterations
     * @param d2   the L1 regularization strength {@code lambda}
     * @param loss the loss function to minimize
     * @param z    whether inputs are rescaled; when {@code false}, training requires all
     *             feature values to lie in {@code [-1, 1]}
     */
    public SMIDAS(double d, int i, double d2, StochasticSTLinearL1.Loss loss, boolean z) {
        setEta(d);
        setEpochs(i);
        setLambda(d2);
        setLoss(loss);
        setReScale(z);
    }

    /**
     * Sets the learning rate.
     *
     * @param d the new learning rate
     * @throws ArithmeticException if {@code d} is NaN, infinite, or not positive
     */
    public void setEta(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d <= 0.0d) {
            throw new ArithmeticException("convergence parameter must be a positive value");
        }
        this.eta = d;
    }

    /** @return the learning rate */
    public double getEta() {
        return this.eta;
    }

    /**
     * Classifies a point using the trained weights.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        return this.loss.classify(wDot(dataPoint.getNumericalValues()));
    }

    /**
     * Predicts a regression target using the trained weights.
     *
     * @throws UntrainedModelException if the model has not been trained
     */
    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        if (this.w == null) {
            throw new UntrainedModelException("Model has not been trained");
        }
        return this.loss.regress(wDot(dataPoint.getNumericalValues()));
    }

    /** Training is sequential; the executor is ignored. */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        trainC(classificationDataSet);
    }

    /**
     * Trains a binary classifier. Classes 0 and 1 are mapped to the labels -1 and +1.
     *
     * @throws FailedToFitException if there are fewer than 3 features, the problem is not
     *                              binary, the data set is empty, or values fall outside
     *                              {@code [-1, 1]} while rescaling is disabled
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        if (classificationDataSet.getNumNumericalVars() < 3) {
            throw new FailedToFitException("SMIDAS requires at least 3 features");
        }
        if (classificationDataSet.getClassSize() != 2) {
            throw new FailedToFitException("SMIDAS only supports binary classification problems");
        }
        Vec[] x = setUpVecs(classificationDataSet);
        rescale(x);
        double[] y = new double[x.length];
        for (int i = 0; i < y.length; i++) {
            // map class index {0, 1} to the signed label {-1, +1}
            y[i] = (classificationDataSet.getDataPointCategory(i) * 2) - 1;
        }
        train(x, y);
    }

    /** Training is sequential; the executor is ignored. */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet, ExecutorService executorService) {
        train(regressionDataSet);
    }

    /**
     * Trains a regressor on the data set's target values.
     *
     * @throws FailedToFitException if there are fewer than 3 features, the data set is empty,
     *                              or values fall outside {@code [-1, 1]} while rescaling
     *                              is disabled
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet regressionDataSet) {
        if (regressionDataSet.getNumNumericalVars() < 3) {
            throw new FailedToFitException("SMIDAS requires at least 3 features");
        }
        Vec[] x = setUpVecs(regressionDataSet);
        rescale(x);
        double[] y = new double[x.length];
        for (int i = 0; i < y.length; i++) {
            y[i] = regressionDataSet.getTargetValue(i);
        }
        train(x, y);
    }

    /**
     * Rescales every vector feature-wise from {@code [obvMin, obvMax]} into
     * {@code [minScaled, maxScaled]}.
     *
     * <p>Each array slot is replaced with a new vector, so the vectors originally held in
     * {@code vecs} (which belong to the caller's data set) are never modified.
     */
    private void rescale(Vec[] vecs) {
        DenseVector mins = DenseVector.toDenseVec(this.obvMin);
        DenseVector maxs = DenseVector.toDenseVec(this.obvMax);
        DenseVector multipliers = new DenseVector(maxs.length());
        multipliers.mutableAdd(this.maxScaled - this.minScaled);
        multipliers.mutablePairwiseDivide(maxs.subtract(mins));

        boolean allMinsZero = true;
        for (double m : this.obvMin) {
            if (m != 0.0d) {
                allMinsZero = false;
                break;
            }
        }
        // With no offset on either side a pure multiply suffices.
        // This avoids the dense subtract/add and keeps sparse inputs sparse.
        boolean multiplyOnly = allMinsZero && this.minScaled == 0.0d;

        for (int i = 0; i < vecs.length; i++) {
            if (multiplyOnly) {
                // copy first: the original vector is owned by the caller's data set
                vecs[i] = vecs[i].mo45clone();
                vecs[i].mutablePairwiseMultiply(multipliers);
            } else {
                vecs[i] = vecs[i].subtract(mins); // subtract returns a new vector
                vecs[i].mutablePairwiseMultiply(multipliers);
                vecs[i].mutableAdd(this.minScaled);
            }
        }
    }

    /**
     * Runs the SMIDAS iterations over already-scaled inputs and stores the result in
     * {@code w} and {@code bias}.
     *
     * @param vecArr the (rescaled) input vectors; must be non-empty
     * @param dArr   the target for each vector (signed labels or regression targets)
     */
    private void train(Vec[] vecArr, double[] dArr) {
        int n = vecArr.length;
        int d = vecArr[0].length();
        // p-norm used by the mirror map; with d >= 3, p - 1 > 0
        double p = 2.0d * Math.log(d);
        double shrink = this.eta * this.lambda;
        DenseVector theta = new DenseVector(d);
        double thetaBias = 0.0d;
        this.w = new DenseVector(d);
        Random random = new Random();
        for (int iter = 0; iter < this.epochs; iter++) {
            int j = random.nextInt(n);
            double deriv = this.loss.deriv(this.w.dot(vecArr[j]) + this.bias, dArr[j]);

            // gradient step in the dual space
            theta.mutableSubtract(this.eta * deriv, vecArr[j]);
            double stepped = thetaBias - (this.eta * deriv);

            // soft-threshold: this step produces sparsity (L1 regularization)
            Iterator<IndexValue> it = theta.iterator();
            while (it.hasNext()) {
                IndexValue iv = it.next();
                double v = iv.getValue();
                theta.set(iv.getIndex(), Math.signum(v) * Math.max(0.0d, Math.abs(v) - shrink));
            }
            thetaBias = Math.signum(stepped) * Math.max(0.0d, Math.abs(stepped) - shrink);

            double norm = theta.pNorm(p);
            if (norm > 0.0d) {
                // Link function: w_i = sign(t_i) * |t_i|^(p-1) / ||theta||_p^(p-2), evaluated
                // in log space. log(0) = -inf gives exp(-inf) = 0 for zero entries.
                double logNorm = Math.log(norm);
                for (int k = 0; k < this.w.length(); k++) {
                    double t = theta.get(k);
                    this.w.set(k, Math.signum(t) * Math.exp(((p - 1.0d) * Math.log(Math.abs(t))) - ((p - 2.0d) * logNorm)));
                }
                this.bias = Math.signum(thetaBias) * Math.exp(((p - 1.0d) * Math.log(Math.abs(thetaBias))) - ((p - 2.0d) * logNorm));
            } else {
                // everything was shrunk away; restart from the origin
                theta.zeroOut();
                thetaBias = 0.0d;
                this.w.zeroOut();
                this.bias = 0.0d;
            }
        }
    }

    /**
     * Reports whether sample weights are honored.
     *
     * @return {@code false}; the training loop samples points uniformly and never reads
     *         data point weights
     */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Creates a deep copy of this learner, including any trained state.
     *
     * <p>Decompiler-renamed {@code clone}; the name is kept for compatibility with callers.
     */
    @Override // jsat.classifiers.linear.StochasticSTLinearL1
    public SMIDAS mo46clone() {
        SMIDAS smidas = new SMIDAS(this.eta, this.epochs, this.lambda, this.loss, this.reScale);
        if (this.w != null) {
            smidas.w = this.w.mo45clone();
        }
        smidas.bias = this.bias;
        smidas.minScaled = this.minScaled;
        smidas.maxScaled = this.maxScaled;
        if (this.obvMin != null) {
            smidas.obvMin = Arrays.copyOf(this.obvMin, this.obvMin.length);
        }
        if (this.obvMax != null) {
            smidas.obvMax = Arrays.copyOf(this.obvMax, this.obvMax.length);
        }
        return smidas;
    }

    /**
     * Collects the numeric vectors of {@code dataSet} and records the per-feature minimum
     * and maximum in {@code obvMin} and {@code obvMax}.
     *
     * <p>The returned array holds the data set's own vector instances. Callers must copy
     * them before mutating.
     *
     * @throws FailedToFitException if the data set is empty, or if rescaling is disabled and
     *                              any value lies outside {@code [-1, 1]}
     */
    private Vec[] setUpVecs(DataSet dataSet) {
        int n = dataSet.getSampleSize();
        if (n == 0) {
            throw new FailedToFitException("SMIDAS requires at least one data point");
        }
        this.obvMin = new double[dataSet.getNumNumericalVars()];
        Arrays.fill(this.obvMin, Double.POSITIVE_INFINITY);
        this.obvMax = new double[dataSet.getNumNumericalVars()];
        Arrays.fill(this.obvMax, Double.NEGATIVE_INFINITY);
        Vec[] vecArr = new Vec[n];
        for (int i = 0; i < n; i++) {
            vecArr[i] = dataSet.getDataPoint(i).getNumericalValues();
            Iterator<IndexValue> it = vecArr[i].iterator();
            while (it.hasNext()) {
                IndexValue next = it.next();
                int index = next.getIndex();
                double value = next.getValue();
                this.obvMin[index] = Math.min(this.obvMin[index], value);
                this.obvMax[index] = Math.max(this.obvMax[index], value);
            }
        }
        // Sparse inputs have implicit zeros that the iterator above may not visit, so 0 is
        // folded into the minima (sparsity is judged from the first vector only).
        if (vecArr[0].isSparse()) {
            for (int i2 = 0; i2 < this.obvMin.length; i2++) {
                this.obvMin[i2] = Math.min(this.obvMin[i2], 0.0d);
            }
        }
        if (!this.reScale) {
            for (double d : this.obvMin) {
                if (d < -1.0d) {
                    throw new FailedToFitException("Values must be in the range [-1,1], " + d + " violation encountered");
                }
            }
            for (double d2 : this.obvMax) {
                if (d2 > 1.0d) {
                    throw new FailedToFitException("Values must be in the range [-1,1], " + d2 + " violation encountered");
                }
            }
        }
        return vecArr;
    }
}
