datafusion_functions/
utils.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::ArrayRef;
19use arrow::datatypes::DataType;
20
21use datafusion_common::{Result, ScalarValue};
22use datafusion_expr::function::Hint;
23use datafusion_expr::ColumnarValue;
24
/// Generates a `pub(crate)` function named `$FUNC` that maps the data type of a
/// string function's first argument to the function's return type.
///
/// The generated function has the signature
/// `fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType>` and resolves
/// `arg_type` as follows:
///
/// * `LargeUtf8` or `LargeBinary` -> `$largeUtf8Type`
/// * `Utf8`, `Binary`, `Utf8View` or `BinaryView` -> `$utf8Type`
/// * `Null` -> `Null`
/// * `Dictionary(_, value_type)` -> resolved from `value_type`, where only
///   `LargeUtf8`, `LargeBinary`, `Utf8`, `Binary` and `Null` are accepted
///
/// Any other type yields an execution error that mentions the upper-cased
/// `name` of the calling function and the offending type.
macro_rules! get_optimal_return_type {
    ($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
        pub(crate) fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
            let return_type = match arg_type {
                DataType::Null => DataType::Null,
                // The "small" string and binary types, as well as both view
                // types, all resolve to the small return type.
                DataType::Utf8
                | DataType::Binary
                | DataType::Utf8View
                | DataType::BinaryView => $utf8Type,
                // The large binary type is treated like the large string type.
                DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                // Dictionaries are resolved from the type of their values; view
                // types are not accepted as dictionary values here.
                DataType::Dictionary(_, value_type) => match value_type.as_ref() {
                    DataType::Null => DataType::Null,
                    DataType::Utf8 | DataType::Binary => $utf8Type,
                    DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
                    unsupported => {
                        return datafusion_common::exec_err!(
                            "The {} function can only accept strings, but got {:?}.",
                            name.to_uppercase(),
                            unsupported
                        );
                    }
                },
                unsupported => {
                    return datafusion_common::exec_err!(
                        "The {} function can only accept strings, but got {:?}.",
                        name.to_uppercase(),
                        unsupported
                    );
                }
            };

            Ok(return_type)
        }
    };
}
68
// `utf8_to_str_type`: resolves the string return type for a string-like input type.
// `LargeUtf8`/`LargeBinary` inputs yield `LargeUtf8`; `Utf8`, `Binary` and the view
// types yield `Utf8`; `Null` stays `Null`; anything else is an error.
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
71
// `utf8_to_int_type`: resolves the integer return type for a string-like input type.
// `LargeUtf8`/`LargeBinary` inputs yield `Int64`; `Utf8`, `Binary` and the view
// types yield `Int32`; `Null` stays `Null`; anything else is an error.
get_optimal_return_type!(utf8_to_int_type, DataType::Int64, DataType::Int32);
74
75/// Creates a scalar function implementation for the given function.
76/// * `inner` - the function to be executed
77/// * `hints` - hints to be used when expanding scalars to arrays
78pub(super) fn make_scalar_function<F>(
79    inner: F,
80    hints: Vec<Hint>,
81) -> impl Fn(&[ColumnarValue]) -> Result<ColumnarValue>
82where
83    F: Fn(&[ArrayRef]) -> Result<ArrayRef>,
84{
85    move |args: &[ColumnarValue]| {
86        // first, identify if any of the arguments is an Array. If yes, store its `len`,
87        // as any scalar will need to be converted to an array of len `len`.
88        let len = args
89            .iter()
90            .fold(Option::<usize>::None, |acc, arg| match arg {
91                ColumnarValue::Scalar(_) => acc,
92                ColumnarValue::Array(a) => Some(a.len()),
93            });
94
95        let is_scalar = len.is_none();
96
97        let inferred_length = len.unwrap_or(1);
98        let args = args
99            .iter()
100            .zip(hints.iter().chain(std::iter::repeat(&Hint::Pad)))
101            .map(|(arg, hint)| {
102                // Decide on the length to expand this scalar to depending
103                // on the given hints.
104                let expansion_len = match hint {
105                    Hint::AcceptsSingular => 1,
106                    Hint::Pad => inferred_length,
107                };
108                arg.to_array(expansion_len)
109            })
110            .collect::<Result<Vec<_>>>()?;
111
112        let result = (inner)(&args);
113        if is_scalar {
114            // If all inputs are scalar, keeps output as scalar
115            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
116            result.map(ColumnarValue::Scalar)
117        } else {
118            result.map(ColumnarValue::Array)
119        }
120    }
121}
122
#[cfg(test)]
pub mod test {
    /// Checks both the planned return type and the evaluated result of a
    /// scalar function.
    ///
    /// * `$FUNC` - the `ScalarUDFImpl` to test
    /// * `$ARGS` - the arguments (a vec of `ColumnarValue`) to pass to the
    ///   function; the expression is expanded at every use inside the macro
    /// * `$EXPECTED` - a `Result<Option<$EXPECTED_TYPE>>`: `Ok(Some(v))` for an
    ///   expected value, `Ok(None)` for an expected null, `Err(e)` for an
    ///   expected error
    /// * `$EXPECTED_TYPE` - the Rust type of the expected value
    /// * `$EXPECTED_DATA_TYPE` - the expected `DataType` of the result
    /// * `$ARRAY_TYPE` - the concrete array type the result is downcast to
    ///
    /// Only the first row of the result is compared against the expected value.
    macro_rules! test_function {
        ($FUNC:expr, $ARGS:expr, $EXPECTED:expr, $EXPECTED_TYPE:ty, $EXPECTED_DATA_TYPE:expr, $ARRAY_TYPE:ident) => {
            let expected: Result<Option<$EXPECTED_TYPE>> = $EXPECTED;
            let func = $FUNC;

            let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
            // Number of rows: the length of the last array argument, or 1 when
            // every argument is a scalar.
            let cardinality = $ARGS
                .iter()
                .fold(Option::<usize>::None, |acc, arg| match arg {
                    ColumnarValue::Scalar(_) => acc,
                    ColumnarValue::Array(a) => Some(a.len()),
                })
                .unwrap_or(1);

            let scalar_arguments = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => Some(scalar.clone()),
                ColumnarValue::Array(_) => None,
            }).collect::<Vec<_>>();
            let scalar_arguments_refs = scalar_arguments.iter().map(|arg| arg.as_ref()).collect::<Vec<_>>();

