package jsat.datatransform;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.math.IndexFunction;
import jsat.math.OnLineStatistics;
import jsat.utils.DoubleList;

/**
 * Transform that tries to reduce the skewness of every numeric feature independently by
 * picking, per feature, one power-style transform from a list of candidate lambda values.
 * <p>
 * {@link #fit(DataSet)} applies every candidate lambda (plus {@code 1.0}, the identity, which
 * is always considered) to every numeric feature and accumulates the skewness of the result,
 * adding each observation with its data point's weight. For each feature the lambda with the
 * smallest absolute skewness is kept, but only when the identity's absolute skewness is larger
 * than that best value by more than a factor of {@value #SKEW_IMPROVEMENT_FACTOR}; otherwise the
 * feature is left untransformed (lambda {@code 1.0}).
 * <p>
 * A value of exactly zero is always mapped to zero. When zeros are <em>not</em> ignored,
 * indices that a vector's iterator does not report are counted as {@code 0.0} in the fitting
 * statistics; when they are ignored (the default) such indices are left out.
 */
public class AutoDeskewTransform implements InPlaceTransform {
    private static final long serialVersionUID = -4894242802345656448L;

    /**
     * The identity's |skewness| must exceed the best candidate's |skewness| by more than this
     * factor before a non-identity lambda is selected for a feature.
     */
    private static final double SKEW_IMPROVEMENT_FACTOR = 1.05;

    /** Default candidate lambdas: -1, -0.5, 0, 0.5 and 1. Shared; never modified by this class. */
    private static final DoubleList defaultList = new DoubleList(7);

    static {
        defaultList.add(-1.0d);
        defaultList.add(-0.5d);
        defaultList.add(0.0d);
        defaultList.add(0.5d);
        defaultList.add(1.0d);
    }

    /** Lambda selected for each numeric feature by {@link #fit(DataSet)}; {@code null} until fit. */
    private double[] finalLambdas;

    /**
     * Per-feature minimum of the values seen during fitting (clamped to at most 0 when any sparse
     * vector was seen). Used as a shift by some lambdas; {@code null} until fit.
     */
    private double[] mins;

    /** Applies the fitted per-feature transform to a vector entry. */
    private final IndexFunction transform = new IndexFunction() {
        private static final long serialVersionUID = -404316813485246422L;

        @Override
        public double indexFunc(double value, int index) {
            // Negative indices yield 0.0.
            if (index < 0) {
                return 0.0d;
            }
            return AutoDeskewTransform.transform(
                    value, AutoDeskewTransform.this.finalLambdas[index], AutoDeskewTransform.this.mins[index]);
        }
    };

    /** Candidate lambdas supplied at construction. Read by {@link #fit(DataSet)}, never modified. */
    private List<Double> lambdas;

    /**
     * Whether indices not reported by a vector's iterator are excluded from the fitting
     * statistics. (The misspelled name is kept so the serialized form stays compatible.)
     */
    private boolean ignorZeros;

    /** Creates an unfitted transform using the default candidate lambdas, ignoring zeros. */
    public AutoDeskewTransform() {
        this(true, defaultList);
    }

    /**
     * Creates an unfitted transform that ignores zeros.
     *
     * @param dArr the candidate lambda values
     */
    public AutoDeskewTransform(double... dArr) {
        this(true, (List<Double>) DoubleList.view(dArr, dArr.length));
    }

    /**
     * Creates an unfitted transform that ignores zeros.
     *
     * @param list the candidate lambda values; not modified by this class
     */
    public AutoDeskewTransform(List<Double> list) {
        this(true, list);
    }

    /**
     * Creates an unfitted transform.
     *
     * @param z    {@code true} to leave unreported (zero) entries out of the fitting statistics
     * @param list the candidate lambda values; not modified by this class
     */
    public AutoDeskewTransform(boolean z, List<Double> list) {
        this.ignorZeros = z;
        this.lambdas = list;
    }

    /**
     * Creates a transform with the default candidate lambdas, ignoring zeros, and fits it.
     *
     * @param dataSet the data to fit on
     */
    public AutoDeskewTransform(DataSet dataSet) {
        this(dataSet, defaultList);
    }

