package jsat.datatransform.featureselection;

import java.util.Iterator;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.ClassificationDataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.RemoveAttributeTransform;
import jsat.exceptions.FailedToFitException;
import jsat.linear.IndexValue;
import jsat.utils.IndexTable;
import jsat.utils.IntSet;

/* loaded from: input_file:jsat/datatransform/featureselection/MutualInfoFS.class */
/**
 * Feature selection transform that scores every considered attribute by its
 * mutual information with the class label of a {@link ClassificationDataSet}
 * and removes all but {@link #getFeatureCount() featureCount} of them.
 *
 * <p>Categorical attributes are always scored. Numeric attributes are scored
 * only when the handling mode is not {@link NumericalHandeling#NONE}; when they
 * are not scored they are neither counted towards the feature budget nor
 * removed.
 */
public class MutualInfoFS extends RemoveAttributeTransform {
    private static final long serialVersionUID = -4394620220403363542L;

    /** Number of considered features to keep after fitting. */
    private int featureCount;

    /** How numeric attributes are folded into the mutual information score. */
    private NumericalHandeling numericHandling;

    /**
     * Controls how numeric attributes take part in the scoring, since their
     * mutual information is only approximated here.
     */
    public enum NumericalHandeling {
        /**
         * Numeric attributes are not scored, so none of them is removed and the
         * feature budget applies to categorical attributes only.
         */
        NONE,
        /**
         * Each numeric attribute is treated as a two-valued feature: "present"
         * when the data point's numeric vector iterator yields an entry for it
         * (presumably the non-zero entries - confirm against the vector
         * implementation) and "absent" otherwise.
         */
        BINARY
    }

    /**
     * Creates an unfitted selector that keeps 100 features and uses
     * {@link NumericalHandeling#BINARY}.
     */
    public MutualInfoFS() {
        this(100);
    }

    /**
     * Creates an unfitted selector using {@link NumericalHandeling#BINARY}.
     *
     * @param featureCount the number of features to keep, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public MutualInfoFS(int featureCount) {
        this(featureCount, NumericalHandeling.BINARY);
    }

    /**
     * Creates a selector using {@link NumericalHandeling#BINARY} and fits it to
     * the given data set.
     *
     * @param dataSet the classification data to select features from
     * @param featureCount the number of features to keep, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1}
     * @throws FailedToFitException if {@code dataSet} is {@code null}
     */
    public MutualInfoFS(ClassificationDataSet dataSet, int featureCount) {
        this(dataSet, featureCount, NumericalHandeling.BINARY);
    }

    /**
     * Copy constructor used by {@link #clone()}.
     *
     * @param toCopy the instance whose state is duplicated
     */
    protected MutualInfoFS(MutualInfoFS toCopy) {
        super(toCopy);
        this.featureCount = toCopy.featureCount;
        this.numericHandling = toCopy.numericHandling;
    }

    /**
     * Creates an unfitted selector.
     *
     * @param featureCount the number of features to keep, must be positive
     * @param numericHandling how numeric attributes are scored
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public MutualInfoFS(int featureCount, NumericalHandeling numericHandling) {
        setFeatureCount(featureCount);
        setHandling(numericHandling);
    }

    /**
     * Creates a selector and fits it to the given data set.
     *
     * @param dataSet the classification data to select features from
     * @param featureCount the number of features to keep, must be positive
     * @param numericHandling how numeric attributes are scored
     * @throws IllegalArgumentException if {@code featureCount < 1}
     * @throws FailedToFitException if {@code dataSet} is {@code null}
     */
    public MutualInfoFS(ClassificationDataSet dataSet, int featureCount, NumericalHandeling numericHandling) {
        this(featureCount, numericHandling);
        // The data set was previously accepted but silently ignored, leaving
        // the transform unfitted; a data-set constructor has to fit.
        fit(dataSet);
    }

