1use super::binary::{binary_numeric_coercion, comparison_coercion};
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::{
21 compute::can_cast_types,
22 datatypes::{DataType, TimeUnit},
23};
24use datafusion_common::types::LogicalType;
25use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
26use datafusion_common::{
27 exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType,
28 utils::list_ndims, Result,
29};
30use datafusion_expr_common::signature::ArrayFunctionArgument;
31use datafusion_expr_common::{
32 signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
33 type_coercion::binary::comparison_coercion_numeric,
34 type_coercion::binary::string_coercion,
35};
36use std::sync::Arc;
37
38pub fn data_types_with_scalar_udf(
46 current_types: &[DataType],
47 func: &ScalarUDF,
48) -> Result<Vec<DataType>> {
49 let signature = func.signature();
50 let type_signature = &signature.type_signature;
51
52 if current_types.is_empty() {
53 if type_signature.supports_zero_argument() {
54 return Ok(vec![]);
55 } else if type_signature.used_to_support_zero_arguments() {
56 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
58 } else {
59 return plan_err!("'{}' does not support zero arguments", func.name());
60 }
61 }
62
63 let valid_types =
64 get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
65
66 if valid_types
67 .iter()
68 .any(|data_type| data_type == current_types)
69 {
70 return Ok(current_types.to_vec());
71 }
72
73 try_coerce_types(func.name(), valid_types, current_types, type_signature)
74}
75
76pub fn data_types_with_aggregate_udf(
84 current_types: &[DataType],
85 func: &AggregateUDF,
86) -> Result<Vec<DataType>> {
87 let signature = func.signature();
88 let type_signature = &signature.type_signature;
89
90 if current_types.is_empty() {
91 if type_signature.supports_zero_argument() {
92 return Ok(vec![]);
93 } else if type_signature.used_to_support_zero_arguments() {
94 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
96 } else {
97 return plan_err!("'{}' does not support zero arguments", func.name());
98 }
99 }
100
101 let valid_types =
102 get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;
103 if valid_types
104 .iter()
105 .any(|data_type| data_type == current_types)
106 {
107 return Ok(current_types.to_vec());
108 }
109
110 try_coerce_types(func.name(), valid_types, current_types, type_signature)
111}
112
113pub fn data_types_with_window_udf(
121 current_types: &[DataType],
122 func: &WindowUDF,
123) -> Result<Vec<DataType>> {
124 let signature = func.signature();
125 let type_signature = &signature.type_signature;
126
127 if current_types.is_empty() {
128 if type_signature.supports_zero_argument() {
129 return Ok(vec![]);
130 } else if type_signature.used_to_support_zero_arguments() {
131 return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
133 } else {
134 return plan_err!("'{}' does not support zero arguments", func.name());
135 }
136 }
137
138 let valid_types =
139 get_valid_types_with_window_udf(type_signature, current_types, func)?;
140 if valid_types
141 .iter()
142 .any(|data_type| data_type == current_types)
143 {
144 return Ok(current_types.to_vec());
145 }
146
147 try_coerce_types(func.name(), valid_types, current_types, type_signature)
148}
149
150pub fn data_types(
158 function_name: impl AsRef<str>,
159 current_types: &[DataType],
160 signature: &Signature,
161) -> Result<Vec<DataType>> {
162 let type_signature = &signature.type_signature;
163
164 if current_types.is_empty() {
165 if type_signature.supports_zero_argument() {
166 return Ok(vec![]);
167 } else if type_signature.used_to_support_zero_arguments() {
168 return plan_err!(
170 "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
171 function_name.as_ref()
172 );
173 } else {
174 return plan_err!(
175 "Function '{}' has signature {type_signature:?} which does not support zero arguments",
176 function_name.as_ref()
177 );
178 }
179 }
180
181 let valid_types =
182 get_valid_types(function_name.as_ref(), type_signature, current_types)?;
183 if valid_types
184 .iter()
185 .any(|data_type| data_type == current_types)
186 {
187 return Ok(current_types.to_vec());
188 }
189
190 try_coerce_types(
191 function_name.as_ref(),
192 valid_types,
193 current_types,
194 type_signature,
195 )
196}
197
198fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
199 if let TypeSignature::OneOf(signatures) = type_signature {
200 return signatures.iter().all(is_well_supported_signature);
201 }
202
203 matches!(
204 type_signature,
205 TypeSignature::UserDefined
206 | TypeSignature::Numeric(_)
207 | TypeSignature::String(_)
208 | TypeSignature::Coercible(_)
209 | TypeSignature::Any(_)
210 | TypeSignature::Nullary
211 | TypeSignature::Comparable(_)
212 )
213}
214
215fn try_coerce_types(
216 function_name: &str,
217 valid_types: Vec<Vec<DataType>>,
218 current_types: &[DataType],
219 type_signature: &TypeSignature,
220) -> Result<Vec<DataType>> {
221 let mut valid_types = valid_types;
222
223 if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
225 if !type_signature.is_one_of() {
228 assert_eq!(valid_types.len(), 1);
229 }
230
231 let valid_types = valid_types.swap_remove(0);
232 if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
233 return Ok(t);
234 }
235 } else {
236 for valid_types in valid_types {
240 if let Some(types) = maybe_data_types(&valid_types, current_types) {
241 return Ok(types);
242 }
243 }
244 }
245
246 plan_err!(
248 "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed"
249 )
250}
251
252fn get_valid_types_with_scalar_udf(
253 signature: &TypeSignature,
254 current_types: &[DataType],
255 func: &ScalarUDF,
256) -> Result<Vec<Vec<DataType>>> {
257 match signature {
258 TypeSignature::UserDefined => match func.coerce_types(current_types) {
259 Ok(coerced_types) => Ok(vec![coerced_types]),
260 Err(e) => exec_err!(
261 "Function '{}' user-defined coercion failed with {:?}",
262 func.name(),
263 e.strip_backtrace()
264 ),
265 },
266 TypeSignature::OneOf(signatures) => {
267 let mut res = vec![];
268 let mut errors = vec![];
269 for sig in signatures {
270 match get_valid_types_with_scalar_udf(sig, current_types, func) {
271 Ok(valid_types) => {
272 res.extend(valid_types);
273 }
274 Err(e) => {
275 errors.push(e.to_string());
276 }
277 }
278 }
279
280 if res.is_empty() {
282 internal_err!(
283 "Function '{}' failed to match any signature, errors: {}",
284 func.name(),
285 errors.join(",")
286 )
287 } else {
288 Ok(res)
289 }
290 }
291 _ => get_valid_types(func.name(), signature, current_types),
292 }
293}
294
295fn get_valid_types_with_aggregate_udf(
296 signature: &TypeSignature,
297 current_types: &[DataType],
298 func: &AggregateUDF,
299) -> Result<Vec<Vec<DataType>>> {
300 let valid_types = match signature {
301 TypeSignature::UserDefined => match func.coerce_types(current_types) {
302 Ok(coerced_types) => vec![coerced_types],
303 Err(e) => {
304 return exec_err!(
305 "Function '{}' user-defined coercion failed with {:?}",
306 func.name(),
307 e.strip_backtrace()
308 )
309 }
310 },
311 TypeSignature::OneOf(signatures) => signatures
312 .iter()
313 .filter_map(|t| {
314 get_valid_types_with_aggregate_udf(t, current_types, func).ok()
315 })
316 .flatten()
317 .collect::<Vec<_>>(),
318 _ => get_valid_types(func.name(), signature, current_types)?,
319 };
320
321 Ok(valid_types)
322}
323
324fn get_valid_types_with_window_udf(
325 signature: &TypeSignature,
326 current_types: &[DataType],
327 func: &WindowUDF,
328) -> Result<Vec<Vec<DataType>>> {
329 let valid_types = match signature {
330 TypeSignature::UserDefined => match func.coerce_types(current_types) {
331 Ok(coerced_types) => vec![coerced_types],
332 Err(e) => {
333 return exec_err!(
334 "Function '{}' user-defined coercion failed with {:?}",
335 func.name(),
336 e.strip_backtrace()
337 )
338 }
339 },
340 TypeSignature::OneOf(signatures) => signatures
341 .iter()
342 .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
343 .flatten()
344 .collect::<Vec<_>>(),
345 _ => get_valid_types(func.name(), signature, current_types)?,
346 };
347
348 Ok(valid_types)
349}
350
351fn get_valid_types(
353 function_name: &str,
354 signature: &TypeSignature,
355 current_types: &[DataType],
356) -> Result<Vec<Vec<DataType>>> {
357 fn array_valid_types(
358 function_name: &str,
359 current_types: &[DataType],
360 arguments: &[ArrayFunctionArgument],
361 array_coercion: Option<&ListCoercion>,
362 ) -> Result<Vec<Vec<DataType>>> {
363 if current_types.len() != arguments.len() {
364 return Ok(vec![vec![]]);
365 }
366
367 let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| {
368 if *arg == ArrayFunctionArgument::Array {
369 Some(idx)
370 } else {
371 None
372 }
373 });
374 let Some(array_idx) = array_idx else {
375 return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument"));
376 };
377 let Some(array_type) = array(¤t_types[array_idx]) else {
378 return Ok(vec![vec![]]);
379 };
380
381 let mut new_base_type = datafusion_common::utils::base_type(&array_type);
384 for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
385 match argument_type {
386 ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => {
387 new_base_type =
388 coerce_array_types(function_name, current_type, &new_base_type)?;
389 }
390 ArrayFunctionArgument::Index => {}
391 }
392 }
393 let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
394 &array_type,
395 &new_base_type,
396 array_coercion,
397 );
398
399 let new_elem_type = match new_array_type {
400 DataType::List(ref field)
401 | DataType::LargeList(ref field)
402 | DataType::FixedSizeList(ref field, _) => field.data_type(),
403 _ => return Ok(vec![vec![]]),
404 };
405
406 let mut valid_types = Vec::with_capacity(arguments.len());
407 for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
408 let valid_type = match argument_type {
409 ArrayFunctionArgument::Element => new_elem_type.clone(),
410 ArrayFunctionArgument::Index => DataType::Int64,
411 ArrayFunctionArgument::Array => {
412 let Some(current_type) = array(current_type) else {
413 return Ok(vec![vec![]]);
414 };
415 let new_type =
416 datafusion_common::utils::coerced_type_with_base_type_only(
417 ¤t_type,
418 &new_base_type,
419 array_coercion,
420 );
421 if new_type != new_array_type {
423 return Ok(vec![vec![]]);
424 }
425 new_type
426 }
427 };
428 valid_types.push(valid_type);
429 }
430
431 Ok(vec![valid_types])
432 }
433
434 fn array(array_type: &DataType) -> Option<DataType> {
435 match array_type {
436 DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
437 DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
438 _ => None,
439 }
440 }
441
442 fn coerce_array_types(
443 function_name: &str,
444 current_type: &DataType,
445 base_type: &DataType,
446 ) -> Result<DataType> {
447 let current_base_type = datafusion_common::utils::base_type(current_type);
448 let new_base_type = comparison_coercion(base_type, ¤t_base_type);
449 new_base_type.ok_or_else(|| {
450 internal_datafusion_err!(
451 "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}"
452 )
453 })
454 }
455
456 fn recursive_array(array_type: &DataType) -> Option<DataType> {
457 match array_type {
458 DataType::List(_)
459 | DataType::LargeList(_)
460 | DataType::FixedSizeList(_, _) => {
461 let array_type = coerced_fixed_size_list_to_list(array_type);
462 Some(array_type)
463 }
464 _ => None,
465 }
466 }
467
468 fn function_length_check(
469 function_name: &str,
470 length: usize,
471 expected_length: usize,
472 ) -> Result<()> {
473 if length != expected_length {
474 return plan_err!(
475 "Function '{function_name}' expects {expected_length} arguments but received {length}"
476 );
477 }
478 Ok(())
479 }
480
481 let valid_types = match signature {
482 TypeSignature::Variadic(valid_types) => valid_types
483 .iter()
484 .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
485 .collect(),
486 TypeSignature::String(number) => {
487 function_length_check(function_name, current_types.len(), *number)?;
488
489 let mut new_types = Vec::with_capacity(current_types.len());
490 for data_type in current_types.iter() {
491 let logical_data_type: NativeType = data_type.into();
492 if logical_data_type == NativeType::String {
493 new_types.push(data_type.to_owned());
494 } else if logical_data_type == NativeType::Null {
495 new_types.push(DataType::Utf8);
497 } else {
498 return plan_err!(
499 "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
500 );
501 }
502 }
503
504 fn find_common_type(
506 function_name: &str,
507 lhs_type: &DataType,
508 rhs_type: &DataType,
509 ) -> Result<DataType> {
510 match (lhs_type, rhs_type) {
511 (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
512 find_common_type(function_name, lhs, rhs)
513 }
514 (DataType::Dictionary(_, v), other)
515 | (other, DataType::Dictionary(_, v)) => {
516 find_common_type(function_name, v, other)
517 }
518 _ => {
519 if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
520 Ok(coerced_type)
521 } else {
522 plan_err!(
523 "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
524 )
525 }
526 }
527 }
528 }
529
530 let mut coerced_type = new_types.first().unwrap().to_owned();
532 for t in new_types.iter().skip(1) {
533 coerced_type = find_common_type(function_name, &coerced_type, t)?;
534 }
535
536 fn base_type_or_default_type(data_type: &DataType) -> DataType {
537 if let DataType::Dictionary(_, v) = data_type {
538 base_type_or_default_type(v)
539 } else {
540 data_type.to_owned()
541 }
542 }
543
544 vec![vec![base_type_or_default_type(&coerced_type); *number]]
545 }
546 TypeSignature::Numeric(number) => {
547 function_length_check(function_name, current_types.len(), *number)?;
548
549 let mut valid_type = current_types.first().unwrap().to_owned();
551 for t in current_types.iter().skip(1) {
552 let logical_data_type: NativeType = t.into();
553 if logical_data_type == NativeType::Null {
554 continue;
555 }
556
557 if !logical_data_type.is_numeric() {
558 return plan_err!(
559 "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
560 );
561 }
562
563 if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
564 valid_type = coerced_type;
565 } else {
566 return plan_err!(
567 "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
568 );
569 }
570 }
571
572 let logical_data_type: NativeType = valid_type.clone().into();
573 if logical_data_type == NativeType::Null {
577 valid_type = DataType::Float64;
578 } else if !logical_data_type.is_numeric() {
579 return plan_err!(
580 "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
581 );
582 }
583
584 vec![vec![valid_type; *number]]
585 }
586 TypeSignature::Comparable(num) => {
587 function_length_check(function_name, current_types.len(), *num)?;
588 let mut target_type = current_types[0].to_owned();
589 for data_type in current_types.iter().skip(1) {
590 if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
591 target_type = dt;
592 } else {
593 return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
594 }
595 }
596 if target_type.is_null() {
598 vec![vec![DataType::Utf8View; *num]]
599 } else {
600 vec![vec![target_type; *num]]
601 }
602 }
603 TypeSignature::Coercible(param_types) => {
604 function_length_check(function_name, current_types.len(), param_types.len())?;
605
606 let mut new_types = Vec::with_capacity(current_types.len());
607 for (current_type, param) in current_types.iter().zip(param_types.iter()) {
608 let current_native_type: NativeType = current_type.into();
609
610 if param.desired_type().matches_native_type(¤t_native_type) {
611 let casted_type = param.desired_type().default_casted_type(
612 ¤t_native_type,
613 current_type,
614 )?;
615
616 new_types.push(casted_type);
617 } else if param
618 .allowed_source_types()
619 .iter()
620 .any(|t| t.matches_native_type(¤t_native_type)) {
621 let default_casted_type = param.default_casted_type().unwrap();
623 let casted_type = default_casted_type.default_cast_for(current_type)?;
624 new_types.push(casted_type);
625 } else {
626 return internal_err!(
627 "Expect {} but received {}, DataType: {}",
628 param.desired_type(),
629 current_native_type,
630 current_type
631 );
632 }
633 }
634
635 vec![new_types]
636 }
637 TypeSignature::Uniform(number, valid_types) => {
638 if *number == 0 {
639 return plan_err!("The function '{function_name}' expected at least one argument");
640 }
641
642 valid_types
643 .iter()
644 .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
645 .collect()
646 }
647 TypeSignature::UserDefined => {
648 return internal_err!(
649 "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
650 )
651 }
652 TypeSignature::VariadicAny => {
653 if current_types.is_empty() {
654 return plan_err!(
655 "Function '{function_name}' expected at least one argument but received 0"
656 );
657 }
658 vec![current_types.to_vec()]
659 }
660 TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
661 TypeSignature::ArraySignature(ref function_signature) => match function_signature {
662 ArrayFunctionSignature::Array { arguments, array_coercion, } => {
663 array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
664 }
665 ArrayFunctionSignature::RecursiveArray => {
666 if current_types.len() != 1 {
667 return Ok(vec![vec![]]);
668 }
669 recursive_array(¤t_types[0])
670 .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
671 }
672 ArrayFunctionSignature::MapArray => {
673 if current_types.len() != 1 {
674 return Ok(vec![vec![]]);
675 }
676
677 match ¤t_types[0] {
678 DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
679 _ => vec![vec![]],
680 }
681 }
682 },
683 TypeSignature::Nullary => {
684 if !current_types.is_empty() {
685 return plan_err!(
686 "The function '{function_name}' expected zero argument but received {}",
687 current_types.len()
688 );
689 }
690 vec![vec![]]
691 }
692 TypeSignature::Any(number) => {
693 if current_types.is_empty() {
694 return plan_err!(
695 "The function '{function_name}' expected at least one argument but received 0"
696 );
697 }
698
699 if current_types.len() != *number {
700 return plan_err!(
701 "The function '{function_name}' expected {number} arguments but received {}",
702 current_types.len()
703 );
704 }
705 vec![(0..*number).map(|i| current_types[i].clone()).collect()]
706 }
707 TypeSignature::OneOf(types) => types
708 .iter()
709 .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
710 .flatten()
711 .collect::<Vec<_>>(),
712 };
713
714 Ok(valid_types)
715}
716
717fn maybe_data_types(
724 valid_types: &[DataType],
725 current_types: &[DataType],
726) -> Option<Vec<DataType>> {
727 if valid_types.len() != current_types.len() {
728 return None;
729 }
730
731 let mut new_type = Vec::with_capacity(valid_types.len());
732 for (i, valid_type) in valid_types.iter().enumerate() {
733 let current_type = ¤t_types[i];
734
735 if current_type == valid_type {
736 new_type.push(current_type.clone())
737 } else {
738 if let Some(coerced_type) = coerced_from(valid_type, current_type) {
742 new_type.push(coerced_type)
743 } else {
744 return None;
746 }
747 }
748 }
749 Some(new_type)
750}
751
752fn maybe_data_types_without_coercion(
756 valid_types: &[DataType],
757 current_types: &[DataType],
758) -> Option<Vec<DataType>> {
759 if valid_types.len() != current_types.len() {
760 return None;
761 }
762
763 let mut new_type = Vec::with_capacity(valid_types.len());
764 for (i, valid_type) in valid_types.iter().enumerate() {
765 let current_type = ¤t_types[i];
766
767 if current_type == valid_type {
768 new_type.push(current_type.clone())
769 } else if can_cast_types(current_type, valid_type) {
770 new_type.push(valid_type.clone())
772 } else {
773 return None;
774 }
775 }
776 Some(new_type)
777}
778
779pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
784 if type_into == type_from {
785 return true;
786 }
787 if let Some(coerced) = coerced_from(type_into, type_from) {
788 return coerced == *type_into;
789 }
790 false
791}
792
/// Computes the type that results from implicitly coercing `type_from`
/// towards `type_into`, or `None` when no rule allows it.
///
/// In most cases the result is `type_into` itself. The wildcard rules
/// (`FIXED_SIZE_LIST_WILDCARD`, `TIMEZONE_WILDCARD`) instead build a concrete
/// type from details of `type_from`.
///
/// The match arms are evaluated in order, so the dictionary rules take
/// precedence over everything that follows.
fn coerced_from<'a>(
    type_into: &'a DataType,
    type_from: &'a DataType,
) -> Option<DataType> {
    use self::DataType::*;

    match (type_into, type_from) {
        // Source is a dictionary: allowed when its value type can be coerced.
        (_, Dictionary(_, value_type))
            if coerced_from(type_into, value_type).is_some() =>
        {
            Some(type_into.clone())
        }
        // Target is a dictionary: allowed when the source can be coerced to
        // the dictionary's value type.
        (Dictionary(_, value_type), _)
            if coerced_from(value_type, type_from).is_some() =>
        {
            Some(type_into.clone())
        }
        // Signed integer targets and the sources each one accepts.
        (Int8, Null | Int8) => Some(type_into.clone()),
        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
            Some(type_into.clone())
        }
        // Unsigned integer targets accept only unsigned sources (and Null).
        (UInt8, Null | UInt8) => Some(type_into.clone()),
        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
        // Float targets accept every integer type; Float64 additionally
        // accepts Float32 and Decimal128.
        (
            Float32,
            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
            | Float32,
        ) => Some(type_into.clone()),
        (
            Float64,
            Null
            | Int8
            | Int16
            | Int32
            | Int64
            | UInt8
            | UInt16
            | UInt32
            | UInt64
            | Float32
            | Float64
            | Decimal128(_, _),
        ) => Some(type_into.clone()),
        // Nanosecond timestamp without a timezone.
        (
            Timestamp(TimeUnit::Nanosecond, None),
            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
        ) => Some(type_into.clone()),
        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
        // Any source type is accepted for a Utf8 / LargeUtf8 target.
        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),

        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),

        // List targets: the source must either have a Null base type or the
        // same number of list dimensions as the target.
        (List(_) | LargeList(_), _)
            if datafusion_common::utils::base_type(type_from).eq(&Null)
                || list_ndims(type_from) == list_ndims(type_into) =>
        {
            Some(type_into.clone())
        }
        // Wildcard-sized fixed-size list target: the element type is coerced
        // recursively and the size is taken from the source.
        (
            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
            FixedSizeList(f_from, size_from),
        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
            // The element coercion produced a different type than the target
            // field declares, so a new field carrying it is built.
            Some(data_type) if &data_type != f_into.data_type() => {
                let new_field =
                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
                Some(FixedSizeList(new_field, *size_from))
            }
            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
            _ => None,
        },
        // Wildcard-timezone timestamp target: keep the target unit and reuse
        // the source timezone when it has one, otherwise use "+00".
        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
            match type_from {
                Timestamp(_, Some(from_tz)) => {
                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
                }
                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
                    Some(Timestamp(*unit, Some("+00".into())))
                }
                _ => None,
            }
        }
        // Timestamp target with a concrete timezone.
        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
            Some(type_into.clone())
        }
        _ => None,
    }
}
901
#[cfg(test)]
mod tests {

    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::assert_contains;

