package jsat.linear;

import java.io.Serializable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.utils.SystemInfo;

/* loaded from: input_file:jsat/linear/CholeskyDecomposition.class */
/**
 * Cholesky decomposition A = L L^T of a square matrix A, computed in place.
 * <p>
 * A is assumed to be symmetric: only its lower triangle and diagonal are read.
 * The matrix passed to a constructor is overwritten with the factor, so pass a
 * copy if the original A is still needed.
 */
public class CholeskyDecomposition implements Serializable {
    private static final long serialVersionUID = 8925094456733750112L;

    /**
     * The factor, stored in the same object that was passed to the constructor.
     * After construction the diagonal and strictly lower triangle hold L, and the
     * strictly upper triangle holds the mirrored entries, i.e. L^T.
     */
    private Matrix L;

    /**
     * Computes the Cholesky decomposition of {@code matrix} on the calling thread.
     *
     * @param matrix the matrix A to decompose; overwritten with the factor
     * @throws ArithmeticException if {@code matrix} is not square or is not
     *                             positive definite
     */
    public CholeskyDecomposition(Matrix matrix) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Input matrix must be symmetric positive definite");
        }
        this.L = matrix;
        int rows = matrix.rows();
        for (int j = 0; j < rows; j++) {
            double ljj = computeLJJ(matrix, j);
            this.L.set(j, j, ljj);
            updateRows(j, j + 1, rows, 1, matrix, ljj);
        }
        copyLowerToUpper(rows);
    }

    /**
     * Computes the Cholesky decomposition of {@code matrix}, splitting the
     * below-diagonal rows of each column between the calling thread and
     * {@code SystemInfo.LogicalCores - 1} tasks submitted to {@code executorService}.
     * <p>
     * Each column is fully finished before the next one starts. If the calling
     * thread is interrupted while waiting on the tasks, the decomposition still
     * completes and the thread's interrupt status is restored before returning.
     * An exception thrown by a worker task is rethrown on the calling thread.
     *
     * @param matrix          the matrix A to decompose; overwritten with the factor
     * @param executorService executor that runs the worker tasks
     * @throws ArithmeticException if {@code matrix} is not square or is not
     *                             positive definite
     */
    public CholeskyDecomposition(final Matrix matrix, ExecutorService executorService) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Input matrix must be symmetric positive definite");
        }
        this.L = matrix;
        final int rows = matrix.rows();
        final int threads = SystemInfo.LogicalCores;
        // The diagonal entry of the next column is computed while the workers are
        // still busy with the current one, so the first one is computed up front.
        // Guarded so an empty matrix does not index element (0, 0).
        double nextLJJ = rows > 0 ? computeLJJ(matrix, 0) : 0.0;
        boolean interrupted = false;
        try {
            for (int j = 0; j < rows; j++) {
                final int col = j;
                final double ljj = nextLJJ;
                this.L.set(col, col, ljj);
                final CountDownLatch latch = new CountDownLatch(threads - 1);
                final AtomicReference<Throwable> failure = new AtomicReference<Throwable>();
                // Worker t handles rows col+1+t, col+1+t+threads, col+1+t+2*threads, ...
                for (int t = 1; t < threads; t++) {
                    final int offset = t;
                    executorService.submit(new Runnable() {
                        @Override
                        public void run() {
                            try {
                                updateRows(col, col + 1 + offset, rows, threads, matrix, ljj);
                            } catch (Throwable ex) {
                                // Keep the first failure; submit() would otherwise hide it.
                                failure.compareAndSet(null, ex);
                            } finally {
                                // Always count down, or the caller would wait forever.
                                latch.countDown();
                            }
                        }
                    });
                }
                try {
                    // The calling thread takes rows col+1, col+1+threads, ... Row col+1
                    // is therefore complete before computeLJJ(col+1) reads it, which
                    // lets that computation overlap with the workers.
                    updateRows(col, col + 1, rows, threads, matrix, ljj);
                    if (col + 1 < rows) {
                        nextLJJ = computeLJJ(matrix, col + 1);
                    }
                } finally {
                    // Never start the next column (or leave the constructor) while
                    // workers may still be writing this one.
                    interrupted |= awaitUninterruptibly(latch);
                }
                rethrowIfFailed(failure.get());
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
        copyLowerToUpper(rows);
    }

    /**
     * Waits until {@code latch} reaches zero, retrying if the wait is interrupted.
     *
     * @param latch the latch to wait on
     * @return {@code true} if an interrupt was received during the wait, so the
     *         caller can restore the thread's interrupt status
     */
    private static boolean awaitUninterruptibly(CountDownLatch latch) {
        boolean interrupted = false;
        while (true) {
            try {
                latch.await();
                return interrupted;
            } catch (InterruptedException e) {
                interrupted = true;
            }
        }
    }

    /**
     * Rethrows a failure captured from a worker task, if there was one.
     * Unchecked exceptions and errors are rethrown as-is; anything else is wrapped.
     *
     * @param failure the captured throwable, or {@code null} if none
     */
    private static void rethrowIfFailed(Throwable failure) {
        if (failure == null) {
            return;
        }
        if (failure instanceof RuntimeException) {
            throw (RuntimeException) failure;
        }
        if (failure instanceof Error) {
            throw (Error) failure;
        }
        throw new RuntimeException(failure);
    }

    /**
     * Returns the upper-triangular factor L^T as a new dense matrix.
     * Entries below the diagonal are left at the {@code DenseMatrix} default.
     *
     * @return a new matrix holding L^T
     */
    public Matrix getLT() {
        int n = this.L.rows();
        DenseMatrix lt = new DenseMatrix(n, this.L.cols());
        for (int i = 0; i < n; i++) {
            for (int j = i; j < n; j++) {
                lt.set(i, j, this.L.get(i, j));
            }
        }
        return lt;
    }

    /**
     * Solves A x = b by forward substitution with L, then back substitution with L^T.
     * NOTE(review): relies on LUPDecomposition.forwardSub reading only the lower
     * triangle and backSub only the upper triangle of the stored factor — confirm.
     *
     * @param vec the right-hand side b
     * @return the solution x
     */
    public Vec solve(Vec vec) {
        return LUPDecomposition.backSub(this.L, LUPDecomposition.forwardSub(this.L, vec));
    }

    /**
     * Solves A X = B column-wise by forward substitution with L, then back
     * substitution with L^T (same assumption as {@link #solve(Vec)}).
     *
     * @param matrix the right-hand side B
     * @return the solution X
     */
    public Matrix solve(Matrix matrix) {
        return LUPDecomposition.backSub(this.L, LUPDecomposition.forwardSub(this.L, matrix));
    }

    /**
     * Solves A X = B like {@link #solve(Matrix)}, passing {@code executorService}
     * to both substitution steps.
     *
     * @param matrix          the right-hand side B
     * @param executorService executor forwarded to the substitution routines
     * @return the solution X
     */
    public Matrix solve(Matrix matrix, ExecutorService executorService) {
        return LUPDecomposition.backSub(this.L, LUPDecomposition.forwardSub(this.L, matrix, executorService), executorService);
    }

    /**
     * Returns the determinant of the decomposed matrix A.
     * <p>
     * L is triangular, so det(L) is the product of its diagonal, and since
     * A = L L^T, det(A) = det(L)^2.
     *
     * @return det(A)
     */
    public double getDet() {
        double detL = 1.0d;
        for (int i = 0; i < this.L.rows(); i++) {
            detL *= this.L.get(i, i);
        }
        return detL * detL;
    }

    /**
     * Computes the diagonal entry L[j][j] = sqrt(A[j][j] - sum_{k<j} L[j][k]^2).
     * Requires that row j of L is already filled in for all columns k < j.
     *
     * @param matrix the matrix A (read at (j, j) only)
     * @param j      the column/row index
     * @return L[j][j]
     * @throws ArithmeticException if the pivot is not strictly positive
     */
    private double computeLJJ(Matrix matrix, int j) {
        double pivot = matrix.get(j, j);
        for (int k = 0; k < j; k++) {
            double ljk = this.L.get(j, k);
            pivot -= ljk * ljk;
        }
        // !(pivot > 0) rejects negative values, NaN, and zero (including -0.0).
        // A zero pivot would make updateRows divide by zero.
        if (!(pivot > 0)) {
            throw new ArithmeticException("input matrix is not positive definite");
        }
        return Math.sqrt(pivot);
    }

    /**
     * Fills column {@code col} of L for rows start, start+step, ... below {@code end}:
     * L[row][col] = (A[row][col] - sum_{k<col} L[row][k] * L[col][k]) / ljj.
     * <p>
     * NOTE(review): the decompiler reports this was originally private; it is public
     * only so the worker tasks could call it. It is not intended for external use.
     *
     * @param col    the column being computed
     * @param start  first row to process
     * @param end    exclusive upper bound on the rows
     * @param step   stride between processed rows
     * @param matrix the matrix A
     * @param ljj    the already-computed diagonal entry L[col][col]
     */
    public void updateRows(int col, int start, int end, int step, Matrix matrix, double ljj) {
        for (int row = start; row < end; row += step) {
            double value = matrix.get(row, col);
            for (int k = 0; k < col; k++) {
                value -= this.L.get(row, k) * this.L.get(col, k);
            }
            this.L.set(row, col, value / ljj);
        }
    }

    /**
     * Mirrors the strictly lower triangle of the stored factor into the strictly
     * upper triangle, so the upper part holds L^T.
     *
     * @param n the dimension of the matrix
     */
    private void copyLowerToUpper(int n) {
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < i; j++) {
                this.L.set(j, i, this.L.get(i, j));
            }
        }
    }
}
