package jsat.classifiers.neuralnetwork;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ExecutorService;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.classifiers.DataPointPair;
import jsat.clustering.SeedSelectionMethods;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.TrainableDistanceMetric;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.math.decayrates.DecayRate;
import jsat.math.decayrates.ExponetialDecay;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:jsat/classifiers/neuralnetwork/LVQ.class */
/**
 * Learning Vector Quantization classifier.
 *
 * <p>Training keeps {@code representativesPerClass} prototype vectors for every
 * class, seeds them from the training data, and then repeatedly moves the
 * prototypes toward or away from each training point according to the selected
 * {@link LVQVersion}. Classification returns the class of the nearest prototype
 * that was retained after training.
 *
 * <p>NOTE(review): {@link #DEFAULT_STOPPING_DIST} is declared but no constructor
 * applies it; the stopping distance stays at 0.0 until
 * {@link #setStoppingDist(double)} is called. Confirm whether that is intended.
 */
public class LVQ implements Classifier, Parameterized {
    private static final long serialVersionUID = -3911765006048793222L;

    /** Default iteration count offered to callers; the constructors here always take an explicit count. */
    public static final int DEFAULT_ITERATIONS = 200;
    /** Learning rate used by {@link #LVQ(DistanceMetric, int)}. */
    public static final double DEFAULT_LEARNING_RATE = 0.1d;
    /** Epsilon (window width) applied by the constructors. */
    public static final double DEFAULT_EPS = 0.3d;
    /** Learning-rate scale applied by the constructors for the same-class two-prototype update. */
    public static final double DEFAULT_MSCALE = 0.30000000000000004d;
    /** Prototypes per class used by {@link #LVQ(DistanceMetric, int)}. */
    public static final int DEFAULT_REPS_PER_CLASS = 3;
    /** Declared default stopping distance; not applied by any constructor in this class. */
    public static final double DEFAULT_STOPPING_DIST = 0.001d;
    /** Update rule used when the caller does not choose one. */
    public static final LVQVersion DEFAULT_LVQ_METHOD = LVQVersion.LVQ3;
    /** Seed selection strategy applied by the constructors. */
    public static final SeedSelectionMethods.SeedSelection DEFAULT_SEED_SELECTION = SeedSelectionMethods.SeedSelection.KPP;

    /** Maps (iteration, total iterations, initial rate) to the learning rate used for one pass. */
    private DecayRate learningDecay;
    /** Maximum number of passes over the training data. */
    private int iterations;
    /** Initial learning rate handed to {@link #learningDecay}. */
    private double learningRate;
    /** Metric used for every prototype/point and prototype/prototype distance. */
    protected DistanceMetric dm;
    /** Selected update rule. */
    private LVQVersion lvqVersion;
    /** Window parameter used by {@link #epsClose(double, double)} and the LVQ3 test. */
    private double eps;
    /** Factor applied to the learning rate in the same-class two-prototype update. */
    private double mScale;
    /** Training stops early once no prototype moved farther than this during a pass. */
    private double stoppingDist;
    /** Number of prototypes allocated for each class. */
    private int representativesPerClass;
    /** Prototype vectors; {@code null} until trained. */
    protected Vec[] weights;
    /** Class label of the prototype at the same index in {@link #weights}. */
    protected int[] weightClass;
    /** Per-prototype win counter, reset at the start of every training pass. */
    protected int[] wins;
    /** Strategy used to pick the initial prototypes of each class. */
    private SeedSelectionMethods.SeedSelection seedSelection;
    /** Search structure over the retained prototypes; each entry is paired with its index in {@link #weights}. */
    protected VectorCollection<VecPaired<Vec, Integer>> vc;
    /** Factory that builds {@link #vc} at the end of training. */
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcf;

    /* loaded from: input_file:jsat/classifiers/neuralnetwork/LVQ$LVQVersion.class */
    /**
     * Update rule applied during training. The declaration order is significant:
     * {@code trainC} enables a rule for "this version or any later one" by
     * comparing constants.
     */
    public enum LVQVersion {
        LVQ1,
        LVQ2,
        LVQ21,
        LVQ3
    }

    /**
     * Creates a learner with {@link #DEFAULT_LEARNING_RATE} and
     * {@link #DEFAULT_REPS_PER_CLASS}.
     *
     * @param distanceMetric metric used to compare points and prototypes
     * @param i maximum number of training passes
     */
    public LVQ(DistanceMetric distanceMetric, int i) {
        this(distanceMetric, i, DEFAULT_LEARNING_RATE, DEFAULT_REPS_PER_CLASS);
    }

