package jsat.datatransform.featureselection;

import java.util.Random;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.Classifier;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransform;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.regression.RegressionDataSet;
import jsat.regression.Regressor;
import jsat.utils.IntSet;
import jsat.utils.ListUtils;

/* loaded from: input_file:jsat/datatransform/featureselection/BDS.class */
/**
 * Bidirectional Search (BDS) feature selection.
 * <p>
 * A sequential forward search (SFS) and a sequential backward search (SBS) are run in
 * lock-step over the same data set. Each round asks the forward search for one feature to add
 * and the backward search for one feature to drop. The feature returned by the forward search
 * is withdrawn from the backward search's candidates, and the feature returned by the backward
 * search is withdrawn from the forward search's candidates. When the search ends, the resulting
 * transform keeps only the features accumulated in the forward search's selected sets and
 * removes every other feature.
 * <p>
 * Candidate subsets are evaluated by the SFS/SBS helpers using the configured
 * {@link Classifier} or {@link Regressor} and the configured number of cross-validation folds.
 */
public class BDS implements DataTransform {
    private static final long serialVersionUID = 8633823674617843754L;

    /** Drops every feature not selected by the last search; {@code null} until fitted. */
    private RemoveAttributeTransform finalTransform;
    /** Indices of the selected categorical features; {@code null} until fitted. */
    private Set<Integer> catSelected;
    /** Indices of the selected numeric features; {@code null} until fitted. */
    private Set<Integer> numSelected;
    /** Upper bound on the number of search rounds; see {@link #fit(DataSet)}. */
    private int featureCount;
    /** Number of cross-validation folds handed to the SFS/SBS helpers. */
    private int folds;
    /** The {@link Classifier} or {@link Regressor} handed to the SFS/SBS helpers. */
    private Object evaluator;

    /**
     * Copy constructor. Fitted state (selected sets and the final transform) is deep-copied.
     * The evaluator is shared by reference, not cloned.
     *
     * @param toCopy the instance to copy
     */
    public BDS(BDS toCopy) {
        this.featureCount = toCopy.featureCount;
        this.folds = toCopy.folds;
        this.evaluator = toCopy.evaluator;
        if (toCopy.finalTransform != null) {
            this.finalTransform = toCopy.finalTransform.clone();
            this.catSelected = new IntSet(toCopy.catSelected);
            this.numSelected = new IntSet(toCopy.numSelected);
        }
    }

    /**
     * Creates an unfitted BDS selector that uses a classifier to evaluate feature subsets.
     *
     * @param featureCount maximum number of features to select; must be at least 1
     * @param evaluator the classifier used to evaluate candidate subsets
     * @param folds number of cross-validation folds; must be positive
     * @throws IllegalArgumentException if {@code featureCount} or {@code folds} is not positive
     */
    public BDS(int featureCount, Classifier evaluator, int folds) {
        setFeatureCount(featureCount);
        setFolds(folds);
        setEvaluator(evaluator);
    }

    /**
     * Creates a BDS selector that uses a classifier and immediately fits it to {@code dataSet}.
     *
     * @param featureCount maximum number of features to select; must be at least 1
     * @param dataSet the data to select features from
     * @param evaluator the classifier used to evaluate candidate subsets
     * @param folds number of cross-validation folds; must be positive
     * @throws IllegalArgumentException if {@code featureCount} or {@code folds} is not positive
     */
    public BDS(int featureCount, ClassificationDataSet dataSet, Classifier evaluator, int folds) {
        // Store (and validate) the configuration first, exactly like the regression variant,
        // so that a later fit() or clone() reuses these settings instead of zeros/null.
        this(featureCount, evaluator, folds);
        search(dataSet, featureCount, folds, evaluator);
    }

    /**
     * Creates an unfitted BDS selector that uses a regressor to evaluate feature subsets.
     *
     * @param featureCount maximum number of features to select; must be at least 1
     * @param evaluator the regressor used to evaluate candidate subsets
     * @param folds number of cross-validation folds; must be positive
     * @throws IllegalArgumentException if {@code featureCount} or {@code folds} is not positive
     */
    public BDS(int featureCount, Regressor evaluator, int folds) {
        setFeatureCount(featureCount);
        setFolds(folds);
        setEvaluator(evaluator);
    }

