pub struct NablaOptimizer;
§Implementations
impl NablaOptimizer
pub fn sgd_update(weights: &mut NDArray, gradient: &NDArray, learning_rate: f64)
Performs a Stochastic Gradient Descent (SGD) update:
w = w - learning_rate * gradient
§Arguments
weights - NDArray of current weights to update
gradient - NDArray of gradients for the weights
learning_rate - Learning rate for the update
§Example
use nabla_ml::nab_array::NDArray;
use nabla_ml::nab_optimizers::NablaOptimizer;
let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
let learning_rate = 0.1;
NablaOptimizer::sgd_update(&mut weights, &gradients, learning_rate);
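// Per w = w - learning_rate * gradient, weights is now ≈ [0.99, 1.98, 2.97]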
pub fn sgd_momentum_update(
    weights: &mut NDArray,
    gradient: &NDArray,
    velocity: &mut NDArray,
    learning_rate: f64,
    momentum: f64,
)
Performs an SGD update with momentum:
v = momentum * v - learning_rate * gradient
w = w + v
§Arguments
weights - NDArray of current weights to update
gradient - NDArray of gradients for the weights
velocity - Mutable reference to momentum velocity
learning_rate - Learning rate for the update
momentum - Momentum coefficient (default: 0.9)
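§Example
A minimal sketch mirroring the other examples on this page; the velocity buffer is assumed to be zero-initialized, like the cache and moment buffers below:
use nabla_ml::nab_array::NDArray;
use nabla_ml::nab_optimizers::NablaOptimizer;
let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
// Velocity starts at zero, so the first step equals plain SGD
let mut velocity = NDArray::zeros(vec![3]);
let learning_rate = 0.1;
let momentum = 0.9;
NablaOptimizer::sgd_momentum_update(
    &mut weights,
    &gradients,
    &mut velocity,
    learning_rate,
    momentum,
);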
pub fn rmsprop_update(
    weights: &mut NDArray,
    gradient: &NDArray,
    cache: &mut NDArray,
    learning_rate: f64,
    decay_rate: f64,
    epsilon: f64,
)
Performs an RMSprop update:
cache = decay_rate * cache + (1 - decay_rate) * gradient^2
w = w - learning_rate * gradient / (sqrt(cache) + epsilon)
§Arguments
weights - NDArray of current weights to update
gradient - NDArray of gradients for the weights
cache - Running average of squared gradients
learning_rate - Learning rate for the update
decay_rate - Decay rate for running average (default: 0.9)
epsilon - Small value for numerical stability (default: 1e-8)
§Example
use nabla_ml::nab_array::NDArray;
use nabla_ml::nab_optimizers::NablaOptimizer;
let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
let mut cache = NDArray::zeros(vec![3]);
let learning_rate = 0.01;
let decay_rate = 0.9;
let epsilon = 1e-8;
NablaOptimizer::rmsprop_update(
&mut weights,
&gradients,
&mut cache,
learning_rate,
decay_rate,
epsilon
);
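Following the formula above, because cache is zero-initialized the first step has magnitude learning_rate * |gradient| / (sqrt((1 - decay_rate) * gradient^2) + epsilon) ≈ learning_rate / sqrt(1 - decay_rate) ≈ 0.032 here for every element, regardless of the gradient's scale; the effective step size only adapts as the running average accumulates history.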
pub fn adam_update(
    weights: &mut NDArray,
    gradient: &NDArray,
    m: &mut NDArray,
    v: &mut NDArray,
    t: usize,
    learning_rate: f64,
    beta1: f64,
    beta2: f64,
    epsilon: f64,
)
Performs an Adam (Adaptive Moment Estimation) update:
m = beta1 * m + (1 - beta1) * gradient        // Update first moment
v = beta2 * v + (1 - beta2) * gradient^2      // Update second moment
m_hat = m / (1 - beta1^t)                     // Bias correction
v_hat = v / (1 - beta2^t)                     // Bias correction
w = w - learning_rate * m_hat / (sqrt(v_hat) + epsilon)
§Arguments
weights - NDArray of current weights to update
gradient - NDArray of gradients for the weights
m - First moment vector (momentum)
v - Second moment vector (uncentered variance)
t - Current timestep (starting from 1)
learning_rate - Learning rate for the update
beta1 - Exponential decay rate for first moment (default: 0.9)
beta2 - Exponential decay rate for second moment (default: 0.999)
epsilon - Small value for numerical stability (default: 1e-8)
§Example
use nabla_ml::nab_array::NDArray;
use nabla_ml::nab_optimizers::NablaOptimizer;
let mut weights = NDArray::from_vec(vec![1.0, 2.0, 3.0]);
let gradients = NDArray::from_vec(vec![0.1, 0.2, 0.3]);
let mut m = NDArray::zeros(vec![3]);
let mut v = NDArray::zeros(vec![3]);
let t = 1;
let learning_rate = 0.001;
let beta1 = 0.9;
let beta2 = 0.999;
let epsilon = 1e-8;
NablaOptimizer::adam_update(
&mut weights,
&gradients,
&mut m,
&mut v,
t,
learning_rate,
beta1,
beta2,
epsilon
);
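Following the formulas above, at t = 1 the bias corrections give m_hat = gradient and v_hat = gradient^2 exactly (since m and v start at zero), so the first step has magnitude learning_rate * |gradient| / (|gradient| + epsilon) ≈ learning_rate for every element. This scale-invariance of the initial step is a well-known property of Adam.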
§Auto Trait Implementations
impl Freeze for NablaOptimizer
impl RefUnwindSafe for NablaOptimizer
impl Send for NablaOptimizer
impl Sync for NablaOptimizer
impl Unpin for NablaOptimizer
impl UnwindSafe for NablaOptimizer
§Blanket Implementations
impl<T> BorrowMut<T> for T
where
    T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.