package jsat.distributions.multivariate;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import jsat.classifiers.DataPoint;
import jsat.linear.CholeskyDecomposition;
import jsat.linear.DenseMatrix;
import jsat.linear.DenseVector;
import jsat.linear.LUPDecomposition;
import jsat.linear.Matrix;
import jsat.linear.MatrixStatistics;
import jsat.linear.SingularValueDecomposition;
import jsat.linear.Vec;

/* loaded from: input_file:jsat/distributions/multivariate/NormalM.class */
/**
 * Multivariate normal (Gaussian) distribution described by a mean vector and a
 * covariance matrix.
 *
 * <p>Setting the covariance caches three derived quantities: the inverse (or
 * pseudo-inverse) of the covariance, the constant term of the log density, and a
 * factor {@code L} obtained from a Cholesky decomposition that is used to draw
 * samples.
 */
public class NormalM extends MultivariateDistributionSkeleton {
    private static final long serialVersionUID = -7043369396743253382L;

    /** Natural logarithm of 2*pi, used in the normalisation constant. */
    private static final double LOG_TWO_PI = Math.log(2.0d * Math.PI);

    /** Determinants below this value (or NaN) send the covariance down the SVD based path. */
    private static final double MIN_DETERMINANT = 1.0E-10d;

    /** Additive constant of the log density (everything that does not depend on the query point). */
    private double logPDFConst;
    /** Inverse of the covariance matrix, or its pseudo-inverse when the determinant is too small. */
    private Matrix invCovariance;
    /** Mean of the distribution; {@code null} until the distribution has been configured. */
    private Vec mean;
    /** Transposed {@code getLT()} factor of the Cholesky decomposition; used by {@link #sample}. */
    private Matrix L;

    /**
     * Creates a distribution with the given mean and covariance.
     *
     * @param vec the mean vector
     * @param matrix the covariance matrix
     * @throws ArithmeticException if the matrix is not square or does not match the mean's length
     */
    public NormalM(Vec vec, Matrix matrix) {
        setMeanCovariance(vec, matrix);
    }

    /**
     * Creates an unconfigured distribution. A mean and covariance must be supplied
     * (directly or via one of the {@code setUsingData} methods) before it can be used.
     */
    public NormalM() {
    }

