1use std::borrow::Borrow;
2use std::hash::Hash;
3#[cfg(feature = "cse")]
4use std::hash::Hasher;
5use std::sync::OnceLock;
6
7use polars_utils::format_pl_smallstr;
8#[cfg(feature = "ir_serde")]
9use serde::{Deserialize, Serialize};
10
11use super::*;
12use crate::constants::{get_len_name, get_literal_name};
13
/// Records the output (column) name of an expression together with how that
/// name was obtained.
#[derive(Default, Debug, Clone, Hash, PartialEq, Eq)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub enum OutputName {
    /// No output name has been set yet.
    #[default]
    None,
    /// The name comes from a literal in the expression (a series literal's
    /// name or the default literal name; see `ExprIR::from_node`).
    LiteralLhs(PlSmallStr),
    /// The name comes from a column referenced by the expression.
    ColumnLhs(PlSmallStr),
    /// The output is explicitly renamed to this name.
    Alias(PlSmallStr),
    /// The name of a struct field selected by the expression.
    #[cfg(feature = "dtype-struct")]
    Field(PlSmallStr),
}
30
31impl OutputName {
32 pub fn get(&self) -> Option<&PlSmallStr> {
33 match self {
34 OutputName::Alias(name) => Some(name),
35 OutputName::ColumnLhs(name) => Some(name),
36 OutputName::LiteralLhs(name) => Some(name),
37 #[cfg(feature = "dtype-struct")]
38 OutputName::Field(name) => Some(name),
39 OutputName::None => None,
40 }
41 }
42
43 pub fn unwrap(&self) -> &PlSmallStr {
44 self.get().expect("no output name set")
45 }
46
47 pub(crate) fn is_none(&self) -> bool {
48 matches!(self, OutputName::None)
49 }
50}
51
/// A reference to an expression in the `AExpr` arena, paired with its output
/// name and a lazily resolved output dtype.
#[derive(Debug)]
#[cfg_attr(feature = "ir_serde", derive(Serialize, Deserialize))]
pub struct ExprIR {
    /// The output name of this expression and how it was obtained.
    output_name: OutputName,
    /// The root of this expression in the `AExpr` arena.
    node: Node,
    /// Cached output dtype. It is filled by `dtype()` on first use (or by
    /// `with_dtype`/`set_dtype`) and cleared by `set_node`. It is skipped during
    /// serialization, so it starts empty after deserializing and is recomputed
    /// on demand.
    #[cfg_attr(feature = "ir_serde", serde(skip))]
    output_dtype: OnceLock<DataType>,
}
64
65impl Eq for ExprIR {}
66
67impl PartialEq for ExprIR {
68 fn eq(&self, other: &Self) -> bool {
69 self.node == other.node && self.output_name == other.output_name
70 }
71}
72
73impl Clone for ExprIR {
74 fn clone(&self) -> Self {
75 let output_dtype = OnceLock::new();
76 if let Some(dt) = self.output_dtype.get() {
77 output_dtype.set(dt.clone()).unwrap()
78 }
79
80 ExprIR {
81 output_name: self.output_name.clone(),
82 node: self.node,
83 output_dtype,
84 }
85 }
86}
87
88impl Borrow<Node> for ExprIR {
89 fn borrow(&self) -> &Node {
90 &self.node
91 }
92}
93
94impl ExprIR {
95 pub fn new(node: Node, output_name: OutputName) -> Self {
96 debug_assert!(!output_name.is_none());
97 ExprIR {
98 output_name,
99 node,
100 output_dtype: OnceLock::new(),
101 }
102 }
103
104 pub fn with_dtype(self, dtype: DataType) -> Self {
105 let _ = self.output_dtype.set(dtype);
106 self
107 }
108
109 pub(crate) fn set_dtype(&mut self, dtype: DataType) {
110 self.output_dtype = OnceLock::from(dtype);
111 }
112
113 pub fn from_node(node: Node, arena: &Arena<AExpr>) -> Self {
114 let mut out = Self {
115 node,
116 output_name: OutputName::None,
117 output_dtype: OnceLock::new(),
118 };
119 out.node = node;
120 for (_, ae) in arena.iter(node) {
121 match ae {
122 AExpr::Column(name) => {
123 out.output_name = OutputName::ColumnLhs(name.clone());
124 break;
125 },
126 AExpr::Literal(lv) => {
127 if let LiteralValue::Series(s) = lv {
128 out.output_name = OutputName::LiteralLhs(s.name().clone());
129 } else {
130 out.output_name = OutputName::LiteralLhs(get_literal_name().clone());
131 }
132 break;
133 },
134 AExpr::Function {
135 input, function, ..
136 } => {
137 match function {
138 #[cfg(feature = "dtype-struct")]
139 FunctionExpr::StructExpr(StructFunction::FieldByName(name)) => {
140 out.output_name = OutputName::Field(name.clone());
141 },
142 _ => {
143 if input.is_empty() {
144 out.output_name =
145 OutputName::LiteralLhs(format_pl_smallstr!("{}", function));
146 } else {
147 out.output_name = input[0].output_name.clone();
148 }
149 },
150 }
151 break;
152 },
153 AExpr::AnonymousFunction { input, options, .. } => {
154 if input.is_empty() {
155 out.output_name =
156 OutputName::LiteralLhs(PlSmallStr::from_static(options.fmt_str));
157 } else {
158 out.output_name = input[0].output_name.clone();
159 }
160 break;
161 },
162 AExpr::Len => out.output_name = OutputName::LiteralLhs(get_len_name()),
163 AExpr::Alias(_, _) => {
164 #[cfg(debug_assertions)]
166 {
167 unreachable!()
168 }
169 },
170 _ => {},
171 }
172 }
173 debug_assert!(!out.output_name.is_none());
174 out
175 }
176
177 #[inline]
178 pub fn node(&self) -> Node {
179 self.node
180 }
181
182 pub fn display<'a>(&'a self, expr_arena: &'a Arena<AExpr>) -> ExprIRDisplay<'a> {
184 ExprIRDisplay {
185 node: self.node(),
186 output_name: self.output_name_inner(),
187 expr_arena,
188 }
189 }
190
191 pub(crate) fn set_node(&mut self, node: Node) {
192 self.node = node;
193 self.output_dtype = OnceLock::new();
194 }
195
196 pub(crate) fn set_alias(&mut self, name: PlSmallStr) {
197 self.output_name = OutputName::Alias(name)
198 }
199
200 pub fn output_name_inner(&self) -> &OutputName {
201 &self.output_name
202 }
203
204 pub fn output_name(&self) -> &PlSmallStr {
205 self.output_name.unwrap()
206 }
207
208 pub fn to_expr(&self, expr_arena: &Arena<AExpr>) -> Expr {
209 let out = node_to_expr(self.node, expr_arena);
210
211 match &self.output_name {
212 OutputName::Alias(name) => out.alias(name.clone()),
213 _ => out,
214 }
215 }
216
217 pub fn get_alias(&self) -> Option<&PlSmallStr> {
218 match &self.output_name {
219 OutputName::Alias(name) => Some(name),
220 _ => None,
221 }
222 }
223
224 #[cfg(debug_assertions)]
226 #[allow(dead_code)]
227 pub(crate) fn print(&self, expr_arena: &Arena<AExpr>) {
228 eprintln!("{:?}", self.to_expr(expr_arena))
229 }
230
231 pub(crate) fn has_alias(&self) -> bool {
232 matches!(self.output_name, OutputName::Alias(_))
233 }
234
235 #[cfg(feature = "cse")]
236 pub(crate) fn traverse_and_hash<H: Hasher>(&self, expr_arena: &Arena<AExpr>, state: &mut H) {
237 traverse_and_hash_aexpr(self.node, expr_arena, state);
238 if let Some(alias) = self.get_alias() {
239 alias.hash(state)
240 }
241 }
242
243 pub fn is_scalar(&self, expr_arena: &Arena<AExpr>) -> bool {
244 is_scalar_ae(self.node, expr_arena)
245 }
246
247 pub fn dtype(
248 &self,
249 schema: &Schema,
250 ctxt: Context,
251 expr_arena: &Arena<AExpr>,
252 ) -> PolarsResult<&DataType> {
253 match self.output_dtype.get() {
254 Some(dtype) => Ok(dtype),
255 None => {
256 let dtype = expr_arena
257 .get(self.node)
258 .to_dtype(schema, ctxt, expr_arena)?;
259 let _ = self.output_dtype.set(dtype);
260 Ok(self.output_dtype.get().unwrap())
261 },
262 }
263 }
264
265 pub fn field(
266 &self,
267 schema: &Schema,
268 ctxt: Context,
269 expr_arena: &Arena<AExpr>,
270 ) -> PolarsResult<Field> {
271 let dtype = self.dtype(schema, ctxt, expr_arena)?;
272 let name = self.output_name();
273 Ok(Field::new(name.clone(), dtype.clone()))
274 }
275}
276
277impl AsRef<ExprIR> for ExprIR {
278 fn as_ref(&self) -> &ExprIR {
279 self
280 }
281}
282
/// A newtype around an arena [`Node`], presumably one that holds an
/// `AExpr::Column`. The type itself does not enforce this; TODO confirm at the
/// construction sites.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub struct ColumnNode(pub(crate) Node);
287
288impl From<ColumnNode> for Node {
289 fn from(value: ColumnNode) -> Self {
290 value.0
291 }
292}
293impl From<&ExprIR> for Node {
294 fn from(value: &ExprIR) -> Self {
295 value.node()
296 }
297}
298
299pub(crate) fn name_to_expr_ir(name: PlSmallStr, expr_arena: &mut Arena<AExpr>) -> ExprIR {
300 let node = expr_arena.add(AExpr::Column(name.clone()));
301 ExprIR::new(node, OutputName::ColumnLhs(name))
302}
303
304pub(crate) fn names_to_expr_irs<I, S>(names: I, expr_arena: &mut Arena<AExpr>) -> Vec<ExprIR>
305where
306 I: IntoIterator<Item = S>,
307 S: Into<PlSmallStr>,
308{
309 names
310 .into_iter()
311 .map(|name| {
312 let name = name.into();
313 name_to_expr_ir(name, expr_arena)
314 })
315 .collect()
316}