package jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.OnLineStatistics;
import jsat.utils.ClosedHashingUtil;
import jsat.utils.DoubleList;
import jsat.utils.IndexTable;

/* loaded from: input_file:jsat/datatransform/Imputer.class */
/**
 * In-place transform that replaces missing feature values with values learned
 * from a data set.
 *
 * <p>A numeric value is treated as missing when it is {@code NaN}; a
 * categorical value is treated as missing when it is negative. Numeric
 * features are filled with either the column mean or a weighted median
 * (see {@link NumericImputionMode}); categorical features are filled with the
 * category that accumulated the largest total data point weight.
 */
public class Imputer implements InPlaceTransform {
    /** Strategy used to compute the replacement value of numeric features. */
    private NumericImputionMode mode;
    /** Replacement category index for each categorical feature, set by {@link #fit(DataSet)}. */
    protected int[] cat_imputs;
    /** Replacement value for each numeric feature, set by {@link #fit(DataSet)}. */
    protected double[] numeric_imputs;

    /** Available strategies for computing the numeric replacement values. */
    public enum NumericImputionMode {
        /** Use the column mean reported by the data set's online column statistics. */
        MEAN,
        /** Use a weighted median of the non-{@code NaN} values iterated in each column. */
        MEDIAN
    }

    /**
     * Creates an imputer fitted to the given data set using {@link NumericImputionMode#MEAN}.
     *
     * @param dataSet the data set to learn the replacement values from
     */
    public Imputer(DataSet<?> dataSet) {
        this(dataSet, NumericImputionMode.MEAN);
    }

    /**
     * Creates an imputer fitted to the given data set.
     *
     * @param dataSet the data set to learn the replacement values from
     * @param numericImputionMode how the numeric replacement values are computed
     */
    public Imputer(DataSet<?> dataSet, NumericImputionMode numericImputionMode) {
        this.mode = numericImputionMode;
        fit(dataSet);
    }

    /**
     * Copy constructor. The learned replacement arrays are copied, so the new
     * instance does not share them with {@code imputer}.
     *
     * @param imputer the imputer to copy
     */
    public Imputer(Imputer imputer) {
        this.mode = imputer.mode;
        if (imputer.cat_imputs != null) {
            this.cat_imputs = Arrays.copyOf(imputer.cat_imputs, imputer.cat_imputs.length);
        }
        if (imputer.numeric_imputs != null) {
            this.numeric_imputs = Arrays.copyOf(imputer.numeric_imputs, imputer.numeric_imputs.length);
        }
    }

    /**
     * Learns the replacement values from {@code dataSet}, overwriting any
     * previously learned values.
     *
     * <p>Categorical features: the weight of every data point is added to the
     * category it holds (negative values are skipped), and the heaviest
     * category becomes the replacement. Numeric features: depends on the mode
     * this imputer was constructed with.
     *
     * @param dataSet the data set to learn the replacement values from
     */
    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        final int numNumeric = dataSet.getNumNumericalVars();
        final int numCategorical = dataSet.getNumCategoricalVars();
        final int sampleSize = dataSet.getSampleSize();

        this.numeric_imputs = new double[numNumeric];
        this.cat_imputs = new int[numCategorical];

        // Only populated in MEDIAN mode: per column, the observed values, the
        // weight that came with each value, and the sum of those weights.
        List<List<Double>> columnValues = null;
        List<List<Double>> columnWeights = null;
        double[] columnWeightSums = null;

        switch (this.mode) {
            case MEAN:
                // NOTE(review): the boolean argument presumably asks for weighted statistics - confirm.
                OnLineStatistics[] stats = dataSet.getOnlineColumnStats(true);
                for (int i = 0; i < stats.length; i++) {
                    this.numeric_imputs[i] = stats[i].getMean();
                }
                break;
            case MEDIAN:
                columnValues = new ArrayList<List<Double>>(numNumeric);
                columnWeights = new ArrayList<List<Double>>(numNumeric);
                columnWeightSums = new double[numNumeric];
                for (int i = 0; i < numNumeric; i++) {
                    columnValues.add(new DoubleList(sampleSize));
                    columnWeights.add(new DoubleList(sampleSize));
                }
                break;
        }

        // categoryWeights[feature][category] = total weight of the points holding that category.
        double[][] categoryWeights = new double[numCategorical][];
        for (int i = 0; i < categoryWeights.length; i++) {
            categoryWeights[i] = new double[dataSet.getCategories()[i].getNumOfCategories()];
        }

