package jsat.datatransform;

import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import jsat.DataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.linear.DenseVector;
import jsat.linear.IndexValue;
import jsat.linear.SparseVector;
import jsat.linear.Vec;
import jsat.utils.IntList;

/* loaded from: input_file:jsat/datatransform/RemoveAttributeTransform.class */
/**
 * A {@link DataTransform} that drops a chosen subset of categorical and numerical
 * attributes from every data point. The remaining attributes keep their original
 * relative order.
 *
 * <p>The attributes to remove are given as sets of original attribute indices. Once the
 * transform is set up, {@link #catIndexMap} and {@link #numIndexMap} map each output
 * index to the original index it is copied from.
 */
public class RemoveAttributeTransform implements DataTransform {
    private static final long serialVersionUID = 2803223213862922734L;

    /** {@code catIndexMap[i]} is the original index of the i-th kept categorical attribute; ascending. */
    protected int[] catIndexMap;
    /** {@code numIndexMap[i]} is the original index of the i-th kept numerical attribute; ascending. */
    protected int[] numIndexMap;
    /** Original categorical indices to remove, or {@code null} if this instance was not given any. */
    private Set<Integer> categoricalToRemove;
    /** Original numerical indices to remove, or {@code null} if this instance was not given any. */
    private Set<Integer> numericalToRemove;

    /**
     * Creates an unconfigured transform. {@link #fit(DataSet)} does nothing on such an
     * instance. Presumably this constructor is meant for subclasses, which can populate
     * the index maps through {@link #setUp}.
     */
    public RemoveAttributeTransform() {
    }

    /**
     * Creates a transform that removes the given attributes once {@link #fit(DataSet)} is called.
     *
     * @param categoricalToRemove original categorical indices to drop (stored by reference)
     * @param numericalToRemove original numerical indices to drop (stored by reference)
     */
    public RemoveAttributeTransform(Set<Integer> categoricalToRemove, Set<Integer> numericalToRemove) {
        this.categoricalToRemove = categoricalToRemove;
        this.numericalToRemove = numericalToRemove;
    }

    /**
     * Creates a transform that removes the given attributes and immediately sets it up
     * against {@code dataSet}.
     *
     * @param dataSet the data set whose attribute counts define the valid indices
     * @param categoricalToRemove original categorical indices to drop (stored by reference)
     * @param numericalToRemove original numerical indices to drop (stored by reference)
     * @throws IllegalArgumentException if any index is out of range for {@code dataSet}
     */
    public RemoveAttributeTransform(DataSet dataSet, Set<Integer> categoricalToRemove, Set<Integer> numericalToRemove) {
        this.categoricalToRemove = categoricalToRemove;
        this.numericalToRemove = numericalToRemove;
        setUp(dataSet, categoricalToRemove, numericalToRemove);
    }

    /**
     * Copy constructor. It deep-copies the index maps and the removal sets, so the copy can
     * be re-fit independently of the original.
     *
     * @param toCopy the transform to copy
     */
    public RemoveAttributeTransform(RemoveAttributeTransform toCopy) {
        if (toCopy.catIndexMap != null) {
            this.catIndexMap = Arrays.copyOf(toCopy.catIndexMap, toCopy.catIndexMap.length);
        }
        if (toCopy.numIndexMap != null) {
            this.numIndexMap = Arrays.copyOf(toCopy.numIndexMap, toCopy.numIndexMap.length);
        }
        // Without these, fit() on a clone would hit its null guard and silently do nothing.
        if (toCopy.categoricalToRemove != null) {
            this.categoricalToRemove = new HashSet<Integer>(toCopy.categoricalToRemove);
        }
        if (toCopy.numericalToRemove != null) {
            this.numericalToRemove = new HashSet<Integer>(toCopy.numericalToRemove);
        }
    }

    /**
     * Returns the original indices of the kept numerical attributes, in output order.
     *
     * @return a view over {@link #numIndexMap} created by {@code IntList.unmodifiableView}
     */
    public List<Integer> getKeptNumeric() {
        return IntList.unmodifiableView(this.numIndexMap, this.numIndexMap.length);
    }

