// polars_plan/dsl/options.rs

1use std::hash::Hash;
2use std::sync::Arc;
3
4use polars_core::error::PolarsResult;
5#[cfg(feature = "iejoin")]
6use polars_ops::frame::IEJoinOptions;
7use polars_ops::frame::{CrossJoinFilter, CrossJoinOptions, JoinTypeOptions};
8use polars_ops::prelude::{JoinArgs, JoinType};
9#[cfg(feature = "dynamic_group_by")]
10use polars_time::RollingGroupOptions;
11use polars_utils::pl_str::PlSmallStr;
12use polars_utils::IdxSize;
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15use strum_macros::IntoStaticStr;
16
17use super::ExprIR;
18use crate::dsl::Selector;
19
20#[derive(Copy, Clone, PartialEq, Debug, Eq, Hash)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22pub struct RollingCovOptions {
23    pub window_size: IdxSize,
24    pub min_periods: IdxSize,
25    pub ddof: u8,
26}
27
28#[derive(Clone, PartialEq, Debug, Eq, Hash)]
29#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
30pub struct StrptimeOptions {
31    /// Formatting string
32    pub format: Option<PlSmallStr>,
33    /// If set then polars will return an error if any date parsing fails
34    pub strict: bool,
35    /// If polars may parse matches that not contain the whole string
36    /// e.g. "foo-2021-01-01-bar" could match "2021-01-01"
37    pub exact: bool,
38    /// use a cache of unique, converted dates to apply the datetime conversion.
39    pub cache: bool,
40}
41
42impl Default for StrptimeOptions {
43    fn default() -> Self {
44        StrptimeOptions {
45            format: None,
46            strict: true,
47            exact: true,
48            cache: true,
49        }
50    }
51}
52
53#[derive(Clone, PartialEq, Eq, IntoStaticStr, Debug)]
54#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
55#[strum(serialize_all = "snake_case")]
56pub enum JoinTypeOptionsIR {
57    #[cfg(feature = "iejoin")]
58    IEJoin(IEJoinOptions),
59    #[cfg_attr(feature = "serde", serde(skip))]
60    // Fused cross join and filter (only in in-memory engine)
61    Cross { predicate: ExprIR },
62}
63
64impl Hash for JoinTypeOptionsIR {
65    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
66        use JoinTypeOptionsIR::*;
67        match self {
68            #[cfg(feature = "iejoin")]
69            IEJoin(opt) => opt.hash(state),
70            Cross { predicate } => predicate.node().hash(state),
71        }
72    }
73}
74
75impl JoinTypeOptionsIR {
76    pub fn compile<C: FnOnce(&ExprIR) -> PolarsResult<Arc<dyn CrossJoinFilter>>>(
77        self,
78        plan: C,
79    ) -> PolarsResult<JoinTypeOptions> {
80        use JoinTypeOptionsIR::*;
81        match self {
82            Cross { predicate } => {
83                let predicate = plan(&predicate)?;
84
85                Ok(JoinTypeOptions::Cross(CrossJoinOptions { predicate }))
86            },
87            #[cfg(feature = "iejoin")]
88            IEJoin(opt) => Ok(JoinTypeOptions::IEJoin(opt)),
89        }
90    }
91}
92
93#[derive(Clone, Debug, PartialEq, Eq, Hash)]
94#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
95pub struct JoinOptions {
96    pub allow_parallel: bool,
97    pub force_parallel: bool,
98    pub args: JoinArgs,
99    pub options: Option<JoinTypeOptionsIR>,
100    /// Proxy of the number of rows in both sides of the joins
101    /// Holds `(Option<known_size>, estimated_size)`
102    pub rows_left: (Option<usize>, usize),
103    pub rows_right: (Option<usize>, usize),
104}
105
106impl Default for JoinOptions {
107    fn default() -> Self {
108        JoinOptions {
109            allow_parallel: true,
110            force_parallel: false,
111            // Todo!: make default
112            args: JoinArgs::new(JoinType::Left),
113            options: Default::default(),
114            rows_left: (None, usize::MAX),
115            rows_right: (None, usize::MAX),
116        }
117    }
118}
119
120#[derive(Clone, Debug, PartialEq, Eq, Hash)]
121#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
122pub enum WindowType {
123    /// Explode the aggregated list and just do a hstack instead of a join
124    /// this requires the groups to be sorted to make any sense
125    Over(WindowMapping),
126    #[cfg(feature = "dynamic_group_by")]
127    Rolling(RollingGroupOptions),
128}
129
130impl From<WindowMapping> for WindowType {
131    fn from(value: WindowMapping) -> Self {
132        Self::Over(value)
133    }
134}
135
136impl Default for WindowType {
137    fn default() -> Self {
138        Self::Over(WindowMapping::default())
139    }
140}
141
142#[derive(Copy, Clone, Debug, PartialEq, Eq, Default, Hash, IntoStaticStr)]
143#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
144#[strum(serialize_all = "snake_case")]
145pub enum WindowMapping {
146    /// Map the group values to the position
147    #[default]
148    GroupsToRows,
149    /// Explode the aggregated list and just do a hstack instead of a join
150    /// this requires the groups to be sorted to make any sense
151    Explode,
152    /// Join the groups as 'List<group_dtype>' to the row positions.
153    /// warning: this can be memory intensive
154    Join,
155}
156
/// Kind of nested data type.
///
/// When the `dtype-array` feature is disabled, this enum has no variants.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum NestedType {
    /// The `Array` data type (presumably the fixed-width array dtype).
    #[cfg(feature = "dtype-array")]
    Array,
    // NOTE(review): a `List` variant is not represented yet.
}
164
165#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
166#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
167pub struct UnpivotArgsDSL {
168    pub on: Vec<Selector>,
169    pub index: Vec<Selector>,
170    pub variable_name: Option<PlSmallStr>,
171    pub value_name: Option<PlSmallStr>,
172}