package jsat.parameters;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.PriorityQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicReference;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.ClassificationModelEvaluation;
import jsat.classifiers.Classifier;
import jsat.distributions.Distribution;
import jsat.exceptions.FailedToFitException;
import jsat.regression.RegressionDataSet;
import jsat.regression.RegressionModelEvaluation;
import jsat.regression.Regressor;
import jsat.utils.FakeExecutor;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:jsat/parameters/RandomSearch.class */
/**
 * Hyper-parameter search that evaluates a fixed number of randomly sampled
 * parameter settings for a base {@link Classifier} or {@link Regressor}.
 * <p>
 * For every trial, one value is drawn for each registered search parameter.
 * The draw applies the inverse CDF of that parameter's {@link Distribution}
 * to a uniform random number; values for {@link IntParameter}s are rounded.
 * The base model is cloned with those values. Every clone is scored by
 * cross-validation, and the best-scoring clone becomes the trained model.
 * The winner is optionally re-trained on the full data set first.
 */
public class RandomSearch extends ModelSearch {
    /** Number of random parameter settings evaluated when not configured otherwise. */
    private static final int DEFAULT_TRIALS = 25;

    /** Number of random parameter settings to evaluate; kept positive by {@link #setTrials(int)}. */
    private int trials = DEFAULT_TRIALS;

    /**
     * Sampling distribution for each search parameter. Entry {@code i} belongs
     * to entry {@code i} of {@code searchParams}, so both lists must always be
     * modified together.
     */
    private final List<Distribution> searchValues;

    /**
     * Creates a random search over the parameters of a regressor.
     *
     * @param regressor the base regressor whose parameters will be searched
     * @param folds forwarded to {@link ModelSearch}; presumably the number of
     *              cross-validation folds used to score each candidate
     */
    public RandomSearch(Regressor regressor, int folds) {
        super(regressor, folds);
        this.searchValues = new ArrayList<Distribution>();
    }

    /**
     * Creates a random search over the parameters of a classifier.
     *
     * @param classifier the base classifier whose parameters will be searched
     * @param folds forwarded to {@link ModelSearch}; presumably the number of
     *              cross-validation folds used to score each candidate
     */
    public RandomSearch(Classifier classifier, int folds) {
        super(classifier, folds);
        this.searchValues = new ArrayList<Distribution>();
    }

    /**
     * Copy constructor. The search distributions are deep-copied, so changes
     * to one search's distributions do not affect the other.
     *
     * @param toCopy the search to copy
     */
    public RandomSearch(RandomSearch toCopy) {
        super(toCopy);
        this.trials = toCopy.trials;
        this.searchValues = new ArrayList<Distribution>(toCopy.searchValues.size());
        for (Distribution dist : toCopy.searchValues) {
            this.searchValues.add(dist.mo136clone());
        }
    }

    /**
     * Registers base-model parameters that can suggest their own range.
     * <p>
     * Every {@link DoubleParameter} or {@link IntParameter} of the base model
     * whose {@code getGuess(data)} returns a non-null distribution is added,
     * with that guess as its sampling distribution.
     *
     * @param data the data set passed to each parameter's {@code getGuess}
     * @return the number of parameters that were added to the search
     */
    public int autoAddParameters(DataSet data) {
        Parameterized model = baseClassifier != null
                ? (Parameterized) baseClassifier
                : (Parameterized) baseRegressor;
        int added = 0;
        for (Parameter param : model.getParameters()) {
            if (param instanceof DoubleParameter) {
                DoubleParameter doubleParam = (DoubleParameter) param;
                Distribution guess = doubleParam.getGuess(data);
                if (guess != null) {
                    addParameter(doubleParam, guess);
                    added++;
                }
            } else if (param instanceof IntParameter) {
                IntParameter intParam = (IntParameter) param;
                Distribution guess = intParam.getGuess(data);
                if (guess != null) {
                    addParameter(intParam, guess);
                    added++;
                }
            }
        }
        return added;
    }

