polars_plan/dsl/expr.rs

1use std::fmt::{Debug, Display, Formatter};
2use std::hash::{Hash, Hasher};
3
4use bytes::Bytes;
5use polars_core::chunked_array::cast::CastOptions;
6use polars_core::error::feature_gated;
7use polars_core::prelude::*;
8#[cfg(feature = "serde")]
9use serde::{Deserialize, Serialize};
10
11pub use super::expr_dyn_fn::*;
12use crate::prelude::*;
13
14#[derive(PartialEq, Clone, Hash)]
15#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
16pub enum AggExpr {
17    Min {
18        input: Arc<Expr>,
19        propagate_nans: bool,
20    },
21    Max {
22        input: Arc<Expr>,
23        propagate_nans: bool,
24    },
25    Median(Arc<Expr>),
26    NUnique(Arc<Expr>),
27    First(Arc<Expr>),
28    Last(Arc<Expr>),
29    Mean(Arc<Expr>),
30    Implode(Arc<Expr>),
31    // include_nulls
32    Count(Arc<Expr>, bool),
33    Quantile {
34        expr: Arc<Expr>,
35        quantile: Arc<Expr>,
36        method: QuantileMethod,
37    },
38    Sum(Arc<Expr>),
39    AggGroups(Arc<Expr>),
40    Std(Arc<Expr>, u8),
41    Var(Arc<Expr>, u8),
42}
43
44impl AsRef<Expr> for AggExpr {
45    fn as_ref(&self) -> &Expr {
46        use AggExpr::*;
47        match self {
48            Min { input, .. } => input,
49            Max { input, .. } => input,
50            Median(e) => e,
51            NUnique(e) => e,
52            First(e) => e,
53            Last(e) => e,
54            Mean(e) => e,
55            Implode(e) => e,
56            Count(e, _) => e,
57            Quantile { expr, .. } => expr,
58            Sum(e) => e,
59            AggGroups(e) => e,
60            Std(e, _) => e,
61            Var(e, _) => e,
62        }
63    }
64}
65
66/// Expressions that can be used in various contexts.
67///
68/// Queries consist of multiple expressions.
69/// When using the polars lazy API, don't construct an `Expr` directly; instead, create one using
70/// the functions in the `polars_lazy::dsl` module. See that module's docs for more info.
71#[derive(Clone, PartialEq)]
72#[must_use]
73#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
74pub enum Expr {
75    Alias(Arc<Expr>, PlSmallStr),
76    Column(PlSmallStr),
77    Columns(Arc<[PlSmallStr]>),
78    DtypeColumn(Vec<DataType>),
79    IndexColumn(Arc<[i64]>),
80    Literal(LiteralValue),
81    BinaryExpr {
82        left: Arc<Expr>,
83        op: Operator,
84        right: Arc<Expr>,
85    },
86    Cast {
87        expr: Arc<Expr>,
88        dtype: DataType,
89        options: CastOptions,
90    },
91    Sort {
92        expr: Arc<Expr>,
93        options: SortOptions,
94    },
95    Gather {
96        expr: Arc<Expr>,
97        idx: Arc<Expr>,
98        returns_scalar: bool,
99    },
100    SortBy {
101        expr: Arc<Expr>,
102        by: Vec<Expr>,
103        sort_options: SortMultipleOptions,
104    },
105    Agg(AggExpr),
106    /// A ternary operation
107    /// if true then "foo" else "bar"
108    Ternary {
109        predicate: Arc<Expr>,
110        truthy: Arc<Expr>,
111        falsy: Arc<Expr>,
112    },
113    Function {
114        /// function arguments
115        input: Vec<Expr>,
116        /// function to apply
117        function: FunctionExpr,
118        options: FunctionOptions,
119    },
120    Explode(Arc<Expr>),
121    Filter {
122        input: Arc<Expr>,
123        by: Arc<Expr>,
124    },
125    /// Polars flavored window functions.
126    Window {
127        /// Also has the input. i.e. avg("foo")
128        function: Arc<Expr>,
129        partition_by: Vec<Expr>,
130        order_by: Option<(Arc<Expr>, SortOptions)>,
131        options: WindowType,
132    },
133    Wildcard,
134    Slice {
135        input: Arc<Expr>,
136        /// length is not yet known so we accept negative offsets
137        offset: Arc<Expr>,
138        length: Arc<Expr>,
139    },
140    /// Can be used in a select statement to exclude a column from selection
141    /// TODO: See if we can replace `Vec<Excluded>` with `Arc<Excluded>`
142    Exclude(Arc<Expr>, Vec<Excluded>),
143    /// Set root name as Alias
144    KeepName(Arc<Expr>),
145    Len,
146    /// Take the nth column in the `DataFrame`
147    Nth(i64),
148    RenameAlias {
149        function: SpecialEq<Arc<dyn RenameAliasFn>>,
150        expr: Arc<Expr>,
151    },
152    #[cfg(feature = "dtype-struct")]
153    Field(Arc<[PlSmallStr]>),
154    AnonymousFunction {
155        /// function arguments
156        input: Vec<Expr>,
157        /// function to apply
158        function: OpaqueColumnUdf,
159        /// output dtype of the function
160        output_type: GetOutput,
161        options: FunctionOptions,
162    },
163    SubPlan(SpecialEq<Arc<DslPlan>>, Vec<String>),
164    /// Expressions in this node should only be expanding
165    /// e.g.
166    /// `Expr::Columns`
167    /// `Expr::Dtypes`
168    /// `Expr::Wildcard`
169    /// `Expr::Exclude`
170    Selector(super::selector::Selector),
171}
172
173pub type OpaqueColumnUdf = LazySerde<SpecialEq<Arc<dyn ColumnsUdf>>>;
174pub(crate) fn new_column_udf<F: ColumnsUdf + 'static>(func: F) -> OpaqueColumnUdf {
175    LazySerde::Deserialized(SpecialEq::new(Arc::new(func)))
176}
177
178#[derive(Clone)]
179pub enum LazySerde<T: Clone> {
180    Deserialized(T),
181    Bytes(Bytes),
182}
183
184impl<T: PartialEq + Clone> PartialEq for LazySerde<T> {
185    fn eq(&self, other: &Self) -> bool {
186        use LazySerde as L;
187        match (self, other) {
188            (L::Deserialized(a), L::Deserialized(b)) => a == b,
189            (L::Bytes(a), L::Bytes(b)) => a.as_ptr() == b.as_ptr() && a.len() == b.len(),
190            _ => false,
191        }
192    }
193}
194
195impl<T: Clone> Debug for LazySerde<T> {
196    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
197        match self {
198            Self::Bytes(_) => write!(f, "lazy-serde<Bytes>"),
199            Self::Deserialized(_) => write!(f, "lazy-serde<T>"),
200        }
201    }
202}
203
204impl OpaqueColumnUdf {
205    pub fn materialize(self) -> PolarsResult<SpecialEq<Arc<dyn ColumnsUdf>>> {
206        match self {
207            Self::Deserialized(t) => Ok(t),
208            Self::Bytes(b) => {
209                feature_gated!("serde";"python", {
210                    python_udf::PythonUdfExpression::try_deserialize(b.as_ref()).map(SpecialEq::new)
211                })
212            },
213        }
214    }
215}
216
217#[allow(clippy::derived_hash_with_manual_eq)]
218impl Hash for Expr {
219    fn hash<H: Hasher>(&self, state: &mut H) {
220        let d = std::mem::discriminant(self);
221        d.hash(state);
222        match self {
223            Expr::Column(name) => name.hash(state),
224            Expr::Columns(names) => names.hash(state),
225            Expr::DtypeColumn(dtypes) => dtypes.hash(state),
226            Expr::IndexColumn(indices) => indices.hash(state),
227            Expr::Literal(lv) => std::mem::discriminant(lv).hash(state),
228            Expr::Selector(s) => s.hash(state),
229            Expr::Nth(v) => v.hash(state),
230            Expr::Filter { input, by } => {
231                input.hash(state);
232                by.hash(state);
233            },
234            Expr::BinaryExpr { left, op, right } => {
235                left.hash(state);
236                right.hash(state);
237                std::mem::discriminant(op).hash(state)
238            },
239            Expr::Cast {
240                expr,
241                dtype,
242                options: strict,
243            } => {
244                expr.hash(state);
245                dtype.hash(state);
246                strict.hash(state)
247            },
248            Expr::Sort { expr, options } => {
249                expr.hash(state);
250                options.hash(state);
251            },
252            Expr::Alias(input, name) => {
253                input.hash(state);
254                name.hash(state)
255            },
256            Expr::KeepName(input) => input.hash(state),
257            Expr::Ternary {
258                predicate,
259                truthy,
260                falsy,
261            } => {
262                predicate.hash(state);
263                truthy.hash(state);
264                falsy.hash(state);
265            },
266            Expr::Function {
267                input,
268                function,
269                options,
270            } => {
271                input.hash(state);
272                std::mem::discriminant(function).hash(state);
273                options.hash(state);
274            },
275            Expr::Gather {
276                expr,
277                idx,
278                returns_scalar,
279            } => {
280                expr.hash(state);
281                idx.hash(state);
282                returns_scalar.hash(state);
283            },
284            // already hashed by discriminant
285            Expr::Wildcard | Expr::Len => {},
286            Expr::SortBy {
287                expr,
288                by,
289                sort_options,
290            } => {
291                expr.hash(state);
292                by.hash(state);
293                sort_options.hash(state);
294            },
295            Expr::Agg(input) => input.hash(state),
296            Expr::Explode(input) => input.hash(state),
297            Expr::Window {
298                function,
299                partition_by,
300                order_by,
301                options,
302            } => {
303                function.hash(state);
304                partition_by.hash(state);
305                order_by.hash(state);
306                options.hash(state);
307            },
308            Expr::Slice {
309                input,
310                offset,
311                length,
312            } => {
313                input.hash(state);
314                offset.hash(state);
315                length.hash(state);
316            },
317            Expr::Exclude(input, excl) => {
318                input.hash(state);
319                excl.hash(state);
320            },
321            Expr::RenameAlias { function: _, expr } => expr.hash(state),
322            Expr::AnonymousFunction {
323                input,
324                function: _,
325                output_type: _,
326                options,
327            } => {
328                input.hash(state);
329                options.hash(state);
330            },
331            Expr::SubPlan(_, names) => names.hash(state),
332            #[cfg(feature = "dtype-struct")]
333            Expr::Field(names) => names.hash(state),
334        }
335    }
336}
337
338impl Eq for Expr {}
339
340impl Default for Expr {
341    fn default() -> Self {
342        Expr::Literal(LiteralValue::Null)
343    }
344}
345
346#[derive(Debug, Clone, PartialEq, Eq, Hash)]
347#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
348
349pub enum Excluded {
350    Name(PlSmallStr),
351    Dtype(DataType),
352}
353
354impl Expr {
355    /// Get Field result of the expression. The schema is the input data.
356    pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
357        // this is not called much and the expression depth is typically shallow
358        let mut arena = Arena::with_capacity(5);
359        self.to_field_amortized(schema, ctxt, &mut arena)
360    }
361    pub(crate) fn to_field_amortized(
362        &self,
363        schema: &Schema,
364        ctxt: Context,
365        expr_arena: &mut Arena<AExpr>,
366    ) -> PolarsResult<Field> {
367        let root = to_aexpr(self.clone(), expr_arena)?;
368        expr_arena
369            .get(root)
370            .to_field_and_validate(schema, ctxt, expr_arena)
371    }
372}
373
/// Binary operators that can appear in a binary expression (`Expr::BinaryExpr`).
///
/// The textual token of each operator is given by its `Display` impl. Variant order is kept
/// stable because it is significant for the derived `Hash` and for index-based serialization.
#[derive(Copy, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Operator {
    Eq,
    EqValidity,
    NotEq,
    NotEqValidity,
    Lt,
    LtEq,
    Gt,
    GtEq,
    Plus,
    Minus,
    Multiply,
    Divide,
    TrueDivide,
    FloorDivide,
    Modulus,
    And,
    Or,
    Xor,
    LogicalAnd,
    LogicalOr,
}

