package jsat.classifiers.knn;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.CategoricalResults;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.distributions.Distribution;
import jsat.distributions.discrete.UniformDiscrete;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.EigenValueDecomposition;
import jsat.linear.Matrix;
import jsat.linear.RowColumnOps;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.VecPairedComparable;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.vectorcollection.DefaultVectorCollectionFactory;
import jsat.linear.vectorcollection.VectorCollection;
import jsat.linear.vectorcollection.VectorCollectionFactory;
import jsat.parameters.Parameter;
import jsat.parameters.Parameterized;
import jsat.utils.BoundedSortedList;
import jsat.utils.FakeExecutor;

/* loaded from: input_file:jsat/classifiers/knn/DANN.class */
/**
 * Discriminant Adaptive Nearest Neighbor (DANN) classifier.
 * <p>
 * For every query point DANN gathers the {@link #getKn() kn} nearest training points and weights
 * them with a tricube kernel on their distance. From those weights it estimates a local
 * between-class matrix {@code B} and a pooled within-class matrix {@code W}. It then builds the
 * local metric
 * <pre>
 *     Sigma = W^(-1/2) [ W^(-1/2) B W^(-1/2) + eps * I ] W^(-1/2)
 * </pre>
 * and uses it to measure {@code d(x, q) = sqrt((x - q)^T Sigma (x - q))}. This is repeated
 * {@link #getMaxIterations() maxIterations} times. The first neighborhood is found through the
 * configured vector collection (Euclidean distance); every later neighborhood, and the final
 * {@link #getK() k} voters, are found with a brute-force scan under the current metric. The
 * prediction is the normalized vote of those {@code k} voters. If at any point the entire kernel
 * weight of the neighborhood belongs to a single class, that class is returned with probability 1.
 */
public class DANN implements Classifier, Parameterized {
    private static final long serialVersionUID = -272865942127664672L;
    /** Default number of neighbors used to estimate the local metric. */
    public static final int DEFAULT_KN = 40;
    /** Default number of neighbors that vote on the final prediction. */
    public static final int DEFAULT_K = 1;
    /** Default regularization added to the sphered between-class matrix. */
    public static final double DEFAULT_EPS = 1.0d;
    /** Default number of metric-refinement iterations. */
    public static final int DEFAULT_ITERATIONS = 1;

    /**
     * Eigenvalues of W at or below this fraction of W's largest eigenvalue are treated as zero
     * when forming W^(-1/2), so a singular within-class matrix yields a pseudo-inverse square
     * root instead of infinite or NaN entries.
     */
    private static final double EIGEN_TOLERANCE = 1e-12;

    /** Neighborhood size used to estimate B and W. */
    private int kn;
    /** Number of neighbors that vote on the final prediction. */
    private int k;
    /** Number of metric-refinement iterations performed per query. */
    private int maxIterations;
    /** Regularization added to the diagonal of the sphered between-class matrix. */
    private double eps;
    private VectorCollectionFactory<VecPaired<Vec, Integer>> vcf;
    private CategoricalData predicting;
    private VectorCollection<VecPaired<Vec, Integer>> vc;
    /** Training vectors paired with their class index; scanned by the brute-force search. */
    private List<VecPaired<Vec, Integer>> vecList;

    /** Creates a DANN classifier with {@link #DEFAULT_KN} and {@link #DEFAULT_K}. */
    public DANN() {
        this(DEFAULT_KN, DEFAULT_K);
    }

    /**
     * @param kn number of neighbors used to adapt the metric
     * @param k  number of neighbors that vote on the prediction
     */
    public DANN(int kn, int k) {
        this(kn, k, DEFAULT_EPS);
    }

    /**
     * @param kn  number of neighbors used to adapt the metric
     * @param k   number of neighbors that vote on the prediction
     * @param eps regularization of the sphered between-class matrix
     */
    public DANN(int kn, int k, double eps) {
        this(kn, k, eps, new DefaultVectorCollectionFactory());
    }

    /**
     * @param kn  number of neighbors used to adapt the metric
     * @param k   number of neighbors that vote on the prediction
     * @param eps regularization of the sphered between-class matrix
     * @param vcf factory for the collection used for the initial Euclidean neighbor search
     */
    public DANN(int kn, int k, double eps, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        this(kn, k, eps, DEFAULT_ITERATIONS, vcf);
    }

