package jsat.datatransform.kernel;

import java.util.Random;
import jsat.DataSet;
import jsat.classifiers.DataPoint;
import jsat.datatransform.DataTransformBase;
import jsat.distributions.Distribution;
import jsat.distributions.kernels.RBFKernel;
import jsat.linear.DenseVector;
import jsat.linear.Matrix;
import jsat.linear.RandomMatrix;
import jsat.linear.RandomVector;
import jsat.linear.Vec;
import jsat.utils.random.XORWOW;

/* loaded from: input_file:jsat/datatransform/kernel/RFF_RBF.class */
/**
 * Data transform that replaces the numeric values of a data point with a
 * randomized cosine feature map of a fixed target dimension. Each output
 * feature {@code j} is {@code cos(x . W_j + b_j) * sqrt(2 / D)}, where
 * {@code W} holds Gaussian values scaled by {@code sqrt(0.5 / sigma^2)},
 * {@code b_j} is drawn uniformly from {@code [0, 2*pi)} and {@code D} is the
 * number of columns of {@code W}.
 * <p>
 * NOTE(review): judging by the class name and the {@code RBFKernel} import,
 * this is a Random Fourier Features approximation of the RBF kernel - confirm
 * against the project documentation.
 */
public class RFF_RBF extends DataTransformBase {
    private static final long serialVersionUID = -3478216020648280477L;

    /** Random projection matrix (numeric features x target dimension); {@code null} until built. */
    private Matrix transform;

    /** Per-output-feature phase offsets added before the cosine; {@code null} until built. */
    private Vec offsets;

    /** Width parameter used to scale the Gaussian entries of {@link #transform}. */
    private double sigma;

    /** Number of output features produced by {@link #transform(DataPoint)}. */
    private int dim;

    /** Whether the random matrix and offsets are copied into concrete storage when built. */
    private boolean inMemory;

    /* loaded from: input_file:jsat/datatransform/kernel/RFF_RBF$RandomMatrixRFF_RBF.class */
    /**
     * Random matrix whose entries are standard Gaussian draws multiplied by a
     * fixed coefficient.
     */
    private static class RandomMatrixRFF_RBF extends RandomMatrix {
        private static final long serialVersionUID = 4702514384718636893L;

        /** Scale applied to every Gaussian draw. */
        private double coef;

        /**
         * @param d  coefficient multiplied into each Gaussian value
         * @param i  first size argument forwarded to {@link RandomMatrix}
         * @param i2 second size argument forwarded to {@link RandomMatrix}
         * @param j  seed forwarded to {@link RandomMatrix}
         */
        public RandomMatrixRFF_RBF(double d, int i, int i2, long j) {
            super(i, i2, j);
            this.coef = d;
        }

        @Override // jsat.linear.RandomMatrix
        protected double getVal(Random random) {
            return this.coef * random.nextGaussian();
        }
    }

    /* loaded from: input_file:jsat/datatransform/kernel/RFF_RBF$RandomVectorRFF_RBF.class */
    /**
     * Random vector whose entries are uniform draws scaled to {@code [0, 2*pi)}.
     */
    private static class RandomVectorRFF_RBF extends RandomVector {
        private static final long serialVersionUID = -6132378281909907937L;

        /**
         * @param i length forwarded to {@link RandomVector}
         * @param j seed forwarded to {@link RandomVector}
         */
        public RandomVectorRFF_RBF(int i, long j) {
            super(i, j);
        }

        @Override // jsat.linear.RandomVector
        protected double getVal(Random random) {
            return random.nextDouble() * 2.0d * Math.PI;
        }

        @Override // jsat.linear.RandomVector, jsat.linear.Vec
        /* renamed from: clone */
        public Vec mo45clone() {
            // Returns this same instance instead of a copy.
            // NOTE(review): presumably safe because the values are derived from the seed
            // and the vector is never mutated - confirm against RandomVector.
            return this;
        }
    }

