1use std::fmt::{Debug, Display, Formatter};
2use std::hash::{Hash, Hasher};
3
4use bytes::Bytes;
5use polars_core::chunked_array::cast::CastOptions;
6use polars_core::error::feature_gated;
7use polars_core::prelude::*;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11pub use super::expr_dyn_fn::*;
12use crate::prelude::*;
13
14#[derive(PartialEq, Clone, Hash)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub enum AggExpr {
17 Min {
18 input: Arc<Expr>,
19 propagate_nans: bool,
20 },
21 Max {
22 input: Arc<Expr>,
23 propagate_nans: bool,
24 },
25 Median(Arc<Expr>),
26 NUnique(Arc<Expr>),
27 First(Arc<Expr>),
28 Last(Arc<Expr>),
29 Mean(Arc<Expr>),
30 Implode(Arc<Expr>),
31 Count(Arc<Expr>, bool),
33 Quantile {
34 expr: Arc<Expr>,
35 quantile: Arc<Expr>,
36 method: QuantileMethod,
37 },
38 Sum(Arc<Expr>),
39 AggGroups(Arc<Expr>),
40 Std(Arc<Expr>, u8),
41 Var(Arc<Expr>, u8),
42}
43
44impl AsRef<Expr> for AggExpr {
45 fn as_ref(&self) -> &Expr {
46 use AggExpr::*;
47 match self {
48 Min { input, .. } => input,
49 Max { input, .. } => input,
50 Median(e) => e,
51 NUnique(e) => e,
52 First(e) => e,
53 Last(e) => e,
54 Mean(e) => e,
55 Implode(e) => e,
56 Count(e, _) => e,
57 Quantile { expr, .. } => expr,
58 Sum(e) => e,
59 AggGroups(e) => e,
60 Std(e, _) => e,
61 Var(e, _) => e,
62 }
63 }
64}
65
66#[derive(Clone, PartialEq)]
72#[must_use]
73#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
74pub enum Expr {
75 Alias(Arc<Expr>, PlSmallStr),
76 Column(PlSmallStr),
77 Columns(Arc<[PlSmallStr]>),
78 DtypeColumn(Vec<DataType>),
79 IndexColumn(Arc<[i64]>),
80 Literal(LiteralValue),
81 BinaryExpr {
82 left: Arc<Expr>,
83 op: Operator,
84 right: Arc<Expr>,
85 },
86 Cast {
87 expr: Arc<Expr>,
88 dtype: DataType,
89 options: CastOptions,
90 },
91 Sort {
92 expr: Arc<Expr>,
93 options: SortOptions,
94 },
95 Gather {
96 expr: Arc<Expr>,
97 idx: Arc<Expr>,
98 returns_scalar: bool,
99 },
100 SortBy {
101 expr: Arc<Expr>,
102 by: Vec<Expr>,
103 sort_options: SortMultipleOptions,
104 },
105 Agg(AggExpr),
106 Ternary {
109 predicate: Arc<Expr>,
110 truthy: Arc<Expr>,
111 falsy: Arc<Expr>,
112 },
113 Function {
114 input: Vec<Expr>,
116 function: FunctionExpr,
118 options: FunctionOptions,
119 },
120 Explode(Arc<Expr>),
121 Filter {
122 input: Arc<Expr>,
123 by: Arc<Expr>,
124 },
125 Window {
127 function: Arc<Expr>,
129 partition_by: Vec<Expr>,
130 order_by: Option<(Arc<Expr>, SortOptions)>,
131 options: WindowType,
132 },
133 Wildcard,
134 Slice {
135 input: Arc<Expr>,
136 offset: Arc<Expr>,
138 length: Arc<Expr>,
139 },
140 Exclude(Arc<Expr>, Vec<Excluded>),
143 KeepName(Arc<Expr>),
145 Len,
146 Nth(i64),
148 RenameAlias {
149 function: SpecialEq<Arc<dyn RenameAliasFn>>,
150 expr: Arc<Expr>,
151 },
152 #[cfg(feature = "dtype-struct")]
153 Field(Arc<[PlSmallStr]>),
154 AnonymousFunction {
155 input: Vec<Expr>,
157 function: OpaqueColumnUdf,
159 output_type: GetOutput,
161 options: FunctionOptions,
162 },
163 SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
164 Selector(super::selector::Selector),
171}
172
173pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
174pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
175 LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
176}
177
178#[derive(Clone)]
179pub enum LazySerde<T: Clone> {
180 Deserialized(T),
181 Bytes(Bytes),
182}
183
184impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
185 fn eq(&self, other: &Self) -> bool {
186 use LazySerde as L;
187 match (self, other) {
188 (L::Deserialized(a), L::Deserialized(b)) => a == b,
189 (L::Bytes(a), L::Bytes(b)) => a.as_ptr() == b.as_ptr() && a.len() == b.len(),
190 _ => false,
191 }
192 }
193}
194
195impl<T: Clone> Debug for LazySerde<T> {
196 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
197 match self {
198 Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
199 Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
200 }
201 }
202}
203
204impl OpaqueColumnUdf {
205 pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
206 match self {
207 Self::Deserialized(t) => Ok(t),
208 Self::Bytes(b) => {
209 feature_gated!("serde";"python", {
210 python_udf::PythonUdfExpression::try_deserialize(b.as_ref()).map(SpecialEq::new)
211 })
212 },
213 }
214 }
215}
216
217#[allow(clippy::derived_hash_with_manual_eq)]
218impl Hash for Expr {
219 fn hash<H: Hasher>(&self, state: &mut H) {
220 let d = std::mem::discriminant(self);
221 d.hash(state);
222 match self {
223 Expr::Column(name) => name.hash(state),
224 Expr::Columns(names) => names.hash(state),
225 Expr::DtypeColumn(dtypes) => dtypes.hash(state),
226 Expr::IndexColumn(indices) => indices.hash(state),
227 Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
228 Expr::Selector(s) => s.hash(state),
229 Expr::Nth(v) => v.hash(state),
230 Expr::Filter { input, by } => {
231 input.hash(state);
232 by.hash(state);
233 },
234 Expr::BinaryExpr { left, op, right } => {
235 left.hash(state);
236 right.hash(state);
237 std::mem::discriminant(op).hash(state)
238 },
239 Expr::Cast {
240 expr,
241 dtype,
242 options: strict,
243 } => {
244 expr.hash(state);
245 dtype.hash(state);
246 strict.hash(state)
247 },
248 Expr::Sort { expr, options } => {
249 expr.hash(state);
250 options.hash(state);
251 },
252 Expr::Alias(input, name) => {
253 input.hash(state);
254 name.hash(state)
255 },
256 Expr::KeepName(input) => input.hash(state),
257 Expr::Ternary {
258 predicate,
259 truthy,
260 falsy,
261 } => {
262 predicate.hash(state);
263 truthy.hash(state);
264 falsy.hash(state);
265 },
266 Expr::Function {
267 input,
268 function,
269 options,
270 } => {
271 input.hash(state);
272 std::mem::discriminant(function).hash(state);
273 options.hash(state);
274 },
275 Expr::Gather {
276 expr,
277 idx,
278 returns_scalar,
279 } => {
280 expr.hash(state);
281 idx.hash(state);
282 returns_scalar.hash(state);
283 },
284 Expr::Wildcard | Expr::Len => {},
286 Expr::SortBy {
287 expr,
288 by,
289 sort_options,
290 } => {
291 expr.hash(state);
292 by.hash(state);
293 sort_options.hash(state);
294 },
295 Expr::Agg(input) => input.hash(state),
296 Expr::Explode(input) => input.hash(state),
297 Expr::Window {
298 function,
299 partition_by,
300 order_by,
301 options,
302 } => {
303 function.hash(state);
304 partition_by.hash(state);
305 order_by.hash(state);
306 options.hash(state);
307 },
308 Expr::Slice {
309 input,
310 offset,
311 length,
312 } => {
313 input.hash(state);
314 offset.hash(state);
315 length.hash(state);
316 },
317 Expr::Exclude(input, excl) => {
318 input.hash(state);
319 excl.hash(state);
320 },
321 Expr::RenameAlias { function: _, expr } => expr.hash(state),
322 Expr::AnonymousFunction {
323 input,
324 function: _,
325 output_type: _,
326 options,
327 } => {
328 input.hash(state);
329 options.hash(state);
330 },
331 Expr::SubPlan(_, names) => names.hash(state),
332 #[cfg(feature = "dtype-struct")]
333 Expr::Field(names) => names.hash(state),
334 }
335 }
336}
337
338impl Eq for Expr {}
339
340impl Default for Expr {
341 fn default() -> Self {
342 Expr::Literal(LiteralValue::Null)
343 }
344}
345
346#[derive(Debug, Clone, PartialEq, Eq, Hash)]
347#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
348
349pub enum Excluded {
350 Name(PlSmallStr),
351 Dtype(DataType),
352}
353
354impl Expr {
355 pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
357 let mut arena = Arena::with_capacity(5);
359 self.to_field_amortized(schema, ctxt, &mut arena)
360 }
361 pub(crate) fn to_field_amortized(
362 &self,
363 schema: &Schema,
364 ctxt: Context,
365 expr_arena: &mut Arena<AExpr>,
366 ) -> PolarsResult<Field> {
367 let root = to_aexpr(self.clone(), expr_arena)?;
368 expr_arena
369 .get(root)
370 .to_field_and_validate(schema, ctxt, expr_arena)
371 }
372}
373
/// Binary operators usable in `Expr::BinaryExpr`.
///
/// The token noted on each variant is the one produced by the `Display` implementation.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    /// `==`
    Eq,
    /// `==v`
    // NOTE(review): presumably an equality that also takes validity (nulls) into account — confirm.
    EqValidity,
    /// `!=`
    NotEq,
    /// `!=v`
    // NOTE(review): presumably an inequality that also takes validity (nulls) into account — confirm.
    NotEqValidity,
    /// `<`
    Lt,
    /// `<=`
    LtEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
    /// `+`
    Plus,
    /// `-`
    Minus,
    /// `*`
    Multiply,
    /// `//`
    Divide,
    /// `/`
    TrueDivide,
    /// `floor_div`
    FloorDivide,
    /// `%`
    Modulus,
    /// `&`
    And,
    /// `|`
    Or,
    /// `^`
    Xor,
    /// `&` (printed the same as [`Operator::And`])
    LogicalAnd,
    /// `|` (printed the same as [`Operator::Or`])
    LogicalOr,
}
398
399impl Display for Operator {
400 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
401 use Operator::*;
402 let tkn = match self {
403 Eq => "==",
404 EqValidity => "==v",
405 NotEq => "!=",
406 NotEqValidity => "!=v",
407 Lt => "<",
408 LtEq => "<=",
409 Gt => ">",
410 GtEq => ">=",
411 Plus => "+",
412 Minus => "-",
413 Multiply => "*",
414 Divide => "//",
415 TrueDivide => "/",
416 FloorDivide => "floor_div",
417 Modulus => "%",
418 And | LogicalAnd => "&",
419 Or | LogicalOr => "|",
420 Xor => "^",
421 };
422 write!(f, "{tkn}")
423 }
424}
425
426impl Operator {
427 pub fn is_comparison(&self) -> bool {
428 matches!(
429 self,
430 Self::Eq
431 | Self::NotEq
432 | Self::Lt
433 | Self::LtEq
434 | Self::Gt
435 | Self::GtEq
436 | Self::And
437 | Self::Or
438 | Self::Xor
439 | Self::EqValidity
440 | Self::NotEqValidity
441 )
442 }
443
444 pub fn swap_operands(self) -> Self {
445 match self {
446 Operator::Eq => Operator::Eq,
447 Operator::Gt => Operator::Lt,
448 Operator::GtEq => Operator::LtEq,
449 Operator::LtEq => Operator::GtEq,
450 Operator::Or => Operator::Or,
451 Operator::LogicalAnd => Operator::LogicalAnd,
452 Operator::LogicalOr => Operator::LogicalOr,
453 Operator::Xor => Operator::Xor,
454 Operator::NotEq => Operator::NotEq,
455 Operator::EqValidity => Operator::EqValidity,
456 Operator::NotEqValidity => Operator::NotEqValidity,
457 Operator::Divide => Operator::Multiply,
458 Operator::Multiply => Operator::Divide,
459 Operator::And => Operator::And,
460 Operator::Plus => Operator::Minus,
461 Operator::Minus => Operator::Plus,
462 Operator::Lt => Operator::Gt,
463 _ => unimplemented!(),
464 }
465 }
466
467 pub fn is_arithmetic(&self) -> bool {
468 !(self.is_comparison())
469 }
470}