    /**
     * Creates a BDS selector that uses a regressor and immediately fits it to {@code dataSet}.
     *
     * @param featureCount maximum number of features to select; must be at least 1
     * @param dataSet the data to select features from
     * @param evaluator the regressor used to evaluate candidate subsets
     * @param folds number of cross-validation folds; must be positive
     * @throws IllegalArgumentException if {@code featureCount} or {@code folds} is not positive
     */
    public BDS(int featureCount, RegressionDataSet dataSet, Regressor evaluator, int folds) {
        this(featureCount, evaluator, folds);
        search(dataSet, featureCount, folds, evaluator);
    }

    /**
     * Removes every non-selected feature from {@code dp}.
     *
     * @param dp the data point to transform
     * @return the transformed data point
     * @throws IllegalStateException if this selector has not been fitted
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        return requireFitted().transform(dp);
    }

    /** Returns a copy of this selector; see {@link #BDS(BDS)}. */
    @Override
    public BDS clone() {
        return new BDS(this);
    }

    /**
     * Returns a copy of the selected categorical feature indices.
     *
     * @return a new mutable set of categorical feature indices
     * @throws IllegalStateException if this selector has not been fitted
     */
    public Set<Integer> getSelectedCategorical() {
        requireFitted();
        return new IntSet(catSelected);
    }

    /**
     * Returns a copy of the selected numeric feature indices.
     *
     * @return a new mutable set of numeric feature indices
     * @throws IllegalStateException if this selector has not been fitted
     */
    public Set<Integer> getSelectedNumerical() {
        requireFitted();
        return new IntSet(numSelected);
    }

    /**
     * Runs the bidirectional search on {@code data} using the configured feature count, fold
     * count and evaluator, replacing any previous selection.
     *
     * @param data the data to select features from
     */
    @Override
    public void fit(DataSet data) {
        search(data, featureCount, folds, evaluator);
    }