    /**
     * Creates an unfitted transform with {@code sigma = 1.0}, 512 output
     * dimensions and in-memory storage.
     */
    public RFF_RBF() {
        this(1.0d);
    }

    /**
     * Creates an unfitted transform with 512 output dimensions and in-memory
     * storage.
     *
     * @param d the sigma value, see {@link #setSigma(double)}
     */
    public RFF_RBF(double d) {
        this(d, 512);
    }

    /**
     * Creates an unfitted transform with in-memory storage.
     *
     * @param d the sigma value, see {@link #setSigma(double)}
     * @param i the number of output dimensions, see {@link #setDimensions(int)}
     */
    public RFF_RBF(double d, int i) {
        this(d, i, true);
    }

    /**
     * Creates an unfitted transform; {@link #fit(DataSet)} must build the
     * projection before {@link #transform(DataPoint)} can be used.
     *
     * @param d the sigma value, see {@link #setSigma(double)}
     * @param i the number of output dimensions, see {@link #setDimensions(int)}
     * @param z whether the projection is copied into concrete storage when built
     * @throws IllegalArgumentException if {@code d} is not a positive finite value
     * @throws ArithmeticException if {@code i} is less than 1
     */
    public RFF_RBF(double d, int i, boolean z) {
        setSigma(d);
        setDimensions(i);
        setInMemory(z);
    }

    /**
     * Creates a ready-to-use transform for inputs with a known number of
     * numeric features.
     *
     * @param i      the number of numeric features of the input data
     * @param d      the sigma value, see {@link #setSigma(double)}
     * @param i2     the number of output dimensions, see {@link #setDimensions(int)}
     * @param random source of the two seeds (projection first, offsets second)
     * @param z      whether the projection is copied into concrete storage
     * @throws IllegalArgumentException if {@code i} is not positive or {@code d}
     *                                  is not a positive finite value
     * @throws ArithmeticException if {@code i2} is less than 1
     */
    public RFF_RBF(int i, double d, int i2, Random random, boolean z) {
        // Sigma and target dimension are validated by the setters invoked here, so no
        // further checks on them are needed (a second check could never be reached for
        // invalid values, and previously rejected the valid dimension 1 with a
        // misleading "must be positive" message).
        this(d, i2, z);
        if (i <= 0) {
            throw new IllegalArgumentException("The number of numeric features must be positive, not " + i);
        }
        // Draw order matters for reproducibility: projection seed first, then offsets.
        long transformSeed = random.nextLong();
        long offsetSeed = random.nextLong();
        initProjection(i, d, i2, z, transformSeed, offsetSeed);
    }

    /**
     * Builds a fresh random projection sized to the number of numeric variables
     * of the given data set, using the currently configured sigma, dimension
     * and in-memory setting. Seeds come from a newly created {@link XORWOW}.
     *
     * @param dataSet the data set whose numeric variable count sizes the projection
     */
    @Override // jsat.datatransform.DataTransform
    public void fit(DataSet dataSet) {
        int numNumericalVars = dataSet.getNumNumericalVars();
        XORWOW xorwow = new XORWOW();
        // Same draw order as the seeded constructor: projection seed, then offsets.
        long transformSeed = xorwow.nextLong();
        long offsetSeed = xorwow.nextLong();
        initProjection(numNumericalVars, this.sigma, this.dim, this.inMemory, transformSeed, offsetSeed);
    }

    /**
     * Creates the projection matrix and offset vector and stores them in the fields.
     *
     * @param numFeatures   number of numeric input features
     * @param sigmaValue    sigma used to scale the Gaussian matrix entries
     * @param targetDim     number of output features
     * @param materialize   whether to copy both structures into concrete storage
     * @param transformSeed seed for the projection matrix
     * @param offsetSeed    seed for the offset vector
     */
    private void initProjection(int numFeatures, double sigmaValue, int targetDim, boolean materialize,
                                long transformSeed, long offsetSeed) {
        Matrix projection = new RandomMatrixRFF_RBF(
                Math.sqrt(0.5d / (sigmaValue * sigmaValue)), numFeatures, targetDim, transformSeed);
        Vec phases = new RandomVectorRFF_RBF(targetDim, offsetSeed);
        if (materialize) {
            // NOTE(review): adding 0.0 presumably yields a concrete copy of the lazily
            // generated matrix - confirm against Matrix.add.
            projection = projection.add(0.0d);
            phases = new DenseVector(phases);
        }
        this.transform = projection;
        this.offsets = phases;
    }

