lance_arrow/
floats.rs

// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright The Lance Authors

//! Float arrays: traits and helpers shared by the bf16, f16, f32 and f64 array types.

6use std::fmt::{Debug, Display};
7use std::iter::Sum;
8use std::sync::Arc;
9use std::{
10    fmt::Formatter,
11    ops::{AddAssign, DivAssign},
12};
13
14use arrow_array::{
15    types::{Float16Type, Float32Type, Float64Type},
16    Array, Float16Array, Float32Array, Float64Array,
17};
18use arrow_schema::{DataType, Field};
19use half::{bf16, f16};
20use num_traits::{AsPrimitive, Bounded, Float, FromPrimitive};
21
22use super::bfloat16::{BFloat16Array, BFloat16Type};
23use crate::bfloat16::is_bfloat16_field;
24use crate::Result;
25
/// Float data type.
///
/// This helps differentiate between the different float types,
/// because bf16 is not an officially supported [DataType] in arrow-rs.
///
/// `FloatType` is a plain field-less tag, so it is `Copy` and can be compared,
/// hashed and passed around by value freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FloatType {
    /// bfloat16 (`half::bf16`).
    BFloat16,
    /// 16-bit half precision float (`half::f16`).
    Float16,
    /// 32-bit float (`f32`).
    Float32,
    /// 64-bit float (`f64`).
    Float64,
}
37
38impl std::fmt::Display for FloatType {
39    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
40        match self {
41            Self::BFloat16 => write!(f, "bfloat16"),
42            Self::Float16 => write!(f, "float16"),
43            Self::Float32 => write!(f, "float32"),
44            Self::Float64 => write!(f, "float64"),
45        }
46    }
47}
48
49/// Try to convert a [DataType] to a [FloatType]. To support bfloat16, always
50/// prefer using the `TryFrom<&Field>` implementation.
51impl TryFrom<&DataType> for FloatType {
52    type Error = crate::ArrowError;
53
54    fn try_from(value: &DataType) -> Result<Self> {
55        match *value {
56            DataType::Float16 => Ok(Self::Float16),
57            DataType::Float32 => Ok(Self::Float32),
58            DataType::Float64 => Ok(Self::Float64),
59            _ => Err(crate::ArrowError::InvalidArgumentError(format!(
60                "{:?} is not a floating type",
61                value
62            ))),
63        }
64    }
65}
66
67impl TryFrom<&Field> for FloatType {
68    type Error = crate::ArrowError;
69
70    fn try_from(field: &Field) -> Result<Self> {
71        match field.data_type() {
72            DataType::FixedSizeBinary(2) if is_bfloat16_field(field) => Ok(Self::BFloat16),
73            _ => Self::try_from(field.data_type()),
74        }
75    }
76}
77
78/// Trait for float types used in Arrow Array.
79///
80pub trait ArrowFloatType: Debug {
81    type Native: FromPrimitive
82        + FloatToArrayType<ArrowType = Self>
83        + AsPrimitive<f32>
84        + Debug
85        + Display;
86
87    const FLOAT_TYPE: FloatType;
88    const MIN: Self::Native;
89    const MAX: Self::Native;
90
91    /// Arrow Float Array Type.
92    type ArrayType: FloatArray<Self>;
93
94    /// Returns empty array of this type.
95    fn empty_array() -> Self::ArrayType {
96        Vec::<Self::Native>::new().into()
97    }
98}
99
100pub trait FloatToArrayType:
101    Float
102    + Bounded
103    + Sum
104    + AddAssign<Self>
105    + AsPrimitive<f64>
106    + AsPrimitive<f32>
107    + DivAssign
108    + Send
109    + Sync
110    + Copy
111{
112    type ArrowType: ArrowFloatType<Native = Self>;
113}
114
115impl FloatToArrayType for bf16 {
116    type ArrowType = BFloat16Type;
117}
118
119impl FloatToArrayType for f16 {
120    type ArrowType = Float16Type;
121}
122
123impl FloatToArrayType for f32 {
124    type ArrowType = Float32Type;
125}
126
127impl FloatToArrayType for f64 {
128    type ArrowType = Float64Type;
129}
130
131impl ArrowFloatType for BFloat16Type {
132    type Native = bf16;
133
134    const FLOAT_TYPE: FloatType = FloatType::BFloat16;
135    const MIN: Self::Native = bf16::MIN;
136    const MAX: Self::Native = bf16::MAX;
137
138    type ArrayType = BFloat16Array;
139}
140
141impl ArrowFloatType for Float16Type {
142    type Native = f16;
143
144    const FLOAT_TYPE: FloatType = FloatType::Float16;
145    const MIN: Self::Native = f16::MIN;
146    const MAX: Self::Native = f16::MAX;
147
148    type ArrayType = Float16Array;
149}
150
151impl ArrowFloatType for Float32Type {
152    type Native = f32;
153
154    const FLOAT_TYPE: FloatType = FloatType::Float32;
155    const MIN: Self::Native = f32::MIN;
156    const MAX: Self::Native = f32::MAX;
157
158    type ArrayType = Float32Array;
159}
160
161impl ArrowFloatType for Float64Type {
162    type Native = f64;
163
164    const FLOAT_TYPE: FloatType = FloatType::Float64;
165    const MIN: Self::Native = f64::MIN;
166    const MAX: Self::Native = f64::MAX;
167
168    type ArrayType = Float64Array;
169}
170
171/// [FloatArray] is a trait that is implemented by all float type arrays.
172pub trait FloatArray<T: ArrowFloatType + ?Sized>:
173    Array + Clone + From<Vec<T::Native>> + 'static
174{
175    type FloatType: ArrowFloatType;
176
177    /// Returns a reference to the underlying data as a slice.
178    fn as_slice(&self) -> &[T::Native];
179}
180
181impl FloatArray<Float16Type> for Float16Array {
182    type FloatType = Float16Type;
183
184    fn as_slice(&self) -> &[<Float16Type as ArrowFloatType>::Native] {
185        self.values()
186    }
187}
188
189impl FloatArray<Float32Type> for Float32Array {
190    type FloatType = Float32Type;
191
192    fn as_slice(&self) -> &[<Float32Type as ArrowFloatType>::Native] {
193        self.values()
194    }
195}
196
197impl FloatArray<Float64Type> for Float64Array {
198    type FloatType = Float64Type;
199
200    fn as_slice(&self) -> &[<Float64Type as ArrowFloatType>::Native] {
201        self.values()
202    }
203}
204
205/// Convert a float32 array to another float array.
206pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Result<Arc<dyn Array>> {
207    match float_type {
208        FloatType::BFloat16 => Ok(Arc::new(BFloat16Array::from_iter_values(
209            input.values().iter().map(|v| bf16::from_f32(*v)),
210        ))),
211        FloatType::Float16 => Ok(Arc::new(Float16Array::from_iter_values(
212            input.values().iter().map(|v| f16::from_f32(*v)),
213        ))),
214        FloatType::Float32 => Ok(Arc::new(input.clone())),
215        FloatType::Float64 => Ok(Arc::new(Float64Array::from_iter_values(
216            input.values().iter().map(|v| *v as f64),
217        ))),
218    }
219}