    /**
     * Sets both the mean and the covariance. The mean is copied. If processing the
     * covariance fails, the previously stored mean is restored before the exception
     * is rethrown so the object is not left with a new mean and an old covariance.
     *
     * @param vec the new mean vector
     * @param matrix the new covariance matrix
     * @throws ArithmeticException if the matrix is not square or its size differs from the mean's length
     */
    public void setMeanCovariance(Vec vec, Matrix matrix) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Covariance matrix must be square");
        }
        if (vec.length() != matrix.rows()) {
            throw new ArithmeticException("The mean vector and matrix must have the same dimension," + vec.length() + " does not match [" + matrix.rows() + ", " + matrix.rows() + "]");
        }
        Vec previousMean = this.mean;
        this.mean = vec.mo45clone();
        try {
            setCovariance(matrix);
        } catch (RuntimeException e) {
            // Roll back so mean and cached covariance data stay consistent with each other.
            this.mean = previousMean;
            throw e;
        }
    }

    /**
     * Sets the covariance matrix and recomputes the cached inverse, log-density
     * constant and sampling factor. The mean must already be set, because the
     * matrix size is validated against it (a {@code null} mean results in a
     * {@link NullPointerException}).
     *
     * <p>All derived values are computed first and stored only once every step has
     * succeeded, so a failure leaves the previous state untouched.
     *
     * @param matrix the new covariance matrix; it is cloned before being decomposed
     * @throws ArithmeticException if the matrix is not square or its size differs from the mean's length
     */
    public void setCovariance(Matrix matrix) {
        if (!matrix.isSquare()) {
            throw new ArithmeticException("Covariance matrix must be square");
        }
        if (matrix.rows() != this.mean.length()) {
            throw new ArithmeticException("Covariance matrix does not agree with the mean");
        }

        Matrix choleskyFactor = new CholeskyDecomposition(matrix.mo161clone()).getLT();
        choleskyFactor.mutableTranspose();

        LUPDecomposition lup = new LUPDecomposition(matrix.mo161clone());
        int dimension = this.mean.length();
        double det = lup.det();

        double newLogConst;
        Matrix newInverse;
        if (!Double.isNaN(det) && det >= MIN_DETERMINANT) {
            newLogConst = (((-dimension) * LOG_TWO_PI) - Math.log(det)) * 0.5d;
            newInverse = lup.solve(Matrix.eye(dimension));
        } else {
            // Determinant is unusable: fall back to the pseudo-determinant and pseudo-inverse.
            // The constant of the degenerate density is -(rank*log(2*pi) + log(pdet)) / 2; it is
            // evaluated in log space so that (2*pi)^rank cannot overflow for large ranks.
            SingularValueDecomposition svd = new SingularValueDecomposition(matrix.mo161clone());
            newLogConst = -0.5d * ((svd.getRank() * LOG_TWO_PI) + Math.log(svd.getPseudoDet()));
            newInverse = svd.getPseudoInverse();
        }

        // Commit only after every computation above has completed.
        this.L = choleskyFactor;
        this.logPDFConst = newLogConst;
        this.invCovariance = newInverse;
    }

    /**
     * Evaluates the log of the density at the given point as the cached constant
     * minus half of {@code (x - mean)^T * invCovariance * (x - mean)}.
     *
     * @param vec the point to evaluate
     * @return the log density at {@code vec}
     * @throws ArithmeticException if no mean has been set
     */
    @Override // jsat.distributions.multivariate.MultivariateDistributionSkeleton, jsat.distributions.multivariate.MultivariateDistribution
    public double logPdf(Vec vec) {
        if (this.mean == null) {
            throw new ArithmeticException("No mean or variance set");
        }
        Vec centered = vec.subtract(this.mean);
        double quadraticForm = centered.dot(this.invCovariance.multiply(centered));
        return this.logPDFConst + (quadraticForm * (-0.5d));
    }

    /**
     * Evaluates the density at the given point by exponentiating {@link #logPdf}.
     *
     * @param vec the point to evaluate
     * @return the density, or {@code 0} when the exponentiated value is infinite or NaN
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public double pdf(Vec vec) {
        double density = Math.exp(logPdf(vec));
        if (Double.isInfinite(density) || Double.isNaN(density)) {
            return 0.0d;
        }
        return density;
    }

    /**
     * Fits the distribution using the mean vector and covariance matrix that
     * {@code MatrixStatistics} computes from the given vectors.
     *
     * @param list the data to fit to
     * @return {@code true} on success; {@code false} if an {@link ArithmeticException}
     *         occurred, in which case the previous mean is restored
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public <V extends Vec> boolean setUsingData(List<V> list) {
        Vec previousMean = this.mean;
        try {
            Vec meanVector = MatrixStatistics.meanVector(list);
            Matrix covarianceMatrix = MatrixStatistics.covarianceMatrix(meanVector, list);
            // The mean has to be in place first: setCovariance validates the matrix against it.
            this.mean = meanVector;
            setCovariance(covarianceMatrix);
            return true;
        } catch (ArithmeticException e) {
            this.mean = previousMean;
            return false;
        }
    }

    /**
     * Fits the distribution to weighted data points. The mean is the weight-scaled
     * sum of the numerical values divided by the total weight; the covariance is
     * filled in by {@code MatrixStatistics.covarianceMatrix}, which is also handed
     * the sum of weights and the sum of squared weights.
     *
     * @param list the data points to fit to; the first element determines the dimension
     * @return {@code true} on success; {@code false} if an {@link ArithmeticException}
     *         occurred, in which case the previous mean is restored
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public boolean setUsingDataList(List<DataPoint> list) {
        Vec previousMean = this.mean;
        try {
            DenseVector weightedMean = new DenseVector(list.get(0).getNumericalValues().length());
            double weightSum = 0.0d;
            double squaredWeightSum = 0.0d;
            for (DataPoint dataPoint : list) {
                double weight = dataPoint.getWeight();
                weightedMean.mutableAdd(weight, dataPoint.getNumericalValues());
                weightSum += weight;
                squaredWeightSum += weight * weight;
            }
            weightedMean.mutableDivide(weightSum);
            DenseMatrix covariance = new DenseMatrix(weightedMean.length(), weightedMean.length());
            MatrixStatistics.covarianceMatrix(weightedMean, list, covariance, weightSum, squaredWeightSum);
            // The mean has to be in place first: setCovariance validates the matrix against it.
            this.mean = weightedMean;
            setCovariance(covariance);
            return true;
        } catch (ArithmeticException e) {
            this.mean = previousMean;
            return false;
        }
    }

    /**
     * Returns an independent copy of this distribution, including the factor used
     * for sampling so that the copy can be sampled from as well.
     *
     * @return a copy of this distribution
     */
    @Override // jsat.distributions.multivariate.MultivariateDistributionSkeleton
    /* renamed from: clone */
    public NormalM mo150clone() {
        NormalM copy = new NormalM();
        if (this.invCovariance != null) {
            copy.invCovariance = this.invCovariance.mo161clone();
        }
        if (this.mean != null) {
            copy.mean = this.mean.mo45clone();
        }
        if (this.L != null) {
            copy.L = this.L.mo161clone();
        }
        copy.logPDFConst = this.logPDFConst;
        return copy;
    }

    /**
     * Draws samples by filling a vector with independent standard Gaussian values,
     * multiplying it by the cached factor {@code L} and adding the mean.
     *
     * @param i the number of samples to draw
     * @param random the source of randomness
     * @return a list containing {@code i} sampled vectors
     * @throws ArithmeticException if the mean or covariance has not been set
     */
    @Override // jsat.distributions.multivariate.MultivariateDistribution
    public List<Vec> sample(int i, Random random) {
        if (this.mean == null || this.L == null) {
            throw new ArithmeticException("No mean or variance set");
        }
        List<Vec> samples = new ArrayList<Vec>(i);
        DenseVector standardNormal = new DenseVector(this.L.rows());
        for (int drawn = 0; drawn < i; drawn++) {
            for (int j = 0; j < standardNormal.length(); j++) {
                standardNormal.set(j, random.nextGaussian());
            }
            Vec sample = this.L.multiply(standardNormal);
            sample.mutableAdd(this.mean);
            samples.add(sample);
        }
        return samples;
    }
}
