package jsat.classifiers.linear;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleWeightVectorModel;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.WarmClassifier;
import jsat.distributions.Distribution;
import jsat.distributions.LogUniform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.ConcatenatedVec;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SubVector;
import jsat.linear.Vec;
import jsat.lossfunctions.LossC;
import jsat.lossfunctions.LossFunc;
import jsat.lossfunctions.LossMC;
import jsat.lossfunctions.LossR;
import jsat.lossfunctions.SoftmaxLoss;
import jsat.math.FunctionP;
import jsat.math.FunctionVec;
import jsat.math.optimization.LBFGS;
import jsat.math.optimization.Optimizer2;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.regression.WarmRegressor;
import jsat.utils.ListUtils;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.ParallelUtils;

/* loaded from: input_file:jsat/classifiers/linear/LinearBatch.class */
/**
 * Batch-trained linear model for classification and regression.
 * <p>
 * Weights are found by handing an {@link Optimizer2} (a {@code new LBFGS(10)} when no optimizer
 * has been set) the objective
 * {@code (sum_i weight_i * loss(w, x_i, y_i)) / (sum_i weight_i) + lambda0 * w·w},
 * where the penalty is only added when {@code lambda0 > 0}. When the bias term is enabled the
 * biases are part of the optimized vector and therefore also covered by the penalty.
 * <p>
 * Binary classification (at most two classes) and regression learn a single weight vector;
 * classification with more than two classes requires a {@link LossMC} loss and learns one
 * weight vector and one bias per class.
 */
public class LinearBatch implements Classifier, Regressor, Parameterized, SimpleWeightVectorModel, WarmClassifier, WarmRegressor {
    private static final long serialVersionUID = -446156124954287580L;
    /** One weight vector per output (a single one for binary classification / regression). */
    private Vec[] ws;
    /** One bias per weight vector; stays zero when the bias term is disabled. */
    private double[] bs;
    private LossFunc loss;
    private double lambda0;
    /** Optimizer template, cloned for each training run; {@code null} means "use LBFGS(10)". */
    private Optimizer2 optimizer;
    private double tolerance;
    private boolean useBiasTerm;

    /**
     * Gradient of the single-weight-vector objective computed by {@link LossFunction}. The
     * vector being differentiated is either the raw weight vector or a {@link VecWithBias}.
     */
    public class GradFunction implements FunctionVec {
        private final DataSet data;
        private final LossFunc loss;
        /** Per-thread scratch gradients for the parallel path, created on first parallel call. */
        private ThreadLocal<Vec> tempVecs;

        public GradFunction(DataSet data, LossFunc loss) {
            this.data = data;
            this.loss = loss;
        }

        @Override // jsat.math.FunctionVec
        public Vec f(double... x) {
            return f(DenseVector.toDenseVec(x));
        }

        @Override // jsat.math.FunctionVec
        public Vec f(Vec w) {
            Vec grad = w.mo45clone();
            f(w, grad);
            return grad;
        }

        /**
         * Computes the gradient at {@code w} serially.
         *
         * @param w    point at which to evaluate the gradient
         * @param grad destination, or {@code null} to allocate a clone of {@code w}
         * @return the gradient (the same object as {@code grad} when it was non-null)
         */
        @Override // jsat.math.FunctionVec
        public Vec f(Vec w, Vec grad) {
            if (grad == null) {
                grad = w.mo45clone();
            }
            grad.zeroOut();
            double weightSum = accumulate(w, grad, 0, data.getSampleSize());
            finishGradient(grad, weightSum, w);
            return grad;
        }

