polars_plan/dsl/function_expr/
pow.rs

1use num_traits::pow::Pow;
2use num_traits::{Float, One, ToPrimitive, Zero};
3use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values};
4use polars_core::with_match_physical_integer_type;
5
6use super::*;
7
/// Identifies which power-style operation an expression refers to.
///
/// The `Display` implementation yields the short lowercase name of each
/// variant (`pow`, `sqrt`, `cbrt`).
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, Copy, PartialEq, Debug, Eq, Hash)]
pub enum PowFunction {
    /// Rendered as `pow`.
    Generic,
    /// Rendered as `sqrt`.
    Sqrt,
    /// Rendered as `cbrt`.
    Cbrt,
}

impl Display for PowFunction {
    /// Writes the variant's short name verbatim to the formatter.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Generic => "pow",
            Self::Sqrt => "sqrt",
            Self::Cbrt => "cbrt",
        };
        f.write_str(name)
    }
}
26
27fn pow_on_chunked_arrays<T, F>(
28    base: &ChunkedArray<T>,
29    exponent: &ChunkedArray<F>,
30) -> ChunkedArray<T>
31where
32    T: PolarsNumericType,
33    F: PolarsNumericType,
34    T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
35{
36    if exponent.len() == 1 {
37        if let Some(e) = exponent.get(0) {
38            if e == F::Native::zero() {
39                return unary_elementwise_values(base, |_| T::Native::one());
40            }
41            if e == F::Native::one() {
42                return base.clone();
43            }
44            if e == F::Native::one() + F::Native::one() {
45                return base * base;
46            }
47        }
48    }
49
50    broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?)))
51}
52
53fn pow_on_floats<T>(
54    base: &ChunkedArray<T>,
55    exponent: &ChunkedArray<T>,
56) -> PolarsResult<Option<Column>>
57where
58    T: PolarsFloatType,
59    T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
60    ChunkedArray<T>: IntoColumn,
61{
62    let dtype = T::get_dtype();
63
64    if exponent.len() == 1 {
65        let Some(exponent_value) = exponent.get(0) else {
66            return Ok(Some(Column::full_null(
67                base.name().clone(),
68                base.len(),
69                &dtype,
70            )));
71        };
72        let s = match exponent_value.to_f64().unwrap() {
73            1.0 => base.clone().into_column(),
74            // specialized sqrt will ensure (-inf)^0.5 = NaN
75            // and will likely be faster as well.
76            0.5 => base.apply_values(|v| v.sqrt()).into_column(),
77            a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
78                let mut out = base.clone();
79
80                for _ in 1..exponent_value.to_u8().unwrap() {
81                    out = out * base.clone()
82                }
83                out.into_column()
84            },
85            _ => base
86                .apply_values(|v| Pow::pow(v, exponent_value))
87                .into_column(),
88        };
89        Ok(Some(s))
90    } else {
91        Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
92    }
93}
94
95fn pow_to_uint_dtype<T, F>(
96    base: &ChunkedArray<T>,
97    exponent: &ChunkedArray<F>,
98) -> PolarsResult<Option<Column>>
99where
100    T: PolarsIntegerType,
101    F: PolarsIntegerType,
102    T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
103    ChunkedArray<T>: IntoColumn,
104{
105    let dtype = T::get_dtype();
106
107    if exponent.len() == 1 {
108        let Some(exponent_value) = exponent.get(0) else {
109            return Ok(Some(Column::full_null(
110                base.name().clone(),
111                base.len(),
112                &dtype,
113            )));
114        };
115        let s = match exponent_value.to_u64().unwrap() {
116            1 => base.clone().into_column(),
117            2..=10 => {
118                let mut out = base.clone();
119
120                for _ in 1..exponent_value.to_u8().unwrap() {
121                    out = out * base.clone()
122                }
123                out.into_column()
124            },
125            _ => base
126                .apply_values(|v| Pow::pow(v, exponent_value))
127                .into_column(),
128        };
129        Ok(Some(s))
130    } else {
131        Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
132    }
133}
134
135fn pow_on_series(base: &Column, exponent: &Column) -> PolarsResult<Option<Column>> {
136    use DataType::*;
137
138    let base_dtype = base.dtype();
139    polars_ensure!(
140        base_dtype.is_primitive_numeric(),
141        InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
142    );
143    let exponent_dtype = exponent.dtype();
144    polars_ensure!(
145        exponent_dtype.is_primitive_numeric(),
146        InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype
147    );
148
149    // if false, dtype is float
150    if base_dtype.is_integer() {
151        with_match_physical_integer_type!(base_dtype, |$native_type| {
152            if exponent_dtype.is_float() {
153                match exponent_dtype {
154                    Float32 => {
155                        let ca = base.cast(&DataType::Float32)?;
156                        pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
157                    },
158                    Float64 => {
159                        let ca = base.cast(&DataType::Float64)?;
160                        pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
161                    },
162                    _ => unreachable!(),
163                }
164            } else {
165                let ca = base.$native_type().unwrap();
166                let exponent = exponent.strict_cast(&DataType::UInt32).map_err(|err| polars_err!(
167                    InvalidOperation:
168                    "{}\n\nHint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first.",
169                    err
170                ))?;
171                pow_to_uint_dtype(ca, exponent.u32().unwrap())
172            }
173        })
174    } else {
175        match base_dtype {
176            Float32 => {
177                let ca = base.f32().unwrap();
178                let exponent = exponent.strict_cast(&DataType::Float32)?;
179                pow_on_floats(ca, exponent.f32().unwrap())
180            },
181            Float64 => {
182                let ca = base.f64().unwrap();
183                let exponent = exponent.strict_cast(&DataType::Float64)?;
184                pow_on_floats(ca, exponent.f64().unwrap())
185            },
186            _ => unreachable!(),
187        }
188    }
189}
190
191pub(super) fn pow(s: &mut [Column]) -> PolarsResult<Option<Column>> {
192    let base = &s[0];
193    let exponent = &s[1];
194
195    let base_len = base.len();
196    let exp_len = exponent.len();
197    match (base_len, exp_len) {
198        (1, _) | (_, 1) => pow_on_series(base, exponent),
199        (len_a, len_b) if len_a == len_b => pow_on_series(base, exponent),
200        _ => polars_bail!(
201            ComputeError:
202            "exponent shape: {} in `pow` expression does not match that of the base: {}",
203            exp_len, base_len,
204        ),
205    }
206}
207
208pub(super) fn sqrt(base: &Column) -> PolarsResult<Column> {
209    use DataType::*;
210    match base.dtype() {
211        Float32 => {
212            let ca = base.f32().unwrap();
213            sqrt_on_floats(ca)
214        },
215        Float64 => {
216            let ca = base.f64().unwrap();
217            sqrt_on_floats(ca)
218        },
219        _ => {
220            let base = base.cast(&DataType::Float64)?;
221            sqrt(&base)
222        },
223    }
224}
225
226fn sqrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
227where
228    T: PolarsFloatType,
229    T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
230    ChunkedArray<T>: IntoColumn,
231{
232    Ok(base.apply_values(|v| v.sqrt()).into_column())
233}
234
235pub(super) fn cbrt(base: &Column) -> PolarsResult<Column> {
236    use DataType::*;
237    match base.dtype() {
238        Float32 => {
239            let ca = base.f32().unwrap();
240            cbrt_on_floats(ca)
241        },
242        Float64 => {
243            let ca = base.f64().unwrap();
244            cbrt_on_floats(ca)
245        },
246        _ => {
247            let base = base.cast(&DataType::Float64)?;
248            cbrt(&base)
249        },
250    }
251}
252
253fn cbrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
254where
255    T: PolarsFloatType,
256    T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
257    ChunkedArray<T>: IntoColumn,
258{
259    Ok(base.apply_values(|v| v.cbrt()).into_column())
260}