use std::any::Any;
use std::sync::Arc;

use arrow::array::{ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringBuilder};
use arrow::datatypes::DataType;
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

use crate::utils::{make_scalar_function, utf8_to_str_type};
/// Scalar UDF implementing `substr_index`, which is also registered under the
/// alias `substring_index`.
///
/// The function takes a string, a delimiter and an `Int64` count; the string
/// and delimiter must both be `Utf8` or both be `LargeUtf8`.
#[derive(Debug)]
pub struct SubstrIndexFunc {
// Accepted argument type combinations together with the function's volatility.
signature: Signature,
// Alternative names handed out by `ScalarUDFImpl::aliases`.
aliases: Vec<String>,
}
impl Default for SubstrIndexFunc {
fn default() -> Self {
Self::new()
}
}
impl SubstrIndexFunc {
pub fn new() -> Self {
use DataType::*;
Self {
signature: Signature::one_of(
vec![
Exact(vec![Utf8, Utf8, Int64]),
Exact(vec![LargeUtf8, LargeUtf8, Int64]),
],
Volatility::Immutable,
),
aliases: vec![String::from("substring_index")],
}
}
}
impl ScalarUDFImpl for SubstrIndexFunc {
fn as_any(&self) -> &dyn Any {
self
}
fn name(&self) -> &str {
"substr_index"
}
fn signature(&self) -> &Signature {
&self.signature
}
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
utf8_to_str_type(&arg_types[0], "substr_index")
}
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
match args[0].data_type() {
DataType::Utf8 => make_scalar_function(substr_index::<i32>, vec![])(args),
DataType::LargeUtf8 => {
make_scalar_function(substr_index::<i64>, vec![])(args)
}
other => {
exec_err!("Unsupported data type {other:?} for function substr_index")
}
}
}
fn aliases(&self) -> &[String] {
&self.aliases
}
}
pub fn substr_index<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() != 3 {
return exec_err!(
"substr_index was called with {} arguments. It requires 3.",
args.len()
);
}
let string_array = as_generic_string_array::<T>(&args[0])?;
let delimiter_array = as_generic_string_array::<T>(&args[1])?;
let count_array = as_int64_array(&args[2])?;
let mut builder = StringBuilder::new();
string_array
.iter()
.zip(delimiter_array.iter())
.zip(count_array.iter())
.for_each(|((string, delimiter), n)| match (string, delimiter, n) {
(Some(string), Some(delimiter), Some(n)) => {
if n == 0 || string.is_empty() || delimiter.is_empty() {
builder.append_value("");
return;
}
let occurrences = usize::try_from(n.unsigned_abs()).unwrap_or(usize::MAX);
let length = if n > 0 {
let split = string.split(delimiter);
split
.take(occurrences)
.map(|s| s.len() + delimiter.len())
.sum::<usize>()
- delimiter.len()
} else {
let split = string.rsplit(delimiter);
split
.take(occurrences)
.map(|s| s.len() + delimiter.len())
.sum::<usize>()
- delimiter.len()
};
if n > 0 {
match string.get(..length) {
Some(substring) => builder.append_value(substring),
None => builder.append_null(),
}
} else {
match string.get(string.len().saturating_sub(length)..) {
Some(substring) => builder.append_value(substring),
None => builder.append_null(),
}
}
}
_ => builder.append_null(),
});
Ok(Arc::new(builder.finish()) as ArrayRef)
}
// Unit tests for `SubstrIndexFunc`, driven through the shared `test_function!`
// macro with scalar `Utf8` arguments.
#[cfg(test)]
mod tests {
use arrow::array::{Array, StringArray};
use arrow::datatypes::DataType::Utf8;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use crate::unicode::substrindex::SubstrIndexFunc;
use crate::utils::test::test_function;
// Covers positive counts, negative counts, a zero count, and empty string /
// empty delimiter inputs.
#[test]
fn test_functions() -> Result<()> {
// count = 1: everything before the first ".".
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
ColumnarValue::Scalar(ScalarValue::from(".")),
ColumnarValue::Scalar(ScalarValue::from(1i64)),
],
Ok(Some("www")),
&str,
Utf8,
StringArray
);
// count = 2: everything before the second ".".
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
ColumnarValue::Scalar(ScalarValue::from(".")),
ColumnarValue::Scalar(ScalarValue::from(2i64)),
],
Ok(Some("www.apache")),
&str,
Utf8,
StringArray
);
// count = -2: everything after the second "." counted from the right.
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
ColumnarValue::Scalar(ScalarValue::from(".")),
ColumnarValue::Scalar(ScalarValue::from(-2i64)),
],
Ok(Some("apache.org")),
&str,
Utf8,
StringArray
);
// count = -1: everything after the last ".".
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
ColumnarValue::Scalar(ScalarValue::from(".")),
ColumnarValue::Scalar(ScalarValue::from(-1i64)),
],
Ok(Some("org")),
&str,
Utf8,
StringArray
);
// count = 0 yields the empty string.
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
ColumnarValue::Scalar(ScalarValue::from(".")),
ColumnarValue::Scalar(ScalarValue::from(0i64)),
],
Ok(Some("")),
&str,
Utf8,
StringArray
);
// An empty input string yields the empty string.
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("")),
ColumnarValue::Scalar(ScalarValue::from(".")),
ColumnarValue::Scalar(ScalarValue::from(1i64)),
],
Ok(Some("")),
&str,
Utf8,
StringArray
);
// An empty delimiter yields the empty string.
test_function!(
SubstrIndexFunc::new(),
&[
ColumnarValue::Scalar(ScalarValue::from("www.apache.org")),
ColumnarValue::Scalar(ScalarValue::from("")),
ColumnarValue::Scalar(ScalarValue::from(1i64)),
],
Ok(Some("")),
&str,
Utf8,
StringArray
);
Ok(())
}
}