// polars_plan/plans/aexpr/schema.rs

1#[cfg(feature = "dtype-decimal")]
2use polars_core::chunked_array::arithmetic::{
3    _get_decimal_scale_add_sub, _get_decimal_scale_div, _get_decimal_scale_mul,
4};
5use recursive::recursive;
6
7use super::*;
8
9fn float_type(field: &mut Field) {
10    let should_coerce = match &field.dtype {
11        DataType::Float32 => false,
12        #[cfg(feature = "dtype-decimal")]
13        DataType::Decimal(..) => true,
14        DataType::Boolean => true,
15        dt => dt.is_primitive_numeric(),
16    };
17    if should_coerce {
18        field.coerce(DataType::Float64);
19    }
20}
21
22fn validate_expr(node: Node, arena: &Arena<AExpr>, schema: &Schema) -> PolarsResult<()> {
23    let mut ctx = ToFieldContext {
24        schema,
25        ctx: Context::Default,
26        arena,
27        validate: true,
28    };
29    arena
30        .get(node)
31        .to_field_impl(&mut ctx, &mut false)
32        .map(|_| ())
33}
34
35struct ToFieldContext<'a> {
36    schema: &'a Schema,
37    ctx: Context,
38    arena: &'a Arena<AExpr>,
39    // Traverse all expressions to validate they are in the schema.
40    validate: bool,
41}
42
43impl AExpr {
44    pub fn to_dtype(
45        &self,
46        schema: &Schema,
47        ctx: Context,
48        arena: &Arena<AExpr>,
49    ) -> PolarsResult<DataType> {
50        self.to_field(schema, ctx, arena).map(|f| f.dtype)
51    }
52
53    /// Get Field result of the expression. The schema is the input data.
54    pub fn to_field(
55        &self,
56        schema: &Schema,
57        ctx: Context,
58        arena: &Arena<AExpr>,
59    ) -> PolarsResult<Field> {
60        // Indicates whether we should auto-implode the result. This is initialized to true if we are
61        // in an aggregation context, so functions that return scalars should explicitly set this
62        // to false in `to_field_impl`.
63        let mut agg_list = matches!(ctx, Context::Aggregation);
64        let mut ctx = ToFieldContext {
65            schema,
66            ctx,
67            arena,
68            validate: true,
69        };
70        let mut field = self.to_field_impl(&mut ctx, &mut agg_list)?;
71
72        if agg_list {
73            field.coerce(field.dtype().clone().implode());
74        }
75
76        Ok(field)
77    }
78
79    /// Get Field result of the expression. The schema is the input data.
80    pub fn to_field_and_validate(
81        &self,
82        schema: &Schema,
83        ctx: Context,
84        arena: &Arena<AExpr>,
85    ) -> PolarsResult<Field> {
86        // Indicates whether we should auto-implode the result. This is initialized to true if we are
87        // in an aggregation context, so functions that return scalars should explicitly set this
88        // to false in `to_field_impl`.
89        let mut agg_list = matches!(ctx, Context::Aggregation);
90
91        let mut ctx = ToFieldContext {
92            schema,
93            ctx,
94            arena,
95            validate: true,
96        };
97        let mut field = self.to_field_impl(&mut ctx, &mut agg_list)?;
98
99        if agg_list {
100            field.coerce(field.dtype().clone().implode());
101        }
102
103        Ok(field)
104    }
105
106    /// Get Field result of the expression. The schema is the input data.
107    ///
108    /// This is taken as `&mut bool` as for some expressions this is determined by the upper node
109    /// (e.g. `alias`, `cast`).
110    #[recursive]
111    pub fn to_field_impl(
112        &self,
113        ctx: &mut ToFieldContext,
114        agg_list: &mut bool,
115    ) -> PolarsResult<Field> {
116        use AExpr::*;
117        use DataType::*;
118        match self {
119            Len => {
120                *agg_list = false;
121                Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE))
122            },
123            Window {
124                function,
125                options,
126                partition_by,
127                order_by,
128            } => {
129                if let WindowType::Over(WindowMapping::Join) = options {
130                    // expr.over(..), defaults to agg-list unless explicitly unset
131                    // by the `to_field_impl` of the `expr`
132                    *agg_list = true;
133                }
134
135                if ctx.validate {
136                    for node in partition_by {
137                        validate_expr(*node, ctx.arena, ctx.schema)?;
138                    }
139                    if let Some((node, _)) = order_by {
140                        validate_expr(*node, ctx.arena, ctx.schema)?;
141                    }
142                }
143
144                let e = ctx.arena.get(*function);
145                e.to_field_impl(ctx, agg_list)
146            },
147            Explode(expr) => {
148                // `Explode` is a "flatten" operation, which is not the same as returning a scalar.
149                // Namely, it should be auto-imploded in the aggregation context, so we don't update
150                // the `agg_list` state here.
151                let field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
152
153                let field = match field.dtype() {
154                    List(inner) => Field::new(field.name().clone(), *inner.clone()),
155                    #[cfg(feature = "dtype-array")]
156                    Array(inner, ..) => Field::new(field.name().clone(), *inner.clone()),
157                    _ => field,
158                };
159
160                Ok(field)
161            },
162            Alias(expr, name) => Ok(Field::new(
163                name.clone(),
164                ctx.arena.get(*expr).to_field_impl(ctx, agg_list)?.dtype,
165            )),
166            Column(name) => ctx
167                .schema
168                .get_field(name)
169                .ok_or_else(|| PolarsError::ColumnNotFound(name.to_string().into())),
170            Literal(sv) => {
171                *agg_list = false;
172                Ok(match sv {
173                    LiteralValue::Series(s) => s.field().into_owned(),
174                    _ => Field::new(sv.output_name().clone(), sv.get_datatype()),
175                })
176            },
177            BinaryExpr { left, right, op } => {
178                use DataType::*;
179
180                let field = match op {
181                    Operator::Lt
182                    | Operator::Gt
183                    | Operator::Eq
184                    | Operator::NotEq
185                    | Operator::LogicalAnd
186                    | Operator::LtEq
187                    | Operator::GtEq
188                    | Operator::NotEqValidity
189                    | Operator::EqValidity
190                    | Operator::LogicalOr => {
191                        let out_field;
192                        let out_name = {
193                            out_field = ctx.arena.get(*left).to_field_impl(ctx, agg_list)?;
194                            out_field.name()
195                        };
196                        Field::new(out_name.clone(), Boolean)
197                    },
198                    Operator::TrueDivide => get_truediv_field(*left, *right, ctx, agg_list)?,
199                    _ => get_arithmetic_field(*left, *right, *op, ctx, agg_list)?,
200                };
201
202                Ok(field)
203            },
204            Sort { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx, agg_list),
205            Gather {
206                expr,
207                idx,
208                returns_scalar,
209                ..
210            } => {
211                if *returns_scalar {
212                    *agg_list = false;
213                }
214                if ctx.validate {
215                    validate_expr(*idx, ctx.arena, ctx.schema)?
216                }
217                ctx.arena.get(*expr).to_field_impl(ctx, &mut false)
218            },
219            SortBy { expr, .. } => ctx.arena.get(*expr).to_field_impl(ctx, agg_list),
220            Filter { input, by } => {
221                if ctx.validate {
222                    validate_expr(*by, ctx.arena, ctx.schema)?
223                }
224                ctx.arena.get(*input).to_field_impl(ctx, agg_list)
225            },
226            Agg(agg) => {
227                use IRAggExpr::*;
228                match agg {
229                    Max { input: expr, .. }
230                    | Min { input: expr, .. }
231                    | First(expr)
232                    | Last(expr) => {
233                        *agg_list = false;
234                        ctx.arena.get(*expr).to_field_impl(ctx, &mut false)
235                    },
236                    Sum(expr) => {
237                        *agg_list = false;
238                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
239                        let dt = match field.dtype() {
240                            Boolean => Some(IDX_DTYPE),
241                            UInt8 | Int8 | Int16 | UInt16 => Some(Int64),
242                            _ => None,
243                        };
244                        if let Some(dt) = dt {
245                            field.coerce(dt);
246                        }
247                        Ok(field)
248                    },
249                    Median(expr) => {
250                        *agg_list = false;
251                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
252                        match field.dtype {
253                            Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
254                            _ => float_type(&mut field),
255                        }
256                        Ok(field)
257                    },
258                    Mean(expr) => {
259                        *agg_list = false;
260                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
261                        match field.dtype {
262                            Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
263                            _ => float_type(&mut field),
264                        }
265                        Ok(field)
266                    },
267                    Implode(expr) => {
268                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
269                        field.coerce(DataType::List(field.dtype().clone().into()));
270                        Ok(field)
271                    },
272                    Std(expr, _) => {
273                        *agg_list = false;
274                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
275                        float_type(&mut field);
276                        Ok(field)
277                    },
278                    Var(expr, _) => {
279                        *agg_list = false;
280                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
281                        float_type(&mut field);
282                        Ok(field)
283                    },
284                    NUnique(expr) => {
285                        *agg_list = false;
286                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
287                        field.coerce(IDX_DTYPE);
288                        Ok(field)
289                    },
290                    Count(expr, _) => {
291                        *agg_list = false;
292                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
293                        field.coerce(IDX_DTYPE);
294                        Ok(field)
295                    },
296                    AggGroups(expr) => {
297                        *agg_list = true;
298                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
299                        field.coerce(List(IDX_DTYPE.into()));
300                        Ok(field)
301                    },
302                    Quantile { expr, .. } => {
303                        *agg_list = false;
304                        let mut field = ctx.arena.get(*expr).to_field_impl(ctx, &mut false)?;
305                        float_type(&mut field);
306                        Ok(field)
307                    },
308                }
309            },
310            Cast { expr, dtype, .. } => {
311                let field = ctx.arena.get(*expr).to_field_impl(ctx, agg_list)?;
312                Ok(Field::new(field.name().clone(), dtype.clone()))
313            },
314            Ternary { truthy, falsy, .. } => {
315                let mut agg_list_truthy = *agg_list;
316                let mut agg_list_falsy = *agg_list;
317
318                // During aggregation:
319                // left: col(foo):              list<T>         nesting: 1
320                // right; col(foo).first():     T               nesting: 0
321                // col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list.
322                let mut truthy = ctx
323                    .arena
324                    .get(*truthy)
325                    .to_field_impl(ctx, &mut agg_list_truthy)?;
326                let falsy = ctx
327                    .arena
328                    .get(*falsy)
329                    .to_field_impl(ctx, &mut agg_list_falsy)?;
330
331                let st = if let DataType::Null = *truthy.dtype() {
332                    falsy.dtype().clone()
333                } else {
334                    try_get_supertype(truthy.dtype(), falsy.dtype())?
335                };
336
337                *agg_list = agg_list_truthy | agg_list_falsy;
338
339                truthy.coerce(st);
340                Ok(truthy)
341            },
342            AnonymousFunction {
343                output_type,
344                input,
345                options,
346                ..
347            } => {
348                let fields = func_args_to_fields(input, ctx, agg_list)?;
349                polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
350                let out = output_type.get_field(ctx.schema, ctx.ctx, &fields)?;
351
352                if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
353                    *agg_list = false;
354                } else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) {
355                    *agg_list = true;
356                }
357
358                Ok(out)
359            },
360            Function {
361                function,
362                input,
363                options,
364            } => {
365                let fields = func_args_to_fields(input, ctx, agg_list)?;
366                polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
367                let out = function.get_field(ctx.schema, ctx.ctx, &fields)?;
368
369                if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
370                    *agg_list = false;
371                } else if !options.is_elementwise() && matches!(ctx.ctx, Context::Aggregation) {
372                    *agg_list = true;
373                }
374
375                Ok(out)
376            },
377            Slice {
378                input,
379                offset,
380                length,
381            } => {
382                if ctx.validate {
383                    validate_expr(*offset, ctx.arena, ctx.schema)?;
384                    validate_expr(*length, ctx.arena, ctx.schema)?;
385                }
386
387                ctx.arena.get(*input).to_field_impl(ctx, agg_list)
388            },
389        }
390    }
391}
392
393fn func_args_to_fields(
394    input: &[ExprIR],
395    ctx: &mut ToFieldContext,
396    agg_list: &mut bool,
397) -> PolarsResult<Vec<Field>> {
398    input
399        .iter()
400        .enumerate()
401        // Default context because `col()` would return a list in aggregation context
402        .map(|(i, e)| {
403            let tmp = &mut false;
404
405            ctx.arena
406                .get(e.node())
407                .to_field_impl(
408                    ctx,
409                    if i == 0 {
410                        // Only mutate first agg_list as that is the dtype of the function.
411                        agg_list
412                    } else {
413                        tmp
414                    },
415                )
416                .map(|mut field| {
417                    field.name = e.output_name().clone();
418                    field
419                })
420        })
421        .collect()
422}
423
424#[allow(clippy::too_many_arguments)]
425fn get_arithmetic_field(
426    left: Node,
427    right: Node,
428    op: Operator,
429    ctx: &mut ToFieldContext,
430    agg_list: &mut bool,
431) -> PolarsResult<Field> {
432    use DataType::*;
433    let left_ae = ctx.arena.get(left);
434    let right_ae = ctx.arena.get(right);
435
436    // don't traverse tree until strictly needed. Can have terrible performance.
437    // # 3210
438
439    // take the left field as a whole.
440    // don't take dtype and name separate as that splits the tree every node
441    // leading to quadratic behavior. # 4736
442    //
443    // further right_type is only determined when needed.
444    let mut left_field = left_ae.to_field_impl(ctx, agg_list)?;
445
446    let super_type = match op {
447        Operator::Minus => {
448            let right_type = right_ae.to_field_impl(ctx, agg_list)?.dtype;
449            match (&left_field.dtype, &right_type) {
450                #[cfg(feature = "dtype-struct")]
451                (Struct(_), Struct(_)) => {
452                    return Ok(left_field);
453                },
454                (Duration(_), Datetime(_, _))
455                | (Datetime(_, _), Duration(_))
456                | (Duration(_), Date)
457                | (Date, Duration(_))
458                | (Duration(_), Time)
459                | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
460                (Datetime(tu, _), Date) | (Date, Datetime(tu, _)) => Duration(*tu),
461                // T - T != T if T is a datetime / date
462                (Datetime(tul, _), Datetime(tur, _)) => Duration(get_time_units(tul, tur)),
463                (_, Datetime(_, _)) | (Datetime(_, _), _) => {
464                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
465                },
466                (Date, Date) => Duration(TimeUnit::Milliseconds),
467                (_, Date) | (Date, _) => {
468                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
469                },
470                (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
471                (_, Duration(_)) | (Duration(_), _) => {
472                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
473                },
474                (Time, Time) => Duration(TimeUnit::Nanoseconds),
475                (_, Time) | (Time, _) => {
476                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
477                },
478                (l @ List(a), r @ List(b))
479                    if ![a, b]
480                        .into_iter()
481                        .all(|x| x.is_supported_list_arithmetic_input()) =>
482                {
483                    polars_bail!(
484                        InvalidOperation:
485                        "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
486                        "sub", l, r,
487                    )
488                },
489                (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
490                    // FIXME: This should not use `try_get_supertype()`! It should instead recursively use the enclosing match block.
491                    // Otherwise we will silently permit addition operations between logical types (see above).
492                    // This currently doesn't cause any problems because the list arithmetic implementation checks and raises errors
493                    // if the leaf types aren't numeric, but it means we don't raise an error until execution and the DSL schema
494                    // may be incorrect.
495                    list_dtype.cast_leaf(try_get_supertype(
496                        list_dtype.leaf_dtype(),
497                        other_dtype.leaf_dtype(),
498                    )?)
499                },
500                #[cfg(feature = "dtype-array")]
501                (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
502                    list_dtype.cast_leaf(try_get_supertype(
503                        list_dtype.leaf_dtype(),
504                        other_dtype.leaf_dtype(),
505                    )?)
506                },
507                #[cfg(feature = "dtype-decimal")]
508                (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
509                    let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right);
510                    Decimal(None, Some(scale))
511                },
512                (left, right) => try_get_supertype(left, right)?,
513            }
514        },
515        Operator::Plus => {
516            let right_type = right_ae.to_field_impl(ctx, agg_list)?.dtype;
517            match (&left_field.dtype, &right_type) {
518                (Duration(_), Datetime(_, _))
519                | (Datetime(_, _), Duration(_))
520                | (Duration(_), Date)
521                | (Date, Duration(_))
522                | (Duration(_), Time)
523                | (Time, Duration(_)) => try_get_supertype(left_field.dtype(), &right_type)?,
524                (_, Datetime(_, _))
525                | (Datetime(_, _), _)
526                | (_, Date)
527                | (Date, _)
528                | (Time, _)
529                | (_, Time) => {
530                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
531                },
532                (Duration(tul), Duration(tur)) => Duration(get_time_units(tul, tur)),
533                (_, Duration(_)) | (Duration(_), _) => {
534                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
535                },
536                (Boolean, Boolean) => IDX_DTYPE,
537                (l @ List(a), r @ List(b))
538                    if ![a, b]
539                        .into_iter()
540                        .all(|x| x.is_supported_list_arithmetic_input()) =>
541                {
542                    polars_bail!(
543                        InvalidOperation:
544                        "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
545                        "add", l, r,
546                    )
547                },
548                (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
549                    list_dtype.cast_leaf(try_get_supertype(
550                        list_dtype.leaf_dtype(),
551                        other_dtype.leaf_dtype(),
552                    )?)
553                },
554                #[cfg(feature = "dtype-array")]
555                (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
556                    list_dtype.cast_leaf(try_get_supertype(
557                        list_dtype.leaf_dtype(),
558                        other_dtype.leaf_dtype(),
559                    )?)
560                },
561                #[cfg(feature = "dtype-decimal")]
562                (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
563                    let scale = _get_decimal_scale_add_sub(*scale_left, *scale_right);
564                    Decimal(None, Some(scale))
565                },
566                (left, right) => try_get_supertype(left, right)?,
567            }
568        },
569        _ => {
570            let right_type = right_ae.to_field_impl(ctx, agg_list)?.dtype;
571
572            match (&left_field.dtype, &right_type) {
573                #[cfg(feature = "dtype-struct")]
574                (Struct(_), Struct(_)) => {
575                    return Ok(left_field);
576                },
577                (Datetime(_, _), _)
578                | (_, Datetime(_, _))
579                | (Time, _)
580                | (_, Time)
581                | (Date, _)
582                | (_, Date) => {
583                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
584                },
585                (Duration(_), Duration(_)) => {
586                    // True divide handled somewhere else
587                    polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
588                },
589                (l, Duration(_)) if l.is_primitive_numeric() => match op {
590                    Operator::Multiply => {
591                        left_field.coerce(right_type);
592                        return Ok(left_field);
593                    },
594                    _ => {
595                        polars_bail!(InvalidOperation: "{} not allowed on {} and {}", op, left_field.dtype, right_type)
596                    },
597                },
598                #[cfg(feature = "dtype-decimal")]
599                (Decimal(_, Some(scale_left)), Decimal(_, Some(scale_right))) => {
600                    let scale = match op {
601                        Operator::Multiply => _get_decimal_scale_mul(*scale_left, *scale_right),
602                        Operator::Divide | Operator::TrueDivide => {
603                            _get_decimal_scale_div(*scale_left)
604                        },
605                        _ => {
606                            debug_assert!(false);
607                            *scale_left
608                        },
609                    };
610                    let dtype = Decimal(None, Some(scale));
611                    left_field.coerce(dtype);
612                    return Ok(left_field);
613                },
614
615                (l @ List(a), r @ List(b))
616                    if ![a, b]
617                        .into_iter()
618                        .all(|x| x.is_supported_list_arithmetic_input()) =>
619                {
620                    polars_bail!(
621                        InvalidOperation:
622                        "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
623                        op, l, r,
624                    )
625                },
626                // List<->primitive operations can be done directly after casting the to the primitive
627                // supertype for the primitive values on both sides.
628                (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
629                    let dtype = list_dtype.cast_leaf(try_get_supertype(
630                        list_dtype.leaf_dtype(),
631                        other_dtype.leaf_dtype(),
632                    )?);
633                    left_field.coerce(dtype);
634                    return Ok(left_field);
635                },
636                #[cfg(feature = "dtype-array")]
637                (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
638                    let dtype = list_dtype.cast_leaf(try_get_supertype(
639                        list_dtype.leaf_dtype(),
640                        other_dtype.leaf_dtype(),
641                    )?);
642                    left_field.coerce(dtype);
643                    return Ok(left_field);
644                },
645                _ => {
646                    // Avoid needlessly type casting numeric columns during arithmetic
647                    // with literals.
648                    if (left_field.dtype.is_integer() && right_type.is_integer())
649                        || (left_field.dtype.is_float() && right_type.is_float())
650                    {
651                        match (left_ae, right_ae) {
652                            (AExpr::Literal(_), AExpr::Literal(_)) => {},
653                            (AExpr::Literal(_), _) => {
654                                // literal will be coerced to match right type
655                                left_field.coerce(right_type);
656                                return Ok(left_field);
657                            },
658                            (_, AExpr::Literal(_)) => {
659                                // literal will be coerced to match right type
660                                return Ok(left_field);
661                            },
662                            _ => {},
663                        }
664                    }
665                },
666            }
667
668            try_get_supertype(&left_field.dtype, &right_type)?
669        },
670    };
671
672    left_field.coerce(super_type);
673    Ok(left_field)
674}
675
676fn get_truediv_field(
677    left: Node,
678    right: Node,
679    ctx: &mut ToFieldContext,
680    agg_list: &mut bool,
681) -> PolarsResult<Field> {
682    let mut left_field = ctx.arena.get(left).to_field_impl(ctx, agg_list)?;
683    let right_field = ctx.arena.get(right).to_field_impl(ctx, agg_list)?;
684    use DataType::*;
685
686    // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code
687    // originally (mostly) only looked at the LHS dtype.
688    let out_type = match (left_field.dtype(), right_field.dtype()) {
689        (l @ List(a), r @ List(b))
690            if ![a, b]
691                .into_iter()
692                .all(|x| x.is_supported_list_arithmetic_input()) =>
693        {
694            polars_bail!(
695                InvalidOperation:
696                "cannot {} two list columns with non-numeric inner types: (left: {}, right: {})",
697                "div", l, r,
698            )
699        },
700        (list_dtype @ List(_), other_dtype) | (other_dtype, list_dtype @ List(_)) => {
701            list_dtype.cast_leaf(match (list_dtype.leaf_dtype(), other_dtype.leaf_dtype()) {
702                (Float32, Float32) => Float32,
703                (Float32, Float64) | (Float64, Float32) => Float64,
704                // FIXME: We should properly recurse on the enclosing match block here.
705                (dt, _) => dt.clone(),
706            })
707        },
708        #[cfg(feature = "dtype-array")]
709        (list_dtype @ Array(..), other_dtype) | (other_dtype, list_dtype @ Array(..)) => {
710            list_dtype.cast_leaf(match (list_dtype.leaf_dtype(), other_dtype.leaf_dtype()) {
711                (Float32, Float32) => Float32,
712                (Float32, Float64) | (Float64, Float32) => Float64,
713                // FIXME: We should properly recurse on the enclosing match block here.
714                (dt, _) => dt.clone(),
715            })
716        },
717        (Float32, _) => Float32,
718        #[cfg(feature = "dtype-decimal")]
719        (Decimal(_, Some(scale_left)), Decimal(_, _)) => {
720            let scale = _get_decimal_scale_div(*scale_left);
721            Decimal(None, Some(scale))
722        },
723        (dt, _) if dt.is_primitive_numeric() => Float64,
724        #[cfg(feature = "dtype-duration")]
725        (Duration(_), Duration(_)) => Float64,
726        #[cfg(feature = "dtype-duration")]
727        (Duration(_), dt) if dt.is_primitive_numeric() => return Ok(left_field),
728        #[cfg(feature = "dtype-duration")]
729        (Duration(_), dt) => {
730            polars_bail!(InvalidOperation: "true division of {} with {} is not allowed", left_field.dtype(), dt)
731        },
732        #[cfg(feature = "dtype-datetime")]
733        (Datetime(_, _), _) => {
734            polars_bail!(InvalidOperation: "division of 'Datetime' datatype is not allowed")
735        },
736        #[cfg(feature = "dtype-time")]
737        (Time, _) => polars_bail!(InvalidOperation: "division of 'Time' datatype is not allowed"),
738        #[cfg(feature = "dtype-date")]
739        (Date, _) => polars_bail!(InvalidOperation: "division of 'Date' datatype is not allowed"),
740        // we don't know what to do here, best return the dtype
741        (dt, _) => dt.clone(),
742    };
743
744    left_field.coerce(out_type);
745    Ok(left_field)
746}