datafusion_functions/unicode/
find_in_set.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    new_null_array, ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray,
23    OffsetSizeTrait, PrimitiveArray,
24};
25use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
26
27use crate::utils::utf8_to_int_type;
28use datafusion_common::{
29    exec_err, internal_err, utils::take_function_args, Result, ScalarValue,
30};
31use datafusion_expr::TypeSignature::Exact;
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37
/// Scalar UDF implementing `find_in_set(str, strlist)`.
///
/// Looks up `str` among the `,`-separated items of `strlist` and yields its
/// 1-based position, `0` when it is not one of the items, or NULL when either
/// input is NULL. The user-facing `Documentation` returned by
/// `documentation()` (through the `doc()` method) comes from the `#[user_doc]`
/// attribute below.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings.",
    syntax_example = "find_in_set(str, strlist)",
    sql_example = r#"```sql
> select find_in_set('b', 'a,b,c,d');
+----------------------------------------+
| find_in_set(Utf8("b"),Utf8("a,b,c,d")) |
+----------------------------------------+
| 2                                      |
+----------------------------------------+
```"#,
    argument(name = "str", description = "String expression to find in strlist."),
    argument(
        name = "strlist",
        description = "A string list is a string composed of substrings separated by , characters."
    )
)]
#[derive(Debug)]
pub struct FindInSetFunc {
    /// Accepted argument types: exactly two strings of the same encoding
    /// (`Utf8View`, `Utf8` or `LargeUtf8`), as built in `new()`.
    signature: Signature,
}
60
61impl Default for FindInSetFunc {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl FindInSetFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![
73                    Exact(vec![Utf8View, Utf8View]),
74                    Exact(vec![Utf8, Utf8]),
75                    Exact(vec![LargeUtf8, LargeUtf8]),
76                ],
77                Volatility::Immutable,
78            ),
79        }
80    }
81}
82
83impl ScalarUDFImpl for FindInSetFunc {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    fn name(&self) -> &str {
89        "find_in_set"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97        utf8_to_int_type(&arg_types[0], "find_in_set")
98    }
99
100    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101        let ScalarFunctionArgs { args, .. } = args;
102
103        let [string, str_list] = take_function_args(self.name(), args)?;
104
105        match (string, str_list) {
106            // both inputs are scalars
107            (
108                ColumnarValue::Scalar(
109                    ScalarValue::Utf8View(string)
110                    | ScalarValue::Utf8(string)
111                    | ScalarValue::LargeUtf8(string),
112                ),
113                ColumnarValue::Scalar(
114                    ScalarValue::Utf8View(str_list)
115                    | ScalarValue::Utf8(str_list)
116                    | ScalarValue::LargeUtf8(str_list),
117                ),
118            ) => {
119                let res = match (string, str_list) {
120                    (Some(string), Some(str_list)) => {
121                        let position = str_list
122                            .split(',')
123                            .position(|s| s == string)
124                            .map_or(0, |idx| idx + 1);
125
126                        Some(position as i32)
127                    }
128                    _ => None,
129                };
130                Ok(ColumnarValue::Scalar(ScalarValue::from(res)))
131            }
132
133            // `string` is an array, `str_list` is scalar
134            (
135                ColumnarValue::Array(str_array),
136                ColumnarValue::Scalar(
137                    ScalarValue::Utf8View(str_list_literal)
138                    | ScalarValue::Utf8(str_list_literal)
139                    | ScalarValue::LargeUtf8(str_list_literal),
140                ),
141            ) => {
142                let result_array = match str_list_literal {
143                    // find_in_set(column_a, null) = null
144                    None => new_null_array(str_array.data_type(), str_array.len()),
145                    Some(str_list_literal) => {
146                        let str_list = str_list_literal.split(',').collect::<Vec<&str>>();
147                        let result = match str_array.data_type() {
148                            DataType::Utf8 => {
149                                let string_array = str_array.as_string::<i32>();
150                                find_in_set_right_literal::<Int32Type, _>(
151                                    string_array,
152                                    str_list,
153                                )
154                            }
155                            DataType::LargeUtf8 => {
156                                let string_array = str_array.as_string::<i64>();
157                                find_in_set_right_literal::<Int64Type, _>(
158                                    string_array,
159                                    str_list,
160                                )
161                            }
162                            DataType::Utf8View => {
163                                let string_array = str_array.as_string_view();
164                                find_in_set_right_literal::<Int32Type, _>(
165                                    string_array,
166                                    str_list,
167                                )
168                            }
169                            other => {
170                                exec_err!("Unsupported data type {other:?} for function find_in_set")
171                            }
172                        };
173                        Arc::new(result?)
174                    }
175                };
176                Ok(ColumnarValue::Array(result_array))
177            }
178
179            // `string` is scalar, `str_list` is an array
180            (
181                ColumnarValue::Scalar(
182                    ScalarValue::Utf8View(string_literal)
183                    | ScalarValue::Utf8(string_literal)
184                    | ScalarValue::LargeUtf8(string_literal),
185                ),
186                ColumnarValue::Array(str_list_array),
187            ) => {
188                let res = match string_literal {
189                    // find_in_set(null, column_b) = null
190                    None => {
191                        new_null_array(str_list_array.data_type(), str_list_array.len())
192                    }
193                    Some(string) => {
194                        let result = match str_list_array.data_type() {
195                            DataType::Utf8 => {
196                                let str_list = str_list_array.as_string::<i32>();
197                                find_in_set_left_literal::<Int32Type, _>(string, str_list)
198                            }
199                            DataType::LargeUtf8 => {
200                                let str_list = str_list_array.as_string::<i64>();
201                                find_in_set_left_literal::<Int64Type, _>(string, str_list)
202                            }
203                            DataType::Utf8View => {
204                                let str_list = str_list_array.as_string_view();
205                                find_in_set_left_literal::<Int32Type, _>(string, str_list)
206                            }
207                            other => {
208                                exec_err!("Unsupported data type {other:?} for function find_in_set")
209                            }
210                        };
211                        Arc::new(result?)
212                    }
213                };
214                Ok(ColumnarValue::Array(res))
215            }
216
217            // both inputs are arrays
218            (ColumnarValue::Array(base_array), ColumnarValue::Array(exp_array)) => {
219                let res = find_in_set(base_array, exp_array)?;
220
221                Ok(ColumnarValue::Array(res))
222            }
223            _ => {
224                internal_err!("Invalid argument types for `find_in_set` function")
225            }
226        }
227    }
228
229    fn documentation(&self) -> Option<&Documentation> {
230        self.doc()
231    }
232}
233
234/// Returns a value in the range of 1 to N if the string `str` is in the string list `strlist`
235/// consisting of N substrings. A string list is a string composed of substrings separated by `,`
236/// characters.
237fn find_in_set(str: ArrayRef, str_list: ArrayRef) -> Result<ArrayRef> {
238    match str.data_type() {
239        DataType::Utf8 => {
240            let string_array = str.as_string::<i32>();
241            let str_list_array = str_list.as_string::<i32>();
242            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
243        }
244        DataType::LargeUtf8 => {
245            let string_array = str.as_string::<i64>();
246            let str_list_array = str_list.as_string::<i64>();
247            find_in_set_general::<Int64Type, _>(string_array, str_list_array)
248        }
249        DataType::Utf8View => {
250            let string_array = str.as_string_view();
251            let str_list_array = str_list.as_string_view();
252            find_in_set_general::<Int32Type, _>(string_array, str_list_array)
253        }
254        other => {
255            exec_err!("Unsupported data type {other:?} for function find_in_set")
256        }
257    }
258}
259
260pub fn find_in_set_general<'a, T, V>(
261    string_array: V,
262    str_list_array: V,
263) -> Result<ArrayRef>
264where
265    T: ArrowPrimitiveType,
266    T::Native: OffsetSizeTrait,
267    V: ArrayAccessor<Item = &'a str>,
268{
269    let string_iter = ArrayIter::new(string_array);
270    let str_list_iter = ArrayIter::new(str_list_array);
271
272    let mut builder = PrimitiveArray::<T>::builder(string_iter.len());
273
274    string_iter
275        .zip(str_list_iter)
276        .for_each(
277            |(string_opt, str_list_opt)| match (string_opt, str_list_opt) {
278                (Some(string), Some(str_list)) => {
279                    let position = str_list
280                        .split(',')
281                        .position(|s| s == string)
282                        .map_or(0, |idx| idx + 1);
283                    builder.append_value(T::Native::from_usize(position).unwrap());
284                }
285                _ => builder.append_null(),
286            },
287        );
288
289    Ok(Arc::new(builder.finish()) as ArrayRef)
290}
291
292fn find_in_set_left_literal<'a, T, V>(
293    string: String,
294    str_list_array: V,
295) -> Result<ArrayRef>
296where
297    T: ArrowPrimitiveType,
298    T::Native: OffsetSizeTrait,
299    V: ArrayAccessor<Item = &'a str>,
300{
301    let mut builder = PrimitiveArray::<T>::builder(str_list_array.len());
302
303    let str_list_iter = ArrayIter::new(str_list_array);
304
305    str_list_iter.for_each(|str_list_opt| match str_list_opt {
306        Some(str_list) => {
307            let position = str_list
308                .split(',')
309                .position(|s| s == string)
310                .map_or(0, |idx| idx + 1);
311            builder.append_value(T::Native::from_usize(position).unwrap());
312        }
313        None => builder.append_null(),
314    });
315
316    Ok(Arc::new(builder.finish()) as ArrayRef)
317}
318
319fn find_in_set_right_literal<'a, T, V>(
320    string_array: V,
321    str_list: Vec<&str>,
322) -> Result<ArrayRef>
323where
324    T: ArrowPrimitiveType,
325    T::Native: OffsetSizeTrait,
326    V: ArrayAccessor<Item = &'a str>,
327{
328    let mut builder = PrimitiveArray::<T>::builder(string_array.len());
329
330    let string_iter = ArrayIter::new(string_array);
331
332    string_iter.for_each(|string_opt| match string_opt {
333        Some(string) => {
334            let position = str_list
335                .iter()
336                .position(|s| *s == string)
337                .map_or(0, |idx| idx + 1);
338            builder.append_value(T::Native::from_usize(position).unwrap());
339        }
340        None => builder.append_null(),
341    });
342
343    Ok(Arc::new(builder.finish()) as ArrayRef)
344}
345
#[cfg(test)]
mod tests {
    use crate::unicode::find_in_set::FindInSetFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, Int32Array, StringArray};
    use arrow::datatypes::DataType::Int32;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Non-null `Utf8` scalar argument.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(value.to_string())))
    }

    /// Non-null `Utf8View` scalar argument.
    fn utf8_view(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(Some(value.to_string())))
    }

    /// Scalar/scalar cases (matches, misses, empty strings, multi-byte text
    /// and NULLs), all expected to produce an `Int32` result.
    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            FindInSetFunc::new(),
            vec![utf8("a"), utf8("a,b,c")],
            Ok(Some(1)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![utf8_view("🔥"), utf8_view("a,Д,🔥")],
            Ok(Some(3)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![utf8_view("d"), utf8_view("a,b,c")],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                utf8_view("Apache Software Foundation"),
                utf8_view("Github,Apache Software Foundation,DataFusion"),
            ],
            Ok(Some(2)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![utf8(""), utf8("a,b,c")],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![utf8("a"), utf8("")],
            Ok(Some(0)),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                utf8_view("a"),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );
        test_function!(
            FindInSetFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
                utf8_view("a,b,c"),
            ],
            Ok(None),
            i32,
            Int32,
            Int32Array
        );

        Ok(())
    }

    /// Invokes the registered `find_in_set` UDF on `args` and asserts that the
    /// materialized `Int32Array` result equals `expected`.
    fn assert_find_in_set(args: Vec<ColumnarValue>, expected: Int32Array) -> Result<()> {
        let udf = crate::unicode::find_in_set();

        let arg_types = args.iter().map(|arg| arg.data_type()).collect::<Vec<_>>();
        // The row count comes from the last array argument; all-scalar calls use 1.
        let number_rows = args
            .iter()
            .filter_map(|arg| match arg {
                ColumnarValue::Array(array) => Some(array.len()),
                ColumnarValue::Scalar(_) => None,
            })
            .last()
            .unwrap_or(1);
        let return_type = udf.return_type(&arg_types)?;

        let invoked = udf.invoke_with_args(ScalarFunctionArgs {
            args,
            number_rows,
            return_type: &return_type,
        });
        assert!(invoked.is_ok());

        let array = invoked?
            .to_array(number_rows)
            .expect("Failed to convert to array");
        let actual = array
            .as_any()
            .downcast_ref::<Int32Array>()
            .expect("Failed to convert to type");
        assert_eq!(*actual, expected);

        Ok(())
    }

    #[test]
    fn test_find_in_set_with_scalar_args() -> Result<()> {
        assert_find_in_set(
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    "", "a", "b", "c", "d",
                ]))),
                utf8("b,c,d"),
            ],
            Int32Array::from(vec![0, 0, 1, 2, 3]),
        )
    }

    #[test]
    fn test_find_in_set_with_scalar_args_2() -> Result<()> {
        assert_find_in_set(
            vec![
                utf8_view("ApacheSoftware"),
                ColumnarValue::Array(Arc::new(StringArray::from(vec![
                    "a,b,c",
                    "ApacheSoftware,Github,DataFusion",
                    "",
                ]))),
            ],
            Int32Array::from(vec![0, 1, 0]),
        )
    }

    #[test]
    fn test_find_in_set_with_scalar_args_3() -> Result<()> {
        assert_find_in_set(
            vec![
                ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
                utf8_view("a,b,c"),
            ],
            Int32Array::from(vec![None::<i32>; 3]),
        )
    }

    #[test]
    fn test_find_in_set_with_scalar_args_4() -> Result<()> {
        assert_find_in_set(
            vec![
                utf8_view("a"),
                ColumnarValue::Array(Arc::new(StringArray::from(vec![None::<&str>; 3]))),
            ],
            Int32Array::from(vec![None::<i32>; 3]),
        )
    }
}