datafusion_common/
rounding.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Floating point rounding mode utility library
19//! TODO: Remove this custom implementation and the "libc" dependency when
20//!       floating-point rounding mode manipulation functions become available
21//!       in Rust.
22
23use std::ops::{Add, BitAnd, Sub};
24
25use crate::Result;
26use crate::ScalarValue;
27
// Rounding-mode selectors handed to `fesetround` on AArch64 (non-Windows).
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
const FE_UPWARD: i32 = 0x0040_0000;
#[cfg(all(target_arch = "aarch64", not(target_os = "windows")))]
const FE_DOWNWARD: i32 = 0x0080_0000;

// Rounding-mode selectors handed to `fesetround` on x86_64 (non-Windows).
#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
const FE_UPWARD: i32 = 0x0000_0800;
#[cfg(all(target_arch = "x86_64", not(target_os = "windows")))]
const FE_DOWNWARD: i32 = 0x0000_0400;
39
40#[cfg(all(
41    any(target_arch = "x86_64", target_arch = "aarch64"),
42    not(target_os = "windows")
43))]
44extern crate libc;
45
// Bindings to the C `<fenv.h>` rounding-mode functions. They are only declared
// on the platforms for which the `FE_*` selector constants are defined.
#[cfg(all(
    any(target_arch = "x86_64", target_arch = "aarch64"),
    not(target_os = "windows")
))]
extern "C" {
    /// C `int fesetround(int round)`: requests the rounding direction selected
    /// by `round`. Returns zero when the requested mode was set and a nonzero
    /// value otherwise. The return type mirrors the C prototype so the
    /// declaration matches the real ABI; callers may still ignore it.
    fn fesetround(round: i32) -> i32;

    /// C `int fegetround(void)`: returns the selector of the rounding
    /// direction currently in effect.
    fn fegetround() -> i32;
}
54
/// Bit-level access to a floating-point type.
///
/// Implementors expose an integer type that carries the raw bit pattern,
/// conversions in both directions, a handful of bit-pattern constants, and
/// accessors for NaN detection and the two infinities.
pub trait FloatBits {
    /// Integer type that holds the raw bit pattern of the float.
    type Item: Copy
        + PartialEq
        + Add<Output = Self::Item>
        + Sub<Output = Self::Item>
        + BitAnd<Output = Self::Item>;

    /// The integer `0` in the bit-pattern type.
    const ZERO: Self::Item;

    /// The integer `1` in the bit-pattern type.
    const ONE: Self::Item;

    /// Bit pattern of the smallest positive value of this float type.
    const TINY_BITS: Self::Item;

    /// Bit pattern of the negative value closest to zero for this float type.
    const NEG_TINY_BITS: Self::Item;

    /// Mask which, AND-ed with a bit pattern, clears the sign bit.
    const CLEAR_SIGN_MASK: Self::Item;

    /// Returns the raw bit pattern of `self`.
    fn to_bits(self) -> Self::Item;

    /// Builds the float whose raw bit pattern is `bits`.
    fn from_bits(bits: Self::Item) -> Self;

    /// Reports whether `self` is NaN (not a number).
    fn float_is_nan(self) -> bool;

    /// Positive infinity for this float type.
    fn infinity() -> Self;