    /**
     * Returns a new map from each output numerical index to the original index it came from.
     *
     * @return a fresh, mutable map of output index to original index
     */
    public Map<Integer, Integer> getReverseNumericMap() {
        return outputToOriginal(this.numIndexMap);
    }

    /**
     * Returns the original indices of the kept categorical attributes, in output order.
     *
     * @return a view over {@link #catIndexMap} created by {@code IntList.unmodifiableView}
     */
    public List<Integer> getKeptNominal() {
        return IntList.unmodifiableView(this.catIndexMap, this.catIndexMap.length);
    }

    /**
     * Returns a new map from each output categorical index to the original index it came from.
     *
     * @return a fresh, mutable map of output index to original index
     */
    public Map<Integer, Integer> getReverseNominalMap() {
        return outputToOriginal(this.catIndexMap);
    }

    /** Builds a map whose key {@code i} maps to {@code indexMap[i]}. */
    private static Map<Integer, Integer> outputToOriginal(int[] indexMap) {
        Map<Integer, Integer> map = new HashMap<Integer, Integer>(indexMap.length * 4 / 3 + 1);
        for (int i = 0; i < indexMap.length; i++) {
            map.put(i, indexMap[i]);
        }
        return map;
    }

    /**
     * Sets up the index maps for {@code dataSet} using the removal sets supplied at
     * construction. This does nothing if either removal set is {@code null}.
     *
     * @param dataSet the data set whose attribute counts define the valid indices
     * @throws IllegalArgumentException if any index is out of range for {@code dataSet}
     */
    @Override
    public void fit(DataSet dataSet) {
        if (this.categoricalToRemove == null || this.numericalToRemove == null) {
            return;
        }
        setUp(dataSet, this.categoricalToRemove, this.numericalToRemove);
    }

    /**
     * Validates the removal sets against {@code dataSet}, then rebuilds {@link #catIndexMap}
     * and {@link #numIndexMap} so that they list every non-removed index in ascending order.
     *
     * @param dataSet the data set whose attribute counts define the valid indices
     * @param categoricalToRemove original categorical indices to drop
     * @param numericalToRemove original numerical indices to drop
     * @throws IllegalArgumentException if any index is negative or not less than the
     *         corresponding attribute count of {@code dataSet}
     */
    public final void setUp(DataSet dataSet, Set<Integer> categoricalToRemove, Set<Integer> numericalToRemove) {
        int numCategorical = dataSet.getNumCategoricalVars();
        int numNumerical = dataSet.getNumNumericalVars();
        // Validate both sets before touching any state, so a bad call leaves the maps as they were.
        checkIndices(categoricalToRemove, numCategorical, "categorical");
        checkIndices(numericalToRemove, numNumerical, "numerical");
        this.catIndexMap = keptIndices(numCategorical, categoricalToRemove);
        this.numIndexMap = keptIndices(numNumerical, numericalToRemove);
    }

    /**
     * Throws if any index in {@code toRemove} lies outside {@code [0, numVars)}.
     * A negative index must be rejected here. Otherwise it would count toward
     * {@code toRemove.size()} without matching any attribute, and the array built by
     * {@link #keptIndices} would be one slot too short.
     */
    private static void checkIndices(Set<Integer> toRemove, int numVars, String kind) {
        for (int index : toRemove) {
            if (index < 0 || index >= numVars) {
                throw new IllegalArgumentException(
                        "The data set does not have a " + kind + " value " + index + " to remove");
            }
        }
    }

    /**
     * Returns, in ascending order, every index in {@code [0, numVars)} that is not in
     * {@code toRemove}. Requires that {@link #checkIndices} has already passed.
     */
    private static int[] keptIndices(int numVars, Set<Integer> toRemove) {
        int[] kept = new int[numVars - toRemove.size()];
        int next = 0;
        for (int i = 0; i < numVars; i++) {
            if (!toRemove.contains(i)) {
                kept[next++] = i;
            }
        }
        return kept;
    }