    /**
     * Performs the bidirectional search and installs the resulting selection and transform.
     * The fields are only updated after the search completes, so a failure part-way through
     * leaves any previous fitted state untouched.
     *
     * @param dataSet the data to search over
     * @param maxFeatures requested number of rounds (capped at half the feature count)
     * @param folds number of cross-validation folds handed to the helpers
     * @param evaluator the Classifier or Regressor handed to the helpers
     */
    private void search(DataSet dataSet, int maxFeatures, int folds, Object evaluator) {
        // Unseeded, so (presumably) cross-validation splits differ from run to run.
        final Random rand = new Random();
        final int numFeatures = dataSet.getNumFeatures();
        final int numCat = dataSet.getNumCategoricalVars();
        // Numeric features are indexed after the categorical ones in the global index space.
        final int numNum = numFeatures - numCat;

        // Forward-search state: global candidates, per-type "to remove" sets (initially all),
        // and the per-type selected sets that make up the final answer (initially empty).
        IntSet availableSFS = allIndices(numFeatures);
        IntSet catToRemoveSFS = allIndices(numCat);
        IntSet numToRemoveSFS = allIndices(numNum);
        IntSet newCatSelected = new IntSet(numCat);
        IntSet newNumSelected = new IntSet(numNum);

        // Backward-search state: global candidates, per-type "to remove" sets (initially
        // empty) and per-type selected sets (initially all).
        // NOTE(review): catToRemoveSBS/numToRemoveSBS start empty and are only ever passed to
        // SFS.removeFeature here; presumably SBS.SBSRemoveFeature populates them — verify.
        IntSet availableSBS = allIndices(numFeatures);
        IntSet catSelectedSBS = allIndices(numCat);
        IntSet numSelectedSBS = allIndices(numNum);
        IntSet catToRemoveSBS = new IntSet(numCat);
        IntSet numToRemoveSBS = new IntSet(numNum);

        // Single-element arrays act as in/out "best score so far" holders for the helpers.
        double[] bestScoreSFS = {Double.POSITIVE_INFINITY};
        double[] bestScoreSBS = {Double.POSITIVE_INFINITY};

        // Capped at half the feature count, presumably so the two searches meet rather than
        // compete over the same features.
        final int rounds = Math.min(maxFeatures, numFeatures / 2);
        for (int round = 0; round < rounds; round++) {
            // Forward step; the chosen feature is no longer a candidate for the backward search.
            int added = SFS.SFSSelectFeature(availableSFS, dataSet, catToRemoveSFS, numToRemoveSFS,
                    newCatSelected, newNumSelected, evaluator, folds, rand, bestScoreSFS, rounds);
            availableSBS.remove(Integer.valueOf(added));
            SFS.removeFeature(added, numCat, catToRemoveSBS, numToRemoveSBS);

            // Backward step; the dropped feature is no longer a candidate for the forward search.
            int dropped = SBS.SBSRemoveFeature(availableSBS, dataSet, catToRemoveSBS, numToRemoveSBS,
                    catSelectedSBS, numSelectedSBS, evaluator, folds, rand, rounds, bestScoreSBS, 0.0d);
            availableSFS.remove(Integer.valueOf(dropped));
            SFS.addFeature(dropped, numCat, catToRemoveSFS, numToRemoveSFS);
        }

        // Everything the forward search did not select gets removed by the final transform.
        IntSet catToDrop = allIndices(numCat);
        IntSet numToDrop = allIndices(numNum);
        catToDrop.removeAll(newCatSelected);
        numToDrop.removeAll(newNumSelected);
        RemoveAttributeTransform transform = new RemoveAttributeTransform(dataSet, catToDrop, numToDrop);

        // Commit all fitted state together.
        this.catSelected = newCatSelected;
        this.numSelected = newNumSelected;
        this.finalTransform = transform;
    }

    /**
     * Returns a new mutable set holding the integers {@code 0 .. size-1}.
     *
     * @param size number of indices to include
     * @return the populated set
     */
    private static IntSet allIndices(int size) {
        IntSet indices = new IntSet(size);
        ListUtils.addRange(indices, 0, size, 1);
        return indices;
    }

    /**
     * Returns the fitted transform.
     *
     * @return the transform built by the last successful search
     * @throws IllegalStateException if neither {@link #fit(DataSet)} nor a fitting constructor
     *     has completed
     */
    private RemoveAttributeTransform requireFitted() {
        if (finalTransform == null) {
            throw new IllegalStateException("BDS has not been fit to a data set yet");
        }
        return finalTransform;
    }

    /**
     * Sets the maximum number of features to select.
     *
     * @param featureCount the new maximum; must be at least 1
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public void setFeatureCount(int featureCount) {
        if (featureCount < 1) {
            throw new IllegalArgumentException("Number of features to select must be positive, not " + featureCount);
        }
        this.featureCount = featureCount;
    }

    /**
     * Returns the maximum number of features to select.
     *
     * @return the configured feature count
     */
    public int getFeatureCount() {
        return this.featureCount;
    }

    /**
     * Sets the number of cross-validation folds used when evaluating subsets.
     *
     * @param folds the new fold count; must be positive
     * @throws IllegalArgumentException if {@code folds <= 0}
     */
    public void setFolds(int folds) {
        if (folds <= 0) {
            throw new IllegalArgumentException("Number of CV folds must be positive, not " + folds);
        }
        this.folds = folds;
    }

    /**
     * Returns the number of cross-validation folds used when evaluating subsets.
     *
     * @return the configured fold count
     */
    public int getFolds() {
        return this.folds;
    }

    /**
     * Stores the Classifier or Regressor used to evaluate subsets.
     *
     * @param evaluator the evaluator; kept as {@code Object} because it may be either type
     */
    private void setEvaluator(Object evaluator) {
        this.evaluator = evaluator;
    }
}
