use arrow::array::{
new_null_array, Array, ArrayDataBuilder, ArrayRef, BufferBuilder, GenericStringArray,
OffsetSizeTrait,
};
use arrow::compute;
use datafusion_common::plan_err;
use datafusion_common::{
cast::as_generic_string_array, internal_err, DataFusionError, Result,
};
use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation};
use hashbrown::HashMap;
use regex::Regex;
use std::sync::{Arc, OnceLock};
use crate::functions::{make_scalar_function, make_scalar_function_with_hints, Hint};
/// Reads the first value of a string argument that is expected to carry a single,
/// constant value.
///
/// * `$ARG` - the `&ArrayRef` to read; it is downcast to `GenericStringArray<$T>`.
/// * `$NAME` - a human-readable name of the argument; it only documents the call site
///   and is not used by the expansion.
/// * `$T` - the offset type of the string array.
/// * `$EARLY_ABORT` - a function or closure taking `&GenericStringArray<$T>`. When the
///   array is empty or its first value is null, the enclosing function immediately
///   returns `$EARLY_ABORT(array)`.
///
/// Otherwise the macro evaluates to the `&str` stored at index 0. Because the downcast
/// uses `?`, the enclosing function must return `Result`.
macro_rules! fetch_string_arg {
    ($ARG:expr, $NAME:expr, $T:ident, $EARLY_ABORT:ident) => {{
        let array = as_generic_string_array::<$T>($ARG)?;
        if array.len() == 0 || array.is_null(0) {
            return $EARLY_ABORT(array);
        } else {
            array.value(0)
        }
    }};
}
/// Implements `regexp_match(string, pattern [, flags])` over string arrays.
///
/// Delegates to arrow's [`compute::regexp_match`] with `args[0]` as the values,
/// `args[1]` as the patterns and, when present, `args[2]` as the per-row flags.
///
/// # Errors
/// * a planning error if any flags value contains `g`, because the "global" option is
///   not supported by `regexp_match`;
/// * an internal error if called with anything other than 2 or 3 arguments;
/// * any error from the downcasts or from the arrow kernel (e.g. an invalid pattern).
pub fn regexp_match<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
    match args.len() {
        2 => {
            let values = as_generic_string_array::<T>(&args[0])?;
            let regex = as_generic_string_array::<T>(&args[1])?;
            compute::regexp_match(values, regex, None).map_err(DataFusionError::ArrowError)
        }
        3 => {
            let values = as_generic_string_array::<T>(&args[0])?;
            let regex = as_generic_string_array::<T>(&args[1])?;
            let flags = as_generic_string_array::<T>(&args[2])?;
            // Reject `g` wherever it appears (e.g. "gi"), not only the exact string "g",
            // so every request for global matching gets the same planning error instead
            // of failing later inside the kernel.
            if flags.iter().flatten().any(|f| f.contains('g')) {
                return plan_err!("regexp_match() does not support the \"global\" option");
            }
            compute::regexp_match(values, regex, Some(flags))
                .map_err(DataFusionError::ArrowError)
        }
        other => internal_err!(
            "regexp_match was called with {other} arguments. It requires at least 2 and at most 3."
        ),
    }
}
/// Translates POSIX-style back-references in a replacement string into the syntax of
/// the `regex` crate.
///
/// Every backslash together with the digits that follow it becomes `${<digits>}`,
/// e.g. `"[\1]"` becomes `"[${1}]"`; a backslash followed by no digits becomes `${}`.
fn regex_replace_posix_groups(replacement: &str) -> String {
    // Compiled on first use and shared by every later call.
    static BACKSLASH_GROUP: OnceLock<Regex> = OnceLock::new();
    let backslash_group = BACKSLASH_GROUP.get_or_init(|| Regex::new(r"(\\)(\d*)").unwrap());
    // `$$` emits a literal `$`; `$2` is the run of digits captured after the backslash.
    let rewritten = backslash_group.replace_all(replacement, "$${$2}");
    rewritten.into_owned()
}
pub fn regexp_replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let mut patterns: HashMap<String, Regex> = HashMap::new();
match args.len() {
3 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern_array = as_generic_string_array::<T>(&args[1])?;
let replacement_array = as_generic_string_array::<T>(&args[2])?;
let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.map(|((string, pattern), replacement)| match (string, pattern, replacement) {
(Some(string), Some(pattern), Some(replacement)) => {
let replacement = regex_replace_posix_groups(replacement);
let re = match patterns.get(pattern) {
Some(re) => Ok(re.clone()),
None => {
match Regex::new(pattern) {
Ok(re) => {
patterns.insert(pattern.to_string(), re.clone());
Ok(re)
},
Err(err) => Err(DataFusionError::External(Box::new(err))),
}
}
};
Some(re.map(|re| re.replace(string, replacement.as_str()))).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
4 => {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern_array = as_generic_string_array::<T>(&args[1])?;
let replacement_array = as_generic_string_array::<T>(&args[2])?;
let flags_array = as_generic_string_array::<T>(&args[3])?;
let result = string_array
.iter()
.zip(pattern_array.iter())
.zip(replacement_array.iter())
.zip(flags_array.iter())
.map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) {
(Some(string), Some(pattern), Some(replacement), Some(flags)) => {
let replacement = regex_replace_posix_groups(replacement);
let (pattern, replace_all) = if flags == "g" {
(pattern.to_string(), true)
} else if flags.contains('g') {
(format!("(?{}){}", flags.to_string().replace('g', ""), pattern), true)
} else {
(format!("(?{flags}){pattern}"), false)
};
let re = match patterns.get(&pattern) {
Some(re) => Ok(re.clone()),
None => {
match Regex::new(pattern.as_str()) {
Ok(re) => {
patterns.insert(pattern, re.clone());
Ok(re)
},
Err(err) => Err(DataFusionError::External(Box::new(err))),
}
}
};
Some(re.map(|re| {
if replace_all {
re.replace_all(string, replacement.as_str())
} else {
re.replace(string, replacement.as_str())
}
})).transpose()
}
_ => Ok(None)
})
.collect::<Result<GenericStringArray<T>>>()?;
Ok(Arc::new(result) as ArrayRef)
}
other => internal_err!(
"regexp_replace was called with {other} arguments. It requires at least 3 and at most 4."
),
}
}
/// Returns an array with the same data type and length as `input_array` in which every
/// slot is null; used as the early-abort result of the static-pattern fast path.
fn _regexp_replace_early_abort<T: OffsetSizeTrait>(
    input_array: &GenericStringArray<T>,
) -> Result<ArrayRef> {
    let (data_type, row_count) = (input_array.data_type(), input_array.len());
    let all_nulls = new_null_array(data_type, row_count);
    Ok(all_nulls)
}
fn _regexp_replace_static_pattern_replace<T: OffsetSizeTrait>(
args: &[ArrayRef],
) -> Result<ArrayRef> {
let string_array = as_generic_string_array::<T>(&args[0])?;
let pattern = fetch_string_arg!(&args[1], "pattern", T, _regexp_replace_early_abort);
let replacement =
fetch_string_arg!(&args[2], "replacement", T, _regexp_replace_early_abort);
let flags = match args.len() {
3 => None,
4 => Some(fetch_string_arg!(&args[3], "flags", T, _regexp_replace_early_abort)),
other => {
return internal_err!(
"regexp_replace was called with {other} arguments. It requires at least 3 and at most 4."
)
}
};
let (pattern, limit) = match flags {
Some("g") => (pattern.to_string(), 0),
Some(flags) => (
format!("(?{}){}", flags.to_string().replace('g', ""), pattern),
!flags.contains('g') as usize,
),
None => (pattern.to_string(), 1),
};
let re =
Regex::new(&pattern).map_err(|err| DataFusionError::External(Box::new(err)))?;
let replacement = regex_replace_posix_groups(replacement);
let mut vals = BufferBuilder::<u8>::new({
let offsets = string_array.value_offsets();
(offsets[string_array.len()] - offsets[0])
.to_usize()
.unwrap()
});
let mut new_offsets = BufferBuilder::<T>::new(string_array.len() + 1);
new_offsets.append(T::zero());
string_array.iter().for_each(|val| {
if let Some(val) = val {
let result = re.replacen(val, limit, replacement.as_str());
vals.append_slice(result.as_bytes());
}
new_offsets.append(T::from_usize(vals.len()).unwrap());
});
let data = ArrayDataBuilder::new(GenericStringArray::<T>::DATA_TYPE)
.len(string_array.len())
.nulls(string_array.nulls().cloned())
.buffers(vec![new_offsets.finish(), vals.finish()])
.build()?;
let result_array = GenericStringArray::<T>::from(data);
Ok(Arc::new(result_array) as ArrayRef)
}
pub fn specialize_regexp_replace<T: OffsetSizeTrait>(
args: &[ColumnarValue],
) -> Result<ScalarFunctionImplementation> {
let (is_source_scalar, is_pattern_scalar, is_replacement_scalar, is_flags_scalar) = (
matches!(args[0], ColumnarValue::Scalar(_)),
matches!(args[1], ColumnarValue::Scalar(_)),
matches!(args[2], ColumnarValue::Scalar(_)),
matches!(args.get(3), Some(ColumnarValue::Scalar(_)) | None),
);
match (
is_source_scalar,
is_pattern_scalar,
is_replacement_scalar,
is_flags_scalar,
) {
(_, true, true, true) => Ok(make_scalar_function_with_hints(
_regexp_replace_static_pattern_replace::<T>,
vec![
Hint::Pad,
Hint::AcceptsSingular,
Hint::AcceptsSingular,
Hint::AcceptsSingular,
],
)),
(_, _, _, _) => Ok(make_scalar_function(regexp_replace::<T>)),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::*;
    use datafusion_common::ScalarValue;

    #[test]
    fn test_case_sensitive_regexp_match() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);

        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.append(false);
        expected_builder.append(false);
        let expected = expected_builder.finish();

        let re = regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns)]).unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_case_insensitive_regexp_match() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns =
            StringArray::from(vec!["^(a)", "^(A)", "(b|d)", "(B|D)", "^(b|c)"]);
        let flags = StringArray::from(vec!["i"; 5]);

        let elem_builder: GenericStringBuilder<i32> = GenericStringBuilder::new();
        let mut expected_builder = ListBuilder::new(elem_builder);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.values().append_value("a");
        expected_builder.append(true);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.values().append_value("b");
        expected_builder.append(true);
        expected_builder.append(false);
        let expected = expected_builder.finish();

        let re =
            regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_unsupported_global_flag_regexp_match() {
        let values = StringArray::from(vec!["abc"]);
        let patterns = StringArray::from(vec!["^(a)"]);
        let flags = StringArray::from(vec!["g"]);

        let re_err =
            regexp_match::<i32>(&[Arc::new(values), Arc::new(patterns), Arc::new(flags)])
                .expect_err("unsupported flag should have failed");

        assert_eq!(
            re_err.strip_backtrace(),
            "Error during planning: regexp_match() does not support the \"global\" option"
        );
    }

    #[test]
    fn test_static_pattern_regexp_replace() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns = StringArray::from(vec!["b"; 5]);
        let replacements = StringArray::from(vec!["foo"; 5]);
        let expected = StringArray::from(vec!["afooc"; 5]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_static_pattern_regexp_replace_with_flags() {
        let values = StringArray::from(vec!["abc", "ABC", "aBc", "AbC", "aBC"]);
        let patterns = StringArray::from(vec!["b"; 5]);
        let replacements = StringArray::from(vec!["foo"; 5]);
        let flags = StringArray::from(vec!["i"; 5]);
        let expected =
            StringArray::from(vec!["afooc", "AfooC", "afooc", "AfooC", "afooC"]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
            Arc::new(flags),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_static_pattern_regexp_replace_early_abort() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns = StringArray::from(vec![None::<&str>; 5]);
        let replacements = StringArray::from(vec!["foo"; 5]);
        let expected = StringArray::from(vec![None::<&str>; 5]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_static_pattern_regexp_replace_early_abort_when_empty() {
        let values = StringArray::from(Vec::<Option<&str>>::new());
        let patterns = StringArray::from(Vec::<Option<&str>>::new());
        let replacements = StringArray::from(Vec::<Option<&str>>::new());
        let expected = StringArray::from(Vec::<Option<&str>>::new());

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_static_pattern_regexp_replace_early_abort_flags() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns = StringArray::from(vec!["a"; 5]);
        let replacements = StringArray::from(vec!["foo"; 5]);
        let flags = StringArray::from(vec![None::<&str>; 5]);
        let expected = StringArray::from(vec![None::<&str>; 5]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
            Arc::new(flags),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_static_pattern_regexp_replace_pattern_error() {
        let values = StringArray::from(vec!["abc"; 5]);
        let patterns = StringArray::from(vec!["["; 5]);
        let replacements = StringArray::from(vec!["foo"; 5]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ]);
        let pattern_err = re.expect_err("broken pattern should have failed");

        // regex-syntax indents the offending pattern and the caret line by 4 spaces.
        assert_eq!(
            pattern_err.strip_backtrace(),
            "External error: regex parse error:\n    [\n    ^\nerror: unclosed character class"
        );
    }

    #[test]
    fn test_regexp_can_specialize_all_cases() {
        macro_rules! make_scalar {
            () => {
                ColumnarValue::Scalar(ScalarValue::Utf8(Some("foo".to_string())))
            };
        }

        macro_rules! make_array {
            () => {
                ColumnarValue::Array(
                    Arc::new(StringArray::from(vec!["bar"; 2])) as ArrayRef
                )
            };
        }

        for source in [make_scalar!(), make_array!()] {
            for pattern in [make_scalar!(), make_array!()] {
                for replacement in [make_scalar!(), make_array!()] {
                    for flags in [Some(make_scalar!()), Some(make_array!()), None] {
                        let mut args =
                            vec![source.clone(), pattern.clone(), replacement.clone()];
                        if let Some(flags) = flags {
                            args.push(flags.clone());
                        }
                        let regex_func = specialize_regexp_replace::<i32>(&args);
                        assert!(regex_func.is_ok());
                    }
                }
            }
        }
    }

    #[test]
    fn test_static_pattern_regexp_replace_with_null_buffers() {
        let values = StringArray::from(vec![
            Some("a"),
            None,
            Some("b"),
            None,
            Some("a"),
            None,
            None,
            Some("c"),
        ]);
        let patterns = StringArray::from(vec!["a"; 1]);
        let replacements = StringArray::from(vec!["foo"; 1]);
        let expected = StringArray::from(vec![
            Some("foo"),
            None,
            Some("b"),
            None,
            Some("foo"),
            None,
            None,
            Some("c"),
        ]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
        assert_eq!(re.null_count(), 4);
    }

    #[test]
    fn test_static_pattern_regexp_replace_with_sliced_null_buffer() {
        let values = StringArray::from(vec![
            Some("a"),
            None,
            Some("b"),
            None,
            Some("a"),
            None,
            None,
            Some("c"),
        ]);
        let values = values.slice(2, 5);
        let patterns = StringArray::from(vec!["a"; 1]);
        let replacements = StringArray::from(vec!["foo"; 1]);
        let expected = StringArray::from(vec![Some("b"), None, Some("foo"), None, None]);

        let re = _regexp_replace_static_pattern_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
        assert_eq!(re.null_count(), 3);
    }

    #[test]
    fn test_regex_replace_posix_groups() {
        assert_eq!(regex_replace_posix_groups("[\\1]"), "[${1}]");
        assert_eq!(regex_replace_posix_groups("no groups"), "no groups");
    }

    #[test]
    fn test_regexp_replace_array_arguments() {
        let values = StringArray::from(vec![Some("abc"), Some("abc"), None, Some("aXc")]);
        let patterns = StringArray::from(vec![Some("b"), None, Some("b"), Some("(X)")]);
        let replacements = StringArray::from(vec!["foo", "foo", "foo", "[\\1]"]);
        let expected = StringArray::from(vec![Some("afooc"), None, None, Some("a[X]c")]);

        let re = regexp_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }

    #[test]
    fn test_regexp_replace_array_arguments_with_flags() {
        let values = StringArray::from(vec!["abcb", "aBcB", "aBcB", "abcb"]);
        let patterns = StringArray::from(vec!["b"; 4]);
        let replacements = StringArray::from(vec!["x"; 4]);
        let flags = StringArray::from(vec!["g", "i", "gi", "i"]);
        let expected = StringArray::from(vec!["axcx", "axcB", "axcx", "axcb"]);

        let re = regexp_replace::<i32>(&[
            Arc::new(values),
            Arc::new(patterns),
            Arc::new(replacements),
            Arc::new(flags),
        ])
        .unwrap();

        assert_eq!(re.as_ref(), &expected);
    }
}