package org.apache.commons.math4.ml.clustering;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import org.apache.commons.math4.analysis.integration.BaseAbstractUnivariateIntegrator;
import org.apache.commons.math4.exception.NumberIsTooSmallException;
import org.apache.commons.math4.ml.clustering.Clusterable;
import org.apache.commons.math4.ml.clustering.KMeansPlusPlusClusterer;
import org.apache.commons.math4.ml.distance.DistanceMeasure;
import org.apache.commons.math4.util.MathUtils;
import org.apache.commons.math4.util.Pair;
import org.apache.commons.rng.UniformRandomProvider;
import org.apache.commons.rng.sampling.ListSampler;

/* loaded from: input_file:org/apache/commons/math4/ml/clustering/MiniBatchKMeansClusterer.class */
/**
 * Clusterer that refines cluster centers from small random batches of points
 * instead of the complete data set.
 *
 * <p>The initial centers are selected by {@link #initialCenters(List)}; each
 * subsequent iteration samples a batch, assigns its points to the nearest
 * center and lets the parent class adjust the centers. Iteration stops when
 * the iteration limit is reached or when the smoothed batch inertia has not
 * improved for {@code maxNoImprovementTimes} consecutive iterations.</p>
 *
 * @param <T> type of the points to cluster
 */
public class MiniBatchKMeansClusterer<T extends Clusterable> extends KMeansPlusPlusClusterer<T> {
    /** Number of points sampled for each mini-batch iteration. */
    private final int batchSize;
    /** Number of candidate initializations that are evaluated. */
    private final int initIterations;
    /** Number of points sampled for each candidate initialization. */
    private final int initBatchSize;
    /** Number of consecutive non-improving iterations after which the search stops. */
    private final int maxNoImprovementTimes;

    /**
     * Tracks an exponentially weighted average (EWA) of the per-point batch
     * inertia and reports when it has stopped improving.
     */
    private static class ImprovementEvaluator {
        /** Number of points in each evaluated batch. */
        private final int batchSize;
        /** Number of consecutive non-improving evaluations that signals convergence. */
        private final int maxNoImprovementTimes;
        /** Current smoothed inertia; {@code NaN} until the first batch is seen. */
        private double ewaInertia = Double.NaN;
        /** Smallest smoothed inertia observed so far. */
        private double ewaInertiaMin = Double.POSITIVE_INFINITY;
        /** Number of consecutive evaluations without a new minimum. */
        private int noImprovementTimes;

        /**
         * @param batchSize number of points in each evaluated batch
         * @param maxNoImprovementTimes number of consecutive non-improving
         *        evaluations that signals convergence
         */
        private ImprovementEvaluator(int batchSize, int maxNoImprovementTimes) {
            this.batchSize = batchSize;
            this.maxNoImprovementTimes = maxNoImprovementTimes;
        }

        /**
         * Folds the inertia of one batch into the running average.
         *
         * @param squareDistance sum of squared distances measured on the batch
         * @param pointSize total number of points being clustered
         * @return {@code true} once the smoothed inertia has failed to reach a
         *         new minimum {@code maxNoImprovementTimes} times in a row
         */
        public boolean converge(double squareDistance, int pointSize) {
            final double batchInertia = squareDistance / batchSize;
            if (Double.isNaN(ewaInertia)) {
                // First batch: nothing to smooth against yet.
                ewaInertia = batchInertia;
            } else {
                // The weight must be computed in floating point: with integer
                // division it collapses to 0 or 1, which either freezes the
                // average at its first value or disables smoothing entirely.
                final double alpha = Math.min(batchSize * 2.0 / (pointSize + 1.0), 1.0);
                ewaInertia = ewaInertia * (1.0 - alpha) + batchInertia * alpha;
            }

            if (ewaInertia < ewaInertiaMin) {
                noImprovementTimes = 0;
                ewaInertiaMin = ewaInertia;
            } else {
                noImprovementTimes++;
            }
            return noImprovementTimes >= maxNoImprovementTimes;
        }
    }