    /**
     * Sets how many random parameter settings are evaluated per search.
     *
     * @param trials the number of trials; must be at least 1
     * @throws IllegalArgumentException if {@code trials < 1}
     */
    public void setTrials(int trials) {
        if (trials < 1) {
            throw new IllegalArgumentException("number of trials must be positive, not " + trials);
        }
        this.trials = trials;
    }

    /**
     * Returns how many random parameter settings are evaluated per search.
     *
     * @return the number of trials (always at least 1)
     */
    public int getTrials() {
        return trials;
    }

    /**
     * Adds a real-valued parameter to the search. A clone of
     * {@code distribution} is stored, so later changes to the argument do not
     * affect the search.
     *
     * @param param the parameter to search over
     * @param distribution the distribution its values are sampled from
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public void addParameter(DoubleParameter param, Distribution distribution) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        addSearchEntry(param, distribution);
    }

    /**
     * Adds an integer-valued parameter to the search. Sampled values are
     * rounded to the nearest integer. A clone of {@code distribution} is
     * stored.
     *
     * @param param the parameter to search over
     * @param distribution the distribution its values are sampled from
     * @throws IllegalArgumentException if either argument is {@code null}
     */
    public void addParameter(IntParameter param, Distribution distribution) {
        if (param == null) {
            throw new IllegalArgumentException("null not allowed for parameter");
        }
        addSearchEntry(param, distribution);
    }

    /**
     * Adds a parameter, looked up by name, to the search.
     *
     * @param name the name resolved with {@code getParameterByName}
     * @param distribution the distribution its values are sampled from
     * @throws IllegalArgumentException if the name does not resolve to a
     *         {@link DoubleParameter} or {@link IntParameter}, or if
     *         {@code distribution} is {@code null}
     */
    public void addParameter(String name, Distribution distribution) {
        Parameter param = getParameterByName(name);
        if (param instanceof DoubleParameter) {
            addParameter((DoubleParameter) param, distribution);
        } else if (param instanceof IntParameter) {
            addParameter((IntParameter) param, distribution);
        } else {
            throw new IllegalArgumentException("Parameter " + name + " is not for double or int values");
        }
    }

    /**
     * Appends a (parameter, distribution) pair.
     * <p>
     * The distribution is validated and cloned before either list is touched,
     * so a failure cannot leave {@code searchParams} and {@code searchValues}
     * with different lengths.
     */
    private void addSearchEntry(Parameter param, Distribution distribution) {
        if (distribution == null) {
            throw new IllegalArgumentException("null not allowed for distribution");
        }
        Distribution copy = distribution.mo136clone();
        searchParams.add(param);
        searchValues.add(copy);
    }

    /**
     * Draws one value for every search parameter and writes it into that
     * parameter. Integer parameters receive the rounded draw.
     * <p>
     * Side effect: the registered parameter objects keep the last sampled
     * values after the search finishes.
     */
    private void sampleSearchParameters(XORWOW rand) {
        for (int i = 0; i < searchParams.size(); i++) {
            double sampled = searchValues.get(i).invCdf(rand.nextDouble());
            Parameter param = searchParams.get(i);
            if (param instanceof DoubleParameter) {
                ((DoubleParameter) param).setValue(sampled);
            } else if (param instanceof IntParameter) {
                ((IntParameter) param).setValue((int) Math.round(sampled));
            }
        }
    }