impl Display for Operator {
    /// Write the operator's symbolic token, e.g. `==`, `+`, `//`.
    ///
    /// `Divide` renders as `//` and `TrueDivide` as `/`. The bitwise and logical variants
    /// share the tokens `&` and `|`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        use Operator::*;
        let tkn = match self {
            Eq => "==",
            EqValidity => "==v",
            NotEq => "!=",
            NotEqValidity => "!=v",
            Lt => "<",
            LtEq => "<=",
            Gt => ">",
            GtEq => ">=",
            Plus => "+",
            Minus => "-",
            Multiply => "*",
            Divide => "//",
            TrueDivide => "/",
            FloorDivide => "floor_div",
            Modulus => "%",
            And | LogicalAnd => "&",
            Or | LogicalOr => "|",
            Xor => "^",
        };
        write!(f, "{tkn}")
    }
}

impl Operator {
    /// Return `true` for the (in)equality and ordering operators (including `EqValidity` and
    /// `NotEqValidity`), and also for `And`, `Or` and `Xor`.
    ///
    /// `LogicalAnd` and `LogicalOr` are *not* included.
    pub fn is_comparison(&self) -> bool {
        matches!(
            self,
            Self::Eq
                | Self::NotEq
                | Self::Lt
                | Self::LtEq
                | Self::Gt
                | Self::GtEq
                | Self::And
                | Self::Or
                | Self::Xor
                | Self::EqValidity
                | Self::NotEqValidity
        )
    }