    /**
     * Builds a mini-batch clusterer.
     *
     * @param k forwarded to the parent constructor (number of clusters, see
     *        {@link #getNumberOfClusters()})
     * @param maxIterations forwarded to the parent constructor (see
     *        {@link #getMaxIterations()}); a negative value removes the
     *        iteration limit in {@link #cluster(Collection)}
     * @param batchSize number of points sampled for each iteration
     * @param initIterations number of candidate initializations to evaluate
     * @param initBatchSize number of points sampled for each candidate initialization
     * @param maxNoImprovementTimes number of consecutive non-improving
     *        iterations after which clustering stops
     * @param distanceMeasure forwarded to the parent constructor
     * @param uniformRandomProvider forwarded to the parent constructor; also
     *        used here (through {@link #getRandomGenerator()}) to sample batches
     * @param emptyClusterStrategy forwarded to the parent constructor
     * @throws NumberIsTooSmallException if {@code batchSize}, {@code initIterations},
     *         {@code initBatchSize} or {@code maxNoImprovementTimes} is smaller than 1
     */
    public MiniBatchKMeansClusterer(int k, int maxIterations, int batchSize, int initIterations, int initBatchSize, int maxNoImprovementTimes, DistanceMeasure distanceMeasure, UniformRandomProvider uniformRandomProvider, KMeansPlusPlusClusterer.EmptyClusterStrategy emptyClusterStrategy) {
        super(k, maxIterations, distanceMeasure, uniformRandomProvider, emptyClusterStrategy);
        // Validate everything (in declaration order) before assigning any field.
        requireAtLeastOne(batchSize);
        requireAtLeastOne(initIterations);
        requireAtLeastOne(initBatchSize);
        requireAtLeastOne(maxNoImprovementTimes);
        this.batchSize = batchSize;
        this.initIterations = initIterations;
        this.initBatchSize = initBatchSize;
        this.maxNoImprovementTimes = maxNoImprovementTimes;
    }

    /**
     * Clusters the given points.
     *
     * @param collection points to cluster
     * @return the clusters, each holding the points of {@code collection}
     *         that are nearest to its center
     * @throws NumberIsTooSmallException if there are fewer points than clusters
     */
    @Override
    public List<CentroidCluster<T>> cluster(Collection<T> collection) {
        MathUtils.checkNotNull(collection);
        final int size = collection.size();
        final int clusterCount = getNumberOfClusters();
        if (size < clusterCount) {
            throw new NumberIsTooSmallException(size, clusterCount, false);
        }

        final int maxIterations = miniBatchIterationLimit(size);
        final List<T> points = new ArrayList<>(collection);
        List<CentroidCluster<T>> clusters = initialCenters(points);

        // The sampler rejects a sample larger than the list, so never ask for
        // more points than are available.
        final int effectiveBatchSize = Math.min(batchSize, size);
        final ImprovementEvaluator evaluator = new ImprovementEvaluator(effectiveBatchSize, maxNoImprovementTimes);

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // Drop the points left over from the previous step.
            clearClustersPoints(clusters);
            final List<T> batch = ListSampler.sample(getRandomGenerator(), points, effectiveBatchSize);
            final Pair<Double, List<CentroidCluster<T>>> result = step(batch, clusters);
            clusters = result.getSecond();
            if (evaluator.converge(result.getFirst(), size)) {
                break;
            }
        }

