datafusion_expr/
expr_schema.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{Between, Expr, Like};
19use crate::expr::{
20    AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21    InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22    WindowFunctionParams,
23};
24use crate::type_coercion::functions::{
25    data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
26};
27use crate::udf::ReturnTypeArgs;
28use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
29use arrow::compute::can_cast_types;
30use arrow::datatypes::{DataType, Field};
31use datafusion_common::{
32    not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
33    Result, TableReference,
34};
35use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
36use datafusion_functions_window_common::field::WindowUDFFieldArgs;
37use std::collections::HashMap;
38use std::sync::Arc;
39
40/// Trait to allow expr to typable with respect to a schema
41pub trait ExprSchemable {
42    /// Given a schema, return the type of the expr
43    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
44
45    /// Given a schema, return the nullability of the expr
46    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
47
48    /// Given a schema, return the expr's optional metadata
49    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;
50
51    /// Convert to a field with respect to a schema
52    fn to_field(
53        &self,
54        input_schema: &dyn ExprSchema,
55    ) -> Result<(Option<TableReference>, Arc<Field>)>;
56
57    /// Cast to a type with respect to a schema
58    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
59
60    /// Given a schema, return the type and nullability of the expr
61    fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
62        -> Result<(DataType, bool)>;
63}
64
65impl ExprSchemable for Expr {
66    /// Returns the [arrow::datatypes::DataType] of the expression
67    /// based on [ExprSchema]
68    ///
69    /// Note: [`DFSchema`] implements [ExprSchema].
70    ///
71    /// [`DFSchema`]: datafusion_common::DFSchema
72    ///
73    /// # Examples
74    ///
75    /// Get the type of an expression that adds 2 columns. Adding an Int32
76    /// and Float32 results in Float32 type
77    ///
78    /// ```
79    /// # use arrow::datatypes::{DataType, Field};
80    /// # use datafusion_common::DFSchema;
81    /// # use datafusion_expr::{col, ExprSchemable};
82    /// # use std::collections::HashMap;
83    ///
84    /// fn main() {
85    ///   let expr = col("c1") + col("c2");
86    ///   let schema = DFSchema::from_unqualified_fields(
87    ///     vec![
88    ///       Field::new("c1", DataType::Int32, true),
89    ///       Field::new("c2", DataType::Float32, true),
90    ///       ].into(),
91    ///       HashMap::new(),
92    ///   ).unwrap();
93    ///   assert_eq!("Float32", format!("{}", expr.get_type(&schema).unwrap()));
94    /// }
95    /// ```
96    ///
97    /// # Errors
98    ///
99    /// This function errors when it is not possible to compute its
100    /// [arrow::datatypes::DataType].  This happens when e.g. the
101    /// expression refers to a column that does not exist in the
102    /// schema, or when the expression is incorrectly typed
103    /// (e.g. `[utf8] + [bool]`).
104    #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
105    fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
106        match self {
107            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
108                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
109                    None => schema.data_type(&Column::from_name(name)).cloned(),
110                    Some(dt) => Ok(dt.clone()),
111                },
112                _ => expr.get_type(schema),
113            },
114            Expr::Negative(expr) => expr.get_type(schema),
115            Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
116            Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
117            Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
118            Expr::Literal(l) => Ok(l.data_type()),
119            Expr::Case(case) => {
120                for (_, then_expr) in &case.when_then_expr {
121                    let then_type = then_expr.get_type(schema)?;
122                    if !then_type.is_null() {
123                        return Ok(then_type);
124                    }
125                }
126                case.else_expr
127                    .as_ref()
128                    .map_or(Ok(DataType::Null), |e| e.get_type(schema))
129            }
130            Expr::Cast(Cast { data_type, .. })
131            | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
132            Expr::Unnest(Unnest { expr }) => {
133                let arg_data_type = expr.get_type(schema)?;
134                // Unnest's output type is the inner type of the list
135                match arg_data_type {
136                    DataType::List(field)
137                    | DataType::LargeList(field)
138                    | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
139                    DataType::Struct(_) => Ok(arg_data_type),
140                    DataType::Null => {
141                        not_impl_err!("unnest() does not support null yet")
142                    }
143                    _ => {
144                        plan_err!(
145                            "unnest() can only be applied to array, struct and null"
146                        )
147                    }
148                }
149            }
150            Expr::ScalarFunction(_func) => {
151                let (return_type, _) = self.data_type_and_nullable(schema)?;
152                Ok(return_type)
153            }
154            Expr::WindowFunction(window_function) => self
155                .data_type_and_nullable_with_window_function(schema, window_function)
156                .map(|(return_type, _)| return_type),
157            Expr::AggregateFunction(AggregateFunction {
158                func,
159                params: AggregateFunctionParams { args, .. },
160            }) => {
161                let data_types = args
162                    .iter()
163                    .map(|e| e.get_type(schema))
164                    .collect::<Result<Vec<_>>>()?;
165                let new_types = data_types_with_aggregate_udf(&data_types, func)
166                    .map_err(|err| {
167                        plan_datafusion_err!(
168                            "{} {}",
169                            match err {
170                                DataFusionError::Plan(msg) => msg,
171                                err => err.to_string(),
172                            },
173                            utils::generate_signature_error_msg(
174                                func.name(),
175                                func.signature().clone(),
176                                &data_types
177                            )
178                        )
179                    })?;
180                Ok(func.return_type(&new_types)?)
181            }
182            Expr::Not(_)
183            | Expr::IsNull(_)
184            | Expr::Exists { .. }
185            | Expr::InSubquery(_)
186            | Expr::Between { .. }
187            | Expr::InList { .. }
188            | Expr::IsNotNull(_)
189            | Expr::IsTrue(_)
190            | Expr::IsFalse(_)
191            | Expr::IsUnknown(_)
192            | Expr::IsNotTrue(_)
193            | Expr::IsNotFalse(_)
194            | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
195            Expr::ScalarSubquery(subquery) => {
196                Ok(subquery.subquery.schema().field(0).data_type().clone())
197            }
198            Expr::BinaryExpr(BinaryExpr {
199                ref left,
200                ref right,
201                ref op,
202            }) => BinaryTypeCoercer::new(
203                &left.get_type(schema)?,
204                op,
205                &right.get_type(schema)?,
206            )
207            .get_result_type(),
208            Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
209            Expr::Placeholder(Placeholder { data_type, .. }) => {
210                if let Some(dtype) = data_type {
211                    Ok(dtype.clone())
212                } else {
213                    // If the placeholder's type hasn't been specified, treat it as
214                    // null (unspecified placeholders generate an error during planning)
215                    Ok(DataType::Null)
216                }
217            }
218            #[expect(deprecated)]
219            Expr::Wildcard { .. } => Ok(DataType::Null),
220            Expr::GroupingSet(_) => {
221                // Grouping sets do not really have a type and do not appear in projections
222                Ok(DataType::Null)
223            }
224        }
225    }
226
227    /// Returns the nullability of the expression based on [ExprSchema].
228    ///
229    /// Note: [`DFSchema`] implements [ExprSchema].
230    ///
231    /// [`DFSchema`]: datafusion_common::DFSchema
232    ///
233    /// # Errors
234    ///
235    /// This function errors when it is not possible to compute its
236    /// nullability.  This happens when the expression refers to a
237    /// column that does not exist in the schema.
238    fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
239        match self {
240            Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
241                expr.nullable(input_schema)
242            }
243
244            Expr::InList(InList { expr, list, .. }) => {
245                // Avoid inspecting too many expressions.
246                const MAX_INSPECT_LIMIT: usize = 6;
247                // Stop if a nullable expression is found or an error occurs.
248                let has_nullable = std::iter::once(expr.as_ref())
249                    .chain(list)
250                    .take(MAX_INSPECT_LIMIT)
251                    .find_map(|e| {
252                        e.nullable(input_schema)
253                            .map(|nullable| if nullable { Some(()) } else { None })
254                            .transpose()
255                    })
256                    .transpose()?;
257                Ok(match has_nullable {
258                    // If a nullable subexpression is found, the result may also be nullable.
259                    Some(_) => true,
260                    // If the list is too long, we assume it is nullable.
261                    None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
262                    // All the subexpressions are non-nullable, so the result must be non-nullable.
263                    _ => false,
264                })
265            }
266
267            Expr::Between(Between {
268                expr, low, high, ..
269            }) => Ok(expr.nullable(input_schema)?
270                || low.nullable(input_schema)?
271                || high.nullable(input_schema)?),
272
273            Expr::Column(c) => input_schema.nullable(c),
274            Expr::OuterReferenceColumn(_, _) => Ok(true),
275            Expr::Literal(value) => Ok(value.is_null()),
276            Expr::Case(case) => {
277                // This expression is nullable if any of the input expressions are nullable
278                let then_nullable = case
279                    .when_then_expr
280                    .iter()
281                    .map(|(_, t)| t.nullable(input_schema))
282                    .collect::<Result<Vec<_>>>()?;
283                if then_nullable.contains(&true) {
284                    Ok(true)
285                } else if let Some(e) = &case.else_expr {
286                    e.nullable(input_schema)
287                } else {
288                    // CASE produces NULL if there is no `else` expr
289                    // (aka when none of the `when_then_exprs` match)
290                    Ok(true)
291                }
292            }
293            Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
294            Expr::ScalarFunction(_func) => {
295                let (_, nullable) = self.data_type_and_nullable(input_schema)?;
296                Ok(nullable)
297            }
298            Expr::AggregateFunction(AggregateFunction { func, .. }) => {
299                Ok(func.is_nullable())
300            }
301            Expr::WindowFunction(window_function) => self
302                .data_type_and_nullable_with_window_function(
303                    input_schema,
304                    window_function,
305                )
306                .map(|(_, nullable)| nullable),
307            Expr::ScalarVariable(_, _)
308            | Expr::TryCast { .. }
309            | Expr::Unnest(_)
310            | Expr::Placeholder(_) => Ok(true),
311            Expr::IsNull(_)
312            | Expr::IsNotNull(_)
313            | Expr::IsTrue(_)
314            | Expr::IsFalse(_)
315            | Expr::IsUnknown(_)
316            | Expr::IsNotTrue(_)
317            | Expr::IsNotFalse(_)
318            | Expr::IsNotUnknown(_)
319            | Expr::Exists { .. } => Ok(false),
320            Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
321            Expr::ScalarSubquery(subquery) => {
322                Ok(subquery.subquery.schema().field(0).is_nullable())
323            }
324            Expr::BinaryExpr(BinaryExpr {
325                ref left,
326                ref right,
327                ..
328            }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
329            Expr::Like(Like { expr, pattern, .. })
330            | Expr::SimilarTo(Like { expr, pattern, .. }) => {
331                Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
332            }
333            #[expect(deprecated)]
334            Expr::Wildcard { .. } => Ok(false),
335            Expr::GroupingSet(_) => {
336                // Grouping sets do not really have the concept of nullable and do not appear
337                // in projections
338                Ok(true)
339            }
340        }
341    }
342
343    fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
344        match self {
345            Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
346            Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
347            Expr::Cast(Cast { expr, .. }) => expr.metadata(schema),
348            _ => Ok(HashMap::new()),
349        }
350    }
351
352    /// Returns the datatype and nullability of the expression based on [ExprSchema].
353    ///
354    /// Note: [`DFSchema`] implements [ExprSchema].
355    ///
356    /// [`DFSchema`]: datafusion_common::DFSchema
357    ///
358    /// # Errors
359    ///
360    /// This function errors when it is not possible to compute its
361    /// datatype or nullability.
362    fn data_type_and_nullable(
363        &self,
364        schema: &dyn ExprSchema,
365    ) -> Result<(DataType, bool)> {
366        match self {
367            Expr::Alias(Alias { expr, name, .. }) => match &**expr {
368                Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
369                    None => schema
370                        .data_type_and_nullable(&Column::from_name(name))
371                        .map(|(d, n)| (d.clone(), n)),
372                    Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)),
373                },
374                _ => expr.data_type_and_nullable(schema),
375            },
376            Expr::Negative(expr) => expr.data_type_and_nullable(schema),
377            Expr::Column(c) => schema
378                .data_type_and_nullable(c)
379                .map(|(d, n)| (d.clone(), n)),
380            Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
381            Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
382            Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
383            Expr::IsNull(_)
384            | Expr::IsNotNull(_)
385            | Expr::IsTrue(_)
386            | Expr::IsFalse(_)
387            | Expr::IsUnknown(_)
388            | Expr::IsNotTrue(_)
389            | Expr::IsNotFalse(_)
390            | Expr::IsNotUnknown(_)
391            | Expr::Exists { .. } => Ok((DataType::Boolean, false)),
392            Expr::ScalarSubquery(subquery) => Ok((
393                subquery.subquery.schema().field(0).data_type().clone(),
394                subquery.subquery.schema().field(0).is_nullable(),
395            )),
396            Expr::BinaryExpr(BinaryExpr {
397                ref left,
398                ref right,
399                ref op,
400            }) => {
401                let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?;
402                let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?;
403                let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type);
404                coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
405                coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
406                Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable))
407            }
408            Expr::WindowFunction(window_function) => {
409                self.data_type_and_nullable_with_window_function(schema, window_function)
410            }
411            Expr::ScalarFunction(ScalarFunction { func, args }) => {
412                let (arg_types, nullables): (Vec<DataType>, Vec<bool>) = args
413                    .iter()
414                    .map(|e| e.data_type_and_nullable(schema))
415                    .collect::<Result<Vec<_>>>()?
416                    .into_iter()
417                    .unzip();
418                // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature`
419                let new_data_types = data_types_with_scalar_udf(&arg_types, func)
420                    .map_err(|err| {
421                        plan_datafusion_err!(
422                            "{} {}",
423                            match err {
424                                DataFusionError::Plan(msg) => msg,
425                                err => err.to_string(),
426                            },
427                            utils::generate_signature_error_msg(
428                                func.name(),
429                                func.signature().clone(),
430                                &arg_types,
431                            )
432                        )
433                    })?;
434
435                let arguments = args
436                    .iter()
437                    .map(|e| match e {
438                        Expr::Literal(sv) => Some(sv),
439                        _ => None,
440                    })
441                    .collect::<Vec<_>>();
442                let args = ReturnTypeArgs {
443                    arg_types: &new_data_types,
444                    scalar_arguments: &arguments,
445                    nullables: &nullables,
446                };
447
448                let (return_type, nullable) =
449                    func.return_type_from_args(args)?.into_parts();
450                Ok((return_type, nullable))
451            }
452            _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
453        }
454    }
455
456    /// Returns a [arrow::datatypes::Field] compatible with this expression.
457    ///
458    /// So for example, a projected expression `col(c1) + col(c2)` is
459    /// placed in an output field **named** col("c1 + c2")
460    fn to_field(
461        &self,
462        input_schema: &dyn ExprSchema,
463    ) -> Result<(Option<TableReference>, Arc<Field>)> {
464        let (relation, schema_name) = self.qualified_name();
465        let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
466        let field = Field::new(schema_name, data_type, nullable)
467            .with_metadata(self.metadata(input_schema)?)
468            .into();
469        Ok((relation, field))
470    }
471
472    /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
473    ///
474    /// # Errors
475    ///
476    /// This function errors when it is impossible to cast the
477    /// expression to the target [arrow::datatypes::DataType].
478    fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
479        let this_type = self.get_type(schema)?;
480        if this_type == *cast_to_type {
481            return Ok(self);
482        }
483
484        // TODO(kszucs): Most of the operations do not validate the type correctness
485        // like all of the binary expressions below. Perhaps Expr should track the
486        // type of the expression?
487
488        if can_cast_types(&this_type, cast_to_type) {
489            match self {
490                Expr::ScalarSubquery(subquery) => {
491                    Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
492                }
493                _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
494            }
495        } else {
496            plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
497        }
498    }
499}
500
501impl Expr {
502    /// Common method for window functions that applies type coercion
503    /// to all arguments of the window function to check if it matches
504    /// its signature.
505    ///
506    /// If successful, this method returns the data type and
507    /// nullability of the window function's result.
508    ///
509    /// Otherwise, returns an error if there's a type mismatch between
510    /// the window function's signature and the provided arguments.
511    fn data_type_and_nullable_with_window_function(
512        &self,
513        schema: &dyn ExprSchema,
514        window_function: &WindowFunction,
515    ) -> Result<(DataType, bool)> {
516        let WindowFunction {
517            fun,
518            params: WindowFunctionParams { args, .. },
519            ..
520        } = window_function;
521
522        let data_types = args
523            .iter()
524            .map(|e| e.get_type(schema))
525            .collect::<Result<Vec<_>>>()?;
526        match fun {
527            WindowFunctionDefinition::AggregateUDF(udaf) => {
528                let new_types = data_types_with_aggregate_udf(&data_types, udaf)
529                    .map_err(|err| {
530                        plan_datafusion_err!(
531                            "{} {}",
532                            match err {
533                                DataFusionError::Plan(msg) => msg,
534                                err => err.to_string(),
535                            },
536                            utils::generate_signature_error_msg(
537                                fun.name(),
538                                fun.signature(),
539                                &data_types
540                            )
541                        )
542                    })?;
543
544                let return_type = udaf.return_type(&new_types)?;
545                let nullable = udaf.is_nullable();
546
547                Ok((return_type, nullable))
548            }
549            WindowFunctionDefinition::WindowUDF(udwf) => {
550                let new_types =
551                    data_types_with_window_udf(&data_types, udwf).map_err(|err| {
552                        plan_datafusion_err!(
553                            "{} {}",
554                            match err {
555                                DataFusionError::Plan(msg) => msg,
556                                err => err.to_string(),
557                            },
558                            utils::generate_signature_error_msg(
559                                fun.name(),
560                                fun.signature(),
561                                &data_types
562                            )
563                        )
564                    })?;
565                let (_, function_name) = self.qualified_name();
566                let field_args = WindowUDFFieldArgs::new(&new_types, &function_name);
567
568                udwf.field(field_args)
569                    .map(|field| (field.data_type().clone(), field.is_nullable()))
570            }
571        }
572    }
573}
574
575/// Cast subquery in InSubquery/ScalarSubquery to a given type.
576///
577/// 1. **Projection plan**: If the subquery is a projection (i.e. a SELECT statement with specific
578///    columns), it casts the first expression in the projection to the target type and creates a
579///    new projection with the casted expression.
580/// 2. **Non-projection plan**: If the subquery isn't a projection, it adds a projection to the plan
581///    with the casted first column.
582///
583pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
584    if subquery.subquery.schema().field(0).data_type() == cast_to_type {
585        return Ok(subquery);
586    }
587
588    let plan = subquery.subquery.as_ref();
589    let new_plan = match plan {
590        LogicalPlan::Projection(projection) => {
591            let cast_expr = projection.expr[0]
592                .clone()
593                .cast_to(cast_to_type, projection.input.schema())?;
594            LogicalPlan::Projection(Projection::try_new(
595                vec![cast_expr],
596                Arc::clone(&projection.input),
597            )?)
598        }
599        _ => {
600            let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
601                .cast_to(cast_to_type, subquery.subquery.schema())?;
602            LogicalPlan::Projection(Projection::try_new(
603                vec![cast_expr],
604                subquery.subquery,
605            )?)
606        }
607    };
608    Ok(Subquery {
609        subquery: Arc::new(new_plan),
610        outer_ref_columns: subquery.outer_ref_columns,
611    })
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};

    use datafusion_common::{internal_err, DFSchema, ScalarValue};

    /// Applies `make` to a NULL literal and asserts the resulting predicate is
    /// reported as non-nullable.
    fn assert_never_null(make: fn(Expr) -> Expr) {
        let predicate = make(lit(ScalarValue::Null));
        assert!(!predicate.nullable(&MockExprSchema::new()).unwrap());
    }

    #[test]
    fn expr_schema_nullability() {
        let equality = col("foo").eq(lit(1));
        let non_null_schema = MockExprSchema::new();
        assert!(!equality.nullable(&non_null_schema).unwrap());
        let nullable_schema = MockExprSchema::new().with_nullable(true);
        assert!(equality.nullable(&nullable_schema).unwrap());

        let null_tests: [fn(Expr) -> Expr; 8] = [
            Expr::is_null,
            Expr::is_not_null,
            Expr::is_true,
            Expr::is_not_true,
            Expr::is_false,
            Expr::is_not_false,
            Expr::is_unknown,
            Expr::is_not_unknown,
        ];
        for make in null_tests {
            assert_never_null(make);
        }
    }

    #[test]
    fn test_between_nullability() {
        let schema_with = |nullable| {
            MockExprSchema::new()
                .with_data_type(DataType::Int32)
                .with_nullable(nullable)
        };
        let null_int = || lit(ScalarValue::Int32(None));

        let literal_bounds = col("foo").between(lit(1), lit(2));
        assert!(!literal_bounds.nullable(&schema_with(false)).unwrap());
        assert!(literal_bounds.nullable(&schema_with(true)).unwrap());

        // A NULL bound makes the predicate nullable even over a non-null column.
        let null_bound_cases = [
            col("foo").between(null_int(), lit(2)),
            col("foo").between(lit(1), null_int()),
            col("foo").between(null_int(), null_int()),
        ];
        for case in &null_bound_cases {
            assert!(case.nullable(&schema_with(false)).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        let schema_with = |nullable| {
            MockExprSchema::new()
                .with_data_type(DataType::Int32)
                .with_nullable(nullable)
        };

        let short_list = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!short_list.nullable(&schema_with(false)).unwrap());
        assert!(short_list.nullable(&schema_with(true)).unwrap());
        // An error raised by the schema's `nullable` must be propagated.
        let failing_schema = schema_with(false).with_error_on_nullable(true);
        assert!(short_list.nullable(&failing_schema).is_err());

        let with_null_item =
            col("foo").in_list(vec![lit(ScalarValue::Int32(None)), lit(1)], false);
        assert!(with_null_item.nullable(&schema_with(false)).unwrap());

        // Lists longer than the inspection limit are conservatively nullable.
        let long_list = col("foo").in_list(vec![lit(1); 6], false);
        assert!(long_list.nullable(&schema_with(false)).unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let schema_with = |nullable| {
            MockExprSchema::new()
                .with_data_type(DataType::Utf8)
                .with_nullable(nullable)
        };

        let literal_pattern = col("foo").like(lit("bar"));
        assert!(!literal_pattern.nullable(&schema_with(false)).unwrap());
        assert!(literal_pattern.nullable(&schema_with(true)).unwrap());

        // A NULL pattern makes LIKE nullable regardless of the column.
        let null_pattern = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(null_pattern.nullable(&schema_with(false)).unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let utf8_schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        let resolved = col("foo").get_type(&utf8_schema).unwrap();
        assert_eq!(DataType::Utf8, resolved);
    }

    #[test]
    fn test_expr_metadata() {
        let meta: HashMap<String, String> = [("bar".to_string(), "buzz".to_string())]
            .into_iter()
            .collect();
        let column = col("foo");
        let mock = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // col, alias, and cast should be metadata-preserving
        assert_eq!(meta, column.metadata(&mock).unwrap());
        let aliased = column.clone().alias("bar");
        assert_eq!(meta, aliased.metadata(&mock).unwrap());
        let casted = column.clone().cast_to(&DataType::Int64, &mock).unwrap();
        assert_eq!(meta, casted.metadata(&mock).unwrap());

        let df_schema = DFSchema::from_unqualified_fields(
            vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())]
                .into(),
            HashMap::new(),
        )
        .unwrap();

        // verify to_field method populates metadata
        let (_, field) = column.to_field(&df_schema).unwrap();
        assert_eq!(&meta, field.metadata());
    }

    /// Minimal [`ExprSchema`] that answers every column lookup with the same
    /// type, nullability and metadata, and can be told to fail `nullable`.
    #[derive(Debug)]
    struct MockExprSchema {
        nullable: bool,
        data_type: DataType,
        error_on_nullable: bool,
        metadata: HashMap<String, String>,
    }

    impl MockExprSchema {
        fn new() -> Self {
            Self {
                nullable: false,
                data_type: DataType::Null,
                error_on_nullable: false,
                metadata: HashMap::new(),
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self { nullable, ..self }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self { data_type, ..self }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
            Self { metadata, ..self }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if self.error_on_nullable {
                return internal_err!("nullable error");
            }
            Ok(self.nullable)
        }

        fn data_type(&self, _col: &Column) -> Result<&DataType> {
            Ok(&self.data_type)
        }

        fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
            Ok(&self.metadata)
        }

        fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
            let data_type = self.data_type(col)?;
            let nullable = self.nullable(col)?;
            Ok((data_type, nullable))
        }
    }
}