    /**
     * Composes this transform's index maps with {@code other}'s, in place. Afterwards, each
     * entry {@code map[i]} is replaced by {@code other.map[map[i]]}. Presumably this is used
     * when this transform was fit on data that {@code other} had already transformed; the
     * result then maps directly to {@code other}'s input indices. Verify against callers.
     *
     * @param other the transform whose index maps are applied on top of this one's
     */
    public void consolidate(RemoveAttributeTransform other) {
        for (int i = 0; i < this.catIndexMap.length; i++) {
            this.catIndexMap[i] = other.catIndexMap[this.catIndexMap[i]];
        }
        for (int i = 0; i < this.numIndexMap.length; i++) {
            this.numIndexMap[i] = other.numIndexMap[this.numIndexMap[i]];
        }
    }

    /**
     * Returns a new data point that contains only the kept attributes, with the original
     * weight. A sparse numeric input produces a {@link SparseVector}; a dense input produces
     * a {@link DenseVector}.
     *
     * @param dataPoint the point to transform
     * @return a new point with the removed attributes dropped
     */
    @Override
    public DataPoint transform(DataPoint dataPoint) {
        int[] catVals = dataPoint.getCategoricalValues();
        Vec numVals = dataPoint.getNumericalValues();
        boolean sparse = numVals.isSparse();

        // NOTE(review): this array reaches the new DataPoint unpopulated (every entry is
        // null). Confirm that downstream code does not need the categorical metadata.
        CategoricalData[] newCatData = new CategoricalData[this.catIndexMap.length];
        int[] newCatVals = new int[newCatData.length];
        for (int i = 0; i < this.catIndexMap.length; i++) {
            newCatVals[i] = catVals[this.catIndexMap[i]];
        }

        Vec newNumVals;
        if (!sparse) {
            newNumVals = new DenseVector(this.numIndexMap.length);
        } else if (numVals instanceof SparseVector) {
            // The input's nnz is an upper bound on the number of surviving non-zeros.
            // Presumably the second argument is an initial capacity.
            newNumVals = new SparseVector(this.numIndexMap.length, ((SparseVector) numVals).nnz());
        } else {
            newNumVals = new SparseVector(this.numIndexMap.length);
        }

        Iterator<IndexValue> nonZeros = numVals.getNonZeroIterator();
        // If there are no non-zeros, the freshly created vector is already the correct result.
        if (nonZeros.hasNext()) {
            if (sparse) {
                copyKeptNonZeros(nonZeros, newNumVals);
            } else {
                for (int i = 0; i < this.numIndexMap.length; i++) {
                    newNumVals.set(i, numVals.get(this.numIndexMap[i]));
                }
            }
        }
        return new DataPoint(newNumVals, newCatVals, newCatData, dataPoint.getWeight());
    }

    /**
     * Walks the ascending {@link #numIndexMap} and the input's non-zero entries together.
     * Each non-zero whose attribute was kept is written to its new position in {@code dest}.
     * Assumes {@code nonZeros} is non-empty and yields indices in ascending order
     * (TODO: confirm this holds for every sparse {@code Vec} implementation).
     */
    private void copyKeptNonZeros(Iterator<IndexValue> nonZeros, Vec dest) {
        IndexValue current = nonZeros.next();
        for (int i = 0; i < this.numIndexMap.length && current != null; i++) {
            int source = this.numIndexMap[i];
            // Skip non-zeros that lie before this kept index, i.e. belong to removed attributes.
            while (source > current.getIndex() && nonZeros.hasNext()) {
                current = nonZeros.next();
            }
            if (source == current.getIndex()) {
                dest.set(i, current.getValue());
                current = nonZeros.hasNext() ? nonZeros.next() : null;
            }
        }
    }

    /**
     * Returns an independent copy of this transform, including its removal configuration.
     *
     * @return a deep copy made by the copy constructor
     */
    @Override
    public RemoveAttributeTransform clone() {
        return new RemoveAttributeTransform(this);
    }
}