        // Final assignment of every input point to the refined centers.
        clearClustersPoints(clusters);
        for (final T point : collection) {
            addToNearestCentroidCluster(point, clusters);
        }
        return clusters;
    }

    /**
     * Throws if the given value is smaller than 1.
     *
     * @param value value to check
     * @throws NumberIsTooSmallException if {@code value < 1}
     */
    private static void requireAtLeastOne(int value) {
        if (value < 1) {
            throw new NumberIsTooSmallException(value, 1, true);
        }
    }

    /**
     * Computes the maximum number of mini-batch iterations.
     *
     * @param size number of points to cluster
     * @return {@link Integer#MAX_VALUE} if {@link #getMaxIterations()} is
     *         negative, otherwise that value multiplied by the number of
     *         batches needed to cover {@code size} points (saturated at
     *         {@link Integer#MAX_VALUE})
     */
    private int miniBatchIterationLimit(int size) {
        final int configured = getMaxIterations();
        if (configured < 0) {
            return Integer.MAX_VALUE;
        }
        // Ceiling of size / batchSize.
        final int batchesPerPass = size / batchSize + (size % batchSize > 0 ? 1 : 0);
        // Multiply in long so that a large product saturates instead of wrapping.
        return (int) Math.min((long) configured * batchesPerPass, Integer.MAX_VALUE);
    }

    /**
     * Removes all points from the given clusters, keeping the clusters themselves.
     *
     * @param clusters clusters to empty
     */
    private void clearClustersPoints(List<CentroidCluster<T>> clusters) {
        for (final CentroidCluster<T> cluster : clusters) {
            cluster.getPoints().clear();
        }
    }

    /**
     * Performs one mini-batch update.
     *
     * <p>The batch points are first added to their nearest cluster, the
     * parent's {@code adjustClustersCenters} is applied to those clusters,
     * and the batch points are then added to the nearest of the resulting
     * clusters while the squared distances are accumulated.</p>
     *
     * @param batch points of the current batch
     * @param clusters clusters before the update
     * @return the sum of squared distances of the batch points to their
     *         nearest adjusted cluster, paired with the adjusted clusters
     */
    private Pair<Double, List<CentroidCluster<T>>> step(List<T> batch, List<CentroidCluster<T>> clusters) {
        for (final T point : batch) {
            addToNearestCentroidCluster(point, clusters);
        }
        final List<CentroidCluster<T>> adjusted = adjustClustersCenters(clusters);

        double squareDistance = 0.0d;
        for (final T point : batch) {
            final double distance = addToNearestCentroidCluster(point, adjusted);
            squareDistance += distance * distance;
        }
        return new Pair<>(Double.valueOf(squareDistance), adjusted);
    }

    /**
     * Selects the initial clusters.
     *
     * <p>Evaluates {@code initIterations} candidates, each produced by the
     * parent's {@code chooseInitialCenters} on a fresh sample and refined by
     * one {@link #step(List, List)} on a fixed validation sample, and keeps
     * the candidate with the smallest sum of squared distances.</p>
     *
     * @param points all points to cluster
     * @return the best candidate clusters; they still hold the validation
     *         points added by {@link #step(List, List)}
     */
    private List<CentroidCluster<T>> initialCenters(List<T> points) {
        // Drawn once so that every candidate is scored on the same points.
        final List<T> validationPoints = initializationSample(points);
        double bestSquareDistance = Double.POSITIVE_INFINITY;
        List<CentroidCluster<T>> bestCenters = null;

        for (int attempt = 0; attempt < initIterations; attempt++) {
            final List<CentroidCluster<T>> candidate = chooseInitialCenters(initializationSample(points));
            final Pair<Double, List<CentroidCluster<T>>> result = step(validationPoints, candidate);
            final double squareDistance = result.getFirst().doubleValue();
            if (squareDistance < bestSquareDistance) {
                bestSquareDistance = squareDistance;
                bestCenters = result.getSecond();
            } else if (bestCenters == null) {
                // No candidate scored below +Infinity so far (e.g. the sum
                // overflowed): keep the first one rather than returning null.
                bestCenters = result.getSecond();
            }
        }
        return bestCenters;
    }

    /**
     * Draws the points used for one initialization attempt.
     *
     * @param points all points to cluster
     * @return a random sample of {@code initBatchSize} points, or a copy of
     *         {@code points} when there are not more than {@code initBatchSize} of them
     */
    private List<T> initializationSample(List<T> points) {
        if (initBatchSize < points.size()) {
            return ListSampler.sample(getRandomGenerator(), points, initBatchSize);
        }
        return new ArrayList<>(points);
    }

    /**
     * Adds a point to the cluster whose center is nearest to it.
     *
     * @param point point to assign
     * @param clusters candidate clusters
     * @return the distance between {@code point} and the center of the
     *         cluster it was added to
     */
    private double addToNearestCentroidCluster(T point, List<CentroidCluster<T>> clusters) {
        double minDistance = Double.POSITIVE_INFINITY;
        CentroidCluster<T> nearest = null;
        for (final CentroidCluster<T> cluster : clusters) {
            final double distance = distance(point, cluster.getCenter());
            if (distance < minDistance) {
                minDistance = distance;
                nearest = cluster;
            }
        }
        // Fails when no cluster is strictly closer than +Infinity
        // (e.g. the list of clusters is empty).
        MathUtils.checkNotNull(nearest);
        nearest.addPoint(point);
        return minDistance;
    }
}