        for (int i = 0; i < sampleSize; i++) {
            DataPoint dataPoint = dataSet.getDataPoint(i);
            double weight = dataPoint.getWeight();

            int[] categoricalValues = dataPoint.getCategoricalValues();
            for (int j = 0; j < categoricalValues.length; j++) {
                int category = categoricalValues[j];
                if (category >= 0) { // negative means the value is missing
                    categoryWeights[j][category] += weight;
                }
            }

            if (this.mode == NumericImputionMode.MEDIAN) {
                // Only the entries the vector's iterator yields are recorded.
                Iterator<IndexValue> it = dataPoint.getNumericalValues().iterator();
                while (it.hasNext()) {
                    IndexValue entry = it.next();
                    double value = entry.getValue();
                    if (!Double.isNaN(value)) {
                        int column = entry.getIndex();
                        columnValues.get(column).add(Double.valueOf(value));
                        columnWeights.get(column).add(Double.valueOf(weight));
                        columnWeightSums[column] += weight;
                    }
                }
            }
        }

        if (this.mode == NumericImputionMode.MEDIAN) {
            // NOTE(review): the table is sized by the number of numeric features even though it
            // sorts per-column value lists; presumably IndexTable.sort adapts its size - confirm.
            IndexTable indexTable = new IndexTable(numNumeric);
            for (int column = 0; column < numNumeric; column++) {
                this.numeric_imputs[column] = weightedMedian(
                        columnValues.get(column),
                        columnWeights.get(column),
                        columnWeightSums[column],
                        indexTable);
            }
        }

        for (int i = 0; i < categoryWeights.length; i++) {
            this.cat_imputs[i] = heaviestCategory(categoryWeights[i]);
        }
    }

    /**
     * Walks {@code values} in the order produced by sorting them with
     * {@code indexTable}, accumulating the matching weights, and stops once the
     * accumulated weight is no longer below half of {@code totalWeight}.
     *
     * @param values the observed values of one column
     * @param weights the weight recorded alongside each entry of {@code values}
     * @param totalWeight the sum of {@code weights}
     * @param indexTable reusable index table; it is reset and re-sorted here
     * @return the last value visited, or {@code 0.0} if no value was visited
     */
    private static double weightedMedian(List<Double> values, List<Double> weights,
            double totalWeight, IndexTable indexTable) {
        indexTable.reset();
        indexTable.sort(values);

        double halfWeight = totalWeight / 2.0d;
        double lastSeen = 0.0d;
        double accumulated = 0.0d;
        for (int i = 0; i < indexTable.length() && accumulated < halfWeight; i++) {
            int original = indexTable.index(i);
            lastSeen = values.get(original).doubleValue();
            accumulated += weights.get(original).doubleValue();
        }
        return lastSeen;
    }

    /**
     * Finds the index holding the largest weight.
     *
     * @param weights total weight per category
     * @return the index of the largest entry; ties go to the lowest index, and
     *         {@code 0} is returned for an empty array
     */
    private static int heaviestCategory(double[] weights) {
        int best = 0;
        for (int candidate = 1; candidate < weights.length; candidate++) {
            if (weights[candidate] > weights[best]) {
                best = candidate;
            }
        }
        return best;
    }

    /**
     * Replaces, in the given data point itself, every {@code NaN} numeric value
     * yielded by the numeric vector's iterator and every negative categorical
     * value with the learned replacement for that feature.
     *
     * @param dataPoint the data point to modify
     */
    @Override // jsat.datatransform.InPlaceTransform
    public void mutableTransform(DataPoint dataPoint) {
        Vec numericalValues = dataPoint.getNumericalValues();
        Iterator<IndexValue> it = numericalValues.iterator();
        while (it.hasNext()) {
            IndexValue entry = it.next();
            if (Double.isNaN(entry.getValue())) {
                numericalValues.set(entry.getIndex(), this.numeric_imputs[entry.getIndex()]);
            }
        }

        int[] categoricalValues = dataPoint.getCategoricalValues();
        for (int i = 0; i < categoricalValues.length; i++) {
            if (categoricalValues[i] < 0) {
                categoricalValues[i] = this.cat_imputs[i];
            }
        }
    }

    /**
     * @return {@code true}, since {@link #mutableTransform(DataPoint)} writes into
     *         the categorical value array of the data point
     */
    @Override // jsat.datatransform.InPlaceTransform
    public boolean mutatesNominal() {
        return true;
    }

    /**
     * Non-mutating variant of {@link #mutableTransform(DataPoint)}: the input is
     * copied and the copy is imputed.
     *
     * @param dataPoint the data point to impute
     * @return an imputed copy of {@code dataPoint}
     */
    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        // NOTE(review): m6clone looks like the decompiler's renamed DataPoint clone method - confirm.
        DataPoint copy = dataPoint.m6clone();
        mutableTransform(copy);
        return copy;
    }

    /**
     * @return an independent copy of this imputer made with the copy constructor
     */
    @Override // jsat.datatransform.DataTransform
    public Imputer clone() {
        return new Imputer(this);
    }
}