    /**
     * Creates a learner with {@link #DEFAULT_LVQ_METHOD} and a new
     * {@link ExponetialDecay} learning-rate schedule.
     *
     * @param distanceMetric metric used to compare points and prototypes
     * @param i maximum number of training passes
     * @param d initial learning rate
     * @param i2 number of prototypes per class
     */
    public LVQ(DistanceMetric distanceMetric, int i, double d, int i2) {
        this(distanceMetric, i, d, i2, DEFAULT_LVQ_METHOD, new ExponetialDecay());
    }

    /**
     * Creates a fully specified learner. Epsilon, the m-scale, the seed
     * selection strategy and the vector collection factory receive their
     * defaults; the stopping distance is left at 0.0.
     *
     * @param distanceMetric metric used to compare points and prototypes
     * @param i maximum number of training passes
     * @param d initial learning rate
     * @param i2 number of prototypes per class
     * @param lVQVersion update rule to apply
     * @param decayRate learning-rate schedule
     * @throws ArithmeticException if a numeric argument is rejected by its setter
     */
    public LVQ(DistanceMetric distanceMetric, int i, double d, int i2, LVQVersion lVQVersion, DecayRate decayRate) {
        setLearningDecay(decayRate);
        setIterations(i);
        setLearningRate(d);
        setDistanceMetric(distanceMetric);
        setLVQMethod(lVQVersion);
        setEpsilonDistance(DEFAULT_EPS);
        setMScale(DEFAULT_MSCALE);
        setSeedSelection(DEFAULT_SEED_SELECTION);
        setVecCollectionFactory(new DefaultVectorCollectionFactory<VecPaired<Vec, Integer>>());
        setRepresentativesPerClass(i2);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Copy constructor. Clones the distance metric, the prototypes, the vector
     * collection and its factory, and copies every scalar setting including the
     * stopping distance. The {@link DecayRate} instance is shared with
     * {@code lvq}, not copied.
     *
     * @param lvq the instance to copy
     */
    public LVQ(LVQ lvq) {
        this(lvq.dm.mo172clone(), lvq.iterations, lvq.learningRate, lvq.representativesPerClass, lvq.lvqVersion, lvq.learningDecay);
        if (lvq.weights != null) {
            this.wins = Arrays.copyOf(lvq.wins, lvq.wins.length);
            this.weights = new Vec[lvq.weights.length];
            this.weightClass = Arrays.copyOf(lvq.weightClass, lvq.weightClass.length);
            for (int i = 0; i < lvq.weights.length; i++) {
                this.weights[i] = lvq.weights[i].mo45clone();
            }
        }
        setEpsilonDistance(lvq.eps);
        setMScale(lvq.getMScale());
        // Previously omitted: without this a copy silently fell back to 0.0 and
        // used a different early-stopping criterion than the instance it copies.
        this.stoppingDist = lvq.stoppingDist;
        setSeedSelection(lvq.getSeedSelection());
        if (lvq.vc != null) {
            this.vc = lvq.vc.m199clone();
        }
        setVecCollectionFactory(lvq.vcf.m178clone());
    }

    /**
     * Sets the factor applied to the learning rate in the same-class
     * two-prototype update.
     *
     * @param d the new scale; must be positive and finite
     * @throws ArithmeticException if {@code d} is not positive, is infinite or is NaN
     */
    public void setMScale(double d) {
        if (d <= 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("Scale factor must be a positive constant, not " + d);
        }
        this.mScale = d;
    }

    /** @return the learning-rate scale of the same-class two-prototype update */
    public double getMScale() {
        return this.mScale;
    }

    /**
     * Sets the epsilon window parameter.
     *
     * @param d the new epsilon; must be positive and finite
     * @throws ArithmeticException if {@code d} is not positive, is infinite or is NaN
     */
    public void setEpsilonDistance(double d) {
        if (d <= 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("eps factor must be a positive constant, not " + d);
        }
        this.eps = d;
    }

    /** @return the epsilon window parameter */
    public double getEpsilonDistance() {
        return this.eps;
    }

    /**
     * Sets the initial learning rate.
     *
     * @param d the new rate; must be positive and finite
     * @throws ArithmeticException if {@code d} is not positive, is infinite or is NaN
     */
    public void setLearningRate(double d) {
        if (d <= 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("learning rate must be a positive constant, not " + d);
        }
        this.learningRate = d;
    }

    /** @return the initial learning rate */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @param decayRate the schedule that derives each pass's learning rate */
    public void setLearningDecay(DecayRate decayRate) {
        this.learningDecay = decayRate;
    }

    /** @return the learning-rate schedule */
    public DecayRate getLearningDecay() {
        return this.learningDecay;
    }

    /**
     * Sets the maximum number of training passes.
     *
     * @param i the new limit; zero is accepted
     * @throws ArithmeticException if {@code i} is negative
     */
    public void setIterations(int i) {
        if (i < 0) {
            throw new ArithmeticException("Can not perform a negative number of iterations");
        }
        this.iterations = i;
    }

    /** @return the maximum number of training passes */
    public int getIterations() {
        return this.iterations;
    }

    /**
     * Sets how many prototypes are allocated for each class on the next call
     * to {@code trainC}.
     *
     * @param i the number of prototypes per class; must be at least one
     * @throws ArithmeticException if {@code i} is less than one
     */
    public void setRepresentativesPerClass(int i) {
        // Zero or negative values used to be stored and only failed later, as an
        // array-size/index error in training or a division by zero in classify.
        if (i < 1) {
            throw new ArithmeticException("There must be at least one representative per class, not " + i);
        }
        this.representativesPerClass = i;
    }

    /** @return the number of prototypes allocated per class */
    public int getRepresentativesPerClass() {
        return this.representativesPerClass;
    }

    /** @param lVQVersion the update rule to use during training */
    public void setLVQMethod(LVQVersion lVQVersion) {
        this.lvqVersion = lVQVersion;
    }

    /** @return the update rule used during training */
    public LVQVersion getLVQMethod() {
        return this.lvqVersion;
    }

    /** @param distanceMetric the metric used to compare points and prototypes */
    public void setDistanceMetric(DistanceMetric distanceMetric) {
        this.dm = distanceMetric;
    }

    /** @return the metric used to compare points and prototypes */
    public DistanceMetric getDistanceMetric() {
        return this.dm;
    }

    /**
     * Sets the early-stopping threshold: training ends after a pass in which no
     * prototype moved farther than this distance.
     *
     * @param d the new threshold; must be zero or positive, and finite
     * @throws ArithmeticException if {@code d} is negative, infinite or NaN
     */
    public void setStoppingDist(double d) {
        if (d < 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new ArithmeticException("stopping dist must be a zero or positive constant, not " + d);
        }
        this.stoppingDist = d;
    }

    /** @return the early-stopping threshold */
    public double getStoppingDist() {
        return this.stoppingDist;
    }

    /** @param seedSelection the strategy used to pick each class's initial prototypes */
    public void setSeedSelection(SeedSelectionMethods.SeedSelection seedSelection) {
        this.seedSelection = seedSelection;
    }

    /** @return the strategy used to pick each class's initial prototypes */
    public SeedSelectionMethods.SeedSelection getSeedSelection() {
        return this.seedSelection;
    }

    /** @return the parameters that {@link Parameter#getParamsFromMethods(Object)} derives from this object */
    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    /**
     * Looks a parameter up by name in the map built from {@link #getParameters()}.
     *
     * @param str the parameter name
     * @return the map's entry for {@code str} (presumably {@code null} when absent)
     */
    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String str) {
        return Parameter.toParameterMap(getParameters()).get(str);
    }

