1use super::{Between, Expr, Like};
19use crate::expr::{
20 AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList,
21 InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction,
22 WindowFunctionParams,
23};
24use crate::type_coercion::functions::{
25 data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf,
26};
27use crate::udf::ReturnTypeArgs;
28use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition};
29use arrow::compute::can_cast_types;
30use arrow::datatypes::{DataType, Field};
31use datafusion_common::{
32 not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema,
33 Result, TableReference,
34};
35use datafusion_expr_common::type_coercion::binary::BinaryTypeCoercer;
36use datafusion_functions_window_common::field::WindowUDFFieldArgs;
37use std::collections::HashMap;
38use std::sync::Arc;
39
40pub trait ExprSchemable {
42 fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType>;
44
45 fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool>;
47
48 fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>>;
50
51 fn to_field(
53 &self,
54 input_schema: &dyn ExprSchema,
55 ) -> Result<(Option<TableReference>, Arc<Field>)>;
56
57 fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr>;
59
60 fn data_type_and_nullable(&self, schema: &dyn ExprSchema)
62 -> Result<(DataType, bool)>;
63}
64
65impl ExprSchemable for Expr {
66 #[cfg_attr(feature = "recursive_protection", recursive::recursive)]
105 fn get_type(&self, schema: &dyn ExprSchema) -> Result<DataType> {
106 match self {
107 Expr::Alias(Alias { expr, name, .. }) => match &**expr {
108 Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
109 None => schema.data_type(&Column::from_name(name)).cloned(),
110 Some(dt) => Ok(dt.clone()),
111 },
112 _ => expr.get_type(schema),
113 },
114 Expr::Negative(expr) => expr.get_type(schema),
115 Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
116 Expr::OuterReferenceColumn(ty, _) => Ok(ty.clone()),
117 Expr::ScalarVariable(ty, _) => Ok(ty.clone()),
118 Expr::Literal(l) => Ok(l.data_type()),
119 Expr::Case(case) => {
120 for (_, then_expr) in &case.when_then_expr {
121 let then_type = then_expr.get_type(schema)?;
122 if !then_type.is_null() {
123 return Ok(then_type);
124 }
125 }
126 case.else_expr
127 .as_ref()
128 .map_or(Ok(DataType::Null), |e| e.get_type(schema))
129 }
130 Expr::Cast(Cast { data_type, .. })
131 | Expr::TryCast(TryCast { data_type, .. }) => Ok(data_type.clone()),
132 Expr::Unnest(Unnest { expr }) => {
133 let arg_data_type = expr.get_type(schema)?;
134 match arg_data_type {
136 DataType::List(field)
137 | DataType::LargeList(field)
138 | DataType::FixedSizeList(field, _) => Ok(field.data_type().clone()),
139 DataType::Struct(_) => Ok(arg_data_type),
140 DataType::Null => {
141 not_impl_err!("unnest() does not support null yet")
142 }
143 _ => {
144 plan_err!(
145 "unnest() can only be applied to array, struct and null"
146 )
147 }
148 }
149 }
150 Expr::ScalarFunction(_func) => {
151 let (return_type, _) = self.data_type_and_nullable(schema)?;
152 Ok(return_type)
153 }
154 Expr::WindowFunction(window_function) => self
155 .data_type_and_nullable_with_window_function(schema, window_function)
156 .map(|(return_type, _)| return_type),
157 Expr::AggregateFunction(AggregateFunction {
158 func,
159 params: AggregateFunctionParams { args, .. },
160 }) => {
161 let data_types = args
162 .iter()
163 .map(|e| e.get_type(schema))
164 .collect::<Result<Vec<_>>>()?;
165 let new_types = data_types_with_aggregate_udf(&data_types, func)
166 .map_err(|err| {
167 plan_datafusion_err!(
168 "{} {}",
169 match err {
170 DataFusionError::Plan(msg) => msg,
171 err => err.to_string(),
172 },
173 utils::generate_signature_error_msg(
174 func.name(),
175 func.signature().clone(),
176 &data_types
177 )
178 )
179 })?;
180 Ok(func.return_type(&new_types)?)
181 }
182 Expr::Not(_)
183 | Expr::IsNull(_)
184 | Expr::Exists { .. }
185 | Expr::InSubquery(_)
186 | Expr::Between { .. }
187 | Expr::InList { .. }
188 | Expr::IsNotNull(_)
189 | Expr::IsTrue(_)
190 | Expr::IsFalse(_)
191 | Expr::IsUnknown(_)
192 | Expr::IsNotTrue(_)
193 | Expr::IsNotFalse(_)
194 | Expr::IsNotUnknown(_) => Ok(DataType::Boolean),
195 Expr::ScalarSubquery(subquery) => {
196 Ok(subquery.subquery.schema().field(0).data_type().clone())
197 }
198 Expr::BinaryExpr(BinaryExpr {
199 ref left,
200 ref right,
201 ref op,
202 }) => BinaryTypeCoercer::new(
203 &left.get_type(schema)?,
204 op,
205 &right.get_type(schema)?,
206 )
207 .get_result_type(),
208 Expr::Like { .. } | Expr::SimilarTo { .. } => Ok(DataType::Boolean),
209 Expr::Placeholder(Placeholder { data_type, .. }) => {
210 if let Some(dtype) = data_type {
211 Ok(dtype.clone())
212 } else {
213 Ok(DataType::Null)
216 }
217 }
218 #[expect(deprecated)]
219 Expr::Wildcard { .. } => Ok(DataType::Null),
220 Expr::GroupingSet(_) => {
221 Ok(DataType::Null)
223 }
224 }
225 }
226
227 fn nullable(&self, input_schema: &dyn ExprSchema) -> Result<bool> {
239 match self {
240 Expr::Alias(Alias { expr, .. }) | Expr::Not(expr) | Expr::Negative(expr) => {
241 expr.nullable(input_schema)
242 }
243
244 Expr::InList(InList { expr, list, .. }) => {
245 const MAX_INSPECT_LIMIT: usize = 6;
247 let has_nullable = std::iter::once(expr.as_ref())
249 .chain(list)
250 .take(MAX_INSPECT_LIMIT)
251 .find_map(|e| {
252 e.nullable(input_schema)
253 .map(|nullable| if nullable { Some(()) } else { None })
254 .transpose()
255 })
256 .transpose()?;
257 Ok(match has_nullable {
258 Some(_) => true,
260 None if list.len() + 1 > MAX_INSPECT_LIMIT => true,
262 _ => false,
264 })
265 }
266
267 Expr::Between(Between {
268 expr, low, high, ..
269 }) => Ok(expr.nullable(input_schema)?
270 || low.nullable(input_schema)?
271 || high.nullable(input_schema)?),
272
273 Expr::Column(c) => input_schema.nullable(c),
274 Expr::OuterReferenceColumn(_, _) => Ok(true),
275 Expr::Literal(value) => Ok(value.is_null()),
276 Expr::Case(case) => {
277 let then_nullable = case
279 .when_then_expr
280 .iter()
281 .map(|(_, t)| t.nullable(input_schema))
282 .collect::<Result<Vec<_>>>()?;
283 if then_nullable.contains(&true) {
284 Ok(true)
285 } else if let Some(e) = &case.else_expr {
286 e.nullable(input_schema)
287 } else {
288 Ok(true)
291 }
292 }
293 Expr::Cast(Cast { expr, .. }) => expr.nullable(input_schema),
294 Expr::ScalarFunction(_func) => {
295 let (_, nullable) = self.data_type_and_nullable(input_schema)?;
296 Ok(nullable)
297 }
298 Expr::AggregateFunction(AggregateFunction { func, .. }) => {
299 Ok(func.is_nullable())
300 }
301 Expr::WindowFunction(window_function) => self
302 .data_type_and_nullable_with_window_function(
303 input_schema,
304 window_function,
305 )
306 .map(|(_, nullable)| nullable),
307 Expr::ScalarVariable(_, _)
308 | Expr::TryCast { .. }
309 | Expr::Unnest(_)
310 | Expr::Placeholder(_) => Ok(true),
311 Expr::IsNull(_)
312 | Expr::IsNotNull(_)
313 | Expr::IsTrue(_)
314 | Expr::IsFalse(_)
315 | Expr::IsUnknown(_)
316 | Expr::IsNotTrue(_)
317 | Expr::IsNotFalse(_)
318 | Expr::IsNotUnknown(_)
319 | Expr::Exists { .. } => Ok(false),
320 Expr::InSubquery(InSubquery { expr, .. }) => expr.nullable(input_schema),
321 Expr::ScalarSubquery(subquery) => {
322 Ok(subquery.subquery.schema().field(0).is_nullable())
323 }
324 Expr::BinaryExpr(BinaryExpr {
325 ref left,
326 ref right,
327 ..
328 }) => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
329 Expr::Like(Like { expr, pattern, .. })
330 | Expr::SimilarTo(Like { expr, pattern, .. }) => {
331 Ok(expr.nullable(input_schema)? || pattern.nullable(input_schema)?)
332 }
333 #[expect(deprecated)]
334 Expr::Wildcard { .. } => Ok(false),
335 Expr::GroupingSet(_) => {
336 Ok(true)
339 }
340 }
341 }
342
343 fn metadata(&self, schema: &dyn ExprSchema) -> Result<HashMap<String, String>> {
344 match self {
345 Expr::Column(c) => Ok(schema.metadata(c)?.clone()),
346 Expr::Alias(Alias { expr, .. }) => expr.metadata(schema),
347 Expr::Cast(Cast { expr, .. }) => expr.metadata(schema),
348 _ => Ok(HashMap::new()),
349 }
350 }
351
352 fn data_type_and_nullable(
363 &self,
364 schema: &dyn ExprSchema,
365 ) -> Result<(DataType, bool)> {
366 match self {
367 Expr::Alias(Alias { expr, name, .. }) => match &**expr {
368 Expr::Placeholder(Placeholder { data_type, .. }) => match &data_type {
369 None => schema
370 .data_type_and_nullable(&Column::from_name(name))
371 .map(|(d, n)| (d.clone(), n)),
372 Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)),
373 },
374 _ => expr.data_type_and_nullable(schema),
375 },
376 Expr::Negative(expr) => expr.data_type_and_nullable(schema),
377 Expr::Column(c) => schema
378 .data_type_and_nullable(c)
379 .map(|(d, n)| (d.clone(), n)),
380 Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)),
381 Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)),
382 Expr::Literal(l) => Ok((l.data_type(), l.is_null())),
383 Expr::IsNull(_)
384 | Expr::IsNotNull(_)
385 | Expr::IsTrue(_)
386 | Expr::IsFalse(_)
387 | Expr::IsUnknown(_)
388 | Expr::IsNotTrue(_)
389 | Expr::IsNotFalse(_)
390 | Expr::IsNotUnknown(_)
391 | Expr::Exists { .. } => Ok((DataType::Boolean, false)),
392 Expr::ScalarSubquery(subquery) => Ok((
393 subquery.subquery.schema().field(0).data_type().clone(),
394 subquery.subquery.schema().field(0).is_nullable(),
395 )),
396 Expr::BinaryExpr(BinaryExpr {
397 ref left,
398 ref right,
399 ref op,
400 }) => {
401 let (lhs_type, lhs_nullable) = left.data_type_and_nullable(schema)?;
402 let (rhs_type, rhs_nullable) = right.data_type_and_nullable(schema)?;
403 let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type);
404 coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default());
405 coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default());
406 Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable))
407 }
408 Expr::WindowFunction(window_function) => {
409 self.data_type_and_nullable_with_window_function(schema, window_function)
410 }
411 Expr::ScalarFunction(ScalarFunction { func, args }) => {
412 let (arg_types, nullables): (Vec<DataType>, Vec<bool>) = args
413 .iter()
414 .map(|e| e.data_type_and_nullable(schema))
415 .collect::<Result<Vec<_>>>()?
416 .into_iter()
417 .unzip();
418 let new_data_types = data_types_with_scalar_udf(&arg_types, func)
420 .map_err(|err| {
421 plan_datafusion_err!(
422 "{} {}",
423 match err {
424 DataFusionError::Plan(msg) => msg,
425 err => err.to_string(),
426 },
427 utils::generate_signature_error_msg(
428 func.name(),
429 func.signature().clone(),
430 &arg_types,
431 )
432 )
433 })?;
434
435 let arguments = args
436 .iter()
437 .map(|e| match e {
438 Expr::Literal(sv) => Some(sv),
439 _ => None,
440 })
441 .collect::<Vec<_>>();
442 let args = ReturnTypeArgs {
443 arg_types: &new_data_types,
444 scalar_arguments: &arguments,
445 nullables: &nullables,
446 };
447
448 let (return_type, nullable) =
449 func.return_type_from_args(args)?.into_parts();
450 Ok((return_type, nullable))
451 }
452 _ => Ok((self.get_type(schema)?, self.nullable(schema)?)),
453 }
454 }
455
456 fn to_field(
461 &self,
462 input_schema: &dyn ExprSchema,
463 ) -> Result<(Option<TableReference>, Arc<Field>)> {
464 let (relation, schema_name) = self.qualified_name();
465 let (data_type, nullable) = self.data_type_and_nullable(input_schema)?;
466 let field = Field::new(schema_name, data_type, nullable)
467 .with_metadata(self.metadata(input_schema)?)
468 .into();
469 Ok((relation, field))
470 }
471
472 fn cast_to(self, cast_to_type: &DataType, schema: &dyn ExprSchema) -> Result<Expr> {
479 let this_type = self.get_type(schema)?;
480 if this_type == *cast_to_type {
481 return Ok(self);
482 }
483
484 if can_cast_types(&this_type, cast_to_type) {
489 match self {
490 Expr::ScalarSubquery(subquery) => {
491 Ok(Expr::ScalarSubquery(cast_subquery(subquery, cast_to_type)?))
492 }
493 _ => Ok(Expr::Cast(Cast::new(Box::new(self), cast_to_type.clone()))),
494 }
495 } else {
496 plan_err!("Cannot automatically convert {this_type:?} to {cast_to_type:?}")
497 }
498 }
499}
500
501impl Expr {
502 fn data_type_and_nullable_with_window_function(
512 &self,
513 schema: &dyn ExprSchema,
514 window_function: &WindowFunction,
515 ) -> Result<(DataType, bool)> {
516 let WindowFunction {
517 fun,
518 params: WindowFunctionParams { args, .. },
519 ..
520 } = window_function;
521
522 let data_types = args
523 .iter()
524 .map(|e| e.get_type(schema))
525 .collect::<Result<Vec<_>>>()?;
526 match fun {
527 WindowFunctionDefinition::AggregateUDF(udaf) => {
528 let new_types = data_types_with_aggregate_udf(&data_types, udaf)
529 .map_err(|err| {
530 plan_datafusion_err!(
531 "{} {}",
532 match err {
533 DataFusionError::Plan(msg) => msg,
534 err => err.to_string(),
535 },
536 utils::generate_signature_error_msg(
537 fun.name(),
538 fun.signature(),
539 &data_types
540 )
541 )
542 })?;
543
544 let return_type = udaf.return_type(&new_types)?;
545 let nullable = udaf.is_nullable();
546
547 Ok((return_type, nullable))
548 }
549 WindowFunctionDefinition::WindowUDF(udwf) => {
550 let new_types =
551 data_types_with_window_udf(&data_types, udwf).map_err(|err| {
552 plan_datafusion_err!(
553 "{} {}",
554 match err {
555 DataFusionError::Plan(msg) => msg,
556 err => err.to_string(),
557 },
558 utils::generate_signature_error_msg(
559 fun.name(),
560 fun.signature(),
561 &data_types
562 )
563 )
564 })?;
565 let (_, function_name) = self.qualified_name();
566 let field_args = WindowUDFFieldArgs::new(&new_types, &function_name);
567
568 udwf.field(field_args)
569 .map(|field| (field.data_type().clone(), field.is_nullable()))
570 }
571 }
572 }
573}
574
575pub fn cast_subquery(subquery: Subquery, cast_to_type: &DataType) -> Result<Subquery> {
584 if subquery.subquery.schema().field(0).data_type() == cast_to_type {
585 return Ok(subquery);
586 }
587
588 let plan = subquery.subquery.as_ref();
589 let new_plan = match plan {
590 LogicalPlan::Projection(projection) => {
591 let cast_expr = projection.expr[0]
592 .clone()
593 .cast_to(cast_to_type, projection.input.schema())?;
594 LogicalPlan::Projection(Projection::try_new(
595 vec![cast_expr],
596 Arc::clone(&projection.input),
597 )?)
598 }
599 _ => {
600 let cast_expr = Expr::Column(Column::from(plan.schema().qualified_field(0)))
601 .cast_to(cast_to_type, subquery.subquery.schema())?;
602 LogicalPlan::Projection(Projection::try_new(
603 vec![cast_expr],
604 subquery.subquery,
605 )?)
606 }
607 };
608 Ok(Subquery {
609 subquery: Arc::new(new_plan),
610 outer_ref_columns: subquery.outer_ref_columns,
611 })
612}
613
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{col, lit};

    use datafusion_common::{internal_err, DFSchema, ScalarValue};

    /// Asserts that applying each listed `Expr` method to a NULL literal
    /// produces an expression that is never nullable.
    macro_rules! assert_never_nullable {
        ($($method:ident),+ $(,)?) => {{
            $(
                let expr = lit(ScalarValue::Null).$method();
                assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
            )+
        }};
    }

    /// Mock schema whose columns all have `data_type` and `nullable`.
    fn mock_schema(data_type: DataType, nullable: bool) -> MockExprSchema {
        MockExprSchema::new()
            .with_data_type(data_type)
            .with_nullable(nullable)
    }

    #[test]
    fn expr_schema_nullability() {
        let equality = col("foo").eq(lit(1));
        let non_nullable_schema = MockExprSchema::new();
        let nullable_schema = MockExprSchema::new().with_nullable(true);
        assert!(!equality.nullable(&non_nullable_schema).unwrap());
        assert!(equality.nullable(&nullable_schema).unwrap());

        assert_never_nullable!(
            is_null,
            is_not_null,
            is_true,
            is_not_true,
            is_false,
            is_not_false,
            is_unknown,
            is_not_unknown,
        );
    }

    #[test]
    fn test_between_nullability() {
        let bounded = col("foo").between(lit(1), lit(2));
        assert!(!bounded.nullable(&mock_schema(DataType::Int32, false)).unwrap());
        assert!(bounded.nullable(&mock_schema(DataType::Int32, true)).unwrap());

        // A NULL bound makes BETWEEN nullable even over a non-null column.
        let null = lit(ScalarValue::Int32(None));
        let non_nullable_schema = mock_schema(DataType::Int32, false);
        for expr in [
            col("foo").between(null.clone(), lit(2)),
            col("foo").between(lit(1), null.clone()),
            col("foo").between(null.clone(), null),
        ] {
            assert!(expr.nullable(&non_nullable_schema).unwrap());
        }
    }

    #[test]
    fn test_inlist_nullability() {
        let within_limit = col("foo").in_list(vec![lit(1); 5], false);
        assert!(!within_limit
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());
        assert!(within_limit
            .nullable(&mock_schema(DataType::Int32, true))
            .unwrap());
        // Errors raised while inspecting sub-expressions are propagated.
        assert!(within_limit
            .nullable(&mock_schema(DataType::Int32, false).with_error_on_nullable(true))
            .is_err());

        let with_null =
            col("foo").in_list(vec![lit(ScalarValue::Int32(None)), lit(1)], false);
        assert!(with_null
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());

        // Too many items to inspect: reported as nullable.
        let over_limit = col("foo").in_list(vec![lit(1); 6], false);
        assert!(over_limit
            .nullable(&mock_schema(DataType::Int32, false))
            .unwrap());
    }

    #[test]
    fn test_like_nullability() {
        let pattern_match = col("foo").like(lit("bar"));
        assert!(!pattern_match
            .nullable(&mock_schema(DataType::Utf8, false))
            .unwrap());
        assert!(pattern_match
            .nullable(&mock_schema(DataType::Utf8, true))
            .unwrap());

        let null_pattern = col("foo").like(lit(ScalarValue::Utf8(None)));
        assert!(null_pattern
            .nullable(&mock_schema(DataType::Utf8, false))
            .unwrap());
    }

    #[test]
    fn expr_schema_data_type() {
        let schema = MockExprSchema::new().with_data_type(DataType::Utf8);
        assert_eq!(DataType::Utf8, col("foo").get_type(&schema).unwrap());
    }

    #[test]
    fn test_expr_metadata() {
        let meta = HashMap::from([("bar".to_string(), "buzz".to_string())]);
        let expr = col("foo");
        let schema = MockExprSchema::new()
            .with_data_type(DataType::Int32)
            .with_metadata(meta.clone());

        // Columns, aliases and casts all surface the input metadata.
        assert_eq!(meta, expr.metadata(&schema).unwrap());
        assert_eq!(meta, expr.clone().alias("bar").metadata(&schema).unwrap());
        let casted = expr.clone().cast_to(&DataType::Int64, &schema).unwrap();
        assert_eq!(meta, casted.metadata(&schema).unwrap());

        let df_schema = DFSchema::from_unqualified_fields(
            vec![Field::new("foo", DataType::Int32, true).with_metadata(meta.clone())]
                .into(),
            HashMap::new(),
        )
        .unwrap();

        let (_, field) = expr.to_field(&df_schema).unwrap();
        assert_eq!(&meta, field.metadata());
    }

    /// Minimal [`ExprSchema`] in which every column shares one data type,
    /// nullability flag and metadata map; it can optionally fail nullability
    /// lookups.
    #[derive(Debug)]
    struct MockExprSchema {
        nullable: bool,
        data_type: DataType,
        error_on_nullable: bool,
        metadata: HashMap<String, String>,
    }

    impl MockExprSchema {
        fn new() -> Self {
            Self {
                data_type: DataType::Null,
                metadata: HashMap::new(),
                nullable: false,
                error_on_nullable: false,
            }
        }

        fn with_nullable(self, nullable: bool) -> Self {
            Self { nullable, ..self }
        }

        fn with_data_type(self, data_type: DataType) -> Self {
            Self { data_type, ..self }
        }

        fn with_error_on_nullable(self, error_on_nullable: bool) -> Self {
            Self {
                error_on_nullable,
                ..self
            }
        }

        fn with_metadata(self, metadata: HashMap<String, String>) -> Self {
            Self { metadata, ..self }
        }
    }

    impl ExprSchema for MockExprSchema {
        fn nullable(&self, _col: &Column) -> Result<bool> {
            if self.error_on_nullable {
                return internal_err!("nullable error");
            }
            Ok(self.nullable)
        }

        fn data_type(&self, _col: &Column) -> Result<&DataType> {
            Ok(&self.data_type)
        }

        fn metadata(&self, _col: &Column) -> Result<&HashMap<String, String>> {
            Ok(&self.metadata)
        }

        fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> {
            let data_type = self.data_type(col)?;
            let nullable = self.nullable(col)?;
            Ok((data_type, nullable))
        }
    }
}