// polars_plan/plans/functions/dsl.rs

1use strum_macros::IntoStaticStr;
2
3use super::*;
4
/// An opaque Python function together with the flags stored alongside it.
///
/// Only compiled when the `python` feature is enabled.
#[cfg(feature = "python")]
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct OpaquePythonUdf {
    /// The wrapped Python function.
    pub function: PythonFunction,
    /// Optional schema attached to the function.
    // NOTE(review): presumably the output schema of the UDF — confirm against callers.
    pub schema: Option<SchemaRef>,
    /// Allow predicate pushdown optimizations.
    pub predicate_pd: bool,
    /// Allow projection pushdown optimizations.
    pub projection_pd: bool,
    // NOTE(review): presumably marks the UDF as usable by the streaming engine — confirm.
    pub streamable: bool,
    // NOTE(review): presumably enables checking the UDF output against `schema` — confirm.
    pub validate_output: bool,
}
18
19// Except for Opaque functions, this only has the DSL name of the function.
20#[derive(Clone, IntoStaticStr)]
21#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
22#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
23pub enum DslFunction {
24    RowIndex {
25        name: PlSmallStr,
26        offset: Option<IdxSize>,
27    },
28    // This is both in DSL and IR because we want to be able to serialize it.
29    #[cfg(feature = "python")]
30    OpaquePython(OpaquePythonUdf),
31    Explode {
32        columns: Vec<Selector>,
33        allow_empty: bool,
34    },
35    #[cfg(feature = "pivot")]
36    Unpivot {
37        args: UnpivotArgsDSL,
38    },
39    Rename {
40        existing: Arc<[PlSmallStr]>,
41        new: Arc<[PlSmallStr]>,
42        strict: bool,
43    },
44    Unnest(Vec<Selector>),
45    Stats(StatsFunction),
46    /// FillValue
47    FillNan(Expr),
48    Drop(DropFunction),
49    // Function that is already converted to IR.
50    #[cfg_attr(feature = "serde", serde(skip))]
51    FunctionIR(FunctionIR),
52}
53
54#[derive(Clone)]
55#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
56pub struct DropFunction {
57    /// Columns that are going to be dropped
58    pub(crate) to_drop: Vec<Selector>,
59    /// If `true`, performs a check for each item in `to_drop` against the schema. Returns an
60    /// `ColumnNotFound` error if the column does not exist in the schema.
61    pub(crate) strict: bool,
62}
63
64#[derive(Clone)]
65#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
66pub enum StatsFunction {
67    Var {
68        ddof: u8,
69    },
70    Std {
71        ddof: u8,
72    },
73    Quantile {
74        quantile: Expr,
75        method: QuantileMethod,
76    },
77    Median,
78    Mean,
79    Sum,
80    Min,
81    Max,
82}
83
84pub(crate) fn validate_columns_in_input<S: AsRef<str>, I: IntoIterator<Item = S>>(
85    columns: I,
86    input_schema: &Schema,
87    operation_name: &str,
88) -> PolarsResult<()> {
89    let columns = columns.into_iter();
90    for c in columns {
91        polars_ensure!(input_schema.contains(c.as_ref()), ColumnNotFound: "'{}' on column: '{}' is invalid\n\nSchema at this point: {:?}", operation_name, c.as_ref(), input_schema)
92    }
93    Ok(())
94}
95
96impl DslFunction {
97    pub(crate) fn into_function_ir(self, input_schema: &Schema) -> PolarsResult<FunctionIR> {
98        let function = match self {
99            #[cfg(feature = "pivot")]
100            DslFunction::Unpivot { args } => {
101                let on = expand_selectors(args.on, input_schema, &[])?;
102                let index = expand_selectors(args.index, input_schema, &[])?;
103                validate_columns_in_input(on.as_ref(), input_schema, "unpivot")?;
104                validate_columns_in_input(index.as_ref(), input_schema, "unpivot")?;
105
106                let args = UnpivotArgsIR {
107                    on: on.iter().cloned().collect(),
108                    index: index.iter().cloned().collect(),
109                    variable_name: args.variable_name.clone(),
110                    value_name: args.value_name.clone(),
111                };
112
113                FunctionIR::Unpivot {
114                    args: Arc::new(args),
115                    schema: Default::default(),
116                }
117            },
118            DslFunction::FunctionIR(func) => func,
119            DslFunction::RowIndex { name, offset } => FunctionIR::RowIndex {
120                name,
121                offset,
122                schema: Default::default(),
123            },
124            DslFunction::Rename {
125                existing,
126                new,
127                strict,
128            } => {
129                let swapping = new.iter().any(|name| input_schema.get(name).is_some());
130                if strict {
131                    validate_columns_in_input(existing.as_ref(), input_schema, "rename")?;
132                }
133                FunctionIR::Rename {
134                    existing,
135                    new,
136                    swapping,
137                    schema: Default::default(),
138                }
139            },
140            DslFunction::Unnest(selectors) => {
141                let columns = expand_selectors(selectors, input_schema, &[])?;
142                validate_columns_in_input(columns.as_ref(), input_schema, "explode")?;
143                FunctionIR::Unnest { columns }
144            },
145            #[cfg(feature = "python")]
146            DslFunction::OpaquePython(inner) => FunctionIR::OpaquePython(inner),
147            DslFunction::Stats(_)
148            | DslFunction::FillNan(_)
149            | DslFunction::Drop(_)
150            | DslFunction::Explode { .. } => {
151                // We should not reach this.
152                panic!("impl error")
153            },
154        };
155        Ok(function)
156    }
157}
158
159impl Debug for DslFunction {
160    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
161        write!(f, "{self}")
162    }
163}
164
165impl Display for DslFunction {
166    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
167        use DslFunction::*;
168        match self {
169            FunctionIR(inner) => write!(f, "{inner}"),
170            v => {
171                let s: &str = v.into();
172                write!(f, "{s}")
173            },
174        }
175    }
176}
177
178impl From<FunctionIR> for DslFunction {
179    fn from(value: FunctionIR) -> Self {
180        DslFunction::FunctionIR(value)
181    }
182}