    /**
     * @param kn            number of neighbors used to adapt the metric
     * @param k             number of neighbors that vote on the prediction
     * @param eps           regularization of the sphered between-class matrix
     * @param maxIterations number of metric-refinement iterations
     * @param vcf           factory for the collection used for the initial Euclidean search
     */
    public DANN(int kn, int k, double eps, int maxIterations, VectorCollectionFactory<VecPaired<Vec, Integer>> vcf) {
        setK(k);
        setKn(kn);
        setEpsilon(eps);
        setMaxIterations(maxIterations);
        this.vcf = vcf;
    }

    /**
     * Sets the number of neighbors that vote on the final prediction.
     *
     * @param k number of voting neighbors
     * @throws ArithmeticException if {@code k < 1}
     */
    public void setK(int k) {
        if (k < 1) {
            throw new ArithmeticException("Number of neighbors must be positive");
        }
        this.k = k;
    }

    /** @return the number of neighbors that vote on the final prediction */
    public int getK() {
        return this.k;
    }

    /**
     * Sets the number of neighbors used to estimate the local metric.
     *
     * @param kn neighborhood size for metric adaptation
     * @throws ArithmeticException if {@code kn < 2}
     */
    public void setKn(int kn) {
        if (kn < 2) {
            throw new ArithmeticException("At least 2 neighbors are needed to adapat the metric");
        }
        this.kn = kn;
    }

    /** @return the number of neighbors used to estimate the local metric */
    public int getKn() {
        return this.kn;
    }

    /**
     * Sets how many times the local metric is re-estimated per query.
     *
     * @param maxIterations number of iterations
     * @throws RuntimeException if {@code maxIterations < 1}
     */
    public void setMaxIterations(int maxIterations) {
        if (maxIterations < 1) {
            throw new RuntimeException("At least one iteration must be performed");
        }
        this.maxIterations = maxIterations;
    }

    /** @return how many times the local metric is re-estimated per query */
    public int getMaxIterations() {
        return this.maxIterations;
    }

    /**
     * Sets the regularization added to the sphered between-class matrix.
     *
     * @param eps a finite, non-negative value
     * @throws ArithmeticException if {@code eps} is negative, infinite, or NaN
     */
    public void setEpsilon(double eps) {
        if (eps < 0.0d || Double.isInfinite(eps) || Double.isNaN(eps)) {
            throw new ArithmeticException("Regularization must be a positive value");
        }
        this.eps = eps;
    }

    /** @return the regularization added to the sphered between-class matrix */
    public double getEpsilon() {
        return this.eps;
    }

    /**
     * Classifies a point by adapting the metric to its neighborhood and voting among the
     * {@code k} nearest training points under the final metric.
     *
     * @throws IllegalStateException if no training points are available
     */
    @Override // jsat.classifiers.Classifier
    public CategoricalResults classify(DataPoint dataPoint) {
        final int numClasses = this.predicting.getNumOfCategories();
        final int dims = dataPoint.numNumericalValues();
        final Vec query = dataPoint.getNumericalValues();
        CategoricalResults results = new CategoricalResults(numClasses);

        // Local metric Sigma; the identity (plain Euclidean distance) until the first update.
        DenseMatrix sigma = Matrix.eye(dims);
        DenseMatrix between = new DenseMatrix(dims, dims);
        DenseMatrix within = new DenseMatrix(dims, dims);
        Vec scratch = new DenseVector(dims);
        Vec overallMean = new DenseVector(dims);
        Vec[] classMeans = new Vec[numClasses];
        for (int c = 0; c < numClasses; c++) {
            classMeans[c] = new DenseVector(dims);
        }
        double[] weights = new double[this.kn];
        double[] classWeight = new double[numClasses];
        int[] classCount = new int[numClasses];

        for (int iter = 0; iter < this.maxIterations; iter++) {
            overallMean.zeroOut();
            for (Vec mean : classMeans) {
                mean.zeroOut();
            }
            Arrays.fill(classWeight, 0.0d);
            Arrays.fill(classCount, 0);
            between.zeroOut();
            within.zeroOut();

            // The first pass can use the (Euclidean) vector collection; later passes must scan
            // under the adapted metric.
            List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> neighbors =
                    iter == 0 ? this.vc.search(query, this.kn) : brute(query, sigma, this.kn);
            final int found = neighbors.size();
            if (found == 0) {
                throw new IllegalStateException("No training points are available to classify against");
            }
            final double totalWeight = tricubeWeights(neighbors, weights);

            // Unweighted overall and per-class means of the neighborhood, plus per-class kernel weight.
            for (int j = 0; j < found; j++) {
                VecPaired<VecPaired<Vec, Integer>, Double> neighbor = neighbors.get(j);
                int label = neighbor.getVector().getPair().intValue();
                overallMean.mutableAdd(neighbor);
                classMeans[label].mutableAdd(neighbor);
                classWeight[label] += weights[j];
                classCount[label]++;
            }
            // Divide by the number actually found: it is below kn when the training set is small.
            overallMean.mutableDivide(found);
            for (int c = 0; c < numClasses; c++) {
                if (classCount[c] != 0) {
                    classMeans[c].mutableDivide(classCount[c]);
                }
                classWeight[c] /= totalWeight;
            }

            // A neighborhood whose entire kernel weight belongs to one class is decisive.
            for (int c = 0; c < numClasses; c++) {
                if (classWeight[c] == 1.0d) {
                    results.setProb(c, 1.0d);
                    return results;
                }
            }

            // B = sum_c pi_c (mean_c - mean)(mean_c - mean)^T, pi_c = normalized class weight.
            for (int c = 0; c < numClasses; c++) {
                if (classWeight[c] > 0.0d) {
                    classMeans[c].copyTo(scratch);
                    scratch.mutableSubtract(overallMean);
                    Matrix.OuterProductUpdate(between, scratch, scratch, classWeight[c]);
                }
            }
            // W = (1 / sum_j w_j) * sum_j w_j (x_j - mean_class(j))(x_j - mean_class(j))^T
            for (int j = 0; j < found; j++) {
                if (weights[j] > 0.0d) {
                    VecPaired<VecPaired<Vec, Integer>, Double> neighbor = neighbors.get(j);
                    neighbor.copyTo(scratch);
                    scratch.mutableSubtract(classMeans[neighbor.getVector().getPair().intValue()]);
                    Matrix.OuterProductUpdate(within, scratch, scratch, weights[j]);
                }
            }
            within.mutableMultiply(1.0d / totalWeight);

            updateMetric(between, within, sigma);
        }

        for (VecPaired<VecPaired<Vec, Integer>, Double> voter : brute(query, sigma, this.k)) {
            results.incProb(voter.getVector().getPair().intValue(), 1.0d);
        }
        results.normalize();
        return results;
    }