    /**
     * Copy constructor; clones the projection and offsets when they exist and
     * copies the scalar settings.
     *
     * @param rff_rbf the instance to copy
     */
    protected RFF_RBF(RFF_RBF rff_rbf) {
        if (rff_rbf.transform != null) {
            this.transform = rff_rbf.transform.mo161clone();
        }
        if (rff_rbf.offsets != null) {
            this.offsets = rff_rbf.offsets.mo45clone();
        }
        this.dim = rff_rbf.dim;
        this.inMemory = rff_rbf.inMemory;
        this.sigma = rff_rbf.sigma;
    }

    /**
     * Maps the numeric values of a data point through the random projection.
     * Entry {@code j} of the result is
     * {@code cos((x * transform)[j] + offsets[j]) * sqrt(2 / transform.cols())}.
     * Categorical values, categorical metadata and weight are carried over.
     * The projection must have been built first, by {@link #fit(DataSet)} or
     * the seeded constructor.
     *
     * @param dataPoint the data point to transform
     * @return a new data point holding the transformed numeric values
     */
    @Override // jsat.datatransform.DataTransform
    public DataPoint transform(DataPoint dataPoint) {
        Vec multiply = dataPoint.getNumericalValues().multiply(this.transform);
        // Normalization constant shared by every output feature.
        double sqrt = Math.sqrt(2.0d / this.transform.cols());
        for (int i = 0; i < multiply.length(); i++) {
            multiply.set(i, Math.cos(multiply.get(i) + this.offsets.get(i)) * sqrt);
        }
        return new DataPoint(multiply, dataPoint.getCategoricalValues(), dataPoint.getCategoricalData(), dataPoint.getWeight());
    }

    /**
     * @return a copy of this transform made with the copy constructor
     */
    @Override // jsat.datatransform.DataTransformBase
    public RFF_RBF clone() {
        return new RFF_RBF(this);
    }

    /**
     * Sets whether the projection is copied into concrete storage the next time
     * it is built.
     *
     * @param z {@code true} to materialize the projection when built
     */
    public void setInMemory(boolean z) {
        this.inMemory = z;
    }

    /**
     * @return whether the projection is materialized when built
     */
    public boolean isInMemory() {
        return this.inMemory;
    }

    /**
     * Sets the number of output features used the next time the projection is built.
     *
     * @param i the number of output dimensions, at least 1
     * @throws ArithmeticException if {@code i} is less than 1
     */
    public void setDimensions(int i) {
        if (i < 1) {
            throw new ArithmeticException("Number of dimensions must be a positive value, not " + i);
        }
        this.dim = i;
    }

    /**
     * @return the configured number of output features
     */
    public int getDimensions() {
        return this.dim;
    }

    /**
     * Sets the sigma value used the next time the projection is built.
     *
     * @param d a positive, finite sigma
     * @throws IllegalArgumentException if {@code d} is non-positive, infinite or NaN
     */
    public void setSigma(double d) {
        if (d <= 0.0d || Double.isInfinite(d) || Double.isNaN(d)) {
            throw new IllegalArgumentException("Sigma must be a positive value, not " + d);
        }
        this.sigma = d;
    }

    /**
     * @return the configured sigma value
     */
    public double getSigma() {
        return this.sigma;
    }

    /**
     * Delegates to {@link RBFKernel#guessSigma(DataSet)}.
     *
     * @param dataSet the data set to base the guess on
     * @return the distribution returned by the RBF kernel's sigma heuristic
     */
    public Distribution guessSigma(DataSet dataSet) {
        return RBFKernel.guessSigma(dataSet);
    }
}