    /** @param vectorCollectionFactory the factory used to index the retained prototypes after training */
    public void setVecCollectionFactory(VectorCollectionFactory<VecPaired<Vec, Integer>> vectorCollectionFactory) {
        this.vcf = vectorCollectionFactory;
    }

    /**
     * Assigns probability 1.0 to the class of the single nearest retained
     * prototype. Requires a prior call to {@code trainC}.
     *
     * <p>NOTE(review): the number of classes is derived from the current
     * {@code representativesPerClass}; changing that parameter after training
     * makes this count wrong.
     *
     * @param dataPoint the point to classify
     * @return results with all probability mass on one class
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        CategoricalResults categoricalResults = new CategoricalResults(this.weightClass.length / this.representativesPerClass);
        categoricalResults.setProb(this.weightClass[this.vc.search(dataPoint.getNumericalValues(), 1).get(0).getVector().getPair().intValue()], 1.0d);
        return categoricalResults;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Tests whether two distances fall inside the epsilon window, i.e. whether
     * both ratios {@code d / d2} and {@code d2 / d} lie strictly between
     * {@code 1 - eps} and {@code 1 + eps}.
     *
     * @param d the first distance
     * @param d2 the second distance
     * @return {@code true} if the distances are within the window
     */
    public boolean epsClose(double d, double d2) {
        double forward = d / d2;
        double backward = d2 / d;
        return Math.min(forward, backward) > 1.0d - this.eps && Math.max(forward, backward) < 1.0d + this.eps;
    }

