datafusion_functions/string/
common.rs1use std::fmt::{Display, Formatter};
21use std::sync::Arc;
22
23use crate::strings::make_and_append_view;
24use arrow::array::{
25 new_null_array, Array, ArrayRef, GenericStringArray, GenericStringBuilder,
26 NullBufferBuilder, OffsetSizeTrait, StringBuilder, StringViewArray,
27};
28use arrow::buffer::{Buffer, ScalarBuffer};
29use arrow::datatypes::DataType;
30use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
31use datafusion_common::Result;
32use datafusion_common::{exec_err, ScalarValue};
33use datafusion_expr::ColumnarValue;
34
/// Which end(s) of a string a trim operation removes characters from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TrimType {
    /// Remove matching characters from the start of the string only (`ltrim`).
    Left,
    /// Remove matching characters from the end of the string only (`rtrim`).
    Right,
    /// Remove matching characters from both ends of the string (`btrim`).
    Both,
}

/// Formats the variant as its function name: `ltrim`, `rtrim` or `btrim`.
impl Display for TrimType {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            TrimType::Left => "ltrim",
            TrimType::Right => "rtrim",
            TrimType::Both => "btrim",
        };
        f.write_str(name)
    }
}
50
51pub(crate) fn general_trim<T: OffsetSizeTrait>(
52 args: &[ArrayRef],
53 trim_type: TrimType,
54 use_string_view: bool,
55) -> Result<ArrayRef> {
56 let func = match trim_type {
57 TrimType::Left => |input, pattern: &str| {
58 let pattern = pattern.chars().collect::<Vec<char>>();
59 let ltrimmed_str =
60 str::trim_start_matches::<&[char]>(input, pattern.as_ref());
61 let start_offset = input.len() - ltrimmed_str.len();
64
65 (ltrimmed_str, start_offset as u32)
66 },
67 TrimType::Right => |input, pattern: &str| {
68 let pattern = pattern.chars().collect::<Vec<char>>();
69 let rtrimmed_str = str::trim_end_matches::<&[char]>(input, pattern.as_ref());
70
71 (rtrimmed_str, 0)
73 },
74 TrimType::Both => |input, pattern: &str| {
75 let pattern = pattern.chars().collect::<Vec<char>>();
76 let ltrimmed_str =
77 str::trim_start_matches::<&[char]>(input, pattern.as_ref());
78 let start_offset = input.len() - ltrimmed_str.len();
81 let btrimmed_str =
82 str::trim_end_matches::<&[char]>(ltrimmed_str, pattern.as_ref());
83
84 (btrimmed_str, start_offset as u32)
85 },
86 };
87
88 if use_string_view {
89 string_view_trim(func, args)
90 } else {
91 string_trim::<T>(func, args)
92 }
93}
94
95fn string_view_trim<'a>(
119 trim_func: fn(&'a str, &'a str) -> (&'a str, u32),
120 args: &'a [ArrayRef],
121) -> Result<ArrayRef> {
122 let string_view_array = as_string_view_array(&args[0])?;
123 let mut views_buf = Vec::with_capacity(string_view_array.len());
124 let mut null_builder = NullBufferBuilder::new(string_view_array.len());
125
126 match args.len() {
127 1 => {
128 let array_iter = string_view_array.iter();
129 let views_iter = string_view_array.views().iter();
130 for (src_str_opt, raw_view) in array_iter.zip(views_iter) {
131 trim_and_append_str(
132 src_str_opt,
133 Some(" "),
134 trim_func,
135 &mut views_buf,
136 &mut null_builder,
137 raw_view,
138 );
139 }
140 }
141 2 => {
142 let characters_array = as_string_view_array(&args[1])?;
143
144 if characters_array.len() == 1 {
145 if characters_array.is_null(0) {
147 return Ok(new_null_array(
148 &DataType::Utf8View,
150 string_view_array.len(),
151 ));
152 }
153
154 let characters = characters_array.value(0);
155 let array_iter = string_view_array.iter();
156 let views_iter = string_view_array.views().iter();
157 for (src_str_opt, raw_view) in array_iter.zip(views_iter) {
158 trim_and_append_str(
159 src_str_opt,
160 Some(characters),
161 trim_func,
162 &mut views_buf,
163 &mut null_builder,
164 raw_view,
165 );
166 }
167 } else {
168 let characters_iter = characters_array.iter();
170 let array_iter = string_view_array.iter();
171 let views_iter = string_view_array.views().iter();
172 for ((src_str_opt, raw_view), characters_opt) in
173 array_iter.zip(views_iter).zip(characters_iter)
174 {
175 trim_and_append_str(
176 src_str_opt,
177 characters_opt,
178 trim_func,
179 &mut views_buf,
180 &mut null_builder,
181 raw_view,
182 );
183 }
184 }
185 }
186 other => {
187 return exec_err!(
188 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
189 );
190 }
191 }
192
193 let views_buf = ScalarBuffer::from(views_buf);
194 let nulls_buf = null_builder.finish();
195
196 unsafe {
201 let array = StringViewArray::new_unchecked(
202 views_buf,
203 string_view_array.data_buffers().to_vec(),
204 nulls_buf,
205 );
206 Ok(Arc::new(array) as ArrayRef)
207 }
208}
209
210fn trim_and_append_str<'a>(
224 src_str_opt: Option<&'a str>,
225 trim_characters_opt: Option<&'a str>,
226 trim_func: fn(&'a str, &'a str) -> (&'a str, u32),
227 views_buf: &mut Vec<u128>,
228 null_builder: &mut NullBufferBuilder,
229 original_view: &u128,
230) {
231 if let (Some(src_str), Some(characters)) = (src_str_opt, trim_characters_opt) {
232 let (trim_str, start_offset) = trim_func(src_str, characters);
233 make_and_append_view(
234 views_buf,
235 null_builder,
236 original_view,
237 trim_str,
238 start_offset,
239 );
240 } else {
241 null_builder.append_null();
242 views_buf.push(0);
243 }
244}
245
246fn string_trim<'a, T: OffsetSizeTrait>(
251 func: fn(&'a str, &'a str) -> (&'a str, u32),
252 args: &'a [ArrayRef],
253) -> Result<ArrayRef> {
254 let string_array = as_generic_string_array::<T>(&args[0])?;
255
256 match args.len() {
257 1 => {
258 let result = string_array
259 .iter()
260 .map(|string| string.map(|string: &str| func(string, " ").0))
261 .collect::<GenericStringArray<T>>();
262
263 Ok(Arc::new(result) as ArrayRef)
264 }
265 2 => {
266 let characters_array = as_generic_string_array::<T>(&args[1])?;
267
268 if characters_array.len() == 1 {
269 if characters_array.is_null(0) {
270 return Ok(new_null_array(
271 string_array.data_type(),
272 string_array.len(),
273 ));
274 }
275
276 let characters = characters_array.value(0);
277 let result = string_array
278 .iter()
279 .map(|item| item.map(|string| func(string, characters).0))
280 .collect::<GenericStringArray<T>>();
281 return Ok(Arc::new(result) as ArrayRef);
282 }
283
284 let result = string_array
285 .iter()
286 .zip(characters_array.iter())
287 .map(|(string, characters)| match (string, characters) {
288 (Some(string), Some(characters)) => Some(func(string, characters).0),
289 _ => None,
290 })
291 .collect::<GenericStringArray<T>>();
292
293 Ok(Arc::new(result) as ArrayRef)
294 }
295 other => {
296 exec_err!(
297 "Function TRIM was called with {other} arguments. It requires at least 1 and at most 2."
298 )
299 }
300 }
301}
302
303pub(crate) fn to_lower(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
304 case_conversion(args, |string| string.to_lowercase(), name)
305}
306
307pub(crate) fn to_upper(args: &[ColumnarValue], name: &str) -> Result<ColumnarValue> {
308 case_conversion(args, |string| string.to_uppercase(), name)
309}
310
311fn case_conversion<'a, F>(
312 args: &'a [ColumnarValue],
313 op: F,
314 name: &str,
315) -> Result<ColumnarValue>
316where
317 F: Fn(&'a str) -> String,
318{
319 match &args[0] {
320 ColumnarValue::Array(array) => match array.data_type() {
321 DataType::Utf8 => Ok(ColumnarValue::Array(case_conversion_array::<i32, _>(
322 array, op,
323 )?)),
324 DataType::LargeUtf8 => Ok(ColumnarValue::Array(case_conversion_array::<
325 i64,
326 _,
327 >(array, op)?)),
328 DataType::Utf8View => {
329 let string_array = as_string_view_array(array)?;
330 let mut string_builder = StringBuilder::with_capacity(
331 string_array.len(),
332 string_array.get_array_memory_size(),
333 );
334
335 for str in string_array.iter() {
336 if let Some(str) = str {
337 string_builder.append_value(op(str));
338 } else {
339 string_builder.append_null();
340 }
341 }
342
343 Ok(ColumnarValue::Array(Arc::new(string_builder.finish())))
344 }
345 other => exec_err!("Unsupported data type {other:?} for function {name}"),
346 },
347 ColumnarValue::Scalar(scalar) => match scalar {
348 ScalarValue::Utf8(a) => {
349 let result = a.as_ref().map(|x| op(x));
350 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
351 }
352 ScalarValue::LargeUtf8(a) => {
353 let result = a.as_ref().map(|x| op(x));
354 Ok(ColumnarValue::Scalar(ScalarValue::LargeUtf8(result)))
355 }
356 ScalarValue::Utf8View(a) => {
357 let result = a.as_ref().map(|x| op(x));
358 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(result)))
359 }
360 other => exec_err!("Unsupported data type {other:?} for function {name}"),
361 },
362 }
363}
364
365fn case_conversion_array<'a, O, F>(array: &'a ArrayRef, op: F) -> Result<ArrayRef>
366where
367 O: OffsetSizeTrait,
368 F: Fn(&'a str) -> String,
369{
370 const PRE_ALLOC_BYTES: usize = 8;
371
372 let string_array = as_generic_string_array::<O>(array)?;
373 let value_data = string_array.value_data();
374
375 if value_data.is_ascii() {
377 return case_conversion_ascii_array::<O, _>(string_array, op);
378 }
379
380 let item_len = string_array.len();
382 let capacity = string_array.value_data().len() + PRE_ALLOC_BYTES;
383 let mut builder = GenericStringBuilder::<O>::with_capacity(item_len, capacity);
384
385 if string_array.null_count() == 0 {
386 let iter =
387 (0..item_len).map(|i| Some(op(unsafe { string_array.value_unchecked(i) })));
388 builder.extend(iter);
389 } else {
390 let iter = string_array.iter().map(|string| string.map(&op));
391 builder.extend(iter);
392 }
393 Ok(Arc::new(builder.finish()))
394}
395
396fn case_conversion_ascii_array<'a, O, F>(
400 string_array: &'a GenericStringArray<O>,
401 op: F,
402) -> Result<ArrayRef>
403where
404 O: OffsetSizeTrait,
405 F: Fn(&'a str) -> String,
406{
407 let value_data = string_array.value_data();
408 let str_values = unsafe { std::str::from_utf8_unchecked(value_data) };
411
412 let converted_values = op(str_values);
414 assert_eq!(converted_values.len(), str_values.len());
415 let bytes = converted_values.into_bytes();
416
417 let values = Buffer::from_vec(bytes);
419 let offsets = string_array.offsets().clone();
420 let nulls = string_array.nulls().cloned();
421 Ok(Arc::new(unsafe {
423 GenericStringArray::<O>::new_unchecked(offsets, values, nulls)
424 }))
425}