        /**
         * Computes the gradient at {@code w}, splitting the data points into
         * {@link SystemInfo#LogicalCores} blocks that run on {@code threadPool}.
         *
         * @throws IllegalStateException if the wait is interrupted or a worker fails
         */
        @Override // jsat.math.FunctionVec
        public Vec f(final Vec w, Vec grad, ExecutorService threadPool) {
            final Vec total = grad == null ? w.mo45clone() : grad;
            total.zeroOut();
            if (tempVecs == null) {
                tempVecs = new ThreadLocal<Vec>() {
                    @Override
                    protected Vec initialValue() {
                        return w.mo45clone();
                    }
                };
            }
            final int n = data.getSampleSize();
            final int blocks = SystemInfo.LogicalCores;
            List<Future<Double>> futures = new ArrayList<Future<Double>>(blocks);
            for (int block = 0; block < blocks; block++) {
                final int id = block;
                futures.add(threadPool.submit(new Callable<Double>() {
                    @Override
                    public Double call() {
                        Vec local = tempVecs.get();
                        local.zeroOut();
                        double weightSum = accumulate(w, local,
                                ParallelUtils.getStartBlock(n, id, blocks),
                                ParallelUtils.getEndBlock(n, id, blocks));
                        synchronized (total) {
                            total.mutableAdd(local);
                        }
                        return weightSum;
                    }
                }));
            }
            // Waiting on the futures (instead of a latch counted down only on success) means a
            // failing worker surfaces as an exception rather than blocking this call forever.
            finishGradient(total, sumFutures(futures), w);
            return total;
        }

        /**
         * Adds the weighted loss gradient of data points {@code [start, end)} into
         * {@code target} and returns the total weight of those points.
         */
        private double accumulate(Vec w, Vec target, int start, int end) {
            double weightSum = 0.0;
            for (int i = start; i < end; i++) {
                DataPoint dp = data.getDataPoint(i);
                Vec x = dp.getNumericalValues();
                double weight = dp.getWeight();
                target.mutableAdd(loss.getDeriv(w.dot(x), getTargetY(data, i)) * weight, x);
                weightSum += weight;
            }
            return weightSum;
        }
    }

    /**
     * Gradient of the multi-class objective computed by {@link LossMCFunction}. The optimized
     * vector is laid out as {@code [w_0 | w_1 | ... | w_{C-1} | (b_0 ... b_{C-1})]}, with the
     * trailing biases present only when the bias term is enabled.
     */
    public class GradMCFunction implements FunctionVec {
        private final ClassificationDataSet data;
        private final LossMC loss;
        /** Per-thread scratch gradients for the parallel path, created on first parallel call. */
        private ThreadLocal<Vec> tempVecs;

        public GradMCFunction(ClassificationDataSet data, LossMC loss) {
            this.data = data;
            this.loss = loss;
        }

        @Override // jsat.math.FunctionVec
        public Vec f(double... x) {
            return f(DenseVector.toDenseVec(x));
        }

        @Override // jsat.math.FunctionVec
        public Vec f(Vec w) {
            Vec grad = w.mo45clone();
            f(w, grad);
            return grad;
        }

        /**
         * Computes the gradient at {@code w} serially.
         *
         * @param w    point at which to evaluate the gradient
         * @param grad destination, or {@code null} to allocate a clone of {@code w}
         * @return the gradient (the same object as {@code grad} when it was non-null)
         */
        @Override // jsat.math.FunctionVec
        public Vec f(Vec w, Vec grad) {
            if (grad == null) {
                grad = w.mo45clone();
            }
            grad.zeroOut();
            int featureCount = featuresPerClass(w, data.getClassSize());
            double weightSum = accumulate(w, featureCount, grad, 0, data.getSampleSize());
            finishGradient(grad, weightSum, w);
            return grad;
        }

