tiny_solver/
loss_functions.rsuse core::f64;
pub enum LossFunc {
HuberLoss,
}
pub trait Loss: Send + Sync {
fn evaluate(&self, s: f64) -> [f64; 3];
}
#[derive(Debug, Clone)]
pub struct HuberLoss {
scale: f64,
scale2: f64,
}
impl HuberLoss {
pub fn new(scale: f64) -> Self {
if scale <= 0.0 {
panic!("scale needs to be larger than zero");
}
HuberLoss {
scale,
scale2: scale * scale,
}
}
}
impl Loss for HuberLoss {
fn evaluate(&self, s: f64) -> [f64; 3] {
if s > self.scale2 {
let r = s.sqrt();
let rho1 = (self.scale / r).max(f64::MIN);
[2.0 * self.scale * r - self.scale2, rho1, -rho1 / (2.0 * s)]
} else {
[s, 1.0, 0.0]
}
}
}
pub struct CauchyLoss {
scale2: f64,
c: f64,
}
impl CauchyLoss {
pub fn new(scale: f64) -> Self {
let scale2 = scale * scale;
CauchyLoss {
scale2,
c: 1.0 / scale2,
}
}
}
impl Loss for CauchyLoss {
fn evaluate(&self, s: f64) -> [f64; 3] {
let sum = 1.0 + s * self.c;
let inv = 1.0 / sum;
[
self.scale2 * sum.log2(),
inv.max(f64::MIN),
-self.c * (inv * inv),
]
}
}
pub struct ArctanLoss {
tolerance: f64,
inv_of_squared_tolerance: f64,
}
impl ArctanLoss {
pub fn new(tolerance: f64) -> Self {
if tolerance <= 0.0 {
panic!("scale needs to be larger than zero");
}
ArctanLoss {
tolerance,
inv_of_squared_tolerance: 1.0 / (tolerance * tolerance),
}
}
}
impl Loss for ArctanLoss {
fn evaluate(&self, s: f64) -> [f64; 3] {
let sum = 1.0 + s * s * self.inv_of_squared_tolerance;
let inv = 1.0 / sum;
[
self.tolerance * s.atan2(self.tolerance),
inv.max(f64::MIN),
-2.0 * s * self.inv_of_squared_tolerance * (inv * inv),
]
}
}