datafusion_expr/type_coercion/
functions.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::binary::{binary_numeric_coercion, comparison_coercion};
19use crate::{AggregateUDF, ScalarUDF, Signature, TypeSignature, WindowUDF};
20use arrow::{
21    compute::can_cast_types,
22    datatypes::{DataType, TimeUnit},
23};
24use datafusion_common::types::LogicalType;
25use datafusion_common::utils::{coerced_fixed_size_list_to_list, ListCoercion};
26use datafusion_common::{
27    exec_err, internal_datafusion_err, internal_err, plan_err, types::NativeType,
28    utils::list_ndims, Result,
29};
30use datafusion_expr_common::signature::ArrayFunctionArgument;
31use datafusion_expr_common::{
32    signature::{ArrayFunctionSignature, FIXED_SIZE_LIST_WILDCARD, TIMEZONE_WILDCARD},
33    type_coercion::binary::comparison_coercion_numeric,
34    type_coercion::binary::string_coercion,
35};
36use std::sync::Arc;
37
38/// Performs type coercion for scalar function arguments.
39///
40/// Returns the data types to which each argument must be coerced to
41/// match `signature`.
42///
43/// For more details on coercion in general, please see the
44/// [`type_coercion`](crate::type_coercion) module.
45pub fn data_types_with_scalar_udf(
46    current_types: &[DataType],
47    func: &ScalarUDF,
48) -> Result<Vec<DataType>> {
49    let signature = func.signature();
50    let type_signature = &signature.type_signature;
51
52    if current_types.is_empty() {
53        if type_signature.supports_zero_argument() {
54            return Ok(vec![]);
55        } else if type_signature.used_to_support_zero_arguments() {
56            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
57            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
58        } else {
59            return plan_err!("'{}' does not support zero arguments", func.name());
60        }
61    }
62
63    let valid_types =
64        get_valid_types_with_scalar_udf(type_signature, current_types, func)?;
65
66    if valid_types
67        .iter()
68        .any(|data_type| data_type == current_types)
69    {
70        return Ok(current_types.to_vec());
71    }
72
73    try_coerce_types(func.name(), valid_types, current_types, type_signature)
74}
75
76/// Performs type coercion for aggregate function arguments.
77///
78/// Returns the data types to which each argument must be coerced to
79/// match `signature`.
80///
81/// For more details on coercion in general, please see the
82/// [`type_coercion`](crate::type_coercion) module.
83pub fn data_types_with_aggregate_udf(
84    current_types: &[DataType],
85    func: &AggregateUDF,
86) -> Result<Vec<DataType>> {
87    let signature = func.signature();
88    let type_signature = &signature.type_signature;
89
90    if current_types.is_empty() {
91        if type_signature.supports_zero_argument() {
92            return Ok(vec![]);
93        } else if type_signature.used_to_support_zero_arguments() {
94            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
95            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
96        } else {
97            return plan_err!("'{}' does not support zero arguments", func.name());
98        }
99    }
100
101    let valid_types =
102        get_valid_types_with_aggregate_udf(type_signature, current_types, func)?;
103    if valid_types
104        .iter()
105        .any(|data_type| data_type == current_types)
106    {
107        return Ok(current_types.to_vec());
108    }
109
110    try_coerce_types(func.name(), valid_types, current_types, type_signature)
111}
112
113/// Performs type coercion for window function arguments.
114///
115/// Returns the data types to which each argument must be coerced to
116/// match `signature`.
117///
118/// For more details on coercion in general, please see the
119/// [`type_coercion`](crate::type_coercion) module.
120pub fn data_types_with_window_udf(
121    current_types: &[DataType],
122    func: &WindowUDF,
123) -> Result<Vec<DataType>> {
124    let signature = func.signature();
125    let type_signature = &signature.type_signature;
126
127    if current_types.is_empty() {
128        if type_signature.supports_zero_argument() {
129            return Ok(vec![]);
130        } else if type_signature.used_to_support_zero_arguments() {
131            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
132            return plan_err!("'{}' does not support zero arguments. Use TypeSignature::Nullary for zero arguments", func.name());
133        } else {
134            return plan_err!("'{}' does not support zero arguments", func.name());
135        }
136    }
137
138    let valid_types =
139        get_valid_types_with_window_udf(type_signature, current_types, func)?;
140    if valid_types
141        .iter()
142        .any(|data_type| data_type == current_types)
143    {
144        return Ok(current_types.to_vec());
145    }
146
147    try_coerce_types(func.name(), valid_types, current_types, type_signature)
148}
149
150/// Performs type coercion for function arguments.
151///
152/// Returns the data types to which each argument must be coerced to
153/// match `signature`.
154///
155/// For more details on coercion in general, please see the
156/// [`type_coercion`](crate::type_coercion) module.
157pub fn data_types(
158    function_name: impl AsRef<str>,
159    current_types: &[DataType],
160    signature: &Signature,
161) -> Result<Vec<DataType>> {
162    let type_signature = &signature.type_signature;
163
164    if current_types.is_empty() {
165        if type_signature.supports_zero_argument() {
166            return Ok(vec![]);
167        } else if type_signature.used_to_support_zero_arguments() {
168            // Special error to help during upgrade: https://github.com/apache/datafusion/issues/13763
169            return plan_err!(
170                "function '{}' has signature {type_signature:?} which does not support zero arguments. Use TypeSignature::Nullary for zero arguments",
171                function_name.as_ref()
172            );
173        } else {
174            return plan_err!(
175                "Function '{}' has signature {type_signature:?} which does not support zero arguments",
176                function_name.as_ref()
177            );
178        }
179    }
180
181    let valid_types =
182        get_valid_types(function_name.as_ref(), type_signature, current_types)?;
183    if valid_types
184        .iter()
185        .any(|data_type| data_type == current_types)
186    {
187        return Ok(current_types.to_vec());
188    }
189
190    try_coerce_types(
191        function_name.as_ref(),
192        valid_types,
193        current_types,
194        type_signature,
195    )
196}
197
198fn is_well_supported_signature(type_signature: &TypeSignature) -> bool {
199    if let TypeSignature::OneOf(signatures) = type_signature {
200        return signatures.iter().all(is_well_supported_signature);
201    }
202
203    matches!(
204        type_signature,
205        TypeSignature::UserDefined
206            | TypeSignature::Numeric(_)
207            | TypeSignature::String(_)
208            | TypeSignature::Coercible(_)
209            | TypeSignature::Any(_)
210            | TypeSignature::Nullary
211            | TypeSignature::Comparable(_)
212    )
213}
214
215fn try_coerce_types(
216    function_name: &str,
217    valid_types: Vec<Vec<DataType>>,
218    current_types: &[DataType],
219    type_signature: &TypeSignature,
220) -> Result<Vec<DataType>> {
221    let mut valid_types = valid_types;
222
223    // Well-supported signature that returns exact valid types.
224    if !valid_types.is_empty() && is_well_supported_signature(type_signature) {
225        // There may be many valid types if valid signature is OneOf
226        // Otherwise, there should be only one valid type
227        if !type_signature.is_one_of() {
228            assert_eq!(valid_types.len(), 1);
229        }
230
231        let valid_types = valid_types.swap_remove(0);
232        if let Some(t) = maybe_data_types_without_coercion(&valid_types, current_types) {
233            return Ok(t);
234        }
235    } else {
236        // TODO: Deprecate this branch after all signatures are well-supported (aka coercion has happened already)
237        // Try and coerce the argument types to match the signature, returning the
238        // coerced types from the first matching signature.
239        for valid_types in valid_types {
240            if let Some(types) = maybe_data_types(&valid_types, current_types) {
241                return Ok(types);
242            }
243        }
244    }
245
246    // none possible -> Error
247    plan_err!(
248        "Failed to coerce arguments to satisfy a call to '{function_name}' function: coercion from {current_types:?} to the signature {type_signature:?} failed"
249    )
250}
251
252fn get_valid_types_with_scalar_udf(
253    signature: &TypeSignature,
254    current_types: &[DataType],
255    func: &ScalarUDF,
256) -> Result<Vec<Vec<DataType>>> {
257    match signature {
258        TypeSignature::UserDefined => match func.coerce_types(current_types) {
259            Ok(coerced_types) => Ok(vec![coerced_types]),
260            Err(e) => exec_err!(
261                "Function '{}' user-defined coercion failed with {:?}",
262                func.name(),
263                e.strip_backtrace()
264            ),
265        },
266        TypeSignature::OneOf(signatures) => {
267            let mut res = vec![];
268            let mut errors = vec![];
269            for sig in signatures {
270                match get_valid_types_with_scalar_udf(sig, current_types, func) {
271                    Ok(valid_types) => {
272                        res.extend(valid_types);
273                    }
274                    Err(e) => {
275                        errors.push(e.to_string());
276                    }
277                }
278            }
279
280            // Every signature failed, return the joined error
281            if res.is_empty() {
282                internal_err!(
283                    "Function '{}' failed to match any signature, errors: {}",
284                    func.name(),
285                    errors.join(",")
286                )
287            } else {
288                Ok(res)
289            }
290        }
291        _ => get_valid_types(func.name(), signature, current_types),
292    }
293}
294
295fn get_valid_types_with_aggregate_udf(
296    signature: &TypeSignature,
297    current_types: &[DataType],
298    func: &AggregateUDF,
299) -> Result<Vec<Vec<DataType>>> {
300    let valid_types = match signature {
301        TypeSignature::UserDefined => match func.coerce_types(current_types) {
302            Ok(coerced_types) => vec![coerced_types],
303            Err(e) => {
304                return exec_err!(
305                    "Function '{}' user-defined coercion failed with {:?}",
306                    func.name(),
307                    e.strip_backtrace()
308                )
309            }
310        },
311        TypeSignature::OneOf(signatures) => signatures
312            .iter()
313            .filter_map(|t| {
314                get_valid_types_with_aggregate_udf(t, current_types, func).ok()
315            })
316            .flatten()
317            .collect::<Vec<_>>(),
318        _ => get_valid_types(func.name(), signature, current_types)?,
319    };
320
321    Ok(valid_types)
322}
323
324fn get_valid_types_with_window_udf(
325    signature: &TypeSignature,
326    current_types: &[DataType],
327    func: &WindowUDF,
328) -> Result<Vec<Vec<DataType>>> {
329    let valid_types = match signature {
330        TypeSignature::UserDefined => match func.coerce_types(current_types) {
331            Ok(coerced_types) => vec![coerced_types],
332            Err(e) => {
333                return exec_err!(
334                    "Function '{}' user-defined coercion failed with {:?}",
335                    func.name(),
336                    e.strip_backtrace()
337                )
338            }
339        },
340        TypeSignature::OneOf(signatures) => signatures
341            .iter()
342            .filter_map(|t| get_valid_types_with_window_udf(t, current_types, func).ok())
343            .flatten()
344            .collect::<Vec<_>>(),
345        _ => get_valid_types(func.name(), signature, current_types)?,
346    };
347
348    Ok(valid_types)
349}
350
/// Returns a Vec of all possible valid argument types for the given signature.
///
/// Each inner `Vec` is one candidate list of argument types for a call to
/// `function_name` made with `current_types`. Some branches return a single
/// empty candidate (`vec![vec![]]`) instead of an error when the arguments do
/// not fit; such a candidate has a different length than any non-empty
/// argument list.
///
/// # Errors
///
/// Returns an error when the arguments cannot satisfy `signature` in the
/// branches that validate them (for example a wrong argument count), and
/// always for `TypeSignature::UserDefined`, which this function does not
/// handle. Inside `TypeSignature::OneOf`, alternatives that fail are dropped
/// rather than reported.
fn get_valid_types(
    function_name: &str,
    signature: &TypeSignature,
    current_types: &[DataType],
) -> Result<Vec<Vec<DataType>>> {
    /// Computes the candidate for an `ArrayFunctionSignature::Array` signature
    /// whose parameters are described by `arguments`.
    fn array_valid_types(
        function_name: &str,
        current_types: &[DataType],
        arguments: &[ArrayFunctionArgument],
        array_coercion: Option<&ListCoercion>,
    ) -> Result<Vec<Vec<DataType>>> {
        // Argument count mismatch: return a single empty candidate.
        if current_types.len() != arguments.len() {
            return Ok(vec![vec![]]);
        }

        // Position of the first `Array` parameter in the signature.
        let array_idx = arguments.iter().enumerate().find_map(|(idx, arg)| {
            if *arg == ArrayFunctionArgument::Array {
                Some(idx)
            } else {
                None
            }
        });
        let Some(array_idx) = array_idx else {
            return Err(internal_datafusion_err!("Function '{function_name}' expected at least one argument array argument"));
        };
        // The argument supplied in that position must be a list type.
        let Some(array_type) = array(&current_types[array_idx]) else {
            return Ok(vec![vec![]]);
        };

        // We need to find the coerced base type, mainly for cases like:
        // `array_append(List(null), i64)` -> `List(i64)`
        let mut new_base_type = datafusion_common::utils::base_type(&array_type);
        for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
            match argument_type {
                ArrayFunctionArgument::Element | ArrayFunctionArgument::Array => {
                    new_base_type =
                        coerce_array_types(function_name, current_type, &new_base_type)?;
                }
                // Index arguments do not take part in the base type coercion.
                ArrayFunctionArgument::Index => {}
            }
        }
        let new_array_type = datafusion_common::utils::coerced_type_with_base_type_only(
            &array_type,
            &new_base_type,
            array_coercion,
        );

        // Element type of the coerced array type; a non-list result yields a
        // single empty candidate.
        let new_elem_type = match new_array_type {
            DataType::List(ref field)
            | DataType::LargeList(ref field)
            | DataType::FixedSizeList(ref field, _) => field.data_type(),
            _ => return Ok(vec![vec![]]),
        };

        // Build the candidate, one target type per argument.
        let mut valid_types = Vec::with_capacity(arguments.len());
        for (current_type, argument_type) in current_types.iter().zip(arguments.iter()) {
            let valid_type = match argument_type {
                ArrayFunctionArgument::Element => new_elem_type.clone(),
                ArrayFunctionArgument::Index => DataType::Int64,
                ArrayFunctionArgument::Array => {
                    let Some(current_type) = array(current_type) else {
                        return Ok(vec![vec![]]);
                    };
                    let new_type =
                        datafusion_common::utils::coerced_type_with_base_type_only(
                            &current_type,
                            &new_base_type,
                            array_coercion,
                        );
                    // All array arguments must be coercible to the same type
                    if new_type != new_array_type {
                        return Ok(vec![vec![]]);
                    }
                    new_type
                }
            };
            valid_types.push(valid_type);
        }

        Ok(vec![valid_types])
    }

    /// Returns `array_type` as a list type: `List` and `LargeList` are
    /// returned unchanged, `FixedSizeList` becomes a `List` over the same
    /// field, and any other type yields `None`.
    fn array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_) | DataType::LargeList(_) => Some(array_type.clone()),
            DataType::FixedSizeList(field, _) => Some(DataType::List(Arc::clone(field))),
            _ => None,
        }
    }

    /// Combines `base_type` with the base type of `current_type` using
    /// `comparison_coercion`, returning an internal error when that yields
    /// no common type.
    fn coerce_array_types(
        function_name: &str,
        current_type: &DataType,
        base_type: &DataType,
    ) -> Result<DataType> {
        let current_base_type = datafusion_common::utils::base_type(current_type);
        let new_base_type = comparison_coercion(base_type, &current_base_type);
        new_base_type.ok_or_else(|| {
            internal_datafusion_err!(
                "Function '{function_name}' does not support coercion from {base_type:?} to {current_base_type:?}"
            )
        })
    }

    /// For any list type returns the result of
    /// `coerced_fixed_size_list_to_list`; returns `None` for non-list types.
    fn recursive_array(array_type: &DataType) -> Option<DataType> {
        match array_type {
            DataType::List(_)
            | DataType::LargeList(_)
            | DataType::FixedSizeList(_, _) => {
                let array_type = coerced_fixed_size_list_to_list(array_type);
                Some(array_type)
            }
            _ => None,
        }
    }

    /// Returns a plan error unless `length` equals `expected_length`.
    fn function_length_check(
        function_name: &str,
        length: usize,
        expected_length: usize,
    ) -> Result<()> {
        if length != expected_length {
            return plan_err!(
                "Function '{function_name}' expects {expected_length} arguments but received {length}"
            );
        }
        Ok(())
    }

    let valid_types = match signature {
        // One candidate per allowed type, repeated once per supplied argument.
        TypeSignature::Variadic(valid_types) => valid_types
            .iter()
            .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
            .collect(),
        // Exactly `number` string arguments, all coerced to one common string type.
        TypeSignature::String(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Keep string types as they are, map null to Utf8, reject the rest.
            let mut new_types = Vec::with_capacity(current_types.len());
            for data_type in current_types.iter() {
                let logical_data_type: NativeType = data_type.into();
                if logical_data_type == NativeType::String {
                    new_types.push(data_type.to_owned());
                } else if logical_data_type == NativeType::Null {
                    // TODO: Switch to Utf8View if all the string functions supports Utf8View
                    new_types.push(DataType::Utf8);
                } else {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::String but received {logical_data_type}"
                    );
                }
            }

            // Find the common string type for the given types
            // Dictionary types are unwrapped to their value type before
            // `string_coercion` is applied.
            fn find_common_type(
                function_name: &str,
                lhs_type: &DataType,
                rhs_type: &DataType,
            ) -> Result<DataType> {
                match (lhs_type, rhs_type) {
                    (DataType::Dictionary(_, lhs), DataType::Dictionary(_, rhs)) => {
                        find_common_type(function_name, lhs, rhs)
                    }
                    (DataType::Dictionary(_, v), other)
                    | (other, DataType::Dictionary(_, v)) => {
                        find_common_type(function_name, v, other)
                    }
                    _ => {
                        if let Some(coerced_type) = string_coercion(lhs_type, rhs_type) {
                            Ok(coerced_type)
                        } else {
                            plan_err!(
                                "Function '{function_name}' could not coerce {lhs_type} and {rhs_type} to a common string type"
                            )
                        }
                    }
                }
            }

            // Length checked above, safe to unwrap
            // NOTE(review): the length check only guarantees
            // `current_types.len() == *number`; with `String(0)` and no
            // arguments `first()` would be `None` — confirm callers never
            // reach this branch with an empty argument list.
            let mut coerced_type = new_types.first().unwrap().to_owned();
            for t in new_types.iter().skip(1) {
                coerced_type = find_common_type(function_name, &coerced_type, t)?;
            }

            // Strips (possibly nested) dictionary wrappers, returning the
            // innermost value type.
            fn base_type_or_default_type(data_type: &DataType) -> DataType {
                if let DataType::Dictionary(_, v) = data_type {
                    base_type_or_default_type(v)
                } else {
                    data_type.to_owned()
                }
            }

            vec![vec![base_type_or_default_type(&coerced_type); *number]]
        }
        // Exactly `number` numeric arguments, all coerced to one common numeric type.
        TypeSignature::Numeric(number) => {
            function_length_check(function_name, current_types.len(), *number)?;

            // Find common numeric type among given types except string
            // NOTE(review): `first()` is `None` when `current_types` is empty
            // (`Numeric(0)` with no arguments) — confirm callers rule that out.
            let mut valid_type = current_types.first().unwrap().to_owned();
            for t in current_types.iter().skip(1) {
                let logical_data_type: NativeType = t.into();
                // Null arguments do not influence the common type.
                if logical_data_type == NativeType::Null {
                    continue;
                }

                if !logical_data_type.is_numeric() {
                    return plan_err!(
                        "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                    );
                }

                if let Some(coerced_type) = binary_numeric_coercion(&valid_type, t) {
                    valid_type = coerced_type;
                } else {
                    return plan_err!(
                        "For function '{function_name}' {valid_type} and {t} are not coercible to a common numeric type"
                    );
                }
            }

            let logical_data_type: NativeType = valid_type.clone().into();
            // Fallback to default type if we don't know which type to coerced to
            // f64 is chosen since most of the math functions utilize Signature::numeric,
            // and their default type is double precision
            if logical_data_type == NativeType::Null {
                valid_type = DataType::Float64;
            } else if !logical_data_type.is_numeric() {
                return plan_err!(
                    "Function '{function_name}' expects NativeType::Numeric but received {logical_data_type}"
                );
            }

            vec![vec![valid_type; *number]]
        }
        // Exactly `num` arguments, all coerced to one type found with
        // `comparison_coercion_numeric`.
        TypeSignature::Comparable(num) => {
            function_length_check(function_name, current_types.len(), *num)?;
            // NOTE(review): indexing `[0]` panics when `current_types` is
            // empty (`Comparable(0)` with no arguments) — confirm callers
            // rule that out.
            let mut target_type = current_types[0].to_owned();
            for data_type in current_types.iter().skip(1) {
                if let Some(dt) = comparison_coercion_numeric(&target_type, data_type) {
                    target_type = dt;
                } else {
                    return plan_err!("For function '{function_name}' {target_type} and {data_type} is not comparable");
                }
            }
            // Convert null to String type.
            if target_type.is_null() {
                vec![vec![DataType::Utf8View; *num]]
            } else {
                vec![vec![target_type; *num]]
            }
        }
        // One parameter description per argument; each argument is resolved
        // independently against its own parameter.
        TypeSignature::Coercible(param_types) => {
            function_length_check(function_name, current_types.len(), param_types.len())?;

            let mut new_types = Vec::with_capacity(current_types.len());
            for (current_type, param) in current_types.iter().zip(param_types.iter()) {
                let current_native_type: NativeType = current_type.into();

                // The argument already matches the desired type of the parameter.
                if param.desired_type().matches_native_type(&current_native_type) {
                    let casted_type = param.desired_type().default_casted_type(
                        &current_native_type,
                        current_type,
                    )?;

                    new_types.push(casted_type);
                } else if param
                .allowed_source_types()
                .iter()
                .any(|t| t.matches_native_type(&current_native_type)) {
                    // If the condition is met, an implicit coercion is provided, so we can safely unwrap
                    let default_casted_type = param.default_casted_type().unwrap();
                    let casted_type = default_casted_type.default_cast_for(current_type)?;
                    new_types.push(casted_type);
                } else {
                    return internal_err!(
                        "Expect {} but received {}, DataType: {}",
                        param.desired_type(),
                        current_native_type,
                        current_type
                    );
                }
            }

            vec![new_types]
        }
        // One candidate per allowed type, each repeated `number` times.
        TypeSignature::Uniform(number, valid_types) => {
            if *number == 0 {
                return plan_err!("The function '{function_name}' expected at least one argument");
            }

            valid_types
                .iter()
                .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
                .collect()
        }
        // User-defined signatures are not handled by this function.
        TypeSignature::UserDefined => {
            return internal_err!(
                "Function '{function_name}' user-defined signature should be handled by function-specific coerce_types"
            )
        }
        // Any non-empty argument list is accepted unchanged.
        TypeSignature::VariadicAny => {
            if current_types.is_empty() {
                return plan_err!(
                    "Function '{function_name}' expected at least one argument but received 0"
                );
            }
            vec![current_types.to_vec()]
        }
        // The listed types are the only candidate.
        TypeSignature::Exact(valid_types) => vec![valid_types.clone()],
        TypeSignature::ArraySignature(ref function_signature) => match function_signature {
            ArrayFunctionSignature::Array { arguments, array_coercion, } => {
                array_valid_types(function_name, current_types, arguments, array_coercion.as_ref())?
            }
            // A single list argument, converted by `recursive_array`.
            ArrayFunctionSignature::RecursiveArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }
                recursive_array(&current_types[0])
                    .map_or_else(|| vec![vec![]], |array_type| vec![vec![array_type]])
            }
            // A single `Map` argument, accepted unchanged.
            ArrayFunctionSignature::MapArray => {
                if current_types.len() != 1 {
                    return Ok(vec![vec![]]);
                }

                match &current_types[0] {
                    DataType::Map(_, _) => vec![vec![current_types[0].clone()]],
                    _ => vec![vec![]],
                }
            }
        },
        // Only an empty argument list is accepted.
        TypeSignature::Nullary => {
            if !current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected zero argument but received {}",
                    current_types.len()
                );
            }
            vec![vec![]]
        }
        // Exactly `number` arguments of any type, accepted unchanged.
        TypeSignature::Any(number) => {
            if current_types.is_empty() {
                return plan_err!(
                    "The function '{function_name}' expected at least one argument but received 0"
                );
            }

            if current_types.len() != *number {
                return plan_err!(
                    "The function '{function_name}' expected {number} arguments but received {}",
                    current_types.len()
                );
            }
            vec![(0..*number).map(|i| current_types[i].clone()).collect()]
        }
        // Concatenate the candidates of every alternative that succeeds;
        // alternatives that return an error are dropped.
        TypeSignature::OneOf(types) => types
            .iter()
            .filter_map(|t| get_valid_types(function_name, t, current_types).ok())
            .flatten()
            .collect::<Vec<_>>(),
    };

    Ok(valid_types)
}
716
717/// Try to coerce the current argument types to match the given `valid_types`.
718///
719/// For example, if a function `func` accepts arguments of  `(int64, int64)`,
720/// but was called with `(int32, int64)`, this function could match the
721/// valid_types by coercing the first argument to `int64`, and would return
722/// `Some([int64, int64])`.
723fn maybe_data_types(
724    valid_types: &[DataType],
725    current_types: &[DataType],
726) -> Option<Vec<DataType>> {
727    if valid_types.len() != current_types.len() {
728        return None;
729    }
730
731    let mut new_type = Vec::with_capacity(valid_types.len());
732    for (i, valid_type) in valid_types.iter().enumerate() {
733        let current_type = &current_types[i];
734
735        if current_type == valid_type {
736            new_type.push(current_type.clone())
737        } else {
738            // attempt to coerce.
739            // TODO: Replace with `can_cast_types` after failing cases are resolved
740            // (they need new signature that returns exactly valid types instead of list of possible valid types).
741            if let Some(coerced_type) = coerced_from(valid_type, current_type) {
742                new_type.push(coerced_type)
743            } else {
744                // not possible
745                return None;
746            }
747        }
748    }
749    Some(new_type)
750}
751
752/// Check if the current argument types can be coerced to match the given `valid_types`
753/// unlike `maybe_data_types`, this function does not coerce the types.
754/// TODO: I think this function should replace `maybe_data_types` after signature are well-supported.
755fn maybe_data_types_without_coercion(
756    valid_types: &[DataType],
757    current_types: &[DataType],
758) -> Option<Vec<DataType>> {
759    if valid_types.len() != current_types.len() {
760        return None;
761    }
762
763    let mut new_type = Vec::with_capacity(valid_types.len());
764    for (i, valid_type) in valid_types.iter().enumerate() {
765        let current_type = &current_types[i];
766
767        if current_type == valid_type {
768            new_type.push(current_type.clone())
769        } else if can_cast_types(current_type, valid_type) {
770            // validate the valid type is castable from the current type
771            new_type.push(valid_type.clone())
772        } else {
773            return None;
774        }
775    }
776    Some(new_type)
777}
778
779/// Return true if a value of type `type_from` can be coerced
780/// (losslessly converted) into a value of `type_to`
781///
782/// See the module level documentation for more detail on coercion.
783pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
784    if type_into == type_from {
785        return true;
786    }
787    if let Some(coerced) = coerced_from(type_into, type_from) {
788        return coerced == *type_into;
789    }
790    false
791}
792
793/// Find the coerced type for the given `type_into` and `type_from`.
794/// Returns `None` if coercion is not possible.
795///
796/// Expect uni-directional coercion, for example, i32 is coerced to i64, but i64 is not coerced to i32.
797///
798/// Unlike [comparison_coercion], the coerced type is usually `wider` for lossless conversion.
799fn coerced_from<'a>(
800    type_into: &'a DataType,
801    type_from: &'a DataType,
802) -> Option<DataType> {
803    use self::DataType::*;
804
805    // match Dictionary first
806    match (type_into, type_from) {
807        // coerced dictionary first
808        (_, Dictionary(_, value_type))
809            if coerced_from(type_into, value_type).is_some() =>
810        {
811            Some(type_into.clone())
812        }
813        (Dictionary(_, value_type), _)
814            if coerced_from(value_type, type_from).is_some() =>
815        {
816            Some(type_into.clone())
817        }
818        // coerced into type_into
819        (Int8, Null | Int8) => Some(type_into.clone()),
820        (Int16, Null | Int8 | Int16 | UInt8) => Some(type_into.clone()),
821        (Int32, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => Some(type_into.clone()),
822        (Int64, Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32) => {
823            Some(type_into.clone())
824        }
825        (UInt8, Null | UInt8) => Some(type_into.clone()),
826        (UInt16, Null | UInt8 | UInt16) => Some(type_into.clone()),
827        (UInt32, Null | UInt8 | UInt16 | UInt32) => Some(type_into.clone()),
828        (UInt64, Null | UInt8 | UInt16 | UInt32 | UInt64) => Some(type_into.clone()),
829        (
830            Float32,
831            Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
832            | Float32,
833        ) => Some(type_into.clone()),
834        (
835            Float64,
836            Null
837            | Int8
838            | Int16
839            | Int32
840            | Int64
841            | UInt8
842            | UInt16
843            | UInt32
844            | UInt64
845            | Float32
846            | Float64
847            | Decimal128(_, _),
848        ) => Some(type_into.clone()),
849        (
850            Timestamp(TimeUnit::Nanosecond, None),
851            Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8,
852        ) => Some(type_into.clone()),
853        (Interval(_), Utf8 | LargeUtf8) => Some(type_into.clone()),
854        // We can go into a Utf8View from a Utf8 or LargeUtf8
855        (Utf8View, Utf8 | LargeUtf8 | Null) => Some(type_into.clone()),
856        // Any type can be coerced into strings
857        (Utf8 | LargeUtf8, _) => Some(type_into.clone()),
858        (Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
859
860        (List(_), FixedSizeList(_, _)) => Some(type_into.clone()),
861
862        // Only accept list and largelist with the same number of dimensions unless the type is Null.
863        // List or LargeList with different dimensions should be handled in TypeSignature or other places before this
864        (List(_) | LargeList(_), _)
865            if datafusion_common::utils::base_type(type_from).eq(&Null)
866                || list_ndims(type_from) == list_ndims(type_into) =>
867        {
868            Some(type_into.clone())
869        }
870        // should be able to coerce wildcard fixed size list to non wildcard fixed size list
871        (
872            FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD),
873            FixedSizeList(f_from, size_from),
874        ) => match coerced_from(f_into.data_type(), f_from.data_type()) {
875            Some(data_type) if &data_type != f_into.data_type() => {
876                let new_field =
877                    Arc::new(f_into.as_ref().clone().with_data_type(data_type));
878                Some(FixedSizeList(new_field, *size_from))
879            }
880            Some(_) => Some(FixedSizeList(Arc::clone(f_into), *size_from)),
881            _ => None,
882        },
883        (Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
884            match type_from {
885                Timestamp(_, Some(from_tz)) => {
886                    Some(Timestamp(*unit, Some(Arc::clone(from_tz))))
887                }
888                Null | Date32 | Utf8 | LargeUtf8 | Timestamp(_, None) => {
889                    // In the absence of any other information assume the time zone is "+00" (UTC).
890                    Some(Timestamp(*unit, Some("+00".into())))
891                }
892                _ => None,
893            }
894        }
895        (Timestamp(_, Some(_)), Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8) => {
896            Some(type_into.clone())
897        }
898        _ => None,
899    }
900}
901
#[cfg(test)]
mod tests {

    use crate::Volatility;

    use super::*;
    use arrow::datatypes::Field;
    use datafusion_common::assert_contains;

    /// `Utf8` and `LargeUtf8` inputs are accepted where `Utf8View` is expected.
    #[test]
    fn test_string_conversion() {
        let cases = [
            (DataType::Utf8View, DataType::Utf8, true),
            (DataType::Utf8View, DataType::LargeUtf8, true),
        ];

        for (type_into, type_from, expected) in cases {
            assert_eq!(can_coerce_from(&type_into, &type_from), expected);
        }
    }

    /// Table-driven check of `maybe_data_types`.
    #[test]
    fn test_maybe_data_types() {
        // Nanosecond timestamp with an optional time zone.
        fn nanos(tz: Option<&str>) -> DataType {
            DataType::Timestamp(TimeUnit::Nanosecond, tz.map(Into::into))
        }

        // Each case is: (valid types, current types, expected result).
        let cases: Vec<(Vec<DataType>, Vec<DataType>, Option<Vec<DataType>>)> = vec![
            // Two entries that already match.
            (
                vec![DataType::UInt8, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt8, DataType::UInt16]),
            ),
            // Two entries, the first one (UInt8) is coerced to UInt16.
            (
                vec![DataType::UInt16, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                Some(vec![DataType::UInt16, DataType::UInt16]),
            ),
            // No entries at all is fine.
            (vec![], vec![], Some(vec![])),
            // UInt8 cannot be coerced to Boolean.
            (
                vec![DataType::Boolean, DataType::UInt16],
                vec![DataType::UInt8, DataType::UInt16],
                None,
            ),
            // The current UInt16 is widened to the valid UInt32.
            (
                vec![DataType::Boolean, DataType::UInt32],
                vec![DataType::Boolean, DataType::UInt16],
                Some(vec![DataType::Boolean, DataType::UInt32]),
            ),
            // Utf8 -> Timestamp; the "+TZ" target is expected to come back as "+00".
            (
                vec![nanos(None), nanos(Some("+TZ")), nanos(Some("+01"))],
                vec![DataType::Utf8, DataType::Utf8, DataType::Utf8],
                Some(vec![nanos(None), nanos(Some("+00")), nanos(Some("+01"))]),
            ),
        ];

        for (valid_types, current_types, expected) in cases {
            assert_eq!(maybe_data_types(&valid_types, &current_types), expected)
        }
    }

    /// Behaviour of `get_valid_types` for `TypeSignature::Numeric`.
    #[test]
    fn test_get_valid_types_numeric() -> Result<()> {
        // Resolves the signature for function "test" and concatenates all
        // returned type lists into a single one.
        fn flattened(
            signature: &TypeSignature,
            current_types: &[DataType],
        ) -> Vec<DataType> {
            get_valid_types("test", signature, current_types)
                .unwrap()
                .into_iter()
                .flatten()
                .collect()
        }

        // Trivial case.
        assert_eq!(
            flattened(&TypeSignature::Numeric(1), &[DataType::Int32]),
            [DataType::Int32]
        );

        // Args are coerced into a common numeric type.
        assert_eq!(
            flattened(
                &TypeSignature::Numeric(2),
                &[DataType::Int32, DataType::Int64]
            ),
            [DataType::Int64, DataType::Int64]
        );

        // Args are coerced into a common numeric type, specifically, int would be
        // coerced to float.
        assert_eq!(
            flattened(
                &TypeSignature::Numeric(3),
                &[DataType::Int32, DataType::Int64, DataType::Float64]
            ),
            [DataType::Float64, DataType::Float64, DataType::Float64]
        );

        // Cannot coerce args to a common numeric type.
        let err = get_valid_types(
            "test",
            &TypeSignature::Numeric(2),
            &[DataType::Int32, DataType::Utf8],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::String"
        );

        // Falls back to float64 if the arg is of type null.
        assert_eq!(
            flattened(&TypeSignature::Numeric(1), &[DataType::Null]),
            [DataType::Float64]
        );

        // Rejects non-numeric arg.
        let err = get_valid_types(
            "test",
            &TypeSignature::Numeric(1),
            &[DataType::Timestamp(TimeUnit::Second, None)],
        )
        .unwrap_err();
        assert_contains!(
            err.to_string(),
            "Function 'test' expects NativeType::Numeric but received NativeType::Timestamp(Second, None)"
        );

        Ok(())
    }

    /// `OneOf(Any(1), Any(2))` yields one candidate for one or two arguments
    /// and none for three.
    #[test]
    fn test_get_valid_types_one_of() -> Result<()> {
        let signature =
            TypeSignature::OneOf(vec![TypeSignature::Any(1), TypeSignature::Any(2)]);

        // Three arguments match neither alternative.
        let too_many = [DataType::Int32, DataType::Int32, DataType::Int32];
        assert!(get_valid_types("test", &signature, &too_many)?.is_empty());

        // Two arguments, then one argument: exactly one candidate, equal to the input.
        for args in [vec![DataType::Int32, DataType::Int32], vec![DataType::Int32]] {
            let valid_types = get_valid_types("test", &signature, &args)?;
            assert_eq!(valid_types.len(), 1);
            assert_eq!(valid_types[0], args);
        }

        Ok(())
    }

    /// A wrong number of arguments is reported with the expected and received counts.
    #[test]
    fn test_get_valid_types_length_check() -> Result<()> {
        let signature = TypeSignature::Numeric(1);

        // Each case is: (arguments, expected fragment of the error message).
        let cases = [
            (vec![], "Function 'test' expects 1 arguments but received 0"),
            (
                vec![DataType::Int32; 3],
                "Function 'test' expects 1 arguments but received 3",
            ),
        ];

        for (args, expected) in cases {
            let err = get_valid_types("test", &signature, &args).unwrap_err();
            assert_contains!(err.to_string(), expected);
        }

        Ok(())
    }

    /// Coercion of a fixed size list against wildcard, mismatching and matching sizes.
    #[test]
    fn test_fixed_list_wildcard_coerce() -> Result<()> {
        let inner = Arc::new(Field::new_list_field(DataType::Int32, false));
        let fixed_list = |size: i32| DataType::FixedSizeList(Arc::clone(&inner), size);
        let exact = |data_type: DataType| {
            Signature::exact(vec![data_type], Volatility::Stable)
        };

        // The argument under test: a fixed size list of length 2.
        let current_types = vec![fixed_list(2)];

        // A wildcard signature is able to coerce for any size.
        let signature = exact(fixed_list(FIXED_SIZE_LIST_WILDCARD));
        assert_eq!(data_types("test", &current_types, &signature)?, current_types);

        // make sure it can't coerce to a different size
        let signature = exact(fixed_list(3));
        assert!(data_types("test", &current_types, &signature).is_err());

        // make sure it works with the same type.
        let signature = exact(fixed_list(2));
        assert_eq!(
            data_types("test", &current_types, &signature).unwrap(),
            current_types
        );

        Ok(())
    }

    /// Wildcard sizes are resolved on every nesting level of a fixed size list.
    #[test]
    fn test_nested_wildcard_fixed_size_lists() -> Result<()> {
        // Builds a fixed size list (`outer_size`) of fixed size lists (`inner_size`)
        // of `element`.
        fn nested(element: DataType, inner_size: i32, outer_size: i32) -> DataType {
            let inner = DataType::FixedSizeList(
                Arc::new(Field::new_list_field(element, false)),
                inner_size,
            );
            DataType::FixedSizeList(
                Arc::new(Field::new_list_field(inner, false)),
                outer_size,
            )
        }

        let type_into = nested(
            DataType::Int32,
            FIXED_SIZE_LIST_WILDCARD,
            FIXED_SIZE_LIST_WILDCARD,
        );
        let type_from = nested(DataType::Int8, 4, 3);

        // Element type comes from `type_into`, sizes come from `type_from`.
        assert_eq!(
            coerced_from(&type_into, &type_from),
            Some(nested(DataType::Int32, 4, 3))
        );

        Ok(())
    }

    /// Dictionary handling of `coerced_from` in both directions.
    #[test]
    fn test_coerced_from_dictionary() {
        let dictionary =
            DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::UInt32));

        // Int64 is rejected where a dictionary with UInt32 values is expected.
        assert_eq!(coerced_from(&dictionary, &DataType::Int64), None);

        // The dictionary is accepted where Int64 is expected.
        assert_eq!(
            coerced_from(&DataType::Int64, &dictionary),
            Some(DataType::Int64)
        );
    }
}