        /**
         * Computes the gradient at {@code w}, splitting the data points into
         * {@link SystemInfo#LogicalCores} blocks that run on {@code threadPool}.
         *
         * @throws IllegalStateException if the wait is interrupted or a worker fails
         */
        @Override // jsat.math.FunctionVec
        public Vec f(final Vec w, Vec grad, ExecutorService threadPool) {
            final Vec total = grad == null ? w.mo45clone() : grad;
            total.zeroOut();
            if (tempVecs == null) {
                tempVecs = new ThreadLocal<Vec>() {
                    @Override
                    protected Vec initialValue() {
                        return w.mo45clone();
                    }
                };
            }
            final int n = data.getSampleSize();
            final int blocks = SystemInfo.LogicalCores;
            final int featureCount = featuresPerClass(w, data.getClassSize());
            List<Future<Double>> futures = new ArrayList<Future<Double>>(blocks);
            for (int block = 0; block < blocks; block++) {
                final int id = block;
                futures.add(threadPool.submit(new Callable<Double>() {
                    @Override
                    public Double call() {
                        Vec local = tempVecs.get();
                        local.zeroOut();
                        double weightSum = accumulate(w, featureCount, local,
                                ParallelUtils.getStartBlock(n, id, blocks),
                                ParallelUtils.getEndBlock(n, id, blocks));
                        synchronized (total) {
                            total.mutableAdd(local);
                        }
                        return weightSum;
                    }
                }));
            }
            finishGradient(total, sumFutures(futures), w);
            return total;
        }

        /**
         * Adds the weighted multi-class loss gradient of data points {@code [start, end)} into
         * the per-class weight slices of {@code target}; returns the total weight of the points.
         */
        private double accumulate(Vec w, int featureCount, Vec target, int start, int end) {
            DenseVector scores = new DenseVector(data.getClassSize());
            double weightSum = 0.0;
            for (int i = start; i < end; i++) {
                DataPoint dp = data.getDataPoint(i);
                Vec x = dp.getNumericalValues();
                double weight = dp.getWeight();
                computeClassScores(w, featureCount, x, scores);
                loss.process(scores, scores);
                loss.deriv(scores, scores, data.getDataPointCategory(i));
                for (int k = 0; k < scores.length(); k++) {
                    double coef = scores.get(k);
                    if (coef != 0.0) {
                        new SubVector(k * featureCount, featureCount, target).mutableAdd(coef * weight, x);
                    }
                }
                weightSum += weight;
            }
            return weightSum;
        }
    }

    /**
     * Single-weight-vector objective: weighted-average loss plus the optional
     * {@code lambda0 * w·w} penalty.
     */
    public class LossFunction implements FunctionP {
        private static final long serialVersionUID = -576682206943283356L;
        private final DataSet data;
        private final LossFunc loss;

        public LossFunction(DataSet data, LossFunc loss) {
            this.data = data;
            this.loss = loss;
        }

        @Override // jsat.math.Function
        public double f(Vec w) {
            double[] sums = sumLoss(w, 0, data.getSampleSize());
            return regularizedObjective(sums[0], sums[1], w);
        }

        /**
         * Evaluates the objective at {@code w}, splitting the data points into
         * {@link SystemInfo#LogicalCores} blocks that run on {@code threadPool}.
         *
         * @throws IllegalStateException if the wait is interrupted or a worker fails
         */
        @Override // jsat.math.FunctionP
        public double f(final Vec w, ExecutorService threadPool) {
            final int n = data.getSampleSize();
            final int blocks = SystemInfo.LogicalCores;
            // Each worker writes only its own slot; the reads below happen after Future.get().
            final double[] weightSums = new double[blocks];
            List<Future<Double>> futures = new ArrayList<Future<Double>>(blocks);
            for (int block = 0; block < blocks; block++) {
                final int id = block;
                futures.add(threadPool.submit(new Callable<Double>() {
                    @Override
                    public Double call() {
                        double[] sums = sumLoss(w,
                                ParallelUtils.getStartBlock(n, id, blocks),
                                ParallelUtils.getEndBlock(n, id, blocks));
                        weightSums[id] = sums[1];
                        return sums[0];
                    }
                }));
            }
            double lossSum = sumFutures(futures);
            double weightSum = 0.0;
            for (double part : weightSums) {
                weightSum += part;
            }
            return regularizedObjective(lossSum, weightSum, w);
        }

        @Override // jsat.math.Function
        public double f(double... x) {
            return f(DenseVector.toDenseVec(x));
        }