    /// Utf8 and LargeUtf8 can both be coerced to Utf8View.
    #[test]
    fn test_string_conversion() {
        let cases = vec![
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for case in cases {
            assert_eq!(can_coerce_from(&case.0, &case.1), case.2);
        }
    }

    /// Each case is `(valid_types, current_types, expected)`.
    #[test]
    fn test_maybe_data_types() {
        let cases = vec![
            // Identical types are returned unchanged.
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // UInt8 can be widened to UInt16.
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // Empty argument lists match each other.
            (vec![], vec![], Some(vec![])),
            // UInt8 cannot be coerced to Boolean.
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // UInt16 can be widened to UInt32.
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // Utf8 to timestamps; the "+TZ" timezone in the valid type comes
            // back as "+00".
            (
                vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+TZ".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![
                    DataType::Timestamp(TimeUnit::Nanosecond, None),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+00".into())),
                    DataType::Timestamp(TimeUnit::Nanosecond, Some("+01".into())),
                ]),
            ),
        ];

        for case in cases {
            assert_eq!(maybe_data_types(&case.0, &case.1), case.2)
        }
    }

    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        // Flattens the candidate lists into a single list of types.
        let get_valid_types_flatten =
            |function_name: &str,
             signature: &TypeSignature,
             current_types: &[DataType]| {
                get_valid_types(function_name, signature, current_types)
                    .unwrap()
                    .into_iter()
                    .flatten()
                    .collect::<Vec<_>>()
            };

        // A single numeric argument is kept as-is.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Int32],
        );
        assert_eq!(got, [DataType::Int32]);

        // Int32 and Int64 coerce to Int64.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Int64],
        );
        assert_eq!(got, [DataType::Int64, DataType::Int64]);

        // Int32, Int64 and Float64 coerce to Float64.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(3),
            &[DataType::Int32, DataType::Int64, DataType::Float64],
        );
        assert_eq!(
            got,
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // A string argument is rejected.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // A lone Null argument becomes Float64.
        let got = get_valid_types_flatten(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Null],
        );
        assert_eq!(got, [DataType::Float64]);

        // A timestamp argument is rejected.
        let got = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            got.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative.
        let invalid_types = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )?;
        assert_eq!(invalid_types.len(), 0);

        // Two arguments match `Any(2)`.
        let args = vec![DataType::Int32, DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        // One argument matches `Any(1)`.
        let args = vec![DataType::Int32];
        let valid_types = get_valid_types("test", &signature, &args)?;
        assert_eq!(valid_types.len(), 1);
        assert_eq!(valid_types[0], args);

        Ok(())
    }

    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        // Too few arguments.
        let err = get_valid_types("test", &signature, &[]).unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 0"
        );

        // Too many arguments.
        let err = get_valid_types(
            "test",
            &signature,
            &[DataType::Int32, DataType::Int32, DataType::Int32],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects 1 arguments but received 3"
        );

        Ok(())
    }

    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        let current_types = vec![
            DataType::FixedSizeList(Arc::clone(&inner), 2),
        ];

        // A wildcard-sized signature accepts a list of size 2.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(
                Arc::clone(&inner),
                FIXED_SIZE_LIST_WILDCARD,
            )],
            Volatility::Stable,
        );

        let coerced_data_types = data_types("test", &current_types, &signature)?;
        assert_eq!(coerced_data_types, current_types);

        // A size-3 signature rejects a list of size 2.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 3)],
            Volatility::Stable,
        );
        let coerced_data_types = data_types("test", &current_types, &signature);
        assert!(coerced_data_types.is_err());

        // A size-2 signature accepts a list of size 2.
        let signature = Signature::exact(
            vec![DataType::FixedSizeList(Arc::clone(&inner), 2)],
            Volatility::Stable,
        );
        let coerced_data_types = data_types("test", &current_types, &signature).unwrap();
        assert_eq!(coerced_data_types, current_types);

        Ok(())
    }

    /// Wildcard sizes are resolved at every nesting level, and the inner
    /// element type is coerced from Int8 to Int32.
    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        let type_into = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int32, false)),
                    FIXED_SIZE_LIST_WILDCARD,
                ),
                false,
            )),
            FIXED_SIZE_LIST_WILDCARD,
        );

        let type_from = DataType::FixedSizeList(
            Arc::new(Field::new_list_field(
                DataType::FixedSizeList(
                    Arc::new(Field::new_list_field(DataType::Int8, false)),
                    4,
                ),
                false,
            )),
            3,
        );

        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(DataType::FixedSizeList(
                Arc::new(Field::new_list_field(
                    DataType::FixedSizeList(
                        Arc::new(Field::new_list_field(DataType::Int32, false)),
                        4,
                    ),
                    false,
                )),
                3,
            ))
        );

        Ok(())
    }

    #[test]
    fn test_coerced_from_dictionary() {
        // Int64 cannot be coerced into a dictionary with UInt32 values.
        let type_into =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_from = DataType::Int64;
        assert_eq!(coerced_from(&type_into, &type_from), None);

        // A dictionary with UInt32 values can be coerced to Int64.
        let type_from =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));
        let type_into = DataType::Int64;
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(type_into.clone())
        );
    }
}