    /**
     * Scores each considered attribute against the class label and configures
     * the underlying attribute-removal transform so that only
     * {@code featureCount} considered attributes remain. When fewer attributes
     * are considered than {@code featureCount}, nothing is removed.
     *
     * @param dataSet the data to fit to; must be a {@link ClassificationDataSet}
     * @throws FailedToFitException if {@code dataSet} is {@code null} or not a
     *         classification data set
     */
    @Override // jsat.datatransform.RemoveAttributeTransform, jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        if (!(dataSet instanceof ClassificationDataSet)) {
            String typeName = dataSet == null ? "null" : dataSet.getClass().getSimpleName();
            throw new FailedToFitException("MutualInfoFS only works for classification data sets, not " + typeName);
        }
        ClassificationDataSet data = (ClassificationDataSet) dataSet;
        super.fit(data);

        double[] classPriors = data.getPriors();
        int numClasses = classPriors.length;
        double[] logClassPriors = new double[numClasses];
        for (int c = 0; c < numClasses; c++) {
            logClassPriors[c] = Math.log(classPriors[c]);
        }

        int numCatVars = data.getNumCategoricalVars();
        int consideredCount = numericHandling != NumericalHandeling.NONE ? data.getNumFeatures() : numCatVars;

        // featWeights[feature][option]: summed weight of points taking that option.
        // jointWeights[feature][option][class]: the same, split by class label.
        // Numeric features track a single "present" option; "absent" is derived.
        double[][] featWeights = new double[consideredCount][];
        double[][][] jointWeights = new double[consideredCount][][];
        CategoricalData[] categories = data.getCategories();
        for (int f = 0; f < consideredCount; f++) {
            int options = f < numCatVars ? categories[f].getNumOfCategories() : 1;
            featWeights[f] = new double[options];
            jointWeights[f] = new double[options][numClasses];
        }

        // Per-class weight totals are accumulated from the very same weights as
        // the joint tables so that "class total - present" can never go
        // negative through rounding (which would turn the logarithm into NaN).
        double[] classWeights = new double[numClasses];
        double weightSum = 0.0d;
        int sampleCount = data.getSampleSize();
        for (int s = 0; s < sampleCount; s++) {
            DataPoint point = data.getDataPoint(s);
            int trueClass = data.getDataPointCategory(s);
            double weight = point.getWeight();
            weightSum += weight;
            classWeights[trueClass] += weight;

            int[] catVals = point.getCategoricalValues();
            for (int j = 0; j < catVals.length; j++) {
                featWeights[j][catVals[j]] += weight;
                jointWeights[j][catVals[j]][trueClass] += weight;
            }

            if (numericHandling == NumericalHandeling.BINARY) {
                Iterator<IndexValue> entries = point.getNumericalValues().iterator();
                while (entries.hasNext()) {
                    int f = entries.next().getIndex() + numCatVars;
                    featWeights[f][0] += weight;
                    jointWeights[f][0][trueClass] += weight;
                }
            }
        }

        double[] mutualInfo = new double[consideredCount];
        for (int f = 0; f < consideredCount; f++) {
            if (f < numCatVars) {
                mutualInfo[f] = categoricalMutualInfo(featWeights[f], jointWeights[f], logClassPriors, weightSum);
            } else {
                mutualInfo[f] = binaryMutualInfo(featWeights[f][0], jointWeights[f][0], classWeights, logClassPriors, weightSum);
            }
        }

        // The first (consideredCount - featureCount) positions of the index
        // table are dropped - presumably the lowest scores, assuming IndexTable
        // orders ascending; confirm against IndexTable.
        IndexTable order = new IndexTable(mutualInfo);
        Set<Integer> catToRemove = new IntSet();
        Set<Integer> numToRemove = new IntSet();
        int removeCount = consideredCount - this.featureCount;
        for (int r = 0; r < removeCount; r++) {
            int removing = order.index(r);
            if (removing < numCatVars) {
                catToRemove.add(removing);
            } else {
                numToRemove.add(removing - numCatVars);
            }
        }
        setUp(data, catToRemove, numToRemove);
    }

    /**
     * Mutual information between one categorical feature and the class label:
     * the sum over (option, class) of p(o,c) * (ln p(o,c) - ln p(o) - ln p(c)).
     * Cells with no mass are skipped.
     */
    private static double categoricalMutualInfo(double[] optionWeights, double[][] jointWeights,
            double[] logClassPriors, double weightSum) {
        double mi = 0.0d;
        for (int option = 0; option < jointWeights.length; option++) {
            double optionPrior = optionWeights[option] / weightSum;
            if (!(optionPrior > 0.0d)) {
                continue;
            }
            double logOptionPrior = Math.log(optionPrior);
            for (int c = 0; c < logClassPriors.length; c++) {
                double joint = jointWeights[option][c] / weightSum;
                if (joint > 0.0d) {
                    mi += joint * ((Math.log(joint) - logOptionPrior) - logClassPriors[c]);
                }
            }
        }
        return mi;
    }

    /**
     * Mutual information between a present/absent feature and the class label.
     * Each joint probability is paired with the marginal of its own event:
     * present with present, absent with absent.
     */
    private static double binaryMutualInfo(double presentWeight, double[] presentJointWeights,
            double[] classWeights, double[] logClassPriors, double weightSum) {
        double presentPrior = presentWeight / weightSum;
        double absentPrior = (weightSum - presentWeight) / weightSum;
        double mi = 0.0d;
        for (int c = 0; c < logClassPriors.length; c++) {
            double jointPresent = presentJointWeights[c] / weightSum;
            double jointAbsent = (classWeights[c] - presentJointWeights[c]) / weightSum;
            if (jointPresent > 0.0d && presentPrior > 0.0d) {
                mi += jointPresent * ((Math.log(jointPresent) - Math.log(presentPrior)) - logClassPriors[c]);
            }
            if (jointAbsent > 0.0d && absentPrior > 0.0d) {
                mi += jointAbsent * ((Math.log(jointAbsent) - Math.log(absentPrior)) - logClassPriors[c]);
            }
        }
        return mi;
    }

    /**
     * Returns a copy of this transform made through the copy constructor.
     *
     * @return a new {@code MutualInfoFS} with the same configuration
     */
    @Override // jsat.datatransform.RemoveAttributeTransform
    public MutualInfoFS clone() {
        return new MutualInfoFS(this);
    }

    /**
     * Sets how many features are kept by the next call to {@link #fit(DataSet)}.
     *
     * @param featureCount the number of features to keep, must be positive
     * @throws IllegalArgumentException if {@code featureCount < 1}
     */
    public void setFeatureCount(int featureCount) {
        if (featureCount < 1) {
            throw new IllegalArgumentException("Number of features must be positive, not " + featureCount);
        }
        this.featureCount = featureCount;
    }

    /**
     * Returns the number of features kept after fitting.
     *
     * @return the configured feature count
     */
    public int getFeatureCount() {
        return this.featureCount;
    }

    /**
     * Sets how numeric attributes are scored by the next call to
     * {@link #fit(DataSet)}.
     *
     * @param numericHandling the handling mode to use
     */
    public void setHandling(NumericalHandeling numericHandling) {
        this.numericHandling = numericHandling;
    }

    /**
     * Returns how numeric attributes are scored.
     *
     * @return the configured handling mode
     */
    public NumericalHandeling getHandling() {
        return this.numericHandling;
    }
}