        /** Returns {weighted loss sum, weight sum} over data points {@code [start, end)}. */
        private double[] sumLoss(Vec w, int start, int end) {
            double lossSum = 0.0;
            double weightSum = 0.0;
            for (int i = start; i < end; i++) {
                DataPoint dp = data.getDataPoint(i);
                double weight = dp.getWeight();
                lossSum += loss.getLoss(w.dot(dp.getNumericalValues()), getTargetY(data, i)) * weight;
                weightSum += weight;
            }
            return new double[]{lossSum, weightSum};
        }
    }

    /**
     * Multi-class objective: weighted-average multi-class loss plus the optional
     * {@code lambda0 * w·w} penalty. Uses the same vector layout as {@link GradMCFunction}.
     */
    public class LossMCFunction implements FunctionP {
        private static final long serialVersionUID = -861700500356609563L;
        private final ClassificationDataSet data;
        private final LossMC loss;

        public LossMCFunction(ClassificationDataSet data, LossMC loss) {
            this.data = data;
            this.loss = loss;
        }

        @Override // jsat.math.Function
        public double f(Vec w) {
            int featureCount = featuresPerClass(w, data.getClassSize());
            double[] sums = sumLoss(w, featureCount, 0, data.getSampleSize());
            // Always normalize by the total weight, matching the parallel path and the gradient.
            return regularizedObjective(sums[0], sums[1], w);
        }

        /**
         * Evaluates the objective at {@code w}, splitting the data points into
         * {@link SystemInfo#LogicalCores} blocks that run on {@code threadPool}.
         *
         * @throws IllegalStateException if the wait is interrupted or a worker fails
         */
        @Override // jsat.math.FunctionP
        public double f(final Vec w, ExecutorService threadPool) {
            final int n = data.getSampleSize();
            final int blocks = SystemInfo.LogicalCores;
            final int featureCount = featuresPerClass(w, data.getClassSize());
            // Each worker writes only its own slot; the reads below happen after Future.get().
            final double[] weightSums = new double[blocks];
            List<Future<Double>> futures = new ArrayList<Future<Double>>(blocks);
            for (int block = 0; block < blocks; block++) {
                final int id = block;
                futures.add(threadPool.submit(new Callable<Double>() {
                    @Override
                    public Double call() {
                        double[] sums = sumLoss(w, featureCount,
                                ParallelUtils.getStartBlock(n, id, blocks),
                                ParallelUtils.getEndBlock(n, id, blocks));
                        weightSums[id] = sums[1];
                        return sums[0];
                    }
                }));
            }
            double lossSum = sumFutures(futures);
            double weightSum = 0.0;
            for (double part : weightSums) {
                weightSum += part;
            }
            return regularizedObjective(lossSum, weightSum, w);
        }

        @Override // jsat.math.Function
        public double f(double... x) {
            return f(DenseVector.toDenseVec(x));
        }

        /** Returns {weighted loss sum, weight sum} over data points {@code [start, end)}. */
        private double[] sumLoss(Vec w, int featureCount, int start, int end) {
            DenseVector scores = new DenseVector(data.getClassSize());
            double lossSum = 0.0;
            double weightSum = 0.0;
            for (int i = start; i < end; i++) {
                DataPoint dp = data.getDataPoint(i);
                double weight = dp.getWeight();
                computeClassScores(w, featureCount, dp.getNumericalValues(), scores);
                loss.process(scores, scores);
                lossSum += loss.getLoss(scores, data.getDataPointCategory(i)) * weight;
                weightSum += weight;
            }
            return new double[]{lossSum, weightSum};
        }
    }

    /**
     * A vector view of length {@code w.length() + 1}: the entries of {@code w} followed by
     * {@code b[0]}. Reads and writes go straight through to {@code w} and {@code b}, so
     * optimizing this view trains the weights and the bias together.
     */
    public class VecWithBias extends Vec {
        public Vec w;
        public double[] b;

        public VecWithBias(Vec w, double[] b) {
            this.w = w;
            this.b = b;
        }