    /**
     * Creates a transform that ignores zeros and fits it.
     *
     * @param dataSet the data to fit on
     * @param list    the candidate lambda values; not modified by this class
     */
    public AutoDeskewTransform(DataSet dataSet, List<Double> list) {
        this(dataSet, true, list);
    }

    /**
     * Creates a transform and fits it.
     *
     * @param dataSet the data to fit on
     * @param z       {@code true} to leave unreported (zero) entries out of the fitting statistics
     * @param list    the candidate lambda values; not modified by this class
     */
    public AutoDeskewTransform(DataSet dataSet, boolean z, List<Double> list) {
        this(z, list);
        fit(dataSet);
    }

    /**
     * Copy constructor used by {@link #clone()}. Copies the configuration (candidate lambdas and
     * zero handling) as well as any fitted state, so the copy can be used or refit on its own.
     *
     * @param autoDeskewTransform the transform to copy; may be unfitted
     */
    protected AutoDeskewTransform(AutoDeskewTransform autoDeskewTransform) {
        this.ignorZeros = autoDeskewTransform.ignorZeros;
        // Safe to share: this class never modifies the candidate list.
        this.lambdas = autoDeskewTransform.lambdas;
        double[] srcLambdas = autoDeskewTransform.finalLambdas;
        double[] srcMins = autoDeskewTransform.mins;
        this.finalLambdas = srcLambdas == null ? null : Arrays.copyOf(srcLambdas, srcLambdas.length);
        this.mins = srcMins == null ? null : Arrays.copyOf(srcMins, srcMins.length);
    }