    /**
     * Trains the prototypes on the given data set.
     *
     * <p>The steps are: train the distance metric if needed, seed
     * {@code representativesPerClass} prototypes per class, run up to
     * {@code iterations} passes that update prototypes per training point, stop
     * early when no prototype moved farther than the stopping distance, and
     * index the prototypes that won at least once during the last pass.
     *
     * @param classificationDataSet the training data
     * @param executorService executor for the parallel variants of metric
     *        training and index construction; {@code null} or a
     *        {@link FakeExecutor} selects the serial variants
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet, ExecutorService executorService) {
        final boolean serial = executorService == null || (executorService instanceof FakeExecutor);
        if (serial) {
            TrainableDistanceMetric.trainIfNeeded(this.dm, classificationDataSet);
        } else {
            TrainableDistanceMetric.trainIfNeeded(this.dm, classificationDataSet, executorService);
        }

        Random random = new Random();
        int numOfCategories = classificationDataSet.getPredicting().getNumOfCategories();
        this.weights = new Vec[numOfCategories * this.representativesPerClass];
        Vec[] previousWeights = new Vec[this.weights.length];
        this.weightClass = new int[this.weights.length];
        this.wins = new int[this.weights.length];

        // Seed the prototypes class by class from a data set holding only that class.
        int next = 0;
        for (int category = 0; category < numOfCategories; category++) {
            List<DataPoint> samples = classificationDataSet.getSamples(category);
            List<DataPointPair<Integer>> labelled = new ArrayList<DataPointPair<Integer>>(samples.size());
            for (DataPoint sample : samples) {
                labelled.add(new DataPointPair<Integer>(sample, Integer.valueOf(category)));
            }
            ClassificationDataSet singleClass = new ClassificationDataSet(labelled, classificationDataSet.getPredicting());
            for (Vec seed : SeedSelectionMethods.selectIntialPoints(singleClass, this.representativesPerClass, this.dm, random, this.seedSelection)) {
                this.weights[next] = seed.mo45clone();
                previousWeights[next] = this.weights[next].mo45clone();
                this.weightClass[next] = category;
                next++;
            }
        }

        // Reusable buffer for the (point - prototype) difference vector.
        Vec scratch = this.weights[0].mo45clone();
        for (int iteration = 0; iteration < this.iterations; iteration++) {
            for (int j = 0; j < this.weights.length; j++) {
                this.weights[j].copyTo(previousWeights[j]);
            }
            Arrays.fill(this.wins, 0);
            double rate = this.learningDecay.rate(iteration, this.iterations, this.learningRate);

            for (int s = 0; s < classificationDataSet.getSampleSize(); s++) {
                Vec x = classificationDataSet.getDataPoint(s).getNumericalValues();
                // Label of the training point; stays -1 only if no distance ever
                // compared below +infinity.
                int sampleClass = -1;
                int nearest = 0;
                int runnerUp = 0;
                double nearestDist = Double.POSITIVE_INFINITY;
                double runnerUpDist = Double.POSITIVE_INFINITY;
                for (int j = 0; j < this.weights.length; j++) {
                    double dist = this.dm.dist(x, this.weights[j]);
                    if (dist < nearestDist) {
                        // NOTE(review): the runner-up is recorded only when the version is
                        // exactly LVQ2, and only when a new nearest is found. For LVQ21 and
                        // LVQ3 it therefore stays at index 0 with distance +infinity, so the
                        // two-prototype rules below cannot fire. Left unchanged because
                        // fixing it alters training results - confirm intent.
                        if (this.lvqVersion == LVQVersion.LVQ2) {
                            runnerUpDist = nearestDist;
                            runnerUp = nearest;
                        }
                        nearestDist = dist;
                        nearest = j;
                        sampleClass = classificationDataSet.getDataPointCategory(s);
                    }
                }

                boolean differentClasses = this.weightClass[nearest] != this.weightClass[runnerUp];
                if (versionAtLeast(LVQVersion.LVQ2) && differentClasses && sampleClass == this.weightClass[runnerUp] && epsClose(nearestDist, runnerUpDist)) {
                    // Nearest has the wrong label, runner-up the right one: swap their pull.
                    push(this.weights[nearest], x, rate, scratch);
                    pull(this.weights[runnerUp], x, rate, scratch);
                    this.wins[runnerUp]++;
                } else if (versionAtLeast(LVQVersion.LVQ21) && differentClasses && sampleClass == this.weightClass[nearest] && epsClose(nearestDist, runnerUpDist)) {
                    // Nearest has the right label, runner-up the wrong one.
                    pull(this.weights[nearest], x, rate, scratch);
                    this.wins[nearest]++;
                    push(this.weights[runnerUp], x, rate, scratch);
                } else if (!versionAtLeast(LVQVersion.LVQ3) || differentClasses || Math.min(nearestDist / runnerUpDist, runnerUpDist / nearestDist) <= (1.0d - this.eps) * (1.0d + this.eps)) {
                    // Single-prototype rule: attract on a label match, repel otherwise.
                    if (sampleClass == this.weightClass[nearest]) {
                        this.wins[nearest]++;
                        pull(this.weights[nearest], x, rate, scratch);
                    } else {
                        push(this.weights[nearest], x, rate, scratch);
                    }
                } else {
                    // Both prototypes share a class and are similarly distant:
                    // attract both with a reduced rate.
                    pull(this.weights[nearest], x, this.mScale * rate, scratch);
                    pull(this.weights[runnerUp], x, this.mScale * rate, scratch);
                    this.wins[nearest]++;
                    this.wins[runnerUp]++;
                }
            }

            if (allWithinStoppingDist(previousWeights)) {
                break;
            }
        }

        // Index only the prototypes that won during the last pass, each paired
        // with its index so classify can look up the class label.
        List<VecPaired<Vec, Integer>> retained = new ArrayList<VecPaired<Vec, Integer>>(this.weights.length);
        for (int j = 0; j < this.weights.length; j++) {
            if (this.wins[j] != 0) {
                retained.add(new VecPaired<Vec, Integer>(this.weights[j], Integer.valueOf(j)));
            }
        }
        if (retained.isEmpty()) {
            // No wins recorded (for example iterations == 0): keep every prototype
            // rather than building an empty index that classify cannot search.
            for (int j = 0; j < this.weights.length; j++) {
                retained.add(new VecPaired<Vec, Integer>(this.weights[j], Integer.valueOf(j)));
            }
        }
        if (serial) {
            this.vc = this.vcf.getVectorCollection(retained, this.dm);
        } else {
            this.vc = this.vcf.getVectorCollection(retained, this.dm, executorService);
        }
    }

    /** Reports whether the configured version is {@code version} or one declared after it. */
    private boolean versionAtLeast(LVQVersion version) {
        return this.lvqVersion.compareTo(version) >= 0;
    }