    /// Map this operator to the one used when its two operands trade places.
    ///
    /// For comparisons this is the mirrored operator, so `a op b` equals
    /// `b op.swap_operands() a` (e.g. `Gt` -> `Lt`, `Eq` -> `Eq`).
    ///
    /// NOTE(review): the arithmetic arms do not follow that rule. They map to the *inverse*
    /// operation (`Plus` <-> `Minus`, `Multiply` <-> `Divide`). Callers presumably depend on
    /// this, so it is left as-is — confirm before changing.
    ///
    /// # Panics
    ///
    /// Panics for `TrueDivide`, `FloorDivide` and `Modulus`, which have no mapping.
    pub fn swap_operands(self) -> Self {
        match self {
            // Operators that map to themselves.
            Self::Eq
            | Self::NotEq
            | Self::EqValidity
            | Self::NotEqValidity
            | Self::And
            | Self::Or
            | Self::Xor
            | Self::LogicalAnd
            | Self::LogicalOr => self,
            // Ordering comparisons flip direction.
            Self::Lt => Self::Gt,
            Self::Gt => Self::Lt,
            Self::LtEq => Self::GtEq,
            Self::GtEq => Self::LtEq,
            // Arithmetic maps to the inverse operation (see NOTE above).
            Self::Plus => Self::Minus,
            Self::Minus => Self::Plus,
            Self::Multiply => Self::Divide,
            Self::Divide => Self::Multiply,
            // Listed explicitly instead of `_` so that a newly added variant is a compile error
            // here rather than a silent runtime panic.
            Self::TrueDivide | Self::FloorDivide | Self::Modulus => {
                unimplemented!("`Operator::swap_operands` is not supported for `{}`", self)
            },
        }
    }

    /// Return `true` for every operator that [`Operator::is_comparison`] rejects.
    ///
    /// Besides the arithmetic operators, this includes `LogicalAnd` and `LogicalOr`.
    pub fn is_arithmetic(&self) -> bool {
        !(self.is_comparison())
    }
}