    /**
     * Fills {@code weights[0 .. neighbors.size())} with tricube kernel weights
     * {@code (1 - (d / h)^3)^3}, where {@code d} is each neighbor's distance and {@code h} is the
     * largest distance in the neighborhood, and returns their sum. When the kernel yields no
     * usable weight (all distances zero, all equal to {@code h}, or any distance non-finite),
     * every neighbor gets weight 1 instead, so later normalizations stay finite.
     */
    private static double tricubeWeights(List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> neighbors, double[] weights) {
        final int count = neighbors.size();
        double h = 0.0d;
        for (int j = 0; j < count; j++) {
            // Math.max propagates NaN, which then disables the kernel below.
            h = Math.max(h, neighbors.get(j).getPair().doubleValue());
        }
        double total = 0.0d;
        if (h > 0.0d && !Double.isInfinite(h)) {
            for (int j = 0; j < count; j++) {
                double u = Math.min(1.0d, neighbors.get(j).getPair().doubleValue() / h);
                double t = 1.0d - u * u * u;
                weights[j] = t * t * t;
                total += weights[j];
            }
        }
        if (!(total > 0.0d)) {
            Arrays.fill(weights, 0, count, 1.0d);
            total = count;
        }
        return total;
    }

    /**
     * Overwrites {@code sigma} with {@code W^(-1/2) [W^(-1/2) B W^(-1/2) + eps * I] W^(-1/2)}.
     *
     * @param between the between-class matrix B
     * @param within  the within-class matrix W (consumed by the eigendecomposition)
     * @param sigma   destination for the new metric
     */
    private void updateMetric(DenseMatrix between, DenseMatrix within, DenseMatrix sigma) {
        final int dims = between.rows();
        Matrix wInvSqrt = inverseSqrt(within);
        // B* = W^(-1/2) B W^(-1/2); multiply(B, C) accumulates into C, which starts at zero.
        DenseMatrix sphered = new DenseMatrix(dims, dims);
        wInvSqrt.multiply(between).multiply(wInvSqrt, sphered);
        RowColumnOps.addDiag(sphered, 0, dims, this.eps);
        sigma.zeroOut();
        wInvSqrt.multiply(sphered).multiply(wInvSqrt, sigma);
    }