    /// Negative infinity for this float type.
    fn neg_infinity() -> Self;
}
96
97impl FloatBits for f32 {
98    type Item = u32;
99    const TINY_BITS: u32 = 0x1; // Smallest positive f32.
100    const NEG_TINY_BITS: u32 = 0x8000_0001; // Smallest (in magnitude) negative f32.
101    const CLEAR_SIGN_MASK: u32 = 0x7fff_ffff;
102    const ONE: Self::Item = 1;
103    const ZERO: Self::Item = 0;
104
105    fn to_bits(self) -> Self::Item {
106        self.to_bits()
107    }
108
109    fn from_bits(bits: Self::Item) -> Self {
110        f32::from_bits(bits)
111    }
112
113    fn float_is_nan(self) -> bool {
114        self.is_nan()
115    }
116
117    fn infinity() -> Self {
118        f32::INFINITY
119    }
120
121    fn neg_infinity() -> Self {
122        f32::NEG_INFINITY
123    }
124}
125
126impl FloatBits for f64 {
127    type Item = u64;
128    const TINY_BITS: u64 = 0x1; // Smallest positive f64.
129    const NEG_TINY_BITS: u64 = 0x8000_0000_0000_0001; // Smallest (in magnitude) negative f64.
130    const CLEAR_SIGN_MASK: u64 = 0x7fff_ffff_ffff_ffff;
131    const ONE: Self::Item = 1;
132    const ZERO: Self::Item = 0;
133
134    fn to_bits(self) -> Self::Item {
135        self.to_bits()
136    }
137
138    fn from_bits(bits: Self::Item) -> Self {
139        f64::from_bits(bits)
140    }
141
142    fn float_is_nan(self) -> bool {
143        self.is_nan()
144    }
145
146    fn infinity() -> Self {
147        f64::INFINITY
148    }
149
150    fn neg_infinity() -> Self {
151        f64::NEG_INFINITY
152    }
153}
154
155/// Returns the next representable floating-point value greater than the input value.
156///
157/// This function takes a floating-point value that implements the FloatBits trait,
158/// calculates the next representable value greater than the input, and returns it.
159///
160/// If the input value is NaN or positive infinity, the function returns the input value.
161///
162/// # Examples
163///
164/// ```
165/// use datafusion_common::rounding::next_up;
166///
167/// let f: f32 = 1.0;
168/// let next_f = next_up(f);
169/// assert_eq!(next_f, 1.0000001);
170/// ```
171pub fn next_up<F: FloatBits + Copy>(float: F) -> F {
172    let bits = float.to_bits();
173    if float.float_is_nan() || bits == F::infinity().to_bits() {
174        return float;
175    }
176
177    let abs = bits & F::CLEAR_SIGN_MASK;
178    let next_bits = if abs == F::ZERO {
179        F::TINY_BITS
180    } else if bits == abs {
181        bits + F::ONE
182    } else {
183        bits - F::ONE
184    };
185    F::from_bits(next_bits)
186}
187
188/// Returns the next representable floating-point value smaller than the input value.
189///
190/// This function takes a floating-point value that implements the FloatBits trait,
191/// calculates the next representable value smaller than the input, and returns it.
192///
193/// If the input value is NaN or negative infinity, the function returns the input value.
194///
195/// # Examples
196///
197/// ```
198/// use datafusion_common::rounding::next_down;
199///
200/// let f: f32 = 1.0;
201/// let next_f = next_down(f);
202/// assert_eq!(next_f, 0.99999994);
203/// ```
204pub fn next_down<F: FloatBits + Copy>(float: F) -> F {
205    let bits = float.to_bits();
206    if float.float_is_nan() || bits == F::neg_infinity().to_bits() {
207        return float;
208    }
209    let abs = bits & F::CLEAR_SIGN_MASK;
210    let next_bits = if abs == F::ZERO {
211        F::NEG_TINY_BITS
212    } else if bits == abs {
213        bits - F::ONE
214    } else {
215        bits + F::ONE
216    };
217    F::from_bits(next_bits)
218}
219
/// Fallback used where the rounding mode is not switched through `fesetround`.
///
/// Runs `operation(lhs, rhs)` and, when the outcome is a non-null `Float64`
/// or `Float32` scalar, moves it one representable step outward: up via
/// [`next_up`] when `UPPER` is `true`, down via [`next_down`] otherwise.
/// Any other outcome is returned untouched, and an error from `operation`
/// is propagated as is.
#[cfg(any(
    not(any(target_arch = "x86_64", target_arch = "aarch64")),
    target_os = "windows"
))]
fn alter_fp_rounding_mode_conservative<const UPPER: bool, F>(
    lhs: &ScalarValue,
    rhs: &ScalarValue,
    operation: F,
) -> Result<ScalarValue>
where
    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
{
    let adjusted = match operation(lhs, rhs)? {
        ScalarValue::Float64(Some(v)) => {
            let nudged = if UPPER { next_up(v) } else { next_down(v) };
            ScalarValue::Float64(Some(nudged))
        }
        ScalarValue::Float32(Some(v)) => {
            let nudged = if UPPER { next_up(v) } else { next_down(v) };
            ScalarValue::Float32(Some(nudged))
        }
        // Nulls and non-floating-point scalars pass through unchanged.
        other => other,
    };
    Ok(adjusted)
}
252
253pub fn alter_fp_rounding_mode<const UPPER: bool, F>(
254    lhs: &ScalarValue,
255    rhs: &ScalarValue,
256    operation: F,
257) -> Result<ScalarValue>
258where
259    F: FnOnce(&ScalarValue, &ScalarValue) -> Result<ScalarValue>,
260{
261    #[cfg(all(
262        any(target_arch = "x86_64", target_arch = "aarch64"),
263        not(target_os = "windows")
264    ))]
265    unsafe {
266        let current = fegetround();
267        fesetround(if UPPER { FE_UPWARD } else { FE_DOWNWARD });
268        let result = operation(lhs, rhs);
269        fesetround(current);
270        result
271    }
272    #[cfg(any(
273        not(any(target_arch = "x86_64", target_arch = "aarch64")),
274        target_os = "windows"
275    ))]
276    alter_fp_rounding_mode_conservative::<UPPER, _>(lhs, rhs, operation)
277}
278
#[cfg(test)]
mod tests {
    use super::{next_down, next_up};