    /**
     * Chooses a lambda for every numeric feature of {@code dataSet}.
     * <p>
     * The candidate list given at construction is copied, not modified. Fitted state is only
     * replaced after all statistics have been computed.
     *
     * @param dataSet the data to fit on
     */
    @Override
    public void fit(DataSet dataSet) {
        // Work on a private copy: the caller's list (or the shared defaultList) must not change,
        // and fixed-size/unmodifiable lists must be accepted.
        List<Double> candidates = new ArrayList<>(this.lambdas);
        // The identity is always a candidate; it is the baseline the other lambdas must beat.
        if (!candidates.contains(1.0d)) {
            candidates.add(1.0d);
        }

        int numVars = dataSet.getNumNumericalVars();
        OnLineStatistics[][] stats = new OnLineStatistics[candidates.size()][numVars];
        for (OnLineStatistics[] row : stats) {
            for (int j = 0; j < row.length; j++) {
                row[j] = new OnLineStatistics();
            }
        }

        double[] minima = computeMinima(dataSet, numVars);

        // Second pass: accumulate the statistics of every candidate transform.
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            DataPoint dp = dataSet.getDataPoint(i);
            Vec x = dp.getNumericalValues();
            double weight = dp.getWeight();
            int prev = -1;
            for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
                IndexValue iv = it.next();
                int j = iv.getIndex();
                updateStats(candidates, stats, j, iv.getValue(), minima, weight);
                if (!this.ignorZeros) {
                    // Indices the iterator skipped count as explicit zeros.
                    for (int gap = prev + 1; gap < j; gap++) {
                        updateStats(candidates, stats, gap, 0.0d, minima, weight);
                    }
                }
                prev = j;
            }
            if (!this.ignorZeros) {
                for (int gap = prev + 1; gap < numVars; gap++) {
                    updateStats(candidates, stats, gap, 0.0d, minima, weight);
                }
            }
        }

        // Pick, per feature, the lambda with the smallest |skewness|, if it is a clear improvement.
        int identity = candidates.indexOf(1.0d);
        double[] chosen = new double[numVars];
        for (int j = 0; j < numVars; j++) {
            double bestSkew = Double.POSITIVE_INFINITY;
            double bestLambda = 1.0d;
            for (int k = 0; k < candidates.size(); k++) {
                double skew = Math.abs(stats[k][j].getSkewness());
                // A NaN skewness never compares as smaller, so it can never be selected.
                if (skew < bestSkew) {
                    bestSkew = skew;
                    bestLambda = candidates.get(k);
                }
            }
            double identitySkew = Math.abs(stats[identity][j].getSkewness());
            chosen[j] = identitySkew > bestSkew * SKEW_IMPROVEMENT_FACTOR ? bestLambda : 1.0d;
        }

        // Publish both arrays together so a failed fit never leaves a mismatched state.
        this.mins = minima;
        this.finalLambdas = chosen;
    }

    /**
     * First fitting pass: per-feature minimum of the values reported by each vector's iterator.
     * If any vector is sparse, every minimum is additionally clamped to at most 0.
     */
    private static double[] computeMinima(DataSet dataSet, int numVars) {
        double[] minima = new double[numVars];
        Arrays.fill(minima, Double.POSITIVE_INFINITY);
        boolean sawSparse = false;
        for (int i = 0; i < dataSet.getSampleSize(); i++) {
            Vec x = dataSet.getDataPoint(i).getNumericalValues();
            if (x.isSparse()) {
                sawSparse = true;
            }
            for (Iterator<IndexValue> it = x.iterator(); it.hasNext();) {
                IndexValue iv = it.next();
                int j = iv.getIndex();
                minima[j] = Math.min(iv.getValue(), minima[j]);
            }
        }
        if (sawSparse) {
            // Presumably sparse iterators skip implicit zeros, so fold 0 into every minimum.
            for (int j = 0; j < minima.length; j++) {
                minima[j] = Math.min(0.0d, minima[j]);
            }
        }
        return minima;
    }

    /**
     * Applies the transform selected by {@code lambda} to a single value.
     * <p>
     * Special cases: 2 to x², 1 to x, 0.5 to sqrt(x - min), 0 to log(x + 1 - min),
     * -0.5 to 1/sqrt(x - min), -1 to 1/x, -2 to 1/x². Any other lambda gives x^lambda / lambda.
     * A value of exactly zero always maps to zero.
     * <p>
     * NOTE(review): only the 0.5, 0 and -0.5 cases shift by {@code min}. For lambda -0.5 a value
     * equal to {@code min} gives +Infinity. For 0.5 and -0.5 a value below {@code min} (possible for
     * unseen data) gives NaN.
     *
     * @param d  the value to transform
     * @param d2 the lambda selecting the transform
     * @param d3 the feature minimum used as a shift by some lambdas
     * @return the transformed value
     */
    public static double transform(double d, double d2, double d3) {
        if (d == 0.0d) {
            return 0.0d;
        }
        if (d2 == 2.0d) {
            return d * d;
        } else if (d2 == 1.0d) {
            return d;
        } else if (d2 == 0.5d) {
            return Math.sqrt(d - d3);
        } else if (d2 == 0.0d) {
            return Math.log((d + 1.0d) - d3);
        } else if (d2 == -0.5d) {
            return 1.0d / Math.sqrt(d - d3);
        } else if (d2 == -1.0d) {
            return 1.0d / d;
        } else if (d2 == -2.0d) {
            return 1.0d / (d * d);
        }
        return Math.pow(d, d2) / d2;
    }

    /**
     * Returns a transformed copy of {@code dataPoint}; the input is not modified.
     *
     * @param dataPoint the point to transform
     * @return a new, transformed data point
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        DataPoint copy = dataPoint.m6clone();
        mutableTransform(copy);
        return copy;
    }

    /**
     * Transforms the numeric values of {@code dataPoint} in place.
     *
     * @param dataPoint the point to modify
     * @throws IllegalStateException if this transform has not been fit
     */
    @Override
    public void mutableTransform(DataPoint dataPoint) {
        if (this.finalLambdas == null || this.mins == null) {
            throw new IllegalStateException("AutoDeskewTransform has not been fit to data");
        }
        dataPoint.getNumericalValues().applyIndexFunction(this.transform);
    }

    /** @return an independent copy carrying the same configuration and fitted state */
    @Override
    public AutoDeskewTransform clone() {
        return new AutoDeskewTransform(this);
    }

    /**
     * Adds the value {@code value} of feature {@code index}, transformed by every candidate lambda,
     * to the matching statistics row.
     */
    private void updateStats(List<Double> candidates, OnLineStatistics[][] stats, int index,
            double value, double[] minima, double weight) {
        for (int k = 0; k < candidates.size(); k++) {
            stats[k][index].add(transform(value, candidates.get(k).doubleValue(), minima[index]), weight);
        }
    }

    /** @return {@code false}: only numeric values are changed */
    @Override
    public boolean mutatesNominal() {
        return false;
    }
}