            let nullables = $ARGS.iter().map(|arg| match arg {
                ColumnarValue::Scalar(scalar) => scalar.is_null(),
                ColumnarValue::Array(a) => a.null_count() > 0,
            }).collect::<Vec<_>>();

            let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs {
                arg_types: &type_array,
                scalar_arguments: &scalar_arguments_refs,
                nullables: &nullables
            });

            match expected {
                Ok(expected) => {
                    assert!(return_info.is_ok());
                    let (return_type, _nullable) = return_info.unwrap().into_parts();
                    assert_eq!(return_type, $EXPECTED_DATA_TYPE);

                    let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type});
                    // The message argument is only evaluated when the assertion fails.
                    assert!(result.is_ok(), "function returned an error: {}", result.unwrap_err());

                    let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array");
                    let result = result.as_any().downcast_ref::<$ARRAY_TYPE>().expect("Failed to convert to type");
                    assert_eq!(result.data_type(), &$EXPECTED_DATA_TYPE);

                    // value is correct
                    match expected {
                        Some(v) => assert_eq!(result.value(0), v),
                        None => assert!(result.is_null(0)),
                    };
                }
                Err(expected_error) => match return_info {
                    // The error may already surface while resolving the return type.
                    Err(error) => {
                        datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace());
                    }
                    Ok(return_info) => {
                        let (return_type, _nullable) = return_info.into_parts();

                        // invoke is expected error - cannot use .expect_err() due to Debug not being implemented
                        match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) {
                            Ok(_) => panic!("expected error"),
                            Err(error) => {
                                // The expected message must start with the actual message.
                                assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace()));
                            }
                        }
                    }
                },
            };
        };
    }

    use arrow::datatypes::DataType;
    #[allow(unused_imports)]
    pub(crate) use test_function;

    use super::*;

    /// Builds a dictionary type with `Int32` keys and the given value type.
    fn dictionary_of(value_type: DataType) -> DataType {
        DataType::Dictionary(Box::new(DataType::Int32), Box::new(value_type))
    }

    #[test]
    fn string_to_int_type() {
        let v = utf8_to_int_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        let v = utf8_to_int_type(&DataType::Binary, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::BinaryView, "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&DataType::LargeBinary, "test").unwrap();
        assert_eq!(v, DataType::Int64);

        let v = utf8_to_int_type(&DataType::Null, "test").unwrap();
        assert_eq!(v, DataType::Null);
    }

    #[test]
    fn string_to_str_type() {
        let v = utf8_to_str_type(&DataType::Utf8, "test").unwrap();
        assert_eq!(v, DataType::Utf8);

        let v = utf8_to_str_type(&DataType::Utf8View, "test").unwrap();
        assert_eq!(v, DataType::Utf8);

        let v = utf8_to_str_type(&DataType::LargeUtf8, "test").unwrap();
        assert_eq!(v, DataType::LargeUtf8);

        let v = utf8_to_str_type(&DataType::Null, "test").unwrap();
        assert_eq!(v, DataType::Null);
    }

    #[test]
    fn dictionary_resolves_from_value_type() {
        let v = utf8_to_int_type(&dictionary_of(DataType::Utf8), "test").unwrap();
        assert_eq!(v, DataType::Int32);

        let v = utf8_to_int_type(&dictionary_of(DataType::LargeUtf8), "test").unwrap();
        assert_eq!(v, DataType::Int64);

        let v = utf8_to_str_type(&dictionary_of(DataType::LargeBinary), "test").unwrap();
        assert_eq!(v, DataType::LargeUtf8);

        let v = utf8_to_str_type(&dictionary_of(DataType::Null), "test").unwrap();
        assert_eq!(v, DataType::Null);
    }

    #[test]
    fn non_string_types_are_rejected() {
        let error = utf8_to_int_type(&DataType::Int32, "test").unwrap_err();
        assert!(error
            .strip_backtrace()
            .contains("The TEST function can only accept strings, but got Int32."));

        let error = utf8_to_str_type(&dictionary_of(DataType::Int64), "test").unwrap_err();
        assert!(error
            .strip_backtrace()
            .contains("The TEST function can only accept strings, but got Int64."));
    }
}