    // --- f64 ---------------------------------------------------------------

    #[test]
    fn test_next_down() {
        // Clamp 1.0 into the half-open range [0, 1).
        let upper = next_down(1.0f64);
        let clamped = 1.0f64.clamp(0.0, upper);
        assert!(clamped < 1.0);
        // The clamped value is exactly one step below 1.0.
        assert_eq!(next_up(clamped), 1.0);
    }

    #[test]
    fn test_next_up_small_positive() {
        assert_eq!(next_up(1.0f64), 1.0000000000000002);
    }

    #[test]
    fn test_next_up_small_negative() {
        assert_eq!(next_up(-1.0f64), -0.9999999999999999);
    }

    #[test]
    fn test_next_up_pos_infinity() {
        assert_eq!(next_up(f64::INFINITY), f64::INFINITY);
    }

    #[test]
    fn test_next_up_nan() {
        assert!(next_up(f64::NAN).is_nan());
    }

    #[test]
    fn test_next_down_small_positive() {
        assert_eq!(next_down(1.0f64), 0.9999999999999999);
    }

    #[test]
    fn test_next_down_small_negative() {
        assert_eq!(next_down(-1.0f64), -1.0000000000000002);
    }

    #[test]
    fn test_next_down_neg_infinity() {
        assert_eq!(next_down(f64::NEG_INFINITY), f64::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan() {
        assert!(next_down(f64::NAN).is_nan());
    }

    // --- f32 ---------------------------------------------------------------

    #[test]
    fn test_next_up_small_positive_f32() {
        assert_eq!(next_up(1.0f32), 1.0000001);
    }

    #[test]
    fn test_next_up_small_negative_f32() {
        assert_eq!(next_up(-1.0f32), -0.99999994);
    }

    #[test]
    fn test_next_up_pos_infinity_f32() {
        assert_eq!(next_up(f32::INFINITY), f32::INFINITY);
    }

    #[test]
    fn test_next_up_nan_f32() {
        assert!(next_up(f32::NAN).is_nan());
    }

    #[test]
    fn test_next_down_small_positive_f32() {
        assert_eq!(next_down(1.0f32), 0.99999994);
    }

    #[test]
    fn test_next_down_small_negative_f32() {
        assert_eq!(next_down(-1.0f32), -1.0000001);
    }

    #[test]
    fn test_next_down_neg_infinity_f32() {
        assert_eq!(next_down(f32::NEG_INFINITY), f32::NEG_INFINITY);
    }

    #[test]
    fn test_next_down_nan_f32() {
        assert!(next_down(f32::NAN).is_nan());
    }
}