datafusion_functions/regex/
regexpcount.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::array::{Array, ArrayRef, AsArray, Datum, Int64Array, StringArrayType};
19use arrow::datatypes::{DataType, Int64Type};
20use arrow::datatypes::{
21    DataType::Int64, DataType::LargeUtf8, DataType::Utf8, DataType::Utf8View,
22};
23use arrow::error::ArrowError;
24use datafusion_common::{exec_err, internal_err, Result, ScalarValue};
25use datafusion_expr::{
26    ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature::Exact,
27    TypeSignature::Uniform, Volatility,
28};
29use datafusion_macros::user_doc;
30use itertools::izip;
31use regex::Regex;
32use std::collections::hash_map::Entry;
33use std::collections::HashMap;
34use std::sync::Arc;
35
36#[user_doc(
37    doc_section(label = "Regular Expression Functions"),
38    description = "Returns the number of matches that a [regular expression](https://docs.rs/regex/latest/regex/#syntax) has in a string.",
39    syntax_example = "regexp_count(str, regexp[, start, flags])",
40    sql_example = r#"```sql
41> select regexp_count('abcAbAbc', 'abc', 2, 'i');
42+---------------------------------------------------------------+
43| regexp_count(Utf8("abcAbAbc"),Utf8("abc"),Int64(2),Utf8("i")) |
44+---------------------------------------------------------------+
45| 1                                                             |
46+---------------------------------------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    standard_argument(name = "regexp", prefix = "Regular"),
50    argument(
51        name = "start",
52        description = "- **start**: Optional start position (the first position is 1) to search for the regular expression. Can be a constant, column, or function."
53    ),
54    argument(
55        name = "flags",
56        description = r#"Optional regular expression flags that control the behavior of the regular expression. The following flags are supported:
57  - **i**: case-insensitive: letters match both upper and lower case
58  - **m**: multi-line mode: ^ and $ match begin/end of line
59  - **s**: allow . to match \n
60  - **R**: enables CRLF mode: when multi-line mode is enabled, \r\n is used
61  - **U**: swap the meaning of x* and x*?"#
62    )
63)]
64#[derive(Debug)]
65pub struct RegexpCountFunc {
66    signature: Signature,
67}
68
69impl Default for RegexpCountFunc {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl RegexpCountFunc {
76    pub fn new() -> Self {
77        Self {
78            signature: Signature::one_of(
79                vec![
80                    Uniform(2, vec![Utf8View, LargeUtf8, Utf8]),
81                    Exact(vec![Utf8View, Utf8View, Int64]),
82                    Exact(vec![LargeUtf8, LargeUtf8, Int64]),
83                    Exact(vec![Utf8, Utf8, Int64]),
84                    Exact(vec![Utf8View, Utf8View, Int64, Utf8View]),
85                    Exact(vec![LargeUtf8, LargeUtf8, Int64, LargeUtf8]),
86                    Exact(vec![Utf8, Utf8, Int64, Utf8]),
87                ],
88                Volatility::Immutable,
89            ),
90        }
91    }
92}
93
94impl ScalarUDFImpl for RegexpCountFunc {
95    fn as_any(&self) -> &dyn std::any::Any {
96        self
97    }
98
99    fn name(&self) -> &str {
100        "regexp_count"
101    }
102
103    fn signature(&self) -> &Signature {
104        &self.signature
105    }
106
107    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
108        Ok(Int64)
109    }
110
111    fn invoke_with_args(
112        &self,
113        args: datafusion_expr::ScalarFunctionArgs,
114    ) -> Result<ColumnarValue> {
115        let args = &args.args;
116
117        let len = args
118            .iter()
119            .fold(Option::<usize>::None, |acc, arg| match arg {
120                ColumnarValue::Scalar(_) => acc,
121                ColumnarValue::Array(a) => Some(a.len()),
122            });
123
124        let is_scalar = len.is_none();
125        let inferred_length = len.unwrap_or(1);
126        let args = args
127            .iter()
128            .map(|arg| arg.to_array(inferred_length))
129            .collect::<Result<Vec<_>>>()?;
130
131        let result = regexp_count_func(&args);
132        if is_scalar {
133            // If all inputs are scalar, keeps output as scalar
134            let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
135            result.map(ColumnarValue::Scalar)
136        } else {
137            result.map(ColumnarValue::Array)
138        }
139    }
140
141    fn documentation(&self) -> Option<&Documentation> {
142        self.doc()
143    }
144}
145
146pub fn regexp_count_func(args: &[ArrayRef]) -> Result<ArrayRef> {
147    let args_len = args.len();
148    if !(2..=4).contains(&args_len) {
149        return exec_err!("regexp_count was called with {args_len} arguments. It requires at least 2 and at most 4.");
150    }
151
152    let values = &args[0];
153    match values.data_type() {
154        Utf8 | LargeUtf8 | Utf8View => (),
155        other => {
156            return internal_err!(
157                "Unsupported data type {other:?} for function regexp_count"
158            );
159        }
160    }
161
162    regexp_count(
163        values,
164        &args[1],
165        if args_len > 2 { Some(&args[2]) } else { None },
166        if args_len > 3 { Some(&args[3]) } else { None },
167    )
168    .map_err(|e| e.into())
169}
170
171/// `arrow-rs` style implementation of `regexp_count` function.
172/// This function `regexp_count` is responsible for counting the occurrences of a regular expression pattern
173/// within a string array. It supports optional start positions and flags for case insensitivity.
174///
175/// The function accepts a variable number of arguments:
176/// - `values`: The array of strings to search within.
177/// - `regex_array`: The array of regular expression patterns to search for.
178/// - `start_array` (optional): The array of start positions for the search.
179/// - `flags_array` (optional): The array of flags to modify the search behavior (e.g., case insensitivity).
180///
181/// The function handles different combinations of scalar and array inputs for the regex patterns, start positions,
182/// and flags. It uses a cache to store compiled regular expressions for efficiency.
183///
184/// # Errors
185/// Returns an error if the input arrays have mismatched lengths or if the regular expression fails to compile.
186pub fn regexp_count(
187    values: &dyn Array,
188    regex_array: &dyn Datum,
189    start_array: Option<&dyn Datum>,
190    flags_array: Option<&dyn Datum>,
191) -> Result<ArrayRef, ArrowError> {
192    let (regex_array, is_regex_scalar) = regex_array.get();
193    let (start_array, is_start_scalar) = start_array.map_or((None, true), |start| {
194        let (start, is_start_scalar) = start.get();
195        (Some(start), is_start_scalar)
196    });
197    let (flags_array, is_flags_scalar) = flags_array.map_or((None, true), |flags| {
198        let (flags, is_flags_scalar) = flags.get();
199        (Some(flags), is_flags_scalar)
200    });
201
202    match (values.data_type(), regex_array.data_type(), flags_array) {
203        (Utf8, Utf8, None) => regexp_count_inner(
204            values.as_string::<i32>(),
205            regex_array.as_string::<i32>(),
206            is_regex_scalar,
207            start_array.map(|start| start.as_primitive::<Int64Type>()),
208            is_start_scalar,
209            None,
210            is_flags_scalar,
211        ),
212        (Utf8, Utf8, Some(flags_array)) if *flags_array.data_type() == Utf8 => regexp_count_inner(
213            values.as_string::<i32>(),
214            regex_array.as_string::<i32>(),
215            is_regex_scalar,
216            start_array.map(|start| start.as_primitive::<Int64Type>()),
217            is_start_scalar,
218            Some(flags_array.as_string::<i32>()),
219            is_flags_scalar,
220        ),
221        (LargeUtf8, LargeUtf8, None) => regexp_count_inner(
222            values.as_string::<i64>(),
223            regex_array.as_string::<i64>(),
224            is_regex_scalar,
225            start_array.map(|start| start.as_primitive::<Int64Type>()),
226            is_start_scalar,
227            None,
228            is_flags_scalar,
229        ),
230        (LargeUtf8, LargeUtf8, Some(flags_array)) if *flags_array.data_type() == LargeUtf8 => regexp_count_inner(
231            values.as_string::<i64>(),
232            regex_array.as_string::<i64>(),
233            is_regex_scalar,
234            start_array.map(|start| start.as_primitive::<Int64Type>()),
235            is_start_scalar,
236            Some(flags_array.as_string::<i64>()),
237            is_flags_scalar,
238        ),
239        (Utf8View, Utf8View, None) => regexp_count_inner(
240            values.as_string_view(),
241            regex_array.as_string_view(),
242            is_regex_scalar,
243            start_array.map(|start| start.as_primitive::<Int64Type>()),
244            is_start_scalar,
245            None,
246            is_flags_scalar,
247        ),
248        (Utf8View, Utf8View, Some(flags_array)) if *flags_array.data_type() == Utf8View => regexp_count_inner(
249            values.as_string_view(),
250            regex_array.as_string_view(),
251            is_regex_scalar,
252            start_array.map(|start| start.as_primitive::<Int64Type>()),
253            is_start_scalar,
254            Some(flags_array.as_string_view()),
255            is_flags_scalar,
256        ),
257        _ => Err(ArrowError::ComputeError(
258            "regexp_count() expected the input arrays to be of type Utf8, LargeUtf8, or Utf8View and the data types of the values, regex_array, and flags_array to match".to_string(),
259        )),
260    }
261}
262
263pub fn regexp_count_inner<'a, S>(
264    values: S,
265    regex_array: S,
266    is_regex_scalar: bool,
267    start_array: Option<&Int64Array>,
268    is_start_scalar: bool,
269    flags_array: Option<S>,
270    is_flags_scalar: bool,
271) -> Result<ArrayRef, ArrowError>
272where
273    S: StringArrayType<'a>,
274{
275    let (regex_scalar, is_regex_scalar) = if is_regex_scalar || regex_array.len() == 1 {
276        (Some(regex_array.value(0)), true)
277    } else {
278        (None, false)
279    };
280
281    let (start_array, start_scalar, is_start_scalar) =
282        if let Some(start_array) = start_array {
283            if is_start_scalar || start_array.len() == 1 {
284                (None, Some(start_array.value(0)), true)
285            } else {
286                (Some(start_array), None, false)
287            }
288        } else {
289            (None, Some(1), true)
290        };
291
292    let (flags_array, flags_scalar, is_flags_scalar) =
293        if let Some(flags_array) = flags_array {
294            if is_flags_scalar || flags_array.len() == 1 {
295                (None, Some(flags_array.value(0)), true)
296            } else {
297                (Some(flags_array), None, false)
298            }
299        } else {
300            (None, None, true)
301        };
302
303    let mut regex_cache = HashMap::new();
304
305    match (is_regex_scalar, is_start_scalar, is_flags_scalar) {
306        (true, true, true) => {
307            let regex = match regex_scalar {
308                None | Some("") => {
309                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
310                }
311                Some(regex) => regex,
312            };
313
314            let pattern = compile_regex(regex, flags_scalar)?;
315
316            Ok(Arc::new(
317                values
318                    .iter()
319                    .map(|value| count_matches(value, &pattern, start_scalar))
320                    .collect::<Result<Int64Array, ArrowError>>()?,
321            ))
322        }
323        (true, true, false) => {
324            let regex = match regex_scalar {
325                None | Some("") => {
326                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
327                }
328                Some(regex) => regex,
329            };
330
331            let flags_array = flags_array.unwrap();
332            if values.len() != flags_array.len() {
333                return Err(ArrowError::ComputeError(format!(
334                    "flags_array must be the same length as values array; got {} and {}",
335                    flags_array.len(),
336                    values.len(),
337                )));
338            }
339
340            Ok(Arc::new(
341                values
342                    .iter()
343                    .zip(flags_array.iter())
344                    .map(|(value, flags)| {
345                        let pattern =
346                            compile_and_cache_regex(regex, flags, &mut regex_cache)?;
347                        count_matches(value, pattern, start_scalar)
348                    })
349                    .collect::<Result<Int64Array, ArrowError>>()?,
350            ))
351        }
352        (true, false, true) => {
353            let regex = match regex_scalar {
354                None | Some("") => {
355                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
356                }
357                Some(regex) => regex,
358            };
359
360            let pattern = compile_regex(regex, flags_scalar)?;
361
362            let start_array = start_array.unwrap();
363
364            Ok(Arc::new(
365                values
366                    .iter()
367                    .zip(start_array.iter())
368                    .map(|(value, start)| count_matches(value, &pattern, start))
369                    .collect::<Result<Int64Array, ArrowError>>()?,
370            ))
371        }
372        (true, false, false) => {
373            let regex = match regex_scalar {
374                None | Some("") => {
375                    return Ok(Arc::new(Int64Array::from(vec![0; values.len()])))
376                }
377                Some(regex) => regex,
378            };
379
380            let flags_array = flags_array.unwrap();
381            if values.len() != flags_array.len() {
382                return Err(ArrowError::ComputeError(format!(
383                    "flags_array must be the same length as values array; got {} and {}",
384                    flags_array.len(),
385                    values.len(),
386                )));
387            }
388
389            Ok(Arc::new(
390                izip!(
391                    values.iter(),
392                    start_array.unwrap().iter(),
393                    flags_array.iter()
394                )
395                .map(|(value, start, flags)| {
396                    let pattern =
397                        compile_and_cache_regex(regex, flags, &mut regex_cache)?;
398
399                    count_matches(value, pattern, start)
400                })
401                .collect::<Result<Int64Array, ArrowError>>()?,
402            ))
403        }
404        (false, true, true) => {
405            if values.len() != regex_array.len() {
406                return Err(ArrowError::ComputeError(format!(
407                    "regex_array must be the same length as values array; got {} and {}",
408                    regex_array.len(),
409                    values.len(),
410                )));
411            }
412
413            Ok(Arc::new(
414                values
415                    .iter()
416                    .zip(regex_array.iter())
417                    .map(|(value, regex)| {
418                        let regex = match regex {
419                            None | Some("") => return Ok(0),
420                            Some(regex) => regex,
421                        };
422
423                        let pattern = compile_and_cache_regex(
424                            regex,
425                            flags_scalar,
426                            &mut regex_cache,
427                        )?;
428                        count_matches(value, pattern, start_scalar)
429                    })
430                    .collect::<Result<Int64Array, ArrowError>>()?,
431            ))
432        }
433        (false, true, false) => {
434            if values.len() != regex_array.len() {
435                return Err(ArrowError::ComputeError(format!(
436                    "regex_array must be the same length as values array; got {} and {}",
437                    regex_array.len(),
438                    values.len(),
439                )));
440            }
441
442            let flags_array = flags_array.unwrap();
443            if values.len() != flags_array.len() {
444                return Err(ArrowError::ComputeError(format!(
445                    "flags_array must be the same length as values array; got {} and {}",
446                    flags_array.len(),
447                    values.len(),
448                )));
449            }
450
451            Ok(Arc::new(
452                izip!(values.iter(), regex_array.iter(), flags_array.iter())
453                    .map(|(value, regex, flags)| {
454                        let regex = match regex {
455                            None | Some("") => return Ok(0),
456                            Some(regex) => regex,
457                        };
458
459                        let pattern =
460                            compile_and_cache_regex(regex, flags, &mut regex_cache)?;
461
462                        count_matches(value, pattern, start_scalar)
463                    })
464                    .collect::<Result<Int64Array, ArrowError>>()?,
465            ))
466        }
467        (false, false, true) => {
468            if values.len() != regex_array.len() {
469                return Err(ArrowError::ComputeError(format!(
470                    "regex_array must be the same length as values array; got {} and {}",
471                    regex_array.len(),
472                    values.len(),
473                )));
474            }
475
476            let start_array = start_array.unwrap();
477            if values.len() != start_array.len() {
478                return Err(ArrowError::ComputeError(format!(
479                    "start_array must be the same length as values array; got {} and {}",
480                    start_array.len(),
481                    values.len(),
482                )));
483            }
484
485            Ok(Arc::new(
486                izip!(values.iter(), regex_array.iter(), start_array.iter())
487                    .map(|(value, regex, start)| {
488                        let regex = match regex {
489                            None | Some("") => return Ok(0),
490                            Some(regex) => regex,
491                        };
492
493                        let pattern = compile_and_cache_regex(
494                            regex,
495                            flags_scalar,
496                            &mut regex_cache,
497                        )?;
498                        count_matches(value, pattern, start)
499                    })
500                    .collect::<Result<Int64Array, ArrowError>>()?,
501            ))
502        }
503        (false, false, false) => {
504            if values.len() != regex_array.len() {
505                return Err(ArrowError::ComputeError(format!(
506                    "regex_array must be the same length as values array; got {} and {}",
507                    regex_array.len(),
508                    values.len(),
509                )));
510            }
511
512            let start_array = start_array.unwrap();
513            if values.len() != start_array.len() {
514                return Err(ArrowError::ComputeError(format!(
515                    "start_array must be the same length as values array; got {} and {}",
516                    start_array.len(),
517                    values.len(),
518                )));
519            }
520
521            let flags_array = flags_array.unwrap();
522            if values.len() != flags_array.len() {
523                return Err(ArrowError::ComputeError(format!(
524                    "flags_array must be the same length as values array; got {} and {}",
525                    flags_array.len(),
526                    values.len(),
527                )));
528            }
529
530            Ok(Arc::new(
531                izip!(
532                    values.iter(),
533                    regex_array.iter(),
534                    start_array.iter(),
535                    flags_array.iter()
536                )
537                .map(|(value, regex, start, flags)| {
538                    let regex = match regex {
539                        None | Some("") => return Ok(0),
540                        Some(regex) => regex,
541                    };
542
543                    let pattern =
544                        compile_and_cache_regex(regex, flags, &mut regex_cache)?;
545                    count_matches(value, pattern, start)
546                })
547                .collect::<Result<Int64Array, ArrowError>>()?,
548            ))
549        }
550    }
551}
552
553fn compile_and_cache_regex<'strings, 'cache>(
554    regex: &'strings str,
555    flags: Option<&'strings str>,
556    regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
557) -> Result<&'cache Regex, ArrowError>
558where
559    'strings: 'cache,
560{
561    let result = match regex_cache.entry((regex, flags)) {
562        Entry::Occupied(occupied_entry) => occupied_entry.into_mut(),
563        Entry::Vacant(vacant_entry) => {
564            let compiled = compile_regex(regex, flags)?;
565            vacant_entry.insert(compiled)
566        }
567    };
568    Ok(result)
569}
570
571fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError> {
572    let pattern = match flags {
573        None | Some("") => regex.to_string(),
574        Some(flags) => {
575            if flags.contains("g") {
576                return Err(ArrowError::ComputeError(
577                    "regexp_count() does not support global flag".to_string(),
578                ));
579            }
580            format!("(?{}){}", flags, regex)
581        }
582    };
583
584    Regex::new(&pattern).map_err(|_| {
585        ArrowError::ComputeError(format!(
586            "Regular expression did not compile: {}",
587            pattern
588        ))
589    })
590}
591
592fn count_matches(
593    value: Option<&str>,
594    pattern: &Regex,
595    start: Option<i64>,
596) -> Result<i64, ArrowError> {
597    let value = match value {
598        None | Some("") => return Ok(0),
599        Some(value) => value,
600    };
601
602    if let Some(start) = start {
603        if start < 1 {
604            return Err(ArrowError::ComputeError(
605                "regexp_count() requires start to be 1 based".to_string(),
606            ));
607        }
608
609        let find_slice = value.chars().skip(start as usize - 1).collect::<String>();
610        let count = pattern.find_iter(find_slice.as_str()).count();
611        Ok(count as i64)
612    } else {
613        let count = pattern.find_iter(value).count();
614        Ok(count as i64)
615    }
616}
617
618#[cfg(test)]
619mod tests {
620    use super::*;
621    use arrow::array::{GenericStringArray, StringViewArray};
622    use datafusion_expr::ScalarFunctionArgs;
623
624    #[test]
625    fn test_regexp_count() {
626        test_case_sensitive_regexp_count_scalar();
627        test_case_sensitive_regexp_count_scalar_start();
628        test_case_insensitive_regexp_count_scalar_flags();
629        test_case_sensitive_regexp_count_start_scalar_complex();
630
631        test_case_sensitive_regexp_count_array::<GenericStringArray<i32>>();
632        test_case_sensitive_regexp_count_array::<GenericStringArray<i64>>();
633        test_case_sensitive_regexp_count_array::<StringViewArray>();
634
635        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i32>>();
636        test_case_sensitive_regexp_count_array_start::<GenericStringArray<i64>>();
637        test_case_sensitive_regexp_count_array_start::<StringViewArray>();
638
639        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i32>>();
640        test_case_insensitive_regexp_count_array_flags::<GenericStringArray<i64>>();
641        test_case_insensitive_regexp_count_array_flags::<StringViewArray>();
642
643        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
644        test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
645        test_case_sensitive_regexp_count_array_complex::<StringViewArray>();
646
647        test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
648    }
649
650    fn test_case_sensitive_regexp_count_scalar() {
651        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
652        let regex = "abc";
653        let expected: Vec<i64> = vec![0, 1, 2, 1, 3];
654
655        values.iter().enumerate().for_each(|(pos, &v)| {
656            // utf8
657            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
658            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
659            let expected = expected.get(pos).cloned();
660            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
661                args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)],
662                number_rows: 2,
663                return_type: &Int64,
664            });
665            match re {
666                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
667                    assert_eq!(v, expected, "regexp_count scalar test failed");
668                }
669                _ => panic!("Unexpected result"),
670            }
671
672            // largeutf8
673            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
674            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
675            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
676                args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)],
677                number_rows: 2,
678                return_type: &Int64,
679            });
680            match re {
681                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
682                    assert_eq!(v, expected, "regexp_count scalar test failed");
683                }
684                _ => panic!("Unexpected result"),
685            }
686
687            // utf8view
688            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
689            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
690            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
691                args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)],
692                number_rows: 2,
693                return_type: &Int64,
694            });
695            match re {
696                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
697                    assert_eq!(v, expected, "regexp_count scalar test failed");
698                }
699                _ => panic!("Unexpected result"),
700            }
701        });
702    }
703
704    fn test_case_sensitive_regexp_count_scalar_start() {
705        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
706        let regex = "abc";
707        let start = 2;
708        let expected: Vec<i64> = vec![0, 1, 1, 0, 2];
709
710        values.iter().enumerate().for_each(|(pos, &v)| {
711            // utf8
712            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
713            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
714            let start_sv = ScalarValue::Int64(Some(start));
715            let expected = expected.get(pos).cloned();
716            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
717                args: vec![
718                    ColumnarValue::Scalar(v_sv),
719                    ColumnarValue::Scalar(regex_sv),
720                    ColumnarValue::Scalar(start_sv.clone()),
721                ],
722                number_rows: 3,
723                return_type: &Int64,
724            });
725            match re {
726                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
727                    assert_eq!(v, expected, "regexp_count scalar test failed");
728                }
729                _ => panic!("Unexpected result"),
730            }
731
732            // largeutf8
733            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
734            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
735            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
736                args: vec![
737                    ColumnarValue::Scalar(v_sv),
738                    ColumnarValue::Scalar(regex_sv),
739                    ColumnarValue::Scalar(start_sv.clone()),
740                ],
741                number_rows: 3,
742                return_type: &Int64,
743            });
744            match re {
745                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
746                    assert_eq!(v, expected, "regexp_count scalar test failed");
747                }
748                _ => panic!("Unexpected result"),
749            }
750
751            // utf8view
752            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
753            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
754            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
755                args: vec![
756                    ColumnarValue::Scalar(v_sv),
757                    ColumnarValue::Scalar(regex_sv),
758                    ColumnarValue::Scalar(start_sv.clone()),
759                ],
760                number_rows: 3,
761                return_type: &Int64,
762            });
763            match re {
764                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
765                    assert_eq!(v, expected, "regexp_count scalar test failed");
766                }
767                _ => panic!("Unexpected result"),
768            }
769        });
770    }
771
772    fn test_case_insensitive_regexp_count_scalar_flags() {
773        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
774        let regex = "abc";
775        let start = 1;
776        let flags = "i";
777        let expected: Vec<i64> = vec![0, 1, 2, 2, 3];
778
779        values.iter().enumerate().for_each(|(pos, &v)| {
780            // utf8
781            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
782            let regex_sv = ScalarValue::Utf8(Some(regex.to_string()));
783            let start_sv = ScalarValue::Int64(Some(start));
784            let flags_sv = ScalarValue::Utf8(Some(flags.to_string()));
785            let expected = expected.get(pos).cloned();
786            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
787                args: vec![
788                    ColumnarValue::Scalar(v_sv),
789                    ColumnarValue::Scalar(regex_sv),
790                    ColumnarValue::Scalar(start_sv.clone()),
791                    ColumnarValue::Scalar(flags_sv.clone()),
792                ],
793                number_rows: 4,
794                return_type: &Int64,
795            });
796            match re {
797                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
798                    assert_eq!(v, expected, "regexp_count scalar test failed");
799                }
800                _ => panic!("Unexpected result"),
801            }
802
803            // largeutf8
804            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
805            let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string()));
806            let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string()));
807            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
808                args: vec![
809                    ColumnarValue::Scalar(v_sv),
810                    ColumnarValue::Scalar(regex_sv),
811                    ColumnarValue::Scalar(start_sv.clone()),
812                    ColumnarValue::Scalar(flags_sv.clone()),
813                ],
814                number_rows: 4,
815                return_type: &Int64,
816            });
817            match re {
818                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
819                    assert_eq!(v, expected, "regexp_count scalar test failed");
820                }
821                _ => panic!("Unexpected result"),
822            }
823
824            // utf8view
825            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
826            let regex_sv = ScalarValue::Utf8View(Some(regex.to_string()));
827            let flags_sv = ScalarValue::Utf8View(Some(flags.to_string()));
828            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
829                args: vec![
830                    ColumnarValue::Scalar(v_sv),
831                    ColumnarValue::Scalar(regex_sv),
832                    ColumnarValue::Scalar(start_sv.clone()),
833                    ColumnarValue::Scalar(flags_sv.clone()),
834                ],
835                number_rows: 4,
836                return_type: &Int64,
837            });
838            match re {
839                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
840                    assert_eq!(v, expected, "regexp_count scalar test failed");
841                }
842                _ => panic!("Unexpected result"),
843            }
844        });
845    }
846
847    fn test_case_sensitive_regexp_count_array<A>()
848    where
849        A: From<Vec<&'static str>> + Array + 'static,
850    {
851        let values = A::from(vec!["", "aabca", "abcabc", "abcAbcab", "abcabcAbc"]);
852        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
853
854        let expected = Int64Array::from(vec![0, 1, 2, 2, 2]);
855
856        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex)]).unwrap();
857        assert_eq!(re.as_ref(), &expected);
858    }
859
860    fn test_case_sensitive_regexp_count_array_start<A>()
861    where
862        A: From<Vec<&'static str>> + Array + 'static,
863    {
864        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
865        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
866        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
867
868        let expected = Int64Array::from(vec![0, 0, 1, 1, 0]);
869
870        let re = regexp_count_func(&[Arc::new(values), Arc::new(regex), Arc::new(start)])
871            .unwrap();
872        assert_eq!(re.as_ref(), &expected);
873    }
874
875    fn test_case_insensitive_regexp_count_array_flags<A>()
876    where
877        A: From<Vec<&'static str>> + Array + 'static,
878    {
879        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
880        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
881        let start = Int64Array::from(vec![1]);
882        let flags = A::from(vec!["", "i", "", "", "i"]);
883
884        let expected = Int64Array::from(vec![0, 1, 2, 2, 3]);
885
886        let re = regexp_count_func(&[
887            Arc::new(values),
888            Arc::new(regex),
889            Arc::new(start),
890            Arc::new(flags),
891        ])
892        .unwrap();
893        assert_eq!(re.as_ref(), &expected);
894    }
895
896    fn test_case_sensitive_regexp_count_start_scalar_complex() {
897        let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"];
898        let regex = ["", "abc", "a", "bc", "ab"];
899        let start = 5;
900        let flags = ["", "i", "", "", "i"];
901        let expected: Vec<i64> = vec![0, 0, 0, 1, 1];
902
903        values.iter().enumerate().for_each(|(pos, &v)| {
904            // utf8
905            let v_sv = ScalarValue::Utf8(Some(v.to_string()));
906            let regex_sv = ScalarValue::Utf8(regex.get(pos).map(|s| s.to_string()));
907            let start_sv = ScalarValue::Int64(Some(start));
908            let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string()));
909            let expected = expected.get(pos).cloned();
910            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
911                args: vec![
912                    ColumnarValue::Scalar(v_sv),
913                    ColumnarValue::Scalar(regex_sv),
914                    ColumnarValue::Scalar(start_sv.clone()),
915                    ColumnarValue::Scalar(flags_sv.clone()),
916                ],
917                number_rows: 4,
918                return_type: &Int64,
919            });
920            match re {
921                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
922                    assert_eq!(v, expected, "regexp_count scalar test failed");
923                }
924                _ => panic!("Unexpected result"),
925            }
926
927            // largeutf8
928            let v_sv = ScalarValue::LargeUtf8(Some(v.to_string()));
929            let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string()));
930            let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string()));
931            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
932                args: vec![
933                    ColumnarValue::Scalar(v_sv),
934                    ColumnarValue::Scalar(regex_sv),
935                    ColumnarValue::Scalar(start_sv.clone()),
936                    ColumnarValue::Scalar(flags_sv.clone()),
937                ],
938                number_rows: 4,
939                return_type: &Int64,
940            });
941            match re {
942                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
943                    assert_eq!(v, expected, "regexp_count scalar test failed");
944                }
945                _ => panic!("Unexpected result"),
946            }
947
948            // utf8view
949            let v_sv = ScalarValue::Utf8View(Some(v.to_string()));
950            let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string()));
951            let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string()));
952            let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs {
953                args: vec![
954                    ColumnarValue::Scalar(v_sv),
955                    ColumnarValue::Scalar(regex_sv),
956                    ColumnarValue::Scalar(start_sv.clone()),
957                    ColumnarValue::Scalar(flags_sv.clone()),
958                ],
959                number_rows: 4,
960                return_type: &Int64,
961            });
962            match re {
963                Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => {
964                    assert_eq!(v, expected, "regexp_count scalar test failed");
965                }
966                _ => panic!("Unexpected result"),
967            }
968        });
969    }
970
971    fn test_case_sensitive_regexp_count_array_complex<A>()
972    where
973        A: From<Vec<&'static str>> + Array + 'static,
974    {
975        let values = A::from(vec!["", "aAbca", "abcabc", "abcAbcab", "abcabcAbc"]);
976        let regex = A::from(vec!["", "abc", "a", "bc", "ab"]);
977        let start = Int64Array::from(vec![1, 2, 3, 4, 5]);
978        let flags = A::from(vec!["", "i", "", "", "i"]);
979
980        let expected = Int64Array::from(vec![0, 1, 1, 1, 1]);
981
982        let re = regexp_count_func(&[
983            Arc::new(values),
984            Arc::new(regex),
985            Arc::new(start),
986            Arc::new(flags),
987        ])
988        .unwrap();
989        assert_eq!(re.as_ref(), &expected);
990    }
991
992    fn test_case_regexp_count_cache_check<A>()
993    where
994        A: From<Vec<&'static str>> + Array + 'static,
995    {
996        let values = A::from(vec!["aaa", "Aaa", "aaa"]);
997        let regex = A::from(vec!["aaa", "aaa", "aaa"]);
998        let start = Int64Array::from(vec![1, 1, 1]);
999        let flags = A::from(vec!["", "i", ""]);
1000
1001        let expected = Int64Array::from(vec![1, 1, 1]);
1002
1003        let re = regexp_count_func(&[
1004            Arc::new(values),
1005            Arc::new(regex),
1006            Arc::new(start),
1007            Arc::new(flags),
1008        ])
1009        .unwrap();
1010        assert_eq!(re.as_ref(), &expected);
1011    }
1012}