package jsat.linear.vectorcollection.lsh;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import jsat.distributions.Normal;
import jsat.linear.DenseVector;
import jsat.linear.Vec;
import jsat.linear.VecPaired;
import jsat.linear.VecPairedComparable;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.linear.distancemetrics.ManhattanDistance;
import jsat.utils.IntList;
import jsat.utils.IntSet;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:jsat/linear/vectorcollection/lsh/E2LSH.class */
/**
 * Locality Sensitive Hashing for radius (range) queries under the L<sub>2</sub>
 * ({@link EuclideanDistance}) or L<sub>1</sub> ({@link ManhattanDistance})
 * distance, following the E2LSH scheme based on p-stable distributions.
 * <p>
 * Each of the {@code L} hash tables is keyed by the concatenation of {@code k}
 * hash functions of the form {@code floor((<x, a>/radius + b) / w)}, where
 * {@code a} is a random projection vector and {@code b} is drawn uniformly
 * from {@code [0, w)}. Projection entries are Gaussian (2-stable) for the L2
 * distance and Cauchy (1-stable) for the L1 distance, so that the collision
 * probabilities used to size {@code k} and {@code L} describe the hash
 * functions that are actually generated.
 * <p>
 * A query only considers stored vectors that share a bucket with it in at
 * least one table, and then keeps those whose true distance is within the
 * requested radius. False positives are therefore filtered out, but a vector
 * that never shares a bucket with the query is missed.
 *
 * @param <V> the type of vector stored
 */
public class E2LSH<V extends Vec> {
    /** The indexed vectors; bucket entries are indices into this list. */
    private List<V> vecs;
    /** Metric used to verify candidates; always Euclidean or Manhattan. */
    private DistanceMetric dm;
    /** Search radius R; projections are divided by R before bucketing. */
    private double radius;
    /** Approximation slack; {@code c = 1 + eps}. */
    private double eps;
    /** Collision probability of one hash function for points at distance R. */
    private double p1;
    /** Collision probability of one hash function for points at distance c*R. */
    private double p2;
    /** Bucket width of each hash function, in units of R. */
    private int w;
    /** Approximation factor {@code 1 + eps}. */
    private double c;
    /** Allowed probability of missing a point within distance R. */
    private double delta;
    /** Number of hash tables. */
    private int L;
    /** Number of hash functions concatenated to form one table key. */
    private int k;
    /** Acceleration cache of {@link #dm} for {@link #vecs}. */
    private List<Double> distCache;
    /** {@code h[t][j]} is the projection vector of hash function j in table t. */
    private Vec[][] h;
    /** {@code b[t][j]} is the offset, in {@code [0, w)}, of hash function j in table t. */
    private double[][] b;
    /** One map per table from combined hash key to the indices of the vectors in that bucket. */
    private List<Map<Integer, List<Integer>>> tables;

    /**
     * Builds the LSH index over the given vectors.
     *
     * @param list the vectors to index; must be non-empty
     * @param d the search radius R; must be positive and finite
     * @param d2 the approximation slack eps; approximate searches use radius
     *        {@code R * (1 + eps)}. Must be non-negative and finite
     * @param i the bucket width w; values {@code <= 0} select the default of 4
     * @param i2 the number of hash functions per table k; values {@code <= 0}
     *        choose it automatically from the number of vectors
     * @param d3 the allowed failure probability delta, in (0, 1); used to
     *        choose the number of tables L
     * @param distanceMetric the metric; must be Euclidean or Manhattan
     * @param list2 the acceleration cache of {@code distanceMetric} for {@code list}
     * @throws IllegalArgumentException if any argument is out of range
     */
    public E2LSH(List<V> list, double d, double d2, int i, int i2, double d3, DistanceMetric distanceMetric, List<Double> list2) {
        if (list == null || list.isEmpty()) {
            throw new IllegalArgumentException("At least one vector is needed to build the LSH tables");
        }
        // Written as a negated range check so that NaN is rejected as well
        if (!(d3 > 0.0d && d3 < 1.0d)) {
            throw new IllegalArgumentException("delta must be in range (0,1), not " + d3);
        }
        this.vecs = list;
        setRadius(d);
        this.delta = d3;
        setEps(d2);
        this.w = i <= 0 ? 4 : i;
        // Must run after w and c are set: it computes p1 and p2 from them
        setDistanceMetric(distanceMetric);
        this.distCache = list2;
        if (i2 <= 0) {
            // p2^k = 1/n: a point farther than c*R lands in the query's bucket with probability ~1/n.
            // Clamp to 1 so a single-vector list does not produce k = 0 (and hence L = 0, no tables).
            this.k = Math.max(1, (int) Math.ceil(Math.log(list.size()) / Math.log(1.0d / this.p2)));
        } else {
            this.k = i2;
        }
        // Smallest L with (1 - p1^k)^L <= delta. log1p keeps precision when p1^k is tiny.
        this.L = (int) Math.ceil(Math.log(1.0d / d3) / (-Math.log1p(-Math.pow(this.p1, this.k))));
        createTablesAndHashes(new XORWOW());
    }

