package jsat.distributions.kernels;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.SubMatrix;
import jsat.linear.Vec;
import jsat.math.FastMath;
import jsat.math.FunctionBase;
import jsat.math.optimization.GoldenSearch;
import jsat.utils.DoubleList;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/distributions/kernels/KernelPoint.class */
/**
 * A point in a kernel-induced feature space, stored implicitly as a weighted sum of basis vectors,
 * {@code sum_i alpha_i * phi(x_i)}. Inner products, norms and distances are evaluated through the
 * supplied {@link KernelTrick}. The number of stored basis vectors can be bounded with
 * {@link #setMaxBudget(int)}, and the way the bound is enforced is chosen with a
 * {@link BudgetStrategy}.
 */
public class KernelPoint {
    /** Kernel used for every inner product. */
    protected KernelTrick k;
    /**
     * PROJECTION only: a new vector whose projection residual onto the current basis is at or below
     * this value is projected instead of being added as a new basis vector.
     */
    private double errorTolerance;
    /** Basis vectors x_i. */
    protected List<Vec> vecs;
    /** Kernel acceleration cache for {@link #vecs}, or {@code null} if the kernel has none. */
    protected List<Double> kernelAccel;
    /** PROJECTION only: Gram matrix of the basis, a view onto {@link #KExpanded}. */
    protected Matrix K;
    /** PROJECTION only: inverse of {@link #K}, a view onto {@link #InvKExpanded}. */
    protected Matrix InvK;
    /** Over-allocated backing storage for {@link #K}; grown by doubling. */
    protected Matrix KExpanded;
    /** Over-allocated backing storage for {@link #InvK}; grown by doubling. */
    protected Matrix InvKExpanded;
    /** Weights alpha_i of the basis vectors. */
    protected DoubleList alpha;
    /** How the basis-size budget is enforced. */
    protected BudgetStrategy budgetStrategy = BudgetStrategy.PROJECTION;
    /** Maximum number of basis vectors. */
    protected int maxBudget = Integer.MAX_VALUE;
    /** Cached squared norm; only valid while {@link #normGood} is true. */
    private double sqrdNorm = 0.0d;
    /** Whether {@link #sqrdNorm} is up to date. */
    private boolean normGood = true;

    /** Policies for keeping the number of basis vectors within {@link #getMaxBudget()}. */
    public enum BudgetStrategy {
        /**
         * Project a new vector onto the current basis when its projection residual is at or below the
         * error tolerance, or when the basis is already at the budget; otherwise add it to the basis.
         */
        PROJECTION,
        /**
         * Add the vector, and once over budget merge the smallest-|alpha| vector with the partner that
         * minimizes the merge loss. Intended for RBF-style kernels.
         */
        MERGE_RBF,
        /** Ignore new vectors once the budget is reached. */
        STOP,
        /** Once the budget is reached, remove a randomly chosen basis vector to make room. */
        RANDOM
    }

    /**
     * Creates an empty point using the PROJECTION strategy and an unbounded budget.
     *
     * @param kernelTrick the kernel to use
     * @param errorTolerance projection residual tolerance, in [0, 1]
     * @throws IllegalArgumentException if {@code errorTolerance} is NaN or outside [0, 1]
     */
    public KernelPoint(KernelTrick kernelTrick, double errorTolerance) {
        this.k = kernelTrick;
        setErrorTolerance(errorTolerance);
        setBudgetStrategy(BudgetStrategy.PROJECTION);
        setMaxBudget(Integer.MAX_VALUE);
        if (kernelTrick.supportsAcceleration()) {
            this.kernelAccel = new DoubleList(16);
        }
        this.alpha = new DoubleList(16);
        this.vecs = new ArrayList<Vec>(16);
    }

