polars_plan/dsl/function_expr/
pow.rs1use num_traits::pow::Pow;
2use num_traits::{Float, One, ToPrimitive, Zero};
3use polars_core::prelude::arity::{broadcast_binary_elementwise, unary_elementwise_values};
4use polars_core::with_match_physical_integer_type;
5
6use super::*;
7
/// Selector for the power-family functions implemented in this module.
///
/// The `Display` implementation renders each variant as its lowercase
/// function name.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PowFunction {
    /// Displayed as `pow`.
    Generic,
    /// Displayed as `sqrt`.
    Sqrt,
    /// Displayed as `cbrt`.
    Cbrt,
}
15
16impl Display for PowFunction {
17 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
18 use self::*;
19 match self {
20 PowFunction::Generic => write!(f, "pow"),
21 PowFunction::Sqrt => write!(f, "sqrt"),
22 PowFunction::Cbrt => write!(f, "cbrt"),
23 }
24 }
25}
26
27fn pow_on_chunked_arrays<T, F>(
28 base: &ChunkedArray<T>,
29 exponent: &ChunkedArray<F>,
30) -> ChunkedArray<T>
31where
32 T: PolarsNumericType,
33 F: PolarsNumericType,
34 T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
35{
36 if exponent.len() == 1 {
37 if let Some(e) = exponent.get(0) {
38 if e == F::Native::zero() {
39 return unary_elementwise_values(base, |_| T::Native::one());
40 }
41 if e == F::Native::one() {
42 return base.clone();
43 }
44 if e == F::Native::one() + F::Native::one() {
45 return base * base;
46 }
47 }
48 }
49
50 broadcast_binary_elementwise(base, exponent, |b, e| Some(Pow::pow(b?, e?)))
51}
52
53fn pow_on_floats<T>(
54 base: &ChunkedArray<T>,
55 exponent: &ChunkedArray<T>,
56) -> PolarsResult<Option<Column>>
57where
58 T: PolarsFloatType,
59 T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
60 ChunkedArray<T>: IntoColumn,
61{
62 let dtype = T::get_dtype();
63
64 if exponent.len() == 1 {
65 let Some(exponent_value) = exponent.get(0) else {
66 return Ok(Some(Column::full_null(
67 base.name().clone(),
68 base.len(),
69 &dtype,
70 )));
71 };
72 let s = match exponent_value.to_f64().unwrap() {
73 1.0 => base.clone().into_column(),
74 0.5 => base.apply_values(|v| v.sqrt()).into_column(),
77 a if a.fract() == 0.0 && a < 10.0 && a > 1.0 => {
78 let mut out = base.clone();
79
80 for _ in 1..exponent_value.to_u8().unwrap() {
81 out = out * base.clone()
82 }
83 out.into_column()
84 },
85 _ => base
86 .apply_values(|v| Pow::pow(v, exponent_value))
87 .into_column(),
88 };
89 Ok(Some(s))
90 } else {
91 Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
92 }
93}
94
95fn pow_to_uint_dtype<T, F>(
96 base: &ChunkedArray<T>,
97 exponent: &ChunkedArray<F>,
98) -> PolarsResult<Option<Column>>
99where
100 T: PolarsIntegerType,
101 F: PolarsIntegerType,
102 T::Native: Pow<F::Native, Output = T::Native> + ToPrimitive,
103 ChunkedArray<T>: IntoColumn,
104{
105 let dtype = T::get_dtype();
106
107 if exponent.len() == 1 {
108 let Some(exponent_value) = exponent.get(0) else {
109 return Ok(Some(Column::full_null(
110 base.name().clone(),
111 base.len(),
112 &dtype,
113 )));
114 };
115 let s = match exponent_value.to_u64().unwrap() {
116 1 => base.clone().into_column(),
117 2..=10 => {
118 let mut out = base.clone();
119
120 for _ in 1..exponent_value.to_u8().unwrap() {
121 out = out * base.clone()
122 }
123 out.into_column()
124 },
125 _ => base
126 .apply_values(|v| Pow::pow(v, exponent_value))
127 .into_column(),
128 };
129 Ok(Some(s))
130 } else {
131 Ok(Some(pow_on_chunked_arrays(base, exponent).into_column()))
132 }
133}
134
135fn pow_on_series(base: &Column, exponent: &Column) -> PolarsResult<Option<Column>> {
136 use DataType::*;
137
138 let base_dtype = base.dtype();
139 polars_ensure!(
140 base_dtype.is_primitive_numeric(),
141 InvalidOperation: "`pow` operation not supported for dtype `{}` as base", base_dtype
142 );
143 let exponent_dtype = exponent.dtype();
144 polars_ensure!(
145 exponent_dtype.is_primitive_numeric(),
146 InvalidOperation: "`pow` operation not supported for dtype `{}` as exponent", exponent_dtype
147 );
148
149 if base_dtype.is_integer() {
151 with_match_physical_integer_type!(base_dtype, |$native_type| {
152 if exponent_dtype.is_float() {
153 match exponent_dtype {
154 Float32 => {
155 let ca = base.cast(&DataType::Float32)?;
156 pow_on_floats(ca.f32().unwrap(), exponent.f32().unwrap())
157 },
158 Float64 => {
159 let ca = base.cast(&DataType::Float64)?;
160 pow_on_floats(ca.f64().unwrap(), exponent.f64().unwrap())
161 },
162 _ => unreachable!(),
163 }
164 } else {
165 let ca = base.$native_type().unwrap();
166 let exponent = exponent.strict_cast(&DataType::UInt32).map_err(|err| polars_err!(
167 InvalidOperation:
168 "{}\n\nHint: if you were trying to raise an integer to a negative integer power, please cast your base or exponent to float first.",
169 err
170 ))?;
171 pow_to_uint_dtype(ca, exponent.u32().unwrap())
172 }
173 })
174 } else {
175 match base_dtype {
176 Float32 => {
177 let ca = base.f32().unwrap();
178 let exponent = exponent.strict_cast(&DataType::Float32)?;
179 pow_on_floats(ca, exponent.f32().unwrap())
180 },
181 Float64 => {
182 let ca = base.f64().unwrap();
183 let exponent = exponent.strict_cast(&DataType::Float64)?;
184 pow_on_floats(ca, exponent.f64().unwrap())
185 },
186 _ => unreachable!(),
187 }
188 }
189}
190
191pub(super) fn pow(s: &mut [Column]) -> PolarsResult<Option<Column>> {
192 let base = &s[0];
193 let exponent = &s[1];
194
195 let base_len = base.len();
196 let exp_len = exponent.len();
197 match (base_len, exp_len) {
198 (1, _) | (_, 1) => pow_on_series(base, exponent),
199 (len_a, len_b) if len_a == len_b => pow_on_series(base, exponent),
200 _ => polars_bail!(
201 ComputeError:
202 "exponent shape: {} in `pow` expression does not match that of the base: {}",
203 exp_len, base_len,
204 ),
205 }
206}
207
208pub(super) fn sqrt(base: &Column) -> PolarsResult<Column> {
209 use DataType::*;
210 match base.dtype() {
211 Float32 => {
212 let ca = base.f32().unwrap();
213 sqrt_on_floats(ca)
214 },
215 Float64 => {
216 let ca = base.f64().unwrap();
217 sqrt_on_floats(ca)
218 },
219 _ => {
220 let base = base.cast(&DataType::Float64)?;
221 sqrt(&base)
222 },
223 }
224}
225
226fn sqrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
227where
228 T: PolarsFloatType,
229 T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
230 ChunkedArray<T>: IntoColumn,
231{
232 Ok(base.apply_values(|v| v.sqrt()).into_column())
233}
234
235pub(super) fn cbrt(base: &Column) -> PolarsResult<Column> {
236 use DataType::*;
237 match base.dtype() {
238 Float32 => {
239 let ca = base.f32().unwrap();
240 cbrt_on_floats(ca)
241 },
242 Float64 => {
243 let ca = base.f64().unwrap();
244 cbrt_on_floats(ca)
245 },
246 _ => {
247 let base = base.cast(&DataType::Float64)?;
248 cbrt(&base)
249 },
250 }
251}
252
253fn cbrt_on_floats<T>(base: &ChunkedArray<T>) -> PolarsResult<Column>
254where
255 T: PolarsFloatType,
256 T::Native: Pow<T::Native, Output = T::Native> + ToPrimitive + Float,
257 ChunkedArray<T>: IntoColumn,
258{
259 Ok(base.apply_values(|v| v.cbrt()).into_column())
260}