use std::ops::{Add, BitAnd, Sub};
use datafusion_common::Result;
use datafusion_common::ScalarValue;
// Rounding-direction arguments passed to `fesetround`. Each value must equal
// the target C library's `FE_UPWARD` / `FE_DOWNWARD` macro; both encode a
// two-bit rounding-control field of the FPU control register.

// aarch64: RMode field at bits 22-23 of FPCR (0b01 = toward +inf, 0b10 = toward -inf).
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
const FE_UPWARD: i32 = 0b01 << 22;
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
const FE_DOWNWARD: i32 = 0b10 << 22;

// x86_64: rounding-control field at bits 10-11 of the x87 control word
// (0b01 = toward -inf, 0b10 = toward +inf).
#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
const FE_UPWARD: i32 = 0b10 << 10;
#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
const FE_DOWNWARD: i32 = 0b01 << 10;
#[cfg(all(
any(target_arch = "x86_64", target_arch = "aarch64"),
not(target_os = "windows")
))]
extern crate libc;
#[cfg(all(
    any(target_arch = "x86_64", target_arch = "aarch64"),
    not(target_os = "windows")
))]
extern "C" {
    /// C `int fesetround(int round)`: sets the current rounding direction.
    /// Returns zero on success. The return type must be declared so that the
    /// Rust binding matches the C prototype.
    fn fesetround(round: i32) -> i32;
    /// C `int fegetround(void)`: returns the current rounding direction.
    fn fegetround() -> i32;
}
/// Bit-level view of a binary floating-point type.
///
/// [`next_up`] and [`next_down`] use this view to step to the adjacent
/// representable value by adding or subtracting one from the raw encoding.
pub trait FloatBits {
    /// Unsigned integer holding the raw encoding (`u32` for `f32`, `u64` for
    /// `f64` in this file).
    type Item: Copy
        + PartialEq
        + Add<Output = Self::Item>
        + Sub<Output = Self::Item>
        + BitAnd<Output = Self::Item>;

    /// The integer `0` in `Item`.
    const ZERO: Self::Item;
    /// The integer `1` in `Item`.
    const ONE: Self::Item;
    /// Mask that keeps every bit except the sign bit.
    const CLEAR_SIGN_MASK: Self::Item;
    /// Encoding of the smallest positive value (only the lowest bit set).
    const TINY_BITS: Self::Item;
    /// Encoding of the negative value closest to zero (sign bit plus lowest bit).
    const NEG_TINY_BITS: Self::Item;

    /// Reinterprets the float as its raw encoding.
    fn to_bits(self) -> Self::Item;
    /// Reinterprets a raw encoding as a float.
    fn from_bits(bits: Self::Item) -> Self;
    /// Returns `true` if the value is NaN.
    fn float_is_nan(self) -> bool;
    /// Positive infinity.
    fn infinity() -> Self;
    /// Negative infinity.
    fn neg_infinity() -> Self;
}
/// `f32` viewed through its 32-bit encoding.
impl FloatBits for f32 {
    type Item = u32;

    /// Only the lowest mantissa bit set: the smallest positive subnormal.
    const TINY_BITS: u32 = 0x0000_0001;
    /// Sign bit plus the lowest mantissa bit: the negative subnormal nearest zero.
    const NEG_TINY_BITS: u32 = 0x8000_0001;
    /// Every bit except the sign bit.
    const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff;
    const ONE: u32 = 1;
    const ZERO: u32 = 0;

    // The `f32::…` paths below resolve to the inherent methods, which take
    // precedence over this trait's methods of the same name.
    fn to_bits(self) -> u32 {
        f32::to_bits(self)
    }

    fn from_bits(bits: u32) -> Self {
        f32::from_bits(bits)
    }

    fn float_is_nan(self) -> bool {
        f32::is_nan(self)
    }

    fn infinity() -> Self {
        f32::INFINITY
    }

    fn neg_infinity() -> Self {
        f32::NEG_INFINITY
    }
}
/// `f64` viewed through its 64-bit encoding.
impl FloatBits for f64 {
    type Item = u64;