    /**
     * Deep-copy constructor. The kernel, basis vectors, weights, caches, projection matrices and
     * budget settings (including the budget strategy) are all copied.
     *
     * @param toCopy the point to copy
     */
    public KernelPoint(KernelPoint toCopy) {
        this.k = toCopy.k.mo144clone();
        this.errorTolerance = toCopy.errorTolerance;
        // Previously omitted: without it a copy silently reverted to PROJECTION.
        this.budgetStrategy = toCopy.budgetStrategy;
        if (toCopy.vecs != null) {
            this.vecs = new ArrayList<Vec>(toCopy.vecs.size());
            for (Vec v : toCopy.vecs) {
                this.vecs.add(v.mo45clone());
            }
            if (toCopy.kernelAccel != null) {
                this.kernelAccel = new DoubleList(toCopy.kernelAccel);
            }
            this.alpha = new DoubleList(toCopy.alpha);
        }
        if (toCopy.KExpanded != null) {
            this.KExpanded = toCopy.KExpanded.mo161clone();
            this.InvKExpanded = toCopy.InvKExpanded.mo161clone();
            this.K = new SubMatrix(this.KExpanded, 0, 0, toCopy.K.rows(), toCopy.K.cols());
            this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, toCopy.InvK.rows(), toCopy.InvK.cols());
        }
        this.maxBudget = toCopy.maxBudget;
        this.sqrdNorm = toCopy.sqrdNorm;
        this.normGood = toCopy.normGood;
    }

    /**
     * Sets the maximum number of basis vectors.
     *
     * @param maxBudget the budget, at least 1
     * @throws IllegalArgumentException if {@code maxBudget < 1}
     */
    public void setMaxBudget(int maxBudget) {
        if (maxBudget < 1) {
            throw new IllegalArgumentException("Budget must be positive, not " + maxBudget);
        }
        this.maxBudget = maxBudget;
    }

    /** @return the maximum number of basis vectors */
    public int getMaxBudget() {
        return this.maxBudget;
    }

    /**
     * Sets the budget strategy. It may only be changed before any basis vector has been added.
     *
     * @param budgetStrategy the strategy to use
     * @throws IllegalStateException if the point already has basis vectors
     */
    public void setBudgetStrategy(BudgetStrategy budgetStrategy) {
        if (getBasisSize() > 0) {
            throw new IllegalStateException("KernelPoint already started, budget may not be changed");
        }
        this.budgetStrategy = budgetStrategy;
    }

    /** @return the current budget strategy */
    public BudgetStrategy getBudgetStrategy() {
        return this.budgetStrategy;
    }

    /**
     * Sets the projection residual tolerance used by the PROJECTION strategy.
     *
     * @param errorTolerance a value in [0, 1]
     * @throws IllegalArgumentException if the value is NaN or outside [0, 1]
     */
    public void setErrorTolerance(double errorTolerance) {
        if (Double.isNaN(errorTolerance) || errorTolerance < 0.0d || errorTolerance > 1.0d) {
            throw new IllegalArgumentException("Error tolerance must be in [0, 1], not " + errorTolerance);
        }
        this.errorTolerance = errorTolerance;
    }

    /** @return the projection residual tolerance */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * Returns the squared norm {@code sum_i sum_j alpha_i alpha_j k(x_i, x_j)} of this point. The
     * value is cached and recomputed only after the point has been modified.
     *
     * @return the squared norm
     */
    public double getSqrdNorm() {
        if (!this.normGood) {
            double norm = 0.0d;
            int n = this.alpha.size();
            for (int i = 0; i < n; i++) {
                double ai = this.alpha.getD(i);
                norm += ai * ai * gramEntry(i, i);
                // The Gram matrix is symmetric, so each off-diagonal term is counted twice.
                for (int j = i + 1; j < n; j++) {
                    norm += 2.0d * ai * this.alpha.getD(j) * gramEntry(i, j);
                }
            }
            this.sqrdNorm = norm;
            this.normGood = true;
        }
        return this.sqrdNorm;
    }

    /** Kernel value between basis vectors i and j, taken from K when PROJECTION maintains it. */
    private double gramEntry(int i, int j) {
        if (this.K != null) {
            return this.K.get(i, j);
        }
        return this.k.eval(i, j, this.vecs, this.kernelAccel);
    }

    /**
     * Computes the inner product between this point and {@code phi(vec)}.
     *
     * @param vec the query vector
     * @return {@code sum_i alpha_i k(x_i, vec)}
     */
    public double dot(Vec vec) {
        return dot(vec, this.k.getQueryInfo(vec));
    }

    /**
     * Computes the inner product between this point and {@code phi(vec)}, using precomputed query
     * information.
     *
     * @param vec the query vector
     * @param qi the query information for {@code vec}, as returned by {@code k.getQueryInfo(vec)}
     * @return {@code sum_i alpha_i k(x_i, vec)}, or 0 if the basis is empty
     */
    public double dot(Vec vec, List<Double> qi) {
        if (getBasisSize() == 0) {
            return 0.0d;
        }
        return this.k.evalSum(this.vecs, this.kernelAccel, this.alpha.getBackingArray(), vec, qi, 0, this.alpha.size());
    }

    /**
     * Computes the inner product between two kernel points using this point's kernel. The other
     * point is presumably built with an equivalent kernel; this method does not check that.
     *
     * @param other the other point
     * @return {@code sum_i sum_j alpha_i beta_j k(x_i, y_j)}, or 0 if either basis is empty
     */
    public double dot(KernelPoint other) {
        if (getBasisSize() == 0 || other.getBasisSize() == 0) {
            return 0.0d;
        }
        int offset = this.alpha.size();
        // Index j of the other point is addressed as offset + j in the merged views.
        List<? extends Vec> mergedVecs = ListUtils.mergedView(this.vecs, other.vecs);
        List<Double> mergedAccel = (this.kernelAccel == null || other.kernelAccel == null)
                ? null
                : ListUtils.mergedView(this.kernelAccel, other.kernelAccel);
        double sum = 0.0d;
        for (int i = 0; i < this.alpha.size(); i++) {
            for (int j = 0; j < other.alpha.size(); j++) {
                sum += this.alpha.getD(i) * other.alpha.getD(j) * this.k.eval(i, j + offset, mergedVecs, mergedAccel);
            }
        }
        return sum;
    }

    /**
     * Computes the feature-space distance between this point and {@code phi(vec)}.
     *
     * @param vec the query vector
     * @return the distance
     */
    public double dist(Vec vec) {
        return dist(vec, this.k.getQueryInfo(vec));
    }

    /**
     * Computes the feature-space distance between this point and {@code phi(vec)}, using precomputed
     * query information.
     *
     * @param vec the query vector
     * @param qi the query information for {@code vec}
     * @return {@code sqrt(k(vec, vec) + ||this||^2 - 2 <this, phi(vec)>)}
     */
    public double dist(Vec vec, List<Double> qi) {
        double sq = (this.k.eval(0, 0, Arrays.asList(vec), qi) + getSqrdNorm()) - (2.0d * dot(vec, qi));
        // Rounding can push the squared distance slightly below zero; clamp to avoid a NaN.
        return Math.sqrt(Math.max(0.0d, sq));
    }

    /**
     * Computes the feature-space distance between two kernel points.
     *
     * @param other the other point
     * @return the distance, exactly 0 when {@code other == this}
     */
    public double dist(KernelPoint other) {
        if (this == other) {
            return 0.0d;
        }
        return Math.sqrt(Math.max(0.0d, (getSqrdNorm() + other.getSqrdNorm()) - (2.0d * dot(other))));
    }

    /**
     * Scales this point in place.
     *
     * @param c the multiplier; must be finite
     * @throws IllegalArgumentException if {@code c} is NaN or infinite
     */
    public void mutableMultiply(double c) {
        if (Double.isNaN(c) || Double.isInfinite(c)) {
            throw new IllegalArgumentException("multiplier must be a real value, not " + c);
        }
        if (getBasisSize() == 0) {
            return;
        }
        // Scaling every weight by c scales the squared norm by c^2.
        this.sqrdNorm *= c * c;
        this.alpha.getVecView().mutableMultiply(c);
    }

    /**
     * Adds {@code phi(vec)} to this point.
     *
     * @param vec the vector to add
     */
    public void mutableAdd(Vec vec) {
        mutableAdd(1.0d, vec);
    }

    /**
     * Adds {@code c * phi(vec)} to this point.
     *
     * @param c the weight
     * @param vec the vector to add
     */
    public void mutableAdd(double c, Vec vec) {
        mutableAdd(c, vec, this.k.getQueryInfo(vec));
    }

    /**
     * Adds {@code c * phi(vec)} to this point, enforcing the budget with the current
     * {@link BudgetStrategy}. The vector is stored by reference, not copied.
     *
     * @param c the weight; a weight of 0 is a no-op
     * @param vec the vector to add
     * @param qi the query information for {@code vec}
     */
    public void mutableAdd(double c, Vec vec, List<Double> qi) {
        if (c == 0.0d) {
            return;
        }
        this.normGood = false;
        if (this.budgetStrategy == BudgetStrategy.PROJECTION) {
            addByProjection(c, vec, qi);
        } else if (this.budgetStrategy == BudgetStrategy.MERGE_RBF) {
            addByMerging(c, vec, qi);
        } else if (this.budgetStrategy == BudgetStrategy.STOP) {
            if (getBasisSize() < this.maxBudget) {
                addPoint(vec, qi, c);
            }
        } else if (this.budgetStrategy == BudgetStrategy.RANDOM) {
            if (getBasisSize() >= this.maxBudget) {
                removeIndex(new Random().nextInt(this.vecs.size()));
            }
            addPoint(vec, qi, c);
        } else {
            throw new RuntimeException("BUG: report me!");
        }
    }

    /** PROJECTION: project onto the basis, or grow the basis and update K and InvK incrementally. */
    private void addByProjection(double c, Vec vec, List<Double> qi) {
        double kxx = this.k.eval(0, 0, Arrays.asList(vec), qi);
        if (this.K == null) {
            // First basis vector: K = [k(x,x)] and InvK = [1 / k(x,x)].
            this.KExpanded = new DenseMatrix(16, 16);
            this.K = new SubMatrix(this.KExpanded, 0, 0, 1, 1);
            this.K.set(0, 0, kxx);
            this.InvKExpanded = new DenseMatrix(16, 16);
            this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, 1, 1);
            this.InvK.set(0, 0, 1.0d / kxx);
            addPoint(vec, qi, c);
            return;
        }
        int rows = this.K.rows();
        DenseVector kVec = new DenseVector(rows);
        for (int i = 0; i < rows; i++) {
            kVec.set(i, this.k.eval(i, vec, qi, this.vecs, this.kernelAccel));
        }
        // coef are the projection coefficients of phi(vec) onto the basis; residual is the squared
        // distance from phi(vec) to that projection.
        Vec coef = this.InvK.multiply(kVec);
        double residual = kxx - coef.dot(kVec);
        if (residual <= this.errorTolerance || rows >= this.maxBudget) {
            this.alpha.getVecView().mutableAdd(c, coef);
            return;
        }
        this.vecs.add(vec);
        if (this.kernelAccel != null) {
            this.kernelAccel.addAll(qi);
        }
        if (rows == this.KExpanded.rows()) {
            this.KExpanded.changeSize(rows * 2, rows * 2);
            this.InvKExpanded.changeSize(rows * 2, rows * 2);
        }
        // Block-inverse update: the top-left block of the new inverse is InvK + coef coef^T / residual.
        Matrix.OuterProductUpdate(this.InvK, coef, coef, 1.0d / residual);
        this.K = new SubMatrix(this.KExpanded, 0, 0, rows + 1, rows + 1);
        this.InvK = new SubMatrix(this.InvKExpanded, 0, 0, rows + 1, rows + 1);
        for (int i = 0; i < rows; i++) {
            double kv = kVec.get(i);
            double inv = (-coef.get(i)) / residual;
            this.K.set(rows, i, kv);
            this.K.set(i, rows, kv);
            this.InvK.set(rows, i, inv);
            this.InvK.set(i, rows, inv);
        }
        this.K.set(rows, rows, kxx);
        this.InvK.set(rows, rows, 1.0d / residual);
        this.alpha.add(c);
    }

    /** MERGE_RBF: add the vector, then merge two basis vectors if the budget is exceeded. */
    private void addByMerging(double c, Vec vec, List<Double> qi) {
        addPoint(vec, qi, c);
        if (this.vecs.size() <= this.maxBudget) {
            return;
        }
        // Basis vector with the smallest |alpha|. Its weight is kept signed: the original code
        // stored |alpha_0| when index 0 won, which flipped the sign used by the merge.
        int m = 0;
        double alphaM = this.alpha.getD(0);
        for (int i = 1; i < this.alpha.size(); i++) {
            double ai = this.alpha.getD(i);
            if (Math.abs(ai) < Math.abs(alphaM)) {
                alphaM = ai;
                m = i;
            }
        }

        int n = -1;
        double bestLoss = Double.POSITIVE_INFINITY;
        double bestH = 0.0d;
        double bestAlphaZ = 0.0d;
        double eps = 0.001d;
        while (true) {
            for (int i = 0; i < this.alpha.size(); i++) {
                if (i == m) {
                    continue;
                }
                double alphaN = this.alpha.getD(i);
                double sum = alphaM + alphaN;
                // Skip partners whose weights nearly cancel: the normalization below divides by sum.
                if (Math.abs(sum) < eps) {
                    continue;
                }
                double kmn = this.k.eval(i, m, this.vecs, this.kernelAccel);
                double h = getH(kmn, alphaM / sum, alphaN / sum);
                double alphaZ = (alphaM * Math.pow(kmn, (1.0d - h) * (1.0d - h))) + (alphaN * Math.pow(kmn, h * h));
                double loss = ((alphaM * alphaM) + (alphaN * alphaN)) + (2.0d * kmn * alphaM * alphaN) - (alphaZ * alphaZ);
                if (loss < bestLoss) {
                    bestLoss = loss;
                    n = i;
                    bestH = h;
                    bestAlphaZ = alphaZ;
                }
            }
            // Relax the threshold until a partner is found. Stop once eps has underflowed to 0,
            // which would otherwise loop forever.
            if (n != -1 || eps == 0.0d) {
                break;
            }
            eps /= 10.0d;
        }
        if (n == -1) {
            // No usable merge partner: drop the smallest-weight vector so the budget still holds.
            removeIndex(m);
            return;
        }
        Vec z = this.vecs.get(m).multiply(bestH);
        z.mutableAdd(1.0d - bestH, this.vecs.get(n));
        finalMergeStep(m, n, z, this.k.getQueryInfo(z), bestAlphaZ, true);
    }

    /** Appends a basis vector, its cache entries (if any) and its weight. */
    private void addPoint(Vec vec, List<Double> qi, double c) {
        this.vecs.add(vec);
        if (this.kernelAccel != null) {
            this.kernelAccel.addAll(qi);
        }
        this.alpha.add(c);
    }

    /**
     * Removes the acceleration-cache entries of basis vector {@code index}. Must be called before
     * that vector is removed from {@link #vecs}. The cache is assumed to hold the same number of
     * contiguous entries per vector, in the order appended by {@code addAll} — TODO confirm for
     * kernels whose query info has more than one entry.
     */
    private void removeCacheEntries(int index) {
        if (this.kernelAccel == null || this.vecs.isEmpty()) {
            return;
        }
        int perVec = this.kernelAccel.size() / this.vecs.size();
        int start = index * perVec;
        for (int j = 0; j < perVec; j++) {
            this.kernelAccel.remove(start);
        }
    }

    /**
     * Replaces the weights at {@code m} and {@code n} with a single weight {@code alphaZ}, appended
     * at the end. If {@code alterVecs} is true, the two basis vectors (and their cache entries) are
     * likewise replaced with {@code z} and its query information.
     *
     * @param m index of the first merged vector
     * @param n index of the second merged vector
     * @param z the merged vector
     * @param zQi query information for {@code z}
     * @param alphaZ weight of the merged vector
     * @param alterVecs whether to modify the basis vectors as well as the weights
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void finalMergeStep(int m, int n, Vec z, List<Double> zQi, double alphaZ, boolean alterVecs) {
        int lo = Math.min(m, n);
        int hi = Math.max(m, n);
        // Remove the larger index first so the smaller one stays valid.
        this.alpha.remove(hi);
        this.alpha.remove(lo);
        if (alterVecs) {
            removeCacheEntries(hi);
            this.vecs.remove(hi);
            removeCacheEntries(lo);
            this.vecs.remove(lo);
            this.vecs.add(z);
            if (this.kernelAccel != null) {
                this.kernelAccel.addAll(zQi);
            }
        }
        this.alpha.add(alphaZ);
    }

    /**
     * Finds the interpolation coefficient h that maximizes
     * {@code a_m * k_mn^((1-h)^2) + a_n * k_mn^(h^2)} by minimizing its negation with
     * {@code GoldenSearch.minimize}. The search interval depends on the signs and relative sizes of
     * the weights, as coded below.
     *
     * @param k_mn kernel value between the two vectors being merged
     * @param a_m normalized weight of the first vector
     * @param a_n normalized weight of the second vector
     * @return the coefficient h; exactly 0.5 when {@code a_m == a_n}
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public static double getH(final double k_mn, final double a_m, final double a_n) {
        if (a_m == a_n) {
            return 0.5d;
        }
        FunctionBase objective = new FunctionBase() { // from class: jsat.distributions.kernels.KernelPoint.1
            private static final long serialVersionUID = -6891301465754898634L;

            @Override // jsat.math.Function
            public double f(Vec vec) {
                double h = vec.get(0);
                return -((a_m * FastMath.pow(k_mn, (1.0d - h) * (1.0d - h))) + (a_n * FastMath.pow(k_mn, h * h)));
            }
        };
        if (Math.signum(a_m) != Math.signum(a_n)) {
            if (a_m < 0.0d) {
                return GoldenSearch.minimize(0.001d, 100, 0.0d, 0.2d, 0, objective, 0.0d);
            }
            if (a_n < 0.0d) {
                return GoldenSearch.minimize(0.001d, 100, 0.8d, 1.0d, 0, objective, 0.0d);
            }
        }
        return a_m > a_n
                ? GoldenSearch.minimize(0.001d, 100, 0.5d, 1.0d, 0, objective, 0.0d)
                : GoldenSearch.minimize(0.001d, 100, 0.0d, 0.5d, 0, objective, 0.0d);
    }

    /**
     * Removes basis vector {@code i} together with its weight and acceleration-cache entries.
     *
     * @param i index of the basis vector to remove
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void removeIndex(int i) {
        removeCacheEntries(i);
        this.alpha.remove(i);
        this.vecs.remove(i);
    }

    /** @return the number of stored basis vectors, or 0 if none have been allocated */
    public int getBasisSize() {
        if (this.vecs == null) {
            return 0;
        }
        return this.vecs.size();
    }

    /** @return an unmodifiable view of the basis vectors (the vectors themselves are not copied) */
    public List<Vec> getRawBasisVecs() {
        return Collections.unmodifiableList(this.vecs);
    }

    /**
     * Returns a deep copy of this point, made with {@link #KernelPoint(KernelPoint)}.
     *
     * @return the copy
     */
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public KernelPoint m146clone() {
        return new KernelPoint(this);
    }
}