    /**
     * Evaluates {@link #getTrials()} randomly parameterized clones of the base
     * classifier by cross-validation and keeps the best one.
     *
     * @param dataSet the data to search on
     * @param threadPool executor for candidate evaluation and final training.
     *        If {@code null}, candidates are evaluated through a
     *        {@link FakeExecutor} and the final model is trained with the
     *        single-argument {@code trainC}.
     * @throws FailedToFitException if waiting is interrupted, if any candidate's
     *         evaluation throws, or if not every candidate was evaluated
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(final ClassificationDataSet dataSet, final ExecutorService threadPool) {
        // Head of the queue is the best model: ascending score when lower is
        // better, descending otherwise.
        final PriorityQueue<ClassificationModelEvaluation> bestModels = new PriorityQueue<ClassificationModelEvaluation>(
                trials, new Comparator<ClassificationModelEvaluation>() {
                    @Override
                    public int compare(ClassificationModelEvaluation a, ClassificationModelEvaluation b) {
                        int order = classificationTargetScore.lowerIsBetter() ? 1 : -1;
                        return order * Double.compare(
                                a.getScoreStats(classificationTargetScore).getMean(),
                                b.getScoreStats(classificationTargetScore).getMean());
                    }
                });

        // Each clone snapshots the parameter values sampled just before it.
        XORWOW rand = new XORWOW();
        final List<Classifier> candidates = new ArrayList<Classifier>(trials);
        for (int trial = 0; trial < trials; trial++) {
            sampleSearchParameters(rand);
            candidates.add(baseClassifier.mo234clone());
        }

        ExecutorService evalPool = (!trainModelsInParallel || threadPool == null) ? new FakeExecutor() : threadPool;

        // Optionally compute one shared set of CV splits for all candidates.
        final List<ClassificationDataSet> cvFolds;
        final List<ClassificationDataSet> trainCombinations;
        if (reuseSameCVFolds) {
            cvFolds = dataSet.cvSet(folds);
            trainCombinations = new ArrayList<ClassificationDataSet>(cvFolds.size());
            for (int i = 0; i < cvFolds.size(); i++) {
                trainCombinations.add(ClassificationDataSet.comineAllBut(cvFolds, i));
            }
        } else {
            cvFolds = null;
            trainCombinations = null;
        }

        final CountDownLatch latch = new CountDownLatch(candidates.size());
        final AtomicReference<RuntimeException> failure = new AtomicReference<RuntimeException>();
        for (final Classifier candidate : candidates) {
            evalPool.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        // When candidates already run concurrently, evaluate
                        // each one single-threaded; otherwise lend it the pool.
                        ClassificationModelEvaluation cme = (trainModelsInParallel || threadPool == null)
                                ? new ClassificationModelEvaluation(candidate, dataSet)
                                : new ClassificationModelEvaluation(candidate, dataSet, threadPool);
                        cme.addScorer(classificationTargetScore.m34clone());
                        if (reuseSameCVFolds) {
                            cme.evaluateCrossValidation(cvFolds, trainCombinations);
                        } else {
                            cme.evaluateCrossValidation(folds);
                        }
                        synchronized (bestModels) {
                            bestModels.add(cme);
                        }
                    } catch (RuntimeException ex) {
                        // Record the failure instead of letting the executor
                        // swallow it inside an ignored Future.
                        failure.compareAndSet(null, ex);
                    } finally {
                        // Always release the latch, or await() would block forever.
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        if (failure.get() != null) {
            throw new FailedToFitException(failure.get());
        }
        synchronized (bestModels) {
            if (bestModels.size() != candidates.size()) {
                throw new FailedToFitException(new IllegalStateException(
                        "only " + bestModels.size() + " of " + candidates.size() + " candidate models were evaluated"));
            }
        }

        Classifier best = bestModels.peek().getClassifier();
        if (trainFinalModel) {
            if (threadPool == null || threadPool instanceof FakeExecutor) {
                best.trainC(dataSet);
            } else {
                best.trainC(dataSet, threadPool);
            }
        }
        trainedClassifier = best;
    }

    /**
     * Single-threaded entry point; equivalent to {@code trainC(dataSet, null)}.
     *
     * @param dataSet the data to search on
     */
    @Override // jsat.classifiers.Classifier
    public void trainC(ClassificationDataSet dataSet) {
        trainC(dataSet, null);
    }