    /// Only the lowest mantissa bit set: the smallest positive subnormal.
    const TINY_BITS: u64 = 0x0000_0000_0000_0001;
    /// Sign bit plus the lowest mantissa bit: the negative subnormal nearest zero.
    const NEG_TINY_BITS: u64 = 0x8000_0000_0000_0001;
    /// Every bit except the sign bit.
    const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff;
    const ONE: u64 = 1;
    const ZERO: u64 = 0;

    // The `f64::…` paths below resolve to the inherent methods, which take
    // precedence over this trait's methods of the same name.
    fn to_bits(self) -> u64 {
        f64::to_bits(self)
    }

    fn from_bits(bits: u64) -> Self {
        f64::from_bits(bits)
    }

    fn float_is_nan(self) -> bool {
        f64::is_nan(self)
    }

    fn infinity() -> Self {
        f64::INFINITY
    }

    fn neg_infinity() -> Self {
        f64::NEG_INFINITY
    }
}
/// Returns the least representable value strictly greater than `float`.
///
/// Works on the raw encoding:
/// - `NaN` and `+inf` are returned unchanged.
/// - `+0.0` and `-0.0` both step to the smallest positive subnormal.
/// - Positive values step by `+1` in their bit pattern.
/// - Negative values, including `-inf`, step by `-1`, which moves them toward zero.
pub fn next_up<F: FloatBits + Copy>(float: F) -> F {
    if float.float_is_nan() {
        return float;
    }
    let bits = float.to_bits();
    if bits == F::infinity().to_bits() {
        return float;
    }

    let magnitude = bits & F::CLEAR_SIGN_MASK;
    let is_non_negative = bits == magnitude;
    let stepped = match (magnitude == F::ZERO, is_non_negative) {
        // Either signed zero: jump to the first positive subnormal.
        (true, _) => F::TINY_BITS,
        // Positive: a larger encoding is a larger value.
        (false, true) => bits + F::ONE,
        // Negative: a smaller encoding is a smaller magnitude, i.e. closer to zero.
        (false, false) => bits - F::ONE,
    };
    F::from_bits(stepped)
}
/// Returns the greatest representable value strictly less than `float`.
///
/// This mirrors [`next_up`]:
/// - `NaN` and `-inf` are returned unchanged.
/// - `+0.0` and `-0.0` both step to the negative subnormal nearest zero.
/// - Positive values, including `+inf`, step by `-1` in their bit pattern.
/// - Negative values step by `+1`.
pub fn next_down<F: FloatBits + Copy>(float: F) -> F {
    if float.float_is_nan() {
        return float;
    }
    let bits = float.to_bits();
    if bits == F::neg_infinity().to_bits() {
        return float;
    }

    let magnitude = bits & F::CLEAR_SIGN_MASK;
    let is_non_negative = bits == magnitude;
    let stepped = match (magnitude == F::ZERO, is_non_negative) {
        // Either signed zero: jump to the first negative subnormal.
        (true, _) => F::NEG_TINY_BITS,
        // Positive: a smaller encoding is a smaller value.
        (false, true) => bits - F::ONE,
        // Negative: a larger encoding is a larger magnitude, i.e. further below zero.
        (false, false) => bits + F::ONE,
    };
    F::from_bits(stepped)
}
/// Fallback for targets where the rounding mode is not switched via `fesetround`.
///
/// Runs `operation` under the default rounding mode. A non-null `Float64` or
/// `Float32` result is then moved to the adjacent representable value:
/// upward when `UPPER` is `true`, downward otherwise. Every other result,
/// including null floats, is returned untouched. Errors from `operation` are
/// propagated.
#[cfg(any(
    not(any(target_arch = "x86_64", target_arch = "aarch64")),
    target_os = "windows"
))]
fn alter_fp_rounding_mode_conservative<const UPPER: bool, F>(
    lhs: &ScalarValue,
    rhs: &ScalarValue,
    operation: F,
) -> Result<ScalarValue>
where
    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
    let widened = match operation(lhs, rhs)? {
        ScalarValue::Float64(Some(v)) => {
            ScalarValue::Float64(Some(if UPPER { next_up(v) } else { next_down(v) }))
        }
        ScalarValue::Float32(Some(v)) => {
            ScalarValue::Float32(Some(if UPPER { next_up(v) } else { next_down(v) }))
        }
        unchanged => unchanged,
    };
    Ok(widened)
}
/// Evaluates `operation(lhs, rhs)` with rounding directed toward `+inf` when
/// `UPPER` is `true`, or toward `-inf` when it is `false`.
///
/// Behavior depends on the target:
/// - On non-Windows x86_64/aarch64, the rounding mode is switched with
///   `fesetround` for the duration of the call. The previous mode is then
///   restored, also when `operation` panics.
/// - On other targets, `alter_fp_rounding_mode_conservative` runs the
///   operation normally and nudges a `Float32`/`Float64` result outward.
///
/// NOTE(review): Rust/LLVM assume the default rounding mode, so float
/// operations that get inlined or constant-folded may still be evaluated
/// with round-to-nearest. Treat the directed rounding as best-effort.
///
/// # Errors
///
/// Returns whatever error `operation` returns.
pub fn alter_fp_rounding_mode<const UPPER: bool, F>(
    lhs: &ScalarValue,
    rhs: &ScalarValue,
    operation: F,
) -> Result<ScalarValue>
where
    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
    #[cfg(all(
        any(target_arch = "x86_64", target_arch = "aarch64"),
        not(target_os = "windows")
    ))]
    {
        /// Reinstates a saved rounding mode when dropped, so an unwinding
        /// `operation` cannot leave the thread in a directed rounding mode.
        struct RestoreRoundingMode(i32);

        impl Drop for RestoreRoundingMode {
            fn drop(&mut self) {
                // SAFETY: `self.0` was returned by `fegetround`, so it is a mode
                // `fesetround` accepts; the call takes no pointers.
                unsafe {
                    fesetround(self.0);
                }
            }
        }

        // SAFETY: `fegetround` takes no arguments and only reads the FP environment.
        let _restore = RestoreRoundingMode(unsafe { fegetround() });
        // SAFETY: the argument is one of this file's `FE_*` rounding-direction constants.
        unsafe {
            fesetround(if UPPER { FE_UPWARD } else { FE_DOWNWARD });
        }
        // `_restore` is dropped when this block ends: after `operation` has
        // produced its result, or during unwinding if it panics.
        operation(lhs, rhs)
    }
    #[cfg(any(
        not(any(target_arch = "x86_64", target_arch = "aarch64")),
        target_os = "windows"
    ))]
    alter_fp_rounding_mode_conservative::<UPPER, _>(lhs, rhs, operation)
}
#[cfg(test)]
mod tests {
    use super::{next_down, next_up};

