1use crate::utils::utf8_to_str_type;
19use arrow::array::{
20 ArrayRef, GenericStringArray, Int64Array, OffsetSizeTrait, StringArrayType,
21 StringViewArray,
22};
23use arrow::array::{AsArray, GenericStringBuilder};
24use arrow::datatypes::DataType;
25use datafusion_common::cast::as_int64_array;
26use datafusion_common::ScalarValue;
27use datafusion_common::{exec_err, DataFusionError, Result};
28use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
29use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
30use datafusion_macros::user_doc;
31use std::any::Any;
32use std::sync::Arc;
33
34#[user_doc(
35 doc_section(label = "String Functions"),
36 description = "Splits a string based on a specified delimiter and returns the substring in the specified position.",
37 syntax_example = "split_part(str, delimiter, pos)",
38 sql_example = r#"```sql
39> select split_part('1.2.3.4.5', '.', 3);
40+--------------------------------------------------+
41| split_part(Utf8("1.2.3.4.5"),Utf8("."),Int64(3)) |
42+--------------------------------------------------+
43| 3 |
44+--------------------------------------------------+
45```"#,
46 standard_argument(name = "str", prefix = "String"),
47 argument(name = "delimiter", description = "String or character to split on."),
48 argument(name = "pos", description = "Position of the part to return.")
49)]
/// Scalar UDF implementation of `split_part(str, delimiter, pos)`.
///
/// The struct only carries the [`Signature`] describing the accepted argument
/// types; it is built once in [`SplitPartFunc::new`].
#[derive(Debug)]
pub struct SplitPartFunc {
    // Accepted argument type combinations, returned by `ScalarUDFImpl::signature`.
    signature: Signature,
}
54
55impl Default for SplitPartFunc {
56 fn default() -> Self {
57 Self::new()
58 }
59}
60
61impl SplitPartFunc {
62 pub fn new() -> Self {
63 use DataType::*;
64 Self {
65 signature: Signature::one_of(
66 vec![
67 TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]),
68 TypeSignature::Exact(vec![Utf8View, Utf8, Int64]),
69 TypeSignature::Exact(vec![Utf8View, LargeUtf8, Int64]),
70 TypeSignature::Exact(vec![Utf8, Utf8View, Int64]),
71 TypeSignature::Exact(vec![Utf8, Utf8, Int64]),
72 TypeSignature::Exact(vec![LargeUtf8, Utf8View, Int64]),
73 TypeSignature::Exact(vec![LargeUtf8, Utf8, Int64]),
74 TypeSignature::Exact(vec![Utf8, LargeUtf8, Int64]),
75 TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for SplitPartFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "split_part"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 utf8_to_str_type(&arg_types[0], "split_part")
98 }
99
100 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
101 let ScalarFunctionArgs { args, .. } = args;
102
103 let len = args.iter().find_map(|arg| match arg {
105 ColumnarValue::Array(a) => Some(a.len()),
106 _ => None,
107 });
108
109 let inferred_length = len.unwrap_or(1);
110 let is_scalar = len.is_none();
111
112 let args = args
114 .iter()
115 .map(|arg| match arg {
116 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(inferred_length),
117 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
118 })
119 .collect::<Result<Vec<_>>>()?;
120
121 let n_array = as_int64_array(&args[2])?;
123 let result = match (args[0].data_type(), args[1].data_type()) {
124 (DataType::Utf8View, DataType::Utf8View) => {
125 split_part_impl::<&StringViewArray, &StringViewArray, i32>(
126 args[0].as_string_view(),
127 args[1].as_string_view(),
128 n_array,
129 )
130 }
131 (DataType::Utf8View, DataType::Utf8) => {
132 split_part_impl::<&StringViewArray, &GenericStringArray<i32>, i32>(
133 args[0].as_string_view(),
134 args[1].as_string::<i32>(),
135 n_array,
136 )
137 }
138 (DataType::Utf8View, DataType::LargeUtf8) => {
139 split_part_impl::<&StringViewArray, &GenericStringArray<i64>, i32>(
140 args[0].as_string_view(),
141 args[1].as_string::<i64>(),
142 n_array,
143 )
144 }
145 (DataType::Utf8, DataType::Utf8View) => {
146 split_part_impl::<&GenericStringArray<i32>, &StringViewArray, i32>(
147 args[0].as_string::<i32>(),
148 args[1].as_string_view(),
149 n_array,
150 )
151 }
152 (DataType::LargeUtf8, DataType::Utf8View) => {
153 split_part_impl::<&GenericStringArray<i64>, &StringViewArray, i64>(
154 args[0].as_string::<i64>(),
155 args[1].as_string_view(),
156 n_array,
157 )
158 }
159 (DataType::Utf8, DataType::Utf8) => {
160 split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i32>, i32>(
161 args[0].as_string::<i32>(),
162 args[1].as_string::<i32>(),
163 n_array,
164 )
165 }
166 (DataType::LargeUtf8, DataType::LargeUtf8) => {
167 split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i64>, i64>(
168 args[0].as_string::<i64>(),
169 args[1].as_string::<i64>(),
170 n_array,
171 )
172 }
173 (DataType::Utf8, DataType::LargeUtf8) => {
174 split_part_impl::<&GenericStringArray<i32>, &GenericStringArray<i64>, i32>(
175 args[0].as_string::<i32>(),
176 args[1].as_string::<i64>(),
177 n_array,
178 )
179 }
180 (DataType::LargeUtf8, DataType::Utf8) => {
181 split_part_impl::<&GenericStringArray<i64>, &GenericStringArray<i32>, i64>(
182 args[0].as_string::<i64>(),
183 args[1].as_string::<i32>(),
184 n_array,
185 )
186 }
187 _ => exec_err!("Unsupported combination of argument types for split_part"),
188 };
189 if is_scalar {
190 let result = result.and_then(|arr| ScalarValue::try_from_array(&arr, 0));
192 result.map(ColumnarValue::Scalar)
193 } else {
194 result.map(ColumnarValue::Array)
195 }
196 }
197
198 fn documentation(&self) -> Option<&Documentation> {
199 self.doc()
200 }
201}
202
203pub fn split_part_impl<'a, StringArrType, DelimiterArrType, StringArrayLen>(
205 string_array: StringArrType,
206 delimiter_array: DelimiterArrType,
207 n_array: &Int64Array,
208) -> Result<ArrayRef>
209where
210 StringArrType: StringArrayType<'a>,
211 DelimiterArrType: StringArrayType<'a>,
212 StringArrayLen: OffsetSizeTrait,
213{
214 let mut builder: GenericStringBuilder<StringArrayLen> = GenericStringBuilder::new();
215
216 string_array
217 .iter()
218 .zip(delimiter_array.iter())
219 .zip(n_array.iter())
220 .try_for_each(|((string, delimiter), n)| -> Result<(), DataFusionError> {
221 match (string, delimiter, n) {
222 (Some(string), Some(delimiter), Some(n)) => {
223 let split_string: Vec<&str> = string.split(delimiter).collect();
224 let len = split_string.len();
225
226 let index = match n.cmp(&0) {
227 std::cmp::Ordering::Less => len as i64 + n,
228 std::cmp::Ordering::Equal => {
229 return exec_err!("field position must not be zero");
230 }
231 std::cmp::Ordering::Greater => n - 1,
232 } as usize;
233
234 if index < len {
235 builder.append_value(split_string[index]);
236 } else {
237 builder.append_value("");
238 }
239 }
240 _ => builder.append_null(),
241 }
242 Ok(())
243 })?;
244
245 Ok(Arc::new(builder.finish()) as ArrayRef)
246}
247
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray};
    use arrow::datatypes::DataType::Utf8;

    use datafusion_common::ScalarValue;
    use datafusion_common::{exec_err, Result};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::split_part::SplitPartFunc;
    use crate::utils::test::test_function;

    /// Builds the scalar argument list `("abc~@~def~@~ghi", "~@~", position)`
    /// shared by every case below.
    fn scalar_args(position: i64) -> Vec<ColumnarValue> {
        vec![
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(
                "abc~@~def~@~ghi",
            )))),
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("~@~")))),
            ColumnarValue::Scalar(ScalarValue::Int64(Some(position))),
        ]
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Positive position inside the range.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(2),
            Ok(Some("def")),
            &str,
            Utf8,
            StringArray
        );
        // Positive position past the last part gives an empty string.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(20),
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Negative position counts from the end.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(-1),
            Ok(Some("ghi")),
            &str,
            Utf8,
            StringArray
        );
        // Zero is rejected.
        test_function!(
            SplitPartFunc::new(),
            scalar_args(0),
            exec_err!("field position must not be zero"),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}