package jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.logging.Level;
import java.util.logging.Logger;
import jsat.DataSet;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseMatrix;
import jsat.linear.Matrix;
import jsat.linear.Vec;
import jsat.linear.distancemetrics.DistanceMetric;
import jsat.linear.distancemetrics.EuclideanDistance;
import jsat.math.OnLineStatistics;
import jsat.utils.FakeExecutor;
import jsat.utils.SystemInfo;
import jsat.utils.concurrent.AtomicDouble;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:jsat/datatransform/visualization/MDS.class */
/**
 * Embeds a data set into a low dimensional space by iteratively reducing the
 * "stress": the sum of squared differences between the pairwise distances of
 * the embedded points and the pairwise dissimilarities of the original data.
 * <p>
 * Each iteration builds a symmetric matrix {@code B} with off-diagonal entries
 * {@code -delta(i,j) / d(i,j)} (zero when the embedded distance {@code d(i,j)}
 * is at most {@code 1e-5}) and diagonal entries equal to the negated sum of
 * the rest of the row, then replaces the embedding {@code X} with
 * {@code B * X / n}. Iteration stops after {@code maxIterations} rounds or once
 * the change in stress between two rounds is no larger than the tolerance.
 * NOTE(review): this update has the shape of the SMACOF / Guttman-transform
 * step; confirm against the upstream documentation before naming it as such.
 * <p>
 * The configuration of an instance is not synchronized; do not change it
 * while a {@code transform} call is running on another thread.
 */
public class MDS implements VisualizationTransform {
    /**
     * Metric used to measure distances between points in the embedded space.
     * Kept per instance so that configuring one {@code MDS} object does not
     * silently reconfigure every other one.
     */
    private DistanceMetric embedMetric = new EuclideanDistance();
    /** Metric used to build the dissimilarity matrix of the input data; there is no setter for it. */
    private DistanceMetric dm = new EuclideanDistance();
    /** Iteration stops once the absolute change in stress is no larger than this value. */
    private double tolerance = 0.001d;
    /** Upper bound on the number of update rounds; there is no setter for it. */
    private int maxIterations = 300;
    /** Number of dimensions of the produced embedding. */
    private int targetSize = 2;

    /**
     * Sets the convergence threshold on the change in stress between two
     * consecutive iterations.
     *
     * @param d the new tolerance, must be finite and non-negative
     * @throws IllegalArgumentException if {@code d} is negative, infinite or NaN
     */
    public void setTolerance(double d) {
        if (d < 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new IllegalArgumentException("tolerance must be a non-negative value, not " + d);
        }
        this.tolerance = d;
    }

    /**
     * @return the convergence threshold on the change in stress
     */
    public double getTolerance() {
        return this.tolerance;
    }

    /**
     * Sets the metric used to measure distances between the embedded points,
     * both when building the update matrix and when computing the stress.
     * The setting applies to this instance only.
     *
     * @param distanceMetric the metric to use in the embedded space
     */
    public void setEmbeddingMetric(DistanceMetric distanceMetric) {
        this.embedMetric = distanceMetric;
    }

    /**
     * @return the metric used to measure distances between the embedded points
     */
    public DistanceMetric getEmbeddingMetric() {
        return this.embedMetric;
    }