    #[test]
    fn test_next_down() {
        let x = 1.0f64;
        // Clamp to the largest representable value below 1.0.
        let clamped = x.clamp(0.0, next_down(1.0f64));
        assert!(clamped < 1.0);
        assert_eq!(next_up(clamped), 1.0);
    }

    #[test]
    fn test_next_up_small_positive() {
        let value: f64 = 1.0;
        let result = next_up(value);
        assert_eq!(result, 1.0000000000000002);
    }

    #[test]
    fn test_next_up_small_negative() {
        let value: f64 = -1.0;
        let result = next_up(value);
        assert_eq!(result, -0.9999999999999999);
    }

    #[test]
    fn test_next_up_pos_infinity() {
        let value: f64 = f64::INFINITY;
        let result = next_up(value);
        assert_eq!(result, f64::INFINITY);
    }

    #[test]
    fn test_next_up_nan() {
        let value: f64 = f64::NAN;
        let result = next_up(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_down_small_positive() {
        let value: f64 = 1.0;
        let result = next_down(value);
        assert_eq!(result, 0.9999999999999999);
    }

    #[test]
    fn test_next_down_small_negative() {
        let value: f64 = -1.0;
        let result = next_down(value);
        assert_eq!(result, -1.0000000000000002);
    }

    #[test]
    fn test_next_down_neg_infinity() {
        let value: f64 = f64::NEG_INFINITY;
        let result = next_down(value);
        assert_eq!(result, f64::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan() {
        let value: f64 = f64::NAN;
        let result = next_down(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_up_small_positive_f32() {
        let value: f32 = 1.0;
        let result = next_up(value);
        assert_eq!(result, 1.0000001);
    }

    #[test]
    fn test_next_up_small_negative_f32() {
        let value: f32 = -1.0;
        let result = next_up(value);
        assert_eq!(result, -0.99999994);
    }

    #[test]
    fn test_next_up_pos_infinity_f32() {
        let value: f32 = f32::INFINITY;
        let result = next_up(value);
        assert_eq!(result, f32::INFINITY);
    }

    #[test]
    fn test_next_up_nan_f32() {
        let value: f32 = f32::NAN;
        let result = next_up(value);
        assert!(result.is_nan());
    }

    #[test]
    fn test_next_down_small_positive_f32() {
        let value: f32 = 1.0;
        let result = next_down(value);
        assert_eq!(result, 0.99999994);
    }

    #[test]
    fn test_next_down_small_negative_f32() {
        let value: f32 = -1.0;
        let result = next_down(value);
        assert_eq!(result, -1.0000001);
    }

    #[test]
    fn test_next_down_neg_infinity_f32() {
        let value: f32 = f32::NEG_INFINITY;
        let result = next_down(value);
        assert_eq!(result, f32::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan_f32() {
        let value: f32 = f32::NAN;
        let result = next_down(value);
        assert!(result.is_nan());
    }

    // The tests below cover the signed-zero branch and the boundaries of the
    // bit-pattern stepping, none of which the tests above exercise.

    #[test]
    fn test_signed_zeros_f64() {
        // Both zeros take the same branch, so the sign of zero must not matter.
        assert_eq!(next_up(0.0f64).to_bits(), 0x1);
        assert_eq!(next_up(-0.0f64).to_bits(), 0x1);
        assert_eq!(next_down(0.0f64).to_bits(), 0x8000_0000_0000_0001);
        assert_eq!(next_down(-0.0f64).to_bits(), 0x8000_0000_0000_0001);
    }

    #[test]
    fn test_signed_zeros_f32() {
        assert_eq!(next_up(0.0f32).to_bits(), 0x1);
        assert_eq!(next_up(-0.0f32).to_bits(), 0x1);
        assert_eq!(next_down(0.0f32).to_bits(), 0x8000_0001);
        assert_eq!(next_down(-0.0f32).to_bits(), 0x8000_0001);
    }

    #[test]
    fn test_tiny_steps_to_zero_f64() {
        let tiny = f64::from_bits(1);
        assert_eq!(next_down(tiny).to_bits(), 0.0f64.to_bits());
        assert_eq!(next_up(-tiny).to_bits(), (-0.0f64).to_bits());
    }

    #[test]
    fn test_finite_extremes_reach_infinity() {
        assert_eq!(next_up(f64::MAX), f64::INFINITY);
        assert_eq!(next_down(f64::MIN), f64::NEG_INFINITY);
        assert_eq!(next_up(f32::MAX), f32::INFINITY);
        assert_eq!(next_down(f32::MIN), f32::NEG_INFINITY);
    }

    #[test]
    fn test_infinities_step_inward() {
        assert_eq!(next_up(f64::NEG_INFINITY), f64::MIN);
        assert_eq!(next_down(f64::INFINITY), f64::MAX);
        assert_eq!(next_up(f32::NEG_INFINITY), f32::MIN);
        assert_eq!(next_down(f32::INFINITY), f32::MAX);
    }

    #[test]
    fn test_subnormal_normal_boundary() {
        let largest_subnormal_f64 = f64::from_bits(0x000f_ffff_ffff_ffff);
        assert_eq!(next_up(largest_subnormal_f64), f64::MIN_POSITIVE);
        assert_eq!(next_down(f64::MIN_POSITIVE), largest_subnormal_f64);

        let largest_subnormal_f32 = f32::from_bits(0x007f_ffff);
        assert_eq!(next_up(largest_subnormal_f32), f32::MIN_POSITIVE);
        assert_eq!(next_down(f32::MIN_POSITIVE), largest_subnormal_f32);
    }

    #[test]
    fn test_round_trip_nonzero_finite() {
        for x in [1.0f64, -1.0, 3.5, -1e-300, 1e300] {
            assert_eq!(next_down(next_up(x)).to_bits(), x.to_bits());
            assert_eq!(next_up(next_down(x)).to_bits(), x.to_bits());
        }
    }
}