        /**
         * When {@code x} has the length of {@code w} alone it is treated as a feature vector
         * and {@code w·x + b[0]} is returned; otherwise a plain element-wise dot product.
         */
        @Override // jsat.linear.Vec
        public double dot(Vec x) {
            if (x.length() == w.length()) {
                return w.dot(x) + b[0];
            }
            return super.dot(x);
        }

        /**
         * When {@code x} has the length of {@code w} alone it is treated as a feature vector
         * with an implicit trailing 1, so {@code c} is also added to the bias; otherwise a plain
         * element-wise update.
         */
        @Override // jsat.linear.Vec
        public void mutableAdd(double c, Vec x) {
            if (x.length() != w.length()) {
                super.mutableAdd(c, x);
                return;
            }
            w.mutableAdd(c, x);
            b[0] += c;
        }

        @Override // jsat.linear.Vec
        public int length() {
            return w.length() + 1;
        }

        @Override // jsat.linear.Vec
        public double get(int index) {
            if (index < w.length()) {
                return w.get(index);
            }
            if (index == w.length()) {
                return b[0];
            }
            throw new IndexOutOfBoundsException();
        }

        @Override // jsat.linear.Vec
        public void set(int index, double value) {
            if (index < w.length()) {
                w.set(index, value);
            } else if (index == w.length()) {
                b[0] = value;
            } else {
                throw new IndexOutOfBoundsException();
            }
        }

        @Override // jsat.linear.Vec
        public boolean isSparse() {
            return w.isSparse();
        }

        @Override // jsat.linear.Vec
        /* renamed from: clone */
        public Vec mo45clone() {
            return new VecWithBias(w.mo45clone(), Arrays.copyOf(b, b.length));
        }
    }

    /** Creates a model with {@link SoftmaxLoss} and {@code lambda0 = 1e-6}. */
    public LinearBatch() {
        this(new SoftmaxLoss(), 1.0E-6d);
    }

    /** Creates a model with the given loss and regularization, and tolerance {@code 0.001}. */
    public LinearBatch(LossFunc loss, double lambda0) {
        this(loss, lambda0, 0.001d);
    }

    /** Creates a model that uses the default optimizer ({@code LBFGS(10)}). */
    public LinearBatch(LossFunc loss, double lambda0, double tolerance) {
        this(loss, lambda0, tolerance, null);
    }

    /**
     * @param loss      loss function to minimize
     * @param lambda0   non-negative L2 regularization strength
     * @param tolerance non-negative convergence tolerance passed to the optimizer
     * @param optimizer optimizer to clone for each training run, or {@code null} for
     *                  {@code LBFGS(10)}
     * @throws IllegalArgumentException if {@code lambda0} or {@code tolerance} is invalid
     */
    public LinearBatch(LossFunc loss, double lambda0, double tolerance, Optimizer2 optimizer) {
        this.useBiasTerm = true;
        setLoss(loss);
        setLambda0(lambda0);
        setOptimizer(optimizer);
        setTolerance(tolerance);
    }

    /** Deep-copy constructor: copies all settings (including the bias flag) and learned weights. */
    public LinearBatch(LinearBatch toCopy) {
        this(toCopy.loss.m203clone(), toCopy.lambda0, toCopy.tolerance,
                toCopy.optimizer == null ? null : toCopy.optimizer.m217clone());
        // The delegated constructor always enables the bias; restore the source's choice.
        this.useBiasTerm = toCopy.useBiasTerm;
        if (toCopy.ws != null) {
            this.ws = new Vec[toCopy.ws.length];
            for (int i = 0; i < toCopy.ws.length; i++) {
                this.ws[i] = toCopy.ws[i].mo45clone();
            }
        }
        if (toCopy.bs != null) {
            this.bs = Arrays.copyOf(toCopy.bs, toCopy.bs.length);
        }
    }

    /** Sets whether bias terms are learned; when disabled the biases stay zero. */
    public void setUseBiasTerm(boolean useBiasTerm) {
        this.useBiasTerm = useBiasTerm;
    }

