datafusion_functions/string/
concat.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{as_largestring_array, Array};
19use arrow::datatypes::DataType;
20use datafusion_expr::sort_properties::ExprProperties;
21use std::any::Any;
22use std::sync::Arc;
23
24use crate::string::concat;
25use crate::strings::{
26    ColumnarValueRef, LargeStringArrayBuilder, StringArrayBuilder, StringViewArrayBuilder,
27};
28use datafusion_common::cast::{as_string_array, as_string_view_array};
29use datafusion_common::{internal_err, plan_err, Result, ScalarValue};
30use datafusion_expr::expr::ScalarFunction;
31use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
32use datafusion_expr::{lit, ColumnarValue, Documentation, Expr, Volatility};
33use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
34use datafusion_macros::user_doc;
35
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Concatenates multiple strings together.",
    syntax_example = "concat(str[, ..., str_n])",
    sql_example = r#"```sql
> select concat('data', 'f', 'us', 'ion');
+-------------------------------------------------------+
| concat(Utf8("data"),Utf8("f"),Utf8("us"),Utf8("ion")) |
+-------------------------------------------------------+
| datafusion                                            |
+-------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(
        name = "str_n",
        description = "Subsequent string expressions to concatenate."
    ),
    related_udf(name = "concat_ws")
)]
/// Scalar UDF implementing the SQL `concat` function, which concatenates its string
/// arguments in order and skips NULL arguments.
#[derive(Debug)]
pub struct ConcatFunc {
    /// Accepted argument types and volatility; built by `ConcatFunc::new`.
    signature: Signature,
}
59
60impl Default for ConcatFunc {
61    fn default() -> Self {
62        ConcatFunc::new()
63    }
64}
65
66impl ConcatFunc {
67    pub fn new() -> Self {
68        use DataType::*;
69        Self {
70            signature: Signature::variadic(
71                vec![Utf8View, Utf8, LargeUtf8],
72                Volatility::Immutable,
73            ),
74        }
75    }
76}
77
78impl ScalarUDFImpl for ConcatFunc {
79    fn as_any(&self) -> &dyn Any {
80        self
81    }
82
83    fn name(&self) -> &str {
84        "concat"
85    }
86
87    fn signature(&self) -> &Signature {
88        &self.signature
89    }
90
91    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
92        use DataType::*;
93        let mut dt = &Utf8;
94        arg_types.iter().for_each(|data_type| {
95            if data_type == &Utf8View {
96                dt = data_type;
97            }
98            if data_type == &LargeUtf8 && dt != &Utf8View {
99                dt = data_type;
100            }
101        });
102
103        Ok(dt.to_owned())
104    }
105
106    /// Concatenates the text representations of all the arguments. NULL arguments are ignored.
107    /// concat('abcde', 2, NULL, 22) = 'abcde222'
108    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109        let ScalarFunctionArgs { args, .. } = args;
110
111        let mut return_datatype = DataType::Utf8;
112        args.iter().for_each(|col| {
113            if col.data_type() == DataType::Utf8View {
114                return_datatype = col.data_type();
115            }
116            if col.data_type() == DataType::LargeUtf8
117                && return_datatype != DataType::Utf8View
118            {
119                return_datatype = col.data_type();
120            }
121        });
122
123        let array_len = args
124            .iter()
125            .filter_map(|x| match x {
126                ColumnarValue::Array(array) => Some(array.len()),
127                _ => None,
128            })
129            .next();
130
131        // Scalar
132        if array_len.is_none() {
133            let mut result = String::new();
134            for arg in args {
135                let ColumnarValue::Scalar(scalar) = arg else {
136                    return internal_err!("concat expected scalar value, got {arg:?}");
137                };
138
139                match scalar.try_as_str() {
140                    Some(Some(v)) => result.push_str(v),
141                    Some(None) => {} // null literal
142                    None => plan_err!(
143                        "Concat function does not support scalar type {:?}",
144                        scalar
145                    )?,
146                }
147            }
148
149            return match return_datatype {
150                DataType::Utf8View => {
151                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8View(Some(result))))
152                }
153                DataType::Utf8 => {
154                    Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(result))))
155                }
156                DataType::LargeUtf8 => {
157                    Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(result))))
158                }
159                other => {
160                    plan_err!("Concat function does not support datatype of {other}")
161                }
162            };
163        }
164
165        // Array
166        let len = array_len.unwrap();
167        let mut data_size = 0;
168        let mut columns = Vec::with_capacity(args.len());
169
170        for arg in &args {
171            match arg {
172                ColumnarValue::Scalar(ScalarValue::Utf8(maybe_value))
173                | ColumnarValue::Scalar(ScalarValue::LargeUtf8(maybe_value))
174                | ColumnarValue::Scalar(ScalarValue::Utf8View(maybe_value)) => {
175                    if let Some(s) = maybe_value {
176                        data_size += s.len() * len;
177                        columns.push(ColumnarValueRef::Scalar(s.as_bytes()));
178                    }
179                }
180                ColumnarValue::Array(array) => {
181                    match array.data_type() {
182                        DataType::Utf8 => {
183                            let string_array = as_string_array(array)?;
184
185                            data_size += string_array.values().len();
186                            let column = if array.is_nullable() {
187                                ColumnarValueRef::NullableArray(string_array)
188                            } else {
189                                ColumnarValueRef::NonNullableArray(string_array)
190                            };
191                            columns.push(column);
192                        },
193                        DataType::LargeUtf8 => {
194                            let string_array = as_largestring_array(array);
195
196                            data_size += string_array.values().len();
197                            let column = if array.is_nullable() {
198                                ColumnarValueRef::NullableLargeStringArray(string_array)
199                            } else {
200                                ColumnarValueRef::NonNullableLargeStringArray(string_array)
201                            };
202                            columns.push(column);
203                        },
204                        DataType::Utf8View => {
205                            let string_array = as_string_view_array(array)?;
206
207                            data_size += string_array.len();
208                            let column = if array.is_nullable() {
209                                ColumnarValueRef::NullableStringViewArray(string_array)
210                            } else {
211                                ColumnarValueRef::NonNullableStringViewArray(string_array)
212                            };
213                            columns.push(column);
214                        },
215                        other => {
216                            return plan_err!("Input was {other} which is not a supported datatype for concat function")
217                        }
218                    };
219                }
220                _ => unreachable!("concat"),
221            }
222        }
223
224        match return_datatype {
225            DataType::Utf8 => {
226                let mut builder = StringArrayBuilder::with_capacity(len, data_size);
227                for i in 0..len {
228                    columns
229                        .iter()
230                        .for_each(|column| builder.write::<true>(column, i));
231                    builder.append_offset();
232                }
233
234                let string_array = builder.finish(None);
235                Ok(ColumnarValue::Array(Arc::new(string_array)))
236            }
237            DataType::Utf8View => {
238                let mut builder = StringViewArrayBuilder::with_capacity(len, data_size);
239                for i in 0..len {
240                    columns
241                        .iter()
242                        .for_each(|column| builder.write::<true>(column, i));
243                    builder.append_offset();
244                }
245
246                let string_array = builder.finish();
247                Ok(ColumnarValue::Array(Arc::new(string_array)))
248            }
249            DataType::LargeUtf8 => {
250                let mut builder = LargeStringArrayBuilder::with_capacity(len, data_size);
251                for i in 0..len {
252                    columns
253                        .iter()
254                        .for_each(|column| builder.write::<true>(column, i));
255                    builder.append_offset();
256                }
257
258                let string_array = builder.finish(None);
259                Ok(ColumnarValue::Array(Arc::new(string_array)))
260            }
261            _ => unreachable!(),
262        }
263    }
264
265    /// Simplify the `concat` function by
266    /// 1. filtering out all `null` literals
267    /// 2. concatenating contiguous literal arguments
268    ///
269    /// For example:
270    /// `concat(col(a), 'hello ', 'world', col(b), null)`
271    /// will be optimized to
272    /// `concat(col(a), 'hello world', col(b))`
273    fn simplify(
274        &self,
275        args: Vec<Expr>,
276        _info: &dyn SimplifyInfo,
277    ) -> Result<ExprSimplifyResult> {
278        simplify_concat(args)
279    }
280
281    fn documentation(&self) -> Option<&Documentation> {
282        self.doc()
283    }
284
285    fn preserves_lex_ordering(&self, _inputs: &[ExprProperties]) -> Result<bool> {
286        Ok(true)
287    }
288}
289
290pub fn simplify_concat(args: Vec<Expr>) -> Result<ExprSimplifyResult> {
291    let mut new_args = Vec::with_capacity(args.len());
292    let mut contiguous_scalar = "".to_string();
293
294    let return_type = {
295        let data_types: Vec<_> = args
296            .iter()
297            .filter_map(|expr| match expr {
298                Expr::Literal(l) => Some(l.data_type()),
299                _ => None,
300            })
301            .collect();
302        ConcatFunc::new().return_type(&data_types)
303    }?;
304
305    for arg in args.clone() {
306        match arg {
307            Expr::Literal(ScalarValue::Utf8(None)) => {}
308            Expr::Literal(ScalarValue::LargeUtf8(None)) => {
309            }
310            Expr::Literal(ScalarValue::Utf8View(None)) => { }
311
312            // filter out `null` args
313            // All literals have been converted to Utf8 or LargeUtf8 in type_coercion.
314            // Concatenate it with the `contiguous_scalar`.
315            Expr::Literal(ScalarValue::Utf8(Some(v))) => {
316                contiguous_scalar += &v;
317            }
318            Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => {
319                contiguous_scalar += &v;
320            }
321            Expr::Literal(ScalarValue::Utf8View(Some(v))) => {
322                contiguous_scalar += &v;
323            }
324
325            Expr::Literal(x) => {
326                return internal_err!(
327                    "The scalar {x} should be casted to string type during the type coercion."
328                )
329            }
330            // If the arg is not a literal, we should first push the current `contiguous_scalar`
331            // to the `new_args` (if it is not empty) and reset it to empty string.
332            // Then pushing this arg to the `new_args`.
333            arg => {
334                if !contiguous_scalar.is_empty() {
335                    match return_type {
336                        DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
337                        DataType::LargeUtf8 => new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
338                        DataType::Utf8View => new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
339                        _ => unreachable!(),
340                    }
341                    contiguous_scalar = "".to_string();
342                }
343                new_args.push(arg);
344            }
345        }
346    }
347
348    if !contiguous_scalar.is_empty() {
349        match return_type {
350            DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
351            DataType::LargeUtf8 => {
352                new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
353            }
354            DataType::Utf8View => {
355                new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
356            }
357            _ => unreachable!(),
358        }
359    }
360
361    if !args.eq(&new_args) {
362        Ok(ExprSimplifyResult::Simplified(Expr::ScalarFunction(
363            ScalarFunction {
364                func: concat(),
365                args: new_args,
366            },
367        )))
368    } else {
369        Ok(ExprSimplifyResult::Original(args))
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringViewArray};
    use arrow::array::{ArrayRef, StringArray};
    use datafusion_expr::col;
    use DataType::*;

    /// Scalar-only invocations: NULLs are skipped and the output type follows the
    /// Utf8View > LargeUtf8 > Utf8 preference.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::from("bb")),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aabbcc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("aa")),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("cc")),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );
        test_function!(
            ConcatFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
            ],
            Ok(Some("aacc")),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }

    /// Mixed array/scalar invocation: a Utf8View argument makes the result a
    /// StringViewArray, and NULL array values are skipped per row.
    #[test]
    fn concat() -> Result<()> {
        let c0 =
            ColumnarValue::Array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
        let c1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some(",".to_string())));
        let c2 = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("x"),
            None,
            Some("z"),
        ])));
        let c3 = ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("a"),
            None,
            Some("b"),
        ])));

        let args = ScalarFunctionArgs {
            args: vec![c0, c1, c2, c3, c4],
            number_rows: 3,
            // Utf8View inputs make the planned return type Utf8View.
            return_type: &Utf8View,
        };

        let result = ConcatFunc::new().invoke_with_args(args)?;
        let expected =
            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", "baz,z,b"]))
                as ArrayRef;
        match &result {
            ColumnarValue::Array(array) => {
                assert_eq!(&expected, array);
            }
            _ => panic!(),
        }
        Ok(())
    }

    /// `simplify_concat` drops NULL literals and merges adjacent string literals.
    #[test]
    fn simplify_merges_adjacent_literals_and_drops_nulls() -> Result<()> {
        let args = vec![
            lit(ScalarValue::Utf8(None)),
            col("a"),
            lit("hello "),
            lit("world"),
            col("b"),
            lit(ScalarValue::Utf8(None)),
        ];

        let ExprSimplifyResult::Simplified(expr) = simplify_concat(args)? else {
            panic!("expected concat arguments to be simplified");
        };
        let Expr::ScalarFunction(ScalarFunction {
            args: simplified_args,
            ..
        }) = expr
        else {
            panic!("expected the simplified expression to be a scalar function");
        };
        assert_eq!(simplified_args, vec![col("a"), lit("hello world"), col("b")]);
        Ok(())
    }
}