    /**
     * Embeds the data set using a {@link FakeExecutor}.
     *
     * @param dataSet the data set to embed
     * @return a shallow clone of {@code dataSet} whose numeric features are the embedded coordinates
     */
    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> dataSet) {
        return (Type) transform(dataSet, new FakeExecutor());
    }

    /**
     * Computes the pairwise dissimilarity matrix of the data set, embeds it
     * with {@link #transform(Matrix, ExecutorService)} and stores the result
     * as the numeric features of a shallow clone of the input.
     * <p>
     * If the calling thread is interrupted the computation still runs to
     * completion; the interrupt status is restored before this method returns.
     * A failed distance task is logged at {@code SEVERE} level and the
     * computation continues with whatever part of the matrix was filled in.
     *
     * @param dataSet the data set to embed
     * @param executorService the source of threads for the parallel parts
     * @return a shallow clone of {@code dataSet} whose numeric features are the embedded coordinates
     */
    @Override
    public <Type extends DataSet> Type transform(final DataSet<Type> dataSet, ExecutorService executorService) {
        final DistanceMetric sourceMetric = this.dm;
        final List<Vec> dataVectors = dataSet.getDataVectors();
        final List<Double> accelerationCache = sourceMetric.getAccelerationCache(dataVectors, executorService);
        int size = dataVectors.size();
        final int sampleCount = dataSet.getSampleSize();
        final DenseMatrix dissimilarity = new DenseMatrix(size, size);

        // Rows are dealt out round-robin: worker w handles rows w, w + cores, w + 2 * cores, ...
        // Callable (not Runnable) is submitted on purpose so that a Future is returned to wait on.
        List<Future<Void>> pending = new ArrayList<Future<Void>>();
        for (int worker = 0; worker < SystemInfo.LogicalCores; worker++) {
            final int firstRow = worker;
            pending.add(executorService.submit(new Callable<Void>() {
                @Override
                public Void call() throws Exception {
                    for (int i = firstRow; i < sampleCount; i += SystemInfo.LogicalCores) {
                        for (int j = i + 1; j < sampleCount; j++) {
                            double dist = sourceMetric.dist(i, j, dataVectors, accelerationCache);
                            dissimilarity.set(i, j, dist);
                            dissimilarity.set(j, i, dist);
                        }
                    }
                    return null;
                }
            }));
        }

        // An interrupt is remembered and only re-asserted on exit, so the
        // remaining waits in this computation are not cut short by it.
        boolean interrupted = false;
        try {
            for (Future<Void> task : pending) {
                try {
                    getUninterruptibly(task);
                } catch (ExecutionException e) {
                    Logger.getLogger(MDS.class.getName()).log(Level.SEVERE, (String) null, (Throwable) e);
                }
                interrupted |= Thread.interrupted();
            }
            SimpleDataSet embedded = transform(dissimilarity, executorService);
            interrupted |= Thread.interrupted();
            // NOTE(review): "shallowClone2" looks like a decompiler-generated name — confirm against DataSet.
            DataSet<Type> result = dataSet.shallowClone2();
            result.replaceNumericFeatures(embedded.getDataVectors());
            return (Type) result;
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Embeds the objects described by a dissimilarity matrix using a {@link FakeExecutor}.
     *
     * @param matrix the pairwise dissimilarities
     * @return a data set with one embedded point per row of {@code matrix}
     */
    public SimpleDataSet transform(Matrix matrix) {
        return transform(matrix, new FakeExecutor());
    }

    /**
     * Embeds the objects described by a dissimilarity matrix. The embedding
     * starts from random coordinates and is refined until the stress changes
     * by no more than the tolerance or the iteration limit is reached.
     * <p>
     * Only the entries of {@code matrix} above the diagonal are read, and
     * {@code matrix.rows()} is used as the number of objects.
     * <p>
     * If the calling thread is interrupted the computation still runs to
     * completion; the interrupt status is restored before this method returns.
     *
     * @param matrix the pairwise dissimilarities
     * @param executorService the source of threads for the parallel parts
     * @return a data set with one embedded point per row of {@code matrix}
     */
    public SimpleDataSet transform(final Matrix matrix, ExecutorService executorService) {
        final int rows = matrix.rows();
        // Snapshot the metric so that the whole run uses one and the same metric.
        final DistanceMetric metric = this.embedMetric;

        // Random starting configuration; embeddedRows holds row views, so
        // updating "embedding" updates the vectors seen through the list.
        XORWOW rand = new XORWOW();
        DenseMatrix embedding = new DenseMatrix(rows, this.targetSize);
        final List<Vec> embeddedRows = new ArrayList<Vec>();
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < this.targetSize; j++) {
                embedding.set(i, j, rand.nextDouble());
            }
            embeddedRows.add(embedding.getRowView(i));
        }
        final List<Double> cache = metric.getAccelerationCache(embeddedRows, executorService);

        final DenseMatrix b = new DenseMatrix(rows, rows);
        DenseMatrix next = new DenseMatrix(embedding.rows(), embedding.cols());

        boolean interrupted = false;
        try {
            double change = Double.POSITIVE_INFINITY;
            double stress = stress(metric, embeddedRows, cache, matrix, executorService);
            interrupted |= Thread.interrupted();

            for (int iter = 0; iter < this.maxIterations && change > this.tolerance; iter++) {
                // Fill the off-diagonal part of B in parallel, rows dealt out round-robin.
                CountDownLatch rowsDone = new CountDownLatch(SystemInfo.LogicalCores);
                for (int worker = 0; worker < SystemInfo.LogicalCores; worker++) {
                    final int firstRow = worker;
                    executorService.submit(new LatchedTask(rowsDone) {
                        @Override
                        protected void compute() {
                            for (int i = firstRow; i < rows; i += SystemInfo.LogicalCores) {
                                for (int j = i + 1; j < rows; j++) {
                                    double embeddedDist = metric.dist(i, j, embeddedRows, cache);
                                    // Near-coincident points get weight 0 to avoid dividing by ~0.
                                    double weight = embeddedDist > 1.0E-5d ? -matrix.get(i, j) / embeddedDist : 0.0d;
                                    b.set(i, j, weight);
                                    b.set(j, i, weight);
                                }
                            }
                        }
                    });
                }
                next.zeroOut();
                awaitUninterruptibly(rowsDone);
                interrupted |= Thread.interrupted();

                // Diagonal of B: negated sum of the other entries in the row.
                for (int i = 0; i < b.rows(); i++) {
                    b.set(i, i, 0.0d);
                    for (int j = 0; j < b.cols(); j++) {
                        if (j != i) {
                            b.increment(i, i, -b.get(i, j));
                        }
                    }
                }

                // X <- B * X / n
                b.multiply(embedding, next, executorService);
                next.mutableMultiply(1.0d / rows);
                next.copyTo(embedding);

                // The embedded vectors changed, so the cached values are stale.
                // A metric without acceleration support is expected to supply
                // no cache at all (null), in which case there is nothing to refresh.
                if (cache != null) {
                    cache.clear();
                    cache.addAll(metric.getAccelerationCache(embeddedRows, executorService));
                }

                double newStress = stress(metric, embeddedRows, cache, matrix, executorService);
                interrupted |= Thread.interrupted();
                change = Math.abs(stress - newStress);
                stress = newStress;
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }

        SimpleDataSet result = new SimpleDataSet(new CategoricalData[0], this.targetSize);
        for (Vec point : embeddedRows) {
            result.add(new DataPoint(point));
        }
        return result;
    }

    /**
     * Computes the sum over all pairs {@code i < j} of the squared difference
     * between the embedded distance and the target dissimilarity.
     *
     * @param metric the metric used in the embedded space
     * @param list the embedded points
     * @param list2 the acceleration cache for {@code list}, as produced by {@code metric}
     * @param matrix the target dissimilarities
     * @param executorService the source of threads
     * @return the stress of the current embedding
     */
    private static double stress(final DistanceMetric metric, final List<Vec> list, final List<Double> list2,
            final Matrix matrix, ExecutorService executorService) {
        final AtomicDouble total = new AtomicDouble(0.0d);
        final int rows = matrix.rows();
        CountDownLatch done = new CountDownLatch(SystemInfo.LogicalCores);
        for (int worker = 0; worker < SystemInfo.LogicalCores; worker++) {
            final int firstRow = worker;
            executorService.submit(new LatchedTask(done) {
                @Override
                protected void compute() {
                    // Accumulate locally and publish once to keep contention low.
                    double partial = 0.0d;
                    for (int i = firstRow; i < rows; i += SystemInfo.LogicalCores) {
                        for (int j = i + 1; j < rows; j++) {
                            double diff = metric.dist(i, j, list, list2) - matrix.get(i, j);
                            partial += diff * diff;
                        }
                    }
                    total.addAndGet(partial);
                }
            });
        }
        awaitUninterruptibly(done);
        return total.get();
    }

    /**
     * A task that always counts its latch down, even when the work fails, so
     * that a failing worker cannot leave the waiting thread blocked forever.
     * A failure is logged and then rethrown.
     */
    private abstract static class LatchedTask implements Runnable {
        private final CountDownLatch latch;

        LatchedTask(CountDownLatch latch) {
            this.latch = latch;
        }

        /** Performs this worker's share of the computation. */
        protected abstract void compute();

        @Override
        public final void run() {
            try {
                compute();
            } catch (RuntimeException e) {
                Logger.getLogger(MDS.class.getName()).log(Level.SEVERE, "MDS worker task failed", e);
                throw e;
            } finally {
                this.latch.countDown();
            }
        }
    }

    /**
     * Waits for the latch to reach zero. An interrupt does not end the wait;
     * the interrupt status is restored once the wait is over.
     */
    private static void awaitUninterruptibly(CountDownLatch latch) {
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    latch.await();
                    return;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Waits for the future to complete. An interrupt does not end the wait;
     * the interrupt status is restored once the wait is over.
     *
     * @throws ExecutionException if the task behind the future failed
     */
    private static void getUninterruptibly(Future<?> future) throws ExecutionException {
        boolean interrupted = false;
        try {
            while (true) {
                try {
                    future.get();
                    return;
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        } finally {
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * @return the number of dimensions of the produced embedding
     */
    @Override
    public int getTargetDimension() {
        return this.targetSize;
    }

    /**
     * Sets the number of dimensions of the produced embedding.
     *
     * @param i the desired number of dimensions
     * @return {@code true} if the value was accepted, {@code false} if it was
     * less than 1 and therefore ignored
     */
    @Override
    public boolean setTargetDimension(int i) {
        if (i < 1) {
            return false;
        }
        this.targetSize = i;
        return true;
    }
}