    /** Moves {@code prototype} toward {@code x} by {@code step} times their difference, using {@code scratch} as a buffer. */
    private static void pull(Vec prototype, Vec x, double step, Vec scratch) {
        x.copyTo(scratch);
        scratch.mutableSubtract(prototype);
        prototype.mutableAdd(step, scratch);
    }

    /** Moves {@code prototype} away from {@code x} by {@code step} times their difference, using {@code scratch} as a buffer. */
    private static void push(Vec prototype, Vec x, double step, Vec scratch) {
        x.copyTo(scratch);
        scratch.mutableSubtract(prototype);
        prototype.mutableSubtract(step, scratch);
    }

    /** Reports whether every prototype is within the stopping distance of its entry in {@code previous}. */
    private boolean allWithinStoppingDist(Vec[] previous) {
        for (int j = 0; j < this.weights.length; j++) {
            if (this.dm.dist(this.weights[j], previous[j]) > this.stoppingDist) {
                return false;
            }
        }
        return true;
    }

    /**
     * Trains serially; equivalent to {@code trainC(classificationDataSet, null)}.
     *
     * @param classificationDataSet the training data
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet classificationDataSet) {
        trainC(classificationDataSet, null);
    }

    /** @return {@code false}; this method does not report support for weighted data */
    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    /** @return a copy of this classifier built with the copy constructor */
    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public LVQ mo69clone() {
        return new LVQ(this);
    }
}