    /**
     * Builds the LSH index, computing the metric's acceleration cache for the
     * given vectors. See the full constructor for parameter details.
     */
    public E2LSH(List<V> list, double d, double d2, int i, int i2, double d3, DistanceMetric distanceMetric) {
        this(list, d, d2, i, i2, d3, distanceMetric, distanceMetric.getAccelerationCache(list));
    }

    /**
     * Returns the indexed vectors found within distance R of {@code vec}.
     * Equivalent to {@code searchR(vec, false)}.
     *
     * @param vec the query vector
     * @return the matches paired with their distances, sorted ascending
     */
    public List<? extends VecPaired<Vec, Double>> searchR(Vec vec) {
        return searchR(vec, false);
    }

    /**
     * Returns the indexed vectors that share a bucket with {@code vec} in at
     * least one table and whose true distance is within the search range.
     *
     * @param vec the query vector
     * @param z if {@code true}, the range is {@code R * c}; otherwise it is {@code R}
     * @return the matches paired with their distances, sorted ascending
     */
    public List<? extends VecPaired<Vec, Double>> searchR(Vec vec, boolean z) {
        IntSet candidates = new IntSet();
        for (int t = 0; t < this.L; t++) {
            List<Integer> bucket = this.tables.get(t).get(hash(t, vec));
            // The query may fall in a bucket that no stored vector hashed to
            if (bucket != null) {
                candidates.addAll(bucket);
            }
        }
        List<Double> queryInfo = this.dm.getQueryInfo(vec);
        double range = z ? this.radius * getC() : this.radius;
        List<VecPairedComparable<Vec, Double>> results = new ArrayList<VecPairedComparable<Vec, Double>>();
        for (int idx : candidates) {
            double dist = this.dm.dist(idx, vec, queryInfo, this.vecs, this.distCache);
            if (dist <= range) {
                results.add(new VecPairedComparable<Vec, Double>(this.vecs.get(idx), dist));
            }
        }
        Collections.sort(results);
        return results;
    }

    /**
     * Computes the combined key of {@code vec} for one table. The key is the
     * hash code of the k bucket indices {@code floor((<vec, h>/R + b) / w)}.
     * Distinct index tuples can collide; collisions only add candidates, which
     * {@link #searchR(Vec, boolean)} filters by true distance.
     */
    private int hash(int table, Vec vec) {
        int[] buckets = new int[this.k];
        for (int j = 0; j < this.k; j++) {
            buckets[j] = (int) Math.floor(((vec.dot(this.h[table][j]) / this.radius) + this.b[table][j]) / this.w);
        }
        return Arrays.hashCode(buckets);
    }

    /**
     * Sets the approximation slack and the derived factor {@code c = 1 + eps}.
     *
     * @throws IllegalArgumentException if {@code d} is negative, NaN or infinite
     */
    private void setEps(double d) {
        if (Double.isNaN(d) || Double.isInfinite(d) || d < 0.0d) {
            throw new IllegalArgumentException("eps must be a non-negative constant, not " + d);
        }
        this.eps = d;
        this.c = d + 1.0d;
    }

    /** Returns the approximation factor {@code c = 1 + eps}. */
    public double getC() {
        return this.c;
    }

    /** Returns the search radius R. */
    public double getRadius() {
        return this.radius;
    }

    /** Returns the number of hash tables L. */
    public int getL() {
        return this.L;
    }

