package jsat.math.optimization.stochastic;

import jsat.linear.DenseVector;
import jsat.linear.ScaledVector;
import jsat.linear.Vec;

/* loaded from: input_file:jsat/math/optimization/stochastic/SGDMomentum.class */
/**
 * Gradient updater that applies a momentum term to each update step.
 * <p>
 * The updater keeps a velocity vector (and a matching scalar velocity for the
 * bias term) that is decayed by {@code momentum} and then reduced by the
 * scaled gradient on every call to {@link #update(Vec, Vec, double, double, double)}.
 * Two variants are implemented, selected by the {@code nestrov} flag: the
 * branch taken when the flag is set applies {@code momentum^2 * velocity} and
 * a {@code (1 + momentum)}-scaled gradient step, while the other branch
 * applies {@code momentum * velocity} and a plain scaled gradient step.
 * <p>
 * {@link #setup(int)} must be called before the first update: the velocity
 * vector is {@code null} until then.
 */
public class SGDMomentum implements GradientUpdater {
    private static final long serialVersionUID = -3837883539010356899L;

    /** Decay factor applied to the velocity on every update; always in (0, 1). */
    private double momentum;

    /** Selects which of the two update formulas in {@code update} is used. */
    private boolean nestrov;

    /** Running velocity for the weight vector; allocated by {@link #setup(int)}. */
    private Vec velocity;

    /** Running velocity for the scalar bias term. */
    private double biasVelocity;

    /**
     * Creates a new momentum updater.
     *
     * @param d the momentum value, must be in the open interval (0, 1)
     * @param z {@code true} to use the Nesterov-style update branch,
     *          {@code false} to use the classic momentum branch
     * @throws IllegalArgumentException if {@code d} is NaN or outside (0, 1)
     */
    public SGDMomentum(double d, boolean z) {
        setMomentum(d);
        this.nestrov = z;
    }

    /**
     * Creates a new momentum updater with the Nesterov-style branch enabled.
     *
     * @param d the momentum value, must be in the open interval (0, 1)
     * @throws IllegalArgumentException if {@code d} is NaN or outside (0, 1)
     */
    public SGDMomentum(double d) {
        this(d, true);
    }

    /**
     * Copy constructor. The velocity vector, when present, is cloned so the
     * copy does not share mutable state with the source.
     *
     * @param sGDMomentum the updater to copy
     */
    public SGDMomentum(SGDMomentum sGDMomentum) {
        this.momentum = sGDMomentum.momentum;
        // The update variant must survive copying; without this a clone of a
        // Nesterov-configured updater would silently fall back to the classic branch.
        this.nestrov = sGDMomentum.nestrov;
        if (sGDMomentum.velocity != null) {
            this.velocity = sGDMomentum.velocity.mo45clone();
        }
        this.biasVelocity = sGDMomentum.biasVelocity;
    }

    /**
     * Sets the momentum (velocity decay factor).
     *
     * @param d the new momentum, must be in the open interval (0, 1)
     * @throws IllegalArgumentException if {@code d} is NaN or outside (0, 1)
     */
    public void setMomentum(double d) {
        if (d <= 0.0d || d >= 1.0d || Double.isNaN(d)) {
            throw new IllegalArgumentException("Momentum must be in (0,1) not " + d);
        }
        this.momentum = d;
    }

    /**
     * Returns the momentum (velocity decay factor) currently in use.
     *
     * @return the momentum value
     */
    public double getMomentum() {
        return this.momentum;
    }

    /**
     * Updates {@code vec} in place without a bias term; delegates to the
     * five-argument overload with both bias arguments set to zero and
     * discards the returned bias value.
     *
     * @param vec  the vector to be updated in place
     * @param vec2 the gradient vector
     * @param d    the step size multiplied onto the gradient
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void update(Vec vec, Vec vec2, double d) {
        update(vec, vec2, d, 0.0d, 0.0d);
    }

    /**
     * Updates {@code vec} in place using the stored velocity and the given
     * gradient, then advances the velocity state.
     *
     * @param vec  the vector to be updated in place
     * @param vec2 the gradient vector
     * @param d    the step size multiplied onto the gradient
     * @param d2   not read by this implementation
     * @param d3   the gradient of the bias term
     * @return the value computed for the bias term; it has the opposite sign
     *         convention to the change applied to {@code vec}, so presumably
     *         the caller subtracts it from its bias (confirm against the
     *         {@code GradientUpdater} contract)
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public double update(Vec vec, Vec vec2, double d, double d2, double d3) {
        double d4;
        if (this.nestrov) {
            vec.mutableAdd(this.momentum * this.momentum, this.velocity);
            vec.mutableSubtract((1.0d + this.momentum) * d, vec2);
            d4 = ((-this.momentum) * this.momentum * this.biasVelocity) + ((1.0d + this.momentum) * d * d3);
        } else {
            vec.mutableAdd(this.momentum, this.velocity);
            vec.mutableSubtract(d, vec2);
            d4 = ((-this.momentum) * this.biasVelocity) + (d * d3);
        }
        // The velocity is advanced only after vec has been updated: both
        // branches above read the velocity from the previous step.
        this.velocity.mutableMultiply(this.momentum);
        this.velocity.mutableSubtract(d, vec2);
        this.biasVelocity = (this.biasVelocity * this.momentum) - (d * d3);
        return d4;
    }

    /**
     * Returns a copy of this updater created through the copy constructor.
     *
     * @return a new, independent {@code SGDMomentum} with the same configuration and state
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public SGDMomentum m232clone() {
        return new SGDMomentum(this);
    }

    /**
     * Allocates a fresh velocity vector of the given length and resets the
     * bias velocity to zero. Any previously accumulated velocity is discarded.
     *
     * @param i the length of the vectors that will be passed to {@code update}
     */
    @Override // jsat.math.optimization.stochastic.GradientUpdater
    public void setup(int i) {
        this.velocity = new ScaledVector(new DenseVector(i));
        this.biasVelocity = 0.0d;
    }
}