    public boolean isUseBiasTerm() {
        return this.useBiasTerm;
    }

    /**
     * Sets the L2 regularization strength.
     *
     * @throws IllegalArgumentException if {@code lambda0} is negative, NaN, or infinite
     */
    public void setLambda0(double lambda0) {
        if (lambda0 < 0.0d || Double.isNaN(lambda0) || Double.isInfinite(lambda0)) {
            throw new IllegalArgumentException("Lambda0 must be non-negative, not " + lambda0);
        }
        this.lambda0 = lambda0;
    }

    public double getLambda0() {
        return this.lambda0;
    }

    public void setLoss(LossFunc loss) {
        this.loss = loss;
    }

    public LossFunc getLoss() {
        return this.loss;
    }

    /** Sets the optimizer template; {@code null} selects {@code LBFGS(10)}. */
    public void setOptimizer(Optimizer2 optimizer) {
        this.optimizer = optimizer;
    }

    public Optimizer2 getOptimizer() {
        return this.optimizer;
    }

    /**
     * Sets the convergence tolerance passed to the optimizer.
     *
     * @throws IllegalArgumentException if {@code tolerance} is negative, NaN, or infinite
     */
    public void setTolerance(double tolerance) {
        if (tolerance < 0.0d || Double.isNaN(tolerance) || Double.isInfinite(tolerance)) {
            throw new IllegalArgumentException("Tolerance must be a non-negative constant, not " + tolerance);
        }
        this.tolerance = tolerance;
    }