    /**
     * Collision probability of one Gaussian-projection hash function of width
     * {@code w} for two points at L2 distance {@code c}. With {@code r = w/c}:
     * {@code 1 - 2*Phi(-r) - 2/(sqrt(2*pi)*r) * (1 - exp(-r^2/2))}.
     */
    private static double getP2L2(double w, double c) {
        double r = w / c;
        return (1.0d - (2.0d * Normal.cdf(-r, 0.0d, 1.0d)))
                - ((2.0d / (Math.sqrt(2.0d * Math.PI) * r)) * (1.0d - Math.exp(-(r * r) / 2.0d)));
    }

    /**
     * Collision probability of one Cauchy-projection hash function of width
     * {@code w} for two points at L1 distance {@code c}. With {@code r = w/c}:
     * {@code 2*atan(r)/pi - ln(1 + r^2)/(pi*r)}.
     */
    private static double getP2L1(double w, double c) {
        double r = w / c;
        return ((2.0d * Math.atan(r)) / Math.PI) - (Math.log(1.0d + r * r) / (Math.PI * r));
    }

    /** Draws a standard Cauchy variate by inverse-CDF sampling: {@code tan(pi * (U - 1/2))}. */
    private static double nextCauchy(Random random) {
        return Math.tan(Math.PI * (random.nextDouble() - 0.5d));
    }

    /**
     * Draws the L*k random projections and offsets, then hashes every indexed
     * vector into each of the L tables.
     */
    private void createTablesAndHashes(Random random) {
        int dim = this.vecs.get(0).length();
        // Mirror setDistanceMetric: Gaussian (2-stable) for L2, otherwise Cauchy (1-stable) for L1
        boolean gaussian = this.dm instanceof EuclideanDistance;
        this.h = new Vec[this.L][this.k];
        this.b = new double[this.L][this.k];
        for (int t = 0; t < this.L; t++) {
            for (int j = 0; j < this.k; j++) {
                DenseVector projection = new DenseVector(dim);
                for (int d = 0; d < dim; d++) {
                    projection.set(d, gaussian ? random.nextGaussian() : nextCauchy(random));
                }
                this.h[t][j] = projection;
                this.b[t][j] = random.nextDouble() * this.w;
            }
        }
        this.tables = new ArrayList<Map<Integer, List<Integer>>>(this.L);
        for (int t = 0; t < this.L; t++) {
            Map<Integer, List<Integer>> table = new HashMap<Integer, List<Integer>>();
            for (int idx = 0; idx < this.vecs.size(); idx++) {
                int key = hash(t, this.vecs.get(idx));
                List<Integer> bucket = table.get(key);
                if (bucket == null) {
                    bucket = new IntList(3);
                    table.put(key, bucket);
                }
                bucket.add(idx);
            }
            this.tables.add(table);
        }
    }

    /**
     * Sets the metric and computes the single-function collision probabilities
     * p1 (distance R) and p2 (distance c*R). Requires {@code w} and {@code c}
     * to be set already.
     *
     * @throws IllegalArgumentException if the metric is neither Euclidean nor Manhattan
     */
    private void setDistanceMetric(DistanceMetric distanceMetric) {
        if (!(distanceMetric instanceof EuclideanDistance) && !(distanceMetric instanceof ManhattanDistance)) {
            throw new IllegalArgumentException("only Euclidean and Manhattan (L1 and L2 norm) distances are supported");
        }
        this.dm = distanceMetric;
        // Distances are measured in units of R because hash() divides projections by the radius
        if (distanceMetric instanceof EuclideanDistance) {
            this.p1 = getP2L2(this.w, 1.0d);
            this.p2 = getP2L2(this.w, this.c);
        } else {
            this.p1 = getP2L1(this.w, 1.0d);
            this.p2 = getP2L1(this.w, this.c);
        }
    }

    /**
     * Sets the search radius R.
     *
     * @throws IllegalArgumentException if {@code d} is not positive and finite
     */
    private void setRadius(double d) {
        if (Double.isInfinite(d) || Double.isNaN(d) || d <= 0.0d) {
            throw new IllegalArgumentException("Radius must be a positive constant, not " + d);
        }
        this.radius = d;
    }
}