    /**
     * Returns {@code V D^(-1/2) V^T} for a symmetric matrix with eigendecomposition
     * {@code V D V^T}. Eigenvalues at or below {@link #EIGEN_TOLERANCE} times the largest
     * magnitude eigenvalue are mapped to 0 (pseudo-inverse), so singular input stays finite.
     */
    private static Matrix inverseSqrt(DenseMatrix symmetric) {
        EigenValueDecomposition evd = new EigenValueDecomposition(symmetric);
        Matrix diag = evd.getD();
        double largest = 0.0d;
        for (int i = 0; i < diag.rows(); i++) {
            largest = Math.max(largest, Math.abs(diag.get(i, i)));
        }
        final double tolerance = largest * EIGEN_TOLERANCE;
        for (int i = 0; i < diag.rows(); i++) {
            double lambda = diag.get(i, i);
            diag.set(i, i, lambda > tolerance ? Math.pow(lambda, -0.5d) : 0.0d);
        }
        Matrix vt = evd.getVT();
        return vt.transposeMultiply(diag).multiply(vt);
    }

    /**
     * Stores the training vectors with their class labels and builds the Euclidean vector
     * collection used for the first neighbor search of every query.
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet, ExecutorService threadPool) {
        this.predicting = dataSet.getPredicting();
        final int size = dataSet.getSampleSize();
        this.vecList = new ArrayList<>(size);
        for (int i = 0; i < size; i++) {
            this.vecList.add(new VecPaired<>(dataSet.getDataPoint(i).getNumericalValues(), Integer.valueOf(dataSet.getDataPointCategory(i))));
        }
        if (threadPool == null || (threadPool instanceof FakeExecutor)) {
            this.vc = this.vcf.getVectorCollection(this.vecList, new EuclideanDistance());
        } else {
            this.vc = this.vcf.getVectorCollection(this.vecList, new EuclideanDistance(), threadPool);
        }
    }

    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet) {
        trainC(dataSet, null);
    }

    @Override // jsat.classifiers.Classifier
    public boolean supportsWeightedData() {
        return false;
    }

    @Override // jsat.classifiers.Classifier
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Classifier m38clone() {
        // Pass eps and maxIterations explicitly: the 4-argument constructor takes eps as its third
        // argument, so passing maxIterations there would corrupt eps and reset maxIterations.
        DANN copy = new DANN(this.kn, this.k, this.eps, this.maxIterations, this.vcf.m178clone());
        if (this.predicting != null) {
            copy.predicting = this.predicting.m1clone();
        }
        if (this.vc != null) {
            copy.vc = this.vc.clone();
        }
        if (this.vecList != null) {
            copy.vecList = new ArrayList<>(this.vecList);
        }
        return copy;
    }

    /**
     * Returns {@code sqrt(max(0, (a - b)^T sigma (a - b)))}, using {@code diff} and
     * {@code product} as scratch space. The square root keeps later-iteration distances in the
     * same units as the Euclidean distances of the first iteration; clamping at 0 guards
     * against tiny negative values from rounding.
     */
    private static double dist(Matrix sigma, Vec a, Vec b, Vec diff, Vec product) {
        a.copyTo(diff);
        diff.mutableSubtract(b);
        product.zeroOut();
        sigma.multiply(diff, 1.0d, product);
        return Math.sqrt(Math.max(0.0d, diff.dot(product)));
    }

    /**
     * Brute-force scan of all training vectors, returning the {@code num} closest to
     * {@code query} under {@code sigma}, paired with their distances.
     */
    private List<? extends VecPaired<VecPaired<Vec, Integer>, Double>> brute(Vec query, Matrix sigma, int num) {
        Vec diff = new DenseVector(query.length());
        Vec product = new DenseVector(query.length());
        BoundedSortedList<VecPairedComparable<VecPaired<Vec, Integer>, Double>> nearest =
                new BoundedSortedList<VecPairedComparable<VecPaired<Vec, Integer>, Double>>(num, num);
        for (VecPaired<Vec, Integer> candidate : this.vecList) {
            double d = dist(sigma, query, candidate, diff, product);
            nearest.add(new VecPairedComparable<VecPaired<Vec, Integer>, Double>(candidate, d));
        }
        return nearest;
    }

    /** Suggested search distribution for {@link #setK(int)}. */
    public static Distribution guessK(DataSet dataSet) {
        return new UniformDiscrete(1, 25);
    }

    /** Suggested search distribution for {@link #setKn(int)}, scaled by the data set size. */
    public static Distribution guessKn(DataSet dataSet) {
        return new UniformDiscrete(40, Math.max(dataSet.getSampleSize() / 5, 50));
    }

    @Override // jsat.parameters.Parameterized
    public List<Parameter> getParameters() {
        return Parameter.getParamsFromMethods(this);
    }

    @Override // jsat.parameters.Parameterized
    public Parameter getParameter(String paramName) {
        return Parameter.toParameterMap(getParameters()).get(paramName);
    }
}