    public double getTolerance() {
        return this.tolerance;
    }

    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        Vec x = dataPoint.getNumericalValues();
        if (ws.length == 1) {
            return ((LossC) loss).getClassification(ws[0].dot(x) + bs[0]);
        }
        DenseVector scores = new DenseVector(ws.length);
        for (int k = 0; k < ws.length; k++) {
            scores.set(k, ws[k].dot(x) + bs[k]);
        }
        LossMC lossMC = (LossMC) loss;
        lossMC.process(scores, scores);
        return lossMC.getClassification(scores);
    }

    @Override // jsat.regression.Regressor
    public double regress(DataPoint dataPoint) {
        return ((LossR) loss).getRegression(ws[0].dot(dataPoint.getNumericalValues()) + bs[0]);
    }

    @Override // jsat.classifiers.WarmClassifier
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution) {
        trainC(dataSet, warmSolution, null);
    }

    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        trainC(dataSet, null, threadPool);
    }

    /**
     * Trains a classifier, optionally warm-started from another {@link SimpleWeightVectorModel}.
     *
     * @param dataSet      training data; must have numeric features
     * @param warmSolution model to copy initial weights from, or {@code null}
     * @param threadPool   executor handed to the optimizer, or {@code null}
     * @throws FailedToFitException if there are no numeric features, the loss does not support
     *                              (multi-class) classification, or the warm start is unusable
     */
    @Override // jsat.classifiers.WarmClassifier
    public void trainC(ClassificationDataSet dataSet, Classifier warmSolution, ExecutorService threadPool) {
        if (dataSet.getNumNumericalVars() <= 0) {
            throw new FailedToFitException("LinearBatch requires numeric features to work");
        }
        if (!(loss instanceof LossC)) {
            throw new FailedToFitException("Loss function " + loss.getClass().getSimpleName() + " does not support classification");
        }
        int numVectors;
        if (dataSet.getClassSize() <= 2) {
            numVectors = 1;
        } else {
            if (!(loss instanceof LossMC)) {
                throw new FailedToFitException("Loss function " + loss.getClass().getSimpleName() + " does not support multi-class classification");
            }
            numVectors = dataSet.getClassSize();
        }
        ws = new Vec[numVectors];
        bs = new double[numVectors];
        for (int i = 0; i < ws.length; i++) {
            ws[i] = new DenseVector(dataSet.getNumNumericalVars());
        }
        Optimizer2 opt = newOptimizer();
        doWarmStartIfNotNull(warmSolution);
        if (ws.length == 1) {
            optimizeSingleWeightVector(opt, dataSet, threadPool);
            return;
        }
        LossMC lossMC = (LossMC) loss;
        List<Vec> parts = new ArrayList<Vec>(Arrays.asList(ws));
        if (useBiasTerm) {
            // NOTE(review): relies on toDenseVec wrapping bs without copying, so the optimized
            // biases land in bs — confirm against DenseVector.
            parts.add(DenseVector.toDenseVec(bs));
        }
        ConcatenatedVec allWeights = new ConcatenatedVec(parts);
        opt.optimize(tolerance, allWeights, new DenseVector(allWeights),
                new LossMCFunction(dataSet, lossMC), new GradMCFunction(dataSet, lossMC), null, threadPool);
    }

    /**
     * Copies weights (and biases, when enabled) from {@code obj} into the freshly allocated
     * {@link #ws} / {@link #bs}; does nothing when {@code obj} is {@code null}.
     *
     * @throws FailedToFitException if {@code obj} is not a {@link SimpleWeightVectorModel} or
     *                              has a different number of weight vectors
     */
    private void doWarmStartIfNotNull(Object obj) throws FailedToFitException {
        if (obj == null) {
            return;
        }
        if (!(obj instanceof SimpleWeightVectorModel)) {
            throw new FailedToFitException("Can not warm start from " + obj.getClass().getCanonicalName());
        }
        SimpleWeightVectorModel warm = (SimpleWeightVectorModel) obj;
        if (warm.numWeightsVecs() != ws.length) {
            throw new FailedToFitException("Warm solution has " + warm.numWeightsVecs() + " weight vectors instead of " + ws.length);
        }
        for (int i = 0; i < ws.length; i++) {
            warm.getRawWeight(i).copyTo(ws[i]);
            if (useBiasTerm) {
                bs[i] = warm.getBias(i);
            }
        }
    }

    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet) {
        trainC(dataSet, (ExecutorService) null);
    }

    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet, ExecutorService threadPool) {
        train(dataSet, null, threadPool);
    }

    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet dataSet, Regressor warmSolution) {
        train(dataSet, warmSolution, null);
    }

    /**
     * Trains a regressor, optionally warm-started from another {@link SimpleWeightVectorModel}.
     *
     * @throws FailedToFitException if there are no numeric features, the loss does not support
     *                              regression, or the warm start is unusable
     */
    @Override // jsat.regression.WarmRegressor
    public void train(RegressionDataSet dataSet, Regressor warmSolution, ExecutorService threadPool) {
        if (dataSet.getNumNumericalVars() <= 0) {
            throw new FailedToFitException("LinearBatch requires numeric features to work");
        }
        if (!(loss instanceof LossR)) {
            throw new FailedToFitException("Loss function " + loss.getClass().getSimpleName() + " does not support regression");
        }
        ws = new Vec[]{new DenseVector(dataSet.getNumNumericalVars())};
        bs = new double[1];
        Optimizer2 opt = newOptimizer();
        doWarmStartIfNotNull(warmSolution);
        optimizeSingleWeightVector(opt, dataSet, threadPool);
    }

    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet) {
        train(dataSet, (ExecutorService) null);
    }

    /** Returns a fresh optimizer for one training run: a clone of the template or LBFGS(10). */
    private Optimizer2 newOptimizer() {
        if (optimizer == null) {
            return new LBFGS(10);
        }
        return optimizer.m217clone();
    }

    /**
     * Optimizes {@code ws[0]} (and {@code bs[0]} through a {@link VecWithBias} view when the bias
     * is enabled) in place on {@code data}.
     */
    private void optimizeSingleWeightVector(Optimizer2 opt, DataSet data, ExecutorService threadPool) {
        Vec w = useBiasTerm ? new VecWithBias(ws[0], bs) : ws[0];
        opt.optimize(tolerance, w, w, new LossFunction(data, loss), new GradFunction(data, loss), null, threadPool);
    }

    /**
     * Number of weights per class in a multi-class parameter vector {@code w}: its length,
     * minus the trailing biases when enabled, divided by {@code numClasses}.
     */
    private int featuresPerClass(Vec w, int numClasses) {
        return (w.length() - (useBiasTerm ? bs.length : 0)) / numClasses;
    }

    /**
     * Writes the raw score of every class for features {@code x} into {@code scores}: the dot
     * product with each class's weight slice of {@code w}, plus the trailing biases when enabled.
     */
    private void computeClassScores(Vec w, int featureCount, Vec x, DenseVector scores) {
        for (int k = 0; k < scores.length(); k++) {
            scores.set(k, new SubVector(k * featureCount, featureCount, w).dot(x));
        }
        if (useBiasTerm) {
            scores.mutableAdd(new SubVector(w.length() - bs.length, bs.length, w));
        }
    }

    /**
     * Objective value: {@code lossSum / weightSum}, plus {@code lambda0 * w·w} when
     * {@code lambda0 > 0}.
     */
    private double regularizedObjective(double lossSum, double weightSum, Vec w) {
        double objective = lossSum / weightSum;
        if (lambda0 > 0.0d) {
            objective += lambda0 * w.dot(w);
        }
        return objective;
    }

    /**
     * Turns a summed, weighted data-loss gradient into the gradient of
     * {@link #regularizedObjective}: divides by {@code weightSum} and, since
     * d/dw (lambda0 * w·w) = 2 * lambda0 * w, adds {@code 2 * lambda0 * w}.
     */
    private void finishGradient(Vec grad, double weightSum, Vec w) {
        grad.mutableDivide(weightSum);
        if (lambda0 > 0.0d) {
            grad.mutableAdd(2.0d * lambda0, w);
        }
    }

    /**
     * Waits for all {@code futures} and returns the sum of their values.
     *
     * @throws IllegalStateException if interrupted (the interrupt flag is restored) or if any
     *                               task failed; a partial sum would silently corrupt the
     *                               optimizer's objective
     */
    private static double sumFutures(List<Future<Double>> futures) {
        double total = 0.0d;
        try {
            for (Double part : ListUtils.collectFutures(futures)) {
                total += part;
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted while waiting for parallel LinearBatch work", e);
        } catch (ExecutionException e) {
            throw new IllegalStateException("Parallel LinearBatch work failed", e.getCause());
        }
        return total;
    }

    /**
     * Target value for data point {@code i}: for classification the category mapped by
     * {@code 2 * category - 1} (so categories 0/1 become -1/+1), for regression the target.
     */
    public static double getTargetY(DataSet dataSet, int i) {
        if (dataSet instanceof ClassificationDataSet) {
            return (((ClassificationDataSet) dataSet).getDataPointCategory(i) * 2) - 1;
        }
        return ((RegressionDataSet) dataSet).getTargetValue(i);
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String name) {
        return Parameter.toParameterMap(getParameters()).get(name);
    }

    @Override // jsat.classifiers.WarmClassifier, jsat.regression.WarmRegressor
    public boolean warmFromSameDataOnly() {
        return false;
    }

    @Override // jsat.SimpleWeightVectorModel
    public Vec getRawWeight(int index) {
        return this.ws[index];
    }

    @Override // jsat.SimpleWeightVectorModel
    public double getBias(int index) {
        return this.bs[index];
    }

    @Override // jsat.SimpleWeightVectorModel
    public int numWeightsVecs() {
        return this.ws.length;
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return true;
    }

    @Override // jsat.regression.Regressor
    public LinearBatch clone() {
        return new LinearBatch(this);
    }

    /** Suggested search distribution for {@code lambda0}: log-uniform over [1e-7, 1e-2]. */
    public static Distribution guessLambda0(DataSet dataSet) {
        return new LogUniform(1.0E-7d, 0.01d);
    }
}
