// polars_plan/plans/expr_ir.rs

1use std::borrow::Borrow;
2use std::hash::Hash;
3#[cfg(feature = "cse")]
4use std::hash::Hasher;
5use std::sync::OnceLock;
6
7use polars_utils::format_pl_smallstr;
8#[cfg(feature = "ir_serde")]
9use serde::{Deserialize, Serialize};
10
11use super::*;
12use crate::constants::{get_len_name, get_literal_name};
13
14#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
15#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
16pub enum OutputName {
17    /// No not yet set.
18    #[default]
19    None,
20    /// The most left-hand-side literal will be the output name.
21    LiteralLhs(PlSmallStr),
22    /// The most left-hand-side column will be the output name.
23    ColumnLhs(PlSmallStr),
24    /// Rename the output as `PlSmallStr`.
25    Alias(PlSmallStr),
26    #[cfg(feature = "dtype-struct")]
27    /// A struct field.
28    Field(PlSmallStr),
29}
30
31impl OutputName {
32    pub fn get(&self) -> Option<&PlSmallStr> {
33        match self {
34            OutputName::Alias(name) => Some(name),
35            OutputName::ColumnLhs(name) => Some(name),
36            OutputName::LiteralLhs(name) => Some(name),
37            #[cfg(feature = "dtype-struct")]
38            OutputName::Field(name) => Some(name),
39            OutputName::None => None,
40        }
41    }
42
43    pub fn unwrap(&self) -> &PlSmallStr {
44        self.get().expect("no output name set")
45    }
46
47    pub(crate) fn is_none(&self) -> bool {
48        matches!(self, OutputName::None)
49    }
50}
51
52#[derive(Debug)]
53#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
54pub struct ExprIR {
55    /// Output name of this expression.
56    output_name: OutputName,
57    /// Output dtype of this expression
58    /// Reduced expression.
59    /// This expression is pruned from `alias` and already expanded.
60    node: Node,
61    #[cfg_attr(feature = "ir_serde", serde(skip))]
62    output_dtype: OnceLock<DataType>,
63}
64
65impl Eq for ExprIR {}
66
67impl PartialEq for ExprIR {
68    fn eq(&self, other: &Self) -> bool {
69        self.node == other.node && self.output_name == other.output_name
70    }
71}
72
73impl Clone for ExprIR {
74    fn clone(&self) -> Self {
75        let output_dtype = OnceLock::new();
76        if let Some(dt) = self.output_dtype.get() {
77            output_dtype.set(dt.clone()).unwrap()
78        }
79
80        ExprIR {
81            output_name: self.output_name.clone(),
82            node: self.node,
83            output_dtype,
84        }
85    }
86}
87
88impl Borrow<Node> for ExprIR {
89    fn borrow(&self) -> &Node {
90        &self.node
91    }
92}
93
94impl ExprIR {
95    pub fn new(node: Node, output_name: OutputName) -> Self {
96        debug_assert!(!output_name.is_none());
97        ExprIR {
98            output_name,
99            node,
100            output_dtype: OnceLock::new(),
101        }
102    }
103
104    pub fn with_dtype(self, dtype: DataType) -> Self {
105        let _ = self.output_dtype.set(dtype);
106        self
107    }
108
109    pub(crate) fn set_dtype(&mut self, dtype: DataType) {
110        self.output_dtype = OnceLock::from(dtype);
111    }
112
113    pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
114        let mut out = Self {
115            node,
116            output_name: OutputName::None,
117            output_dtype: OnceLock::new(),
118        };
119        out.node = node;
120        for (_, ae) in arena.iter(node) {
121            match ae {
122                AExpr::Column(name) => {
123                    out.output_name = OutputName::ColumnLhs(name.clone());
124                    break;
125                },
126                AExpr::Literal(lv) => {
127                    if let LiteralValue::Series(s) = lv {
128                        out.output_name = OutputName::LiteralLhs(s.name().clone());
129                    } else {
130                        out.output_name = OutputName::LiteralLhs(get_literal_name().clone());
131                    }
132                    break;
133                },
134                AExpr::Function {
135                    input, function, ..
136                } => {
137                    match function {
138                        #[cfg(feature = "dtype-struct")]
139                        FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => {
140                            out.output_name = OutputName::Field(name.clone());
141                        },
142                        _ => {
143                            if input.is_empty() {
144                                out.output_name =
145                                    OutputName::LiteralLhs(format_pl_smallstr!("{}", function));
146                            } else {
147                                out.output_name = input[0].output_name.clone();
148                            }
149                        },
150                    }
151                    break;
152                },
153                AExpr::AnonymousFunction { input, options, .. } => {
154                    if input.is_empty() {
155                        out.output_name =
156                            OutputName::LiteralLhs(PlSmallStr::from_static(options.fmt_str));
157                    } else {
158                        out.output_name = input[0].output_name.clone();
159                    }
160                    break;
161                },
162                AExpr::Len => out.output_name = OutputName::LiteralLhs(get_len_name()),
163                AExpr::Alias(_, _) => {
164                    // Should be removed during conversion.
165                    #[cfg(debug_assertions)]
166                    {
167                        unreachable!()
168                    }
169                },
170                _ => {},
171            }
172        }
173        debug_assert!(!out.output_name.is_none());
174        out
175    }
176
177    #[inline]
178    pub fn node(&self) -> Node {
179        self.node
180    }
181
182    /// Create a `ExprIR` structure that implements display
183    pub fn display<'a>(&'a self, expr_arena: &'a Arena<AExpr>) -> ExprIRDisplay<'a> {
184        ExprIRDisplay {
185            node: self.node(),
186            output_name: self.output_name_inner(),
187            expr_arena,
188        }
189    }
190
191    pub(crate) fn set_node(&mut self, node: Node) {
192        self.node = node;
193        self.output_dtype = OnceLock::new();
194    }
195
196    pub(crate) fn set_alias(&mut self, name: PlSmallStr) {
197        self.output_name = OutputName::Alias(name)
198    }
199
200    pub fn output_name_inner(&self) -> &OutputName {
201        &self.output_name
202    }
203
204    pub fn output_name(&self) -> &PlSmallStr {
205        self.output_name.unwrap()
206    }
207
208    pub fn to_expr(&self, expr_arena: &Arena<AExpr>) -> Expr {
209        let out = node_to_expr(self.node, expr_arena);
210
211        match &self.output_name {
212            OutputName::Alias(name) => out.alias(name.clone()),
213            _ => out,
214        }
215    }
216
217    pub fn get_alias(&self) -> Option<&PlSmallStr> {
218        match &self.output_name {
219            OutputName::Alias(name) => Some(name),
220            _ => None,
221        }
222    }
223
224    // Utility for debugging.
225    #[cfg(debug_assertions)]
226    #[allow(dead_code)]
227    pub(crate) fn print(&self, expr_arena: &Arena<AExpr>) {
228        eprintln!("{:?}", self.to_expr(expr_arena))
229    }
230
231    pub(crate) fn has_alias(&self) -> bool {
232        matches!(self.output_name, OutputName::Alias(_))
233    }
234
235    #[cfg(feature = "cse")]
236    pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
237        traverse_and_hash_aexpr(self.node, expr_arena, state);
238        if let Some(alias) = self.get_alias() {
239            alias.hash(state)
240        }
241    }
242
243    pub fn is_scalar(&self, expr_arena: &Arena<AExpr>) -> bool {
244        is_scalar_ae(self.node, expr_arena)
245    }
246
247    pub fn dtype(
248        &self,
249        schema: &Schema,
250        ctxt: Context,
251        expr_arena: &Arena<AExpr>,
252    ) -> PolarsResult<&DataType> {
253        match self.output_dtype.get() {
254            Some(dtype) => Ok(dtype),
255            None => {
256                let dtype = expr_arena
257                    .get(self.node)
258                    .to_dtype(schema, ctxt, expr_arena)?;
259                let _ = self.output_dtype.set(dtype);
260                Ok(self.output_dtype.get().unwrap())
261            },
262        }
263    }
264
265    pub fn field(
266        &self,
267        schema: &Schema,
268        ctxt: Context,
269        expr_arena: &Arena<AExpr>,
270    ) -> PolarsResult<Field> {
271        let dtype = self.dtype(schema, ctxt, expr_arena)?;
272        let name = self.output_name();
273        Ok(Field::new(name.clone(), dtype.clone()))
274    }
275}
276
277impl AsRef<ExprIR> for ExprIR {
278    fn as_ref(&self) -> &ExprIR {
279        self
280    }
281}
282
283/// A Node that is restricted to `AExpr::Column`
284#[repr(transparent)]
285#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
286pub struct ColumnNode(pub(crate) Node);
287
288impl From<ColumnNode> for Node {
289    fn from(value: ColumnNode) -> Self {
290        value.0
291    }
292}
293impl From<&ExprIR> for Node {
294    fn from(value: &ExprIR) -> Self {
295        value.node()
296    }
297}
298
299pub(crate) fn name_to_expr_ir(name: PlSmallStr, expr_arena: &mut Arena<AExpr>) -> ExprIR {
300    let node = expr_arena.add(AExpr::Column(name.clone()));
301    ExprIR::new(node, OutputName::ColumnLhs(name))
302}
303
304pub(crate) fn names_to_expr_irs<I, S>(names: I, expr_arena: &mut Arena<AExpr>) -> Vec<ExprIR>
305where
306    I: IntoIterator<Item = S>,
307    S: Into<PlSmallStr>,
308{
309    names
310        .into_iter()
311        .map(|name| {
312            let name = name.into();
313            name_to_expr_ir(name, expr_arena)
314        })
315        .collect()
316}