    /**
     * Evaluates {@link #getTrials()} randomly parameterized clones of the base
     * regressor by cross-validation and keeps the best one.
     *
     * @param dataSet the data to search on
     * @param threadPool executor for candidate evaluation and final training.
     *        If {@code null}, candidates are evaluated through a
     *        {@link FakeExecutor} and the final model is trained with the
     *        single-argument {@code train}.
     * @throws FailedToFitException if waiting is interrupted, if any candidate's
     *         evaluation throws, or if not every candidate was evaluated
     */
    @Override // jsat.regression.Regressor
    public void train(final RegressionDataSet dataSet, final ExecutorService threadPool) {
        // Head of the queue is the best model: ascending score when lower is
        // better, descending otherwise.
        final PriorityQueue<RegressionModelEvaluation> bestModels = new PriorityQueue<RegressionModelEvaluation>(
                trials, new Comparator<RegressionModelEvaluation>() {
                    @Override
                    public int compare(RegressionModelEvaluation a, RegressionModelEvaluation b) {
                        int order = regressionTargetScore.lowerIsBetter() ? 1 : -1;
                        return order * Double.compare(
                                a.getScoreStats(regressionTargetScore).getMean(),
                                b.getScoreStats(regressionTargetScore).getMean());
                    }
                });

        // Each clone snapshots the parameter values sampled just before it.
        XORWOW rand = new XORWOW();
        final List<Regressor> candidates = new ArrayList<Regressor>(trials);
        for (int trial = 0; trial < trials; trial++) {
            sampleSearchParameters(rand);
            candidates.add(baseRegressor.mo234clone());
        }

        ExecutorService evalPool = (!trainModelsInParallel || threadPool == null) ? new FakeExecutor() : threadPool;

        // Optionally compute one shared set of CV splits for all candidates.
        final List<RegressionDataSet> cvFolds;
        final List<RegressionDataSet> trainCombinations;
        if (reuseSameCVFolds) {
            cvFolds = dataSet.cvSet(folds);
            trainCombinations = new ArrayList<RegressionDataSet>(cvFolds.size());
            for (int i = 0; i < cvFolds.size(); i++) {
                trainCombinations.add(RegressionDataSet.comineAllBut(cvFolds, i));
            }
        } else {
            cvFolds = null;
            trainCombinations = null;
        }

        final CountDownLatch latch = new CountDownLatch(candidates.size());
        final AtomicReference<RuntimeException> failure = new AtomicReference<RuntimeException>();
        for (final Regressor candidate : candidates) {
            evalPool.submit(new Runnable() {
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        // When candidates already run concurrently, evaluate
                        // each one single-threaded; otherwise lend it the pool.
                        RegressionModelEvaluation rme = (trainModelsInParallel || threadPool == null)
                                ? new RegressionModelEvaluation(candidate, dataSet)
                                : new RegressionModelEvaluation(candidate, dataSet, threadPool);
                        rme.addScorer(regressionTargetScore.m242clone());
                        if (reuseSameCVFolds) {
                            rme.evaluateCrossValidation(cvFolds, trainCombinations);
                        } else {
                            rme.evaluateCrossValidation(folds);
                        }
                        synchronized (bestModels) {
                            bestModels.add(rme);
                        }
                    } catch (RuntimeException ex) {
                        // Record the failure instead of letting the executor
                        // swallow it inside an ignored Future.
                        failure.compareAndSet(null, ex);
                    } finally {
                        // Always release the latch, or await() would block forever.
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (InterruptedException ex) {
            Thread.currentThread().interrupt();
            throw new FailedToFitException(ex);
        }
        if (failure.get() != null) {
            throw new FailedToFitException(failure.get());
        }
        synchronized (bestModels) {
            if (bestModels.size() != candidates.size()) {
                throw new FailedToFitException(new IllegalStateException(
                        "only " + bestModels.size() + " of " + candidates.size() + " candidate models were evaluated"));
            }
        }

        Regressor best = bestModels.peek().getRegressor();
        if (trainFinalModel) {
            if (threadPool == null || threadPool instanceof FakeExecutor) {
                best.train(dataSet);
            } else {
                best.train(dataSet, threadPool);
            }
        }
        trainedRegressor = best;
    }

    /**
     * Single-threaded entry point; equivalent to {@code train(dataSet, null)}.
     *
     * @param dataSet the data to search on
     */
    @Override // jsat.regression.Regressor
    public void train(RegressionDataSet dataSet) {
        train(dataSet, null);
    }

    /**
     * Returns a copy of this search made with the copy constructor.
     *
     * @return a new search with the same settings and deep-copied distributions
     */
    @Override // jsat.parameters.ModelSearch
    /* renamed from: clone */
    public RandomSearch mo234clone() {
        return new RandomSearch(this);
    }
}
