polars_plan/dsl/function_expr/
strings.rs

1use std::borrow::Cow;
2
3use arrow::legacy::utils::CustomIterTools;
4#[cfg(feature = "timezones")]
5use once_cell::sync::Lazy;
6#[cfg(feature = "timezones")]
7use polars_core::chunked_array::temporal::validate_time_zone;
8use polars_core::utils::handle_casting_failures;
9#[cfg(feature = "dtype-struct")]
10use polars_utils::format_pl_smallstr;
11#[cfg(feature = "regex")]
12use regex::{escape, NoExpand, Regex};
13#[cfg(feature = "serde")]
14use serde::{Deserialize, Serialize};
15
16use super::*;
17use crate::{map, map_as_slice};
18
/// Lazily compiled regex used by `to_datetime` to decide whether a strptime format string
/// should be treated as time-zone aware.
///
/// It matches when the format contains one of `%z`, `%:z`, `%::z`, `%:::z` or `%#z`, or when
/// the whole format is exactly `%+`.
#[cfg(all(feature = "regex", feature = "timezones"))]
static TZ_AWARE_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"(%z)|(%:z)|(%::z)|(%:::z)|(%#z)|(^%\+$)").unwrap());
22
/// The set of string-namespace functions that can appear in an expression.
///
/// Each variant carries the static (non-column) arguments of the operation. The `Display`
/// impl renders a variant as `str.<name>`, `get_field` resolves its output dtype, and the
/// `From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>>` impl binds it to the function
/// that executes it. Most variants are gated behind a cargo feature.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
pub enum StringFunction {
    #[cfg(feature = "concat_str")]
    ConcatHorizontal {
        delimiter: PlSmallStr,
        ignore_nulls: bool,
    },
    #[cfg(feature = "concat_str")]
    ConcatVertical {
        delimiter: PlSmallStr,
        ignore_nulls: bool,
    },
    #[cfg(feature = "regex")]
    Contains {
        literal: bool,
        strict: bool,
    },
    /// Payload is the `literal` flag forwarded to `count_matches`.
    CountMatches(bool),
    EndsWith,
    /// Payload is the capture-group index forwarded to `extract`.
    Extract(usize),
    ExtractAll,
    #[cfg(feature = "extract_groups")]
    ExtractGroups {
        dtype: DataType,
        pat: PlSmallStr,
    },
    #[cfg(feature = "regex")]
    Find {
        literal: bool,
        strict: bool,
    },
    /// Payload is the `strict` flag forwarded to `to_integer`.
    #[cfg(feature = "string_to_integer")]
    ToInteger(bool),
    LenBytes,
    LenChars,
    Lowercase,
    #[cfg(feature = "extract_jsonpath")]
    JsonDecode {
        dtype: Option<DataType>,
        infer_schema_len: Option<usize>,
    },
    #[cfg(feature = "extract_jsonpath")]
    JsonPathMatch,
    #[cfg(feature = "regex")]
    Replace {
        // negative is replace all
        // how many matches to replace
        n: i64,
        literal: bool,
    },
    #[cfg(feature = "string_normalize")]
    Normalize {
        form: UnicodeForm,
    },
    #[cfg(feature = "string_reverse")]
    Reverse,
    #[cfg(feature = "string_pad")]
    PadStart {
        length: usize,
        fill_char: char,
    },
    #[cfg(feature = "string_pad")]
    PadEnd {
        length: usize,
        fill_char: char,
    },
    Slice,
    Head,
    Tail,
    #[cfg(feature = "string_encoding")]
    HexEncode,
    /// Payload is the `strict` flag forwarded to `hex_decode`.
    #[cfg(feature = "binary_encoding")]
    HexDecode(bool),
    #[cfg(feature = "string_encoding")]
    Base64Encode,
    /// Payload is the `strict` flag forwarded to `base64_decode`.
    #[cfg(feature = "binary_encoding")]
    Base64Decode(bool),
    StartsWith,
    StripChars,
    StripCharsStart,
    StripCharsEnd,
    StripPrefix,
    StripSuffix,
    /// The output struct has `n + 1` string fields (see `get_field`).
    #[cfg(feature = "dtype-struct")]
    SplitExact {
        n: usize,
        inclusive: bool,
    },
    /// Payload is `n`; the output struct has `n` string fields (see `get_field`).
    #[cfg(feature = "dtype-struct")]
    SplitN(usize),
    /// Payload is the target dtype and the parsing options.
    #[cfg(feature = "temporal")]
    Strptime(DataType, StrptimeOptions),
    /// Payload is the `inclusive` flag; `true` is displayed as `split_inclusive`.
    Split(bool),
    /// Payload is the inference length forwarded to `to_decimal`.
    #[cfg(feature = "dtype-decimal")]
    ToDecimal(usize),
    #[cfg(feature = "nightly")]
    Titlecase,
    Uppercase,
    #[cfg(feature = "string_pad")]
    ZFill,
    #[cfg(feature = "find_many")]
    ContainsAny {
        ascii_case_insensitive: bool,
    },
    #[cfg(feature = "find_many")]
    ReplaceMany {
        ascii_case_insensitive: bool,
    },
    #[cfg(feature = "find_many")]
    ExtractMany {
        ascii_case_insensitive: bool,
        overlapping: bool,
    },
    #[cfg(feature = "find_many")]
    FindMany {
        ascii_case_insensitive: bool,
        overlapping: bool,
    },
    #[cfg(feature = "regex")]
    EscapeRegex,
}
145
146impl StringFunction {
147    pub(super) fn get_field(&self, mapper: FieldsMapper) -> PolarsResult<Field> {
148        use StringFunction::*;
149        match self {
150            #[cfg(feature = "concat_str")]
151            ConcatVertical { .. } | ConcatHorizontal { .. } => mapper.with_dtype(DataType::String),
152            #[cfg(feature = "regex")]
153            Contains { .. } => mapper.with_dtype(DataType::Boolean),
154            CountMatches(_) => mapper.with_dtype(DataType::UInt32),
155            EndsWith | StartsWith => mapper.with_dtype(DataType::Boolean),
156            Extract(_) => mapper.with_same_dtype(),
157            ExtractAll => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
158            #[cfg(feature = "extract_groups")]
159            ExtractGroups { dtype, .. } => mapper.with_dtype(dtype.clone()),
160            #[cfg(feature = "string_to_integer")]
161            ToInteger { .. } => mapper.with_dtype(DataType::Int64),
162            #[cfg(feature = "regex")]
163            Find { .. } => mapper.with_dtype(DataType::UInt32),
164            #[cfg(feature = "extract_jsonpath")]
165            JsonDecode { dtype, .. } => mapper.with_opt_dtype(dtype.clone()),
166            #[cfg(feature = "extract_jsonpath")]
167            JsonPathMatch => mapper.with_dtype(DataType::String),
168            LenBytes => mapper.with_dtype(DataType::UInt32),
169            LenChars => mapper.with_dtype(DataType::UInt32),
170            #[cfg(feature = "regex")]
171            Replace { .. } => mapper.with_same_dtype(),
172            #[cfg(feature = "string_normalize")]
173            Normalize { .. } => mapper.with_same_dtype(),
174            #[cfg(feature = "string_reverse")]
175            Reverse => mapper.with_same_dtype(),
176            #[cfg(feature = "temporal")]
177            Strptime(dtype, _) => mapper.with_dtype(dtype.clone()),
178            Split(_) => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
179            #[cfg(feature = "nightly")]
180            Titlecase => mapper.with_same_dtype(),
181            #[cfg(feature = "dtype-decimal")]
182            ToDecimal(_) => mapper.with_dtype(DataType::Decimal(None, None)),
183            #[cfg(feature = "string_encoding")]
184            HexEncode => mapper.with_same_dtype(),
185            #[cfg(feature = "binary_encoding")]
186            HexDecode(_) => mapper.with_dtype(DataType::Binary),
187            #[cfg(feature = "string_encoding")]
188            Base64Encode => mapper.with_same_dtype(),
189            #[cfg(feature = "binary_encoding")]
190            Base64Decode(_) => mapper.with_dtype(DataType::Binary),
191            Uppercase | Lowercase | StripChars | StripCharsStart | StripCharsEnd | StripPrefix
192            | StripSuffix | Slice | Head | Tail => mapper.with_same_dtype(),
193            #[cfg(feature = "string_pad")]
194            PadStart { .. } | PadEnd { .. } | ZFill => mapper.with_same_dtype(),
195            #[cfg(feature = "dtype-struct")]
196            SplitExact { n, .. } => mapper.with_dtype(DataType::Struct(
197                (0..n + 1)
198                    .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
199                    .collect(),
200            )),
201            #[cfg(feature = "dtype-struct")]
202            SplitN(n) => mapper.with_dtype(DataType::Struct(
203                (0..*n)
204                    .map(|i| Field::new(format_pl_smallstr!("field_{i}"), DataType::String))
205                    .collect(),
206            )),
207            #[cfg(feature = "find_many")]
208            ContainsAny { .. } => mapper.with_dtype(DataType::Boolean),
209            #[cfg(feature = "find_many")]
210            ReplaceMany { .. } => mapper.with_same_dtype(),
211            #[cfg(feature = "find_many")]
212            ExtractMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::String))),
213            #[cfg(feature = "find_many")]
214            FindMany { .. } => mapper.with_dtype(DataType::List(Box::new(DataType::UInt32))),
215            #[cfg(feature = "regex")]
216            EscapeRegex => mapper.with_same_dtype(),
217        }
218    }
219}
220
221impl Display for StringFunction {
222    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
223        use StringFunction::*;
224        let s = match self {
225            #[cfg(feature = "regex")]
226            Contains { .. } => "contains",
227            CountMatches(_) => "count_matches",
228            EndsWith { .. } => "ends_with",
229            Extract(_) => "extract",
230            #[cfg(feature = "concat_str")]
231            ConcatHorizontal { .. } => "concat_horizontal",
232            #[cfg(feature = "concat_str")]
233            ConcatVertical { .. } => "concat_vertical",
234            ExtractAll => "extract_all",
235            #[cfg(feature = "extract_groups")]
236            ExtractGroups { .. } => "extract_groups",
237            #[cfg(feature = "string_to_integer")]
238            ToInteger { .. } => "to_integer",
239            #[cfg(feature = "regex")]
240            Find { .. } => "find",
241            Head { .. } => "head",
242            Tail { .. } => "tail",
243            #[cfg(feature = "extract_jsonpath")]
244            JsonDecode { .. } => "json_decode",
245            #[cfg(feature = "extract_jsonpath")]
246            JsonPathMatch => "json_path_match",
247            LenBytes => "len_bytes",
248            Lowercase => "lowercase",
249            LenChars => "len_chars",
250            #[cfg(feature = "string_pad")]
251            PadEnd { .. } => "pad_end",
252            #[cfg(feature = "string_pad")]
253            PadStart { .. } => "pad_start",
254            #[cfg(feature = "regex")]
255            Replace { .. } => "replace",
256            #[cfg(feature = "string_normalize")]
257            Normalize { .. } => "normalize",
258            #[cfg(feature = "string_reverse")]
259            Reverse => "reverse",
260            #[cfg(feature = "string_encoding")]
261            HexEncode => "hex_encode",
262            #[cfg(feature = "binary_encoding")]
263            HexDecode(_) => "hex_decode",
264            #[cfg(feature = "string_encoding")]
265            Base64Encode => "base64_encode",
266            #[cfg(feature = "binary_encoding")]
267            Base64Decode(_) => "base64_decode",
268            Slice => "slice",
269            StartsWith { .. } => "starts_with",
270            StripChars => "strip_chars",
271            StripCharsStart => "strip_chars_start",
272            StripCharsEnd => "strip_chars_end",
273            StripPrefix => "strip_prefix",
274            StripSuffix => "strip_suffix",
275            #[cfg(feature = "dtype-struct")]
276            SplitExact { inclusive, .. } => {
277                if *inclusive {
278                    "split_exact_inclusive"
279                } else {
280                    "split_exact"
281                }
282            },
283            #[cfg(feature = "dtype-struct")]
284            SplitN(_) => "splitn",
285            #[cfg(feature = "temporal")]
286            Strptime(_, _) => "strptime",
287            Split(inclusive) => {
288                if *inclusive {
289                    "split_inclusive"
290                } else {
291                    "split"
292                }
293            },
294            #[cfg(feature = "nightly")]
295            Titlecase => "titlecase",
296            #[cfg(feature = "dtype-decimal")]
297            ToDecimal(_) => "to_decimal",
298            Uppercase => "uppercase",
299            #[cfg(feature = "string_pad")]
300            ZFill => "zfill",
301            #[cfg(feature = "find_many")]
302            ContainsAny { .. } => "contains_any",
303            #[cfg(feature = "find_many")]
304            ReplaceMany { .. } => "replace_many",
305            #[cfg(feature = "find_many")]
306            ExtractMany { .. } => "extract_many",
307            #[cfg(feature = "find_many")]
308            FindMany { .. } => "extract_many",
309            #[cfg(feature = "regex")]
310            EscapeRegex => "escape_regex",
311        };
312        write!(f, "str.{s}")
313    }
314}
315
316impl From<StringFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
317    fn from(func: StringFunction) -> Self {
318        use StringFunction::*;
319        match func {
320            #[cfg(feature = "regex")]
321            Contains { literal, strict } => map_as_slice!(strings::contains, literal, strict),
322            CountMatches(literal) => {
323                map_as_slice!(strings::count_matches, literal)
324            },
325            EndsWith { .. } => map_as_slice!(strings::ends_with),
326            StartsWith { .. } => map_as_slice!(strings::starts_with),
327            Extract(group_index) => map_as_slice!(strings::extract, group_index),
328            ExtractAll => {
329                map_as_slice!(strings::extract_all)
330            },
331            #[cfg(feature = "extract_groups")]
332            ExtractGroups { pat, dtype } => {
333                map!(strings::extract_groups, &pat, &dtype)
334            },
335            #[cfg(feature = "regex")]
336            Find { literal, strict } => map_as_slice!(strings::find, literal, strict),
337            LenBytes => map!(strings::len_bytes),
338            LenChars => map!(strings::len_chars),
339            #[cfg(feature = "string_pad")]
340            PadEnd { length, fill_char } => {
341                map!(strings::pad_end, length, fill_char)
342            },
343            #[cfg(feature = "string_pad")]
344            PadStart { length, fill_char } => {
345                map!(strings::pad_start, length, fill_char)
346            },
347            #[cfg(feature = "string_pad")]
348            ZFill => {
349                map_as_slice!(strings::zfill)
350            },
351            #[cfg(feature = "temporal")]
352            Strptime(dtype, options) => {
353                map_as_slice!(strings::strptime, dtype.clone(), &options)
354            },
355            Split(inclusive) => {
356                map_as_slice!(strings::split, inclusive)
357            },
358            #[cfg(feature = "dtype-struct")]
359            SplitExact { n, inclusive } => map_as_slice!(strings::split_exact, n, inclusive),
360            #[cfg(feature = "dtype-struct")]
361            SplitN(n) => map_as_slice!(strings::splitn, n),
362            #[cfg(feature = "concat_str")]
363            ConcatVertical {
364                delimiter,
365                ignore_nulls,
366            } => map!(strings::join, &delimiter, ignore_nulls),
367            #[cfg(feature = "concat_str")]
368            ConcatHorizontal {
369                delimiter,
370                ignore_nulls,
371            } => map_as_slice!(strings::concat_hor, &delimiter, ignore_nulls),
372            #[cfg(feature = "regex")]
373            Replace { n, literal } => map_as_slice!(strings::replace, literal, n),
374            #[cfg(feature = "string_normalize")]
375            Normalize { form } => map!(strings::normalize, form.clone()),
376            #[cfg(feature = "string_reverse")]
377            Reverse => map!(strings::reverse),
378            Uppercase => map!(uppercase),
379            Lowercase => map!(lowercase),
380            #[cfg(feature = "nightly")]
381            Titlecase => map!(strings::titlecase),
382            StripChars => map_as_slice!(strings::strip_chars),
383            StripCharsStart => map_as_slice!(strings::strip_chars_start),
384            StripCharsEnd => map_as_slice!(strings::strip_chars_end),
385            StripPrefix => map_as_slice!(strings::strip_prefix),
386            StripSuffix => map_as_slice!(strings::strip_suffix),
387            #[cfg(feature = "string_to_integer")]
388            ToInteger(strict) => map_as_slice!(strings::to_integer, strict),
389            Slice => map_as_slice!(strings::str_slice),
390            Head => map_as_slice!(strings::str_head),
391            Tail => map_as_slice!(strings::str_tail),
392            #[cfg(feature = "string_encoding")]
393            HexEncode => map!(strings::hex_encode),
394            #[cfg(feature = "binary_encoding")]
395            HexDecode(strict) => map!(strings::hex_decode, strict),
396            #[cfg(feature = "string_encoding")]
397            Base64Encode => map!(strings::base64_encode),
398            #[cfg(feature = "binary_encoding")]
399            Base64Decode(strict) => map!(strings::base64_decode, strict),
400            #[cfg(feature = "dtype-decimal")]
401            ToDecimal(infer_len) => map!(strings::to_decimal, infer_len),
402            #[cfg(feature = "extract_jsonpath")]
403            JsonDecode {
404                dtype,
405                infer_schema_len,
406            } => map!(strings::json_decode, dtype.clone(), infer_schema_len),
407            #[cfg(feature = "extract_jsonpath")]
408            JsonPathMatch => map_as_slice!(strings::json_path_match),
409            #[cfg(feature = "find_many")]
410            ContainsAny {
411                ascii_case_insensitive,
412            } => {
413                map_as_slice!(contains_any, ascii_case_insensitive)
414            },
415            #[cfg(feature = "find_many")]
416            ReplaceMany {
417                ascii_case_insensitive,
418            } => {
419                map_as_slice!(replace_many, ascii_case_insensitive)
420            },
421            #[cfg(feature = "find_many")]
422            ExtractMany {
423                ascii_case_insensitive,
424                overlapping,
425            } => {
426                map_as_slice!(extract_many, ascii_case_insensitive, overlapping)
427            },
428            #[cfg(feature = "find_many")]
429            FindMany {
430                ascii_case_insensitive,
431                overlapping,
432            } => {
433                map_as_slice!(find_many, ascii_case_insensitive, overlapping)
434            },
435            #[cfg(feature = "regex")]
436            EscapeRegex => map!(escape_regex),
437        }
438    }
439}
440
/// Checks the string column `s[0]` against the patterns in `s[1]` via
/// `polars_ops::chunked_array::strings::contains_any`.
///
/// Errors if either input is not a string column or if the underlying kernel fails.
#[cfg(feature = "find_many")]
fn contains_any(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let haystack = s[0].str()?;
    let needles = s[1].str()?;
    let mask = polars_ops::chunked_array::strings::contains_any(
        haystack,
        needles,
        ascii_case_insensitive,
    )?;
    Ok(mask.into_column())
}
448
/// Runs `polars_ops::chunked_array::strings::replace_all` on the string column `s[0]`, with
/// patterns taken from `s[1]` and replacements from `s[2]`.
///
/// Errors if any of the three inputs is not a string column or if the kernel fails.
#[cfg(feature = "find_many")]
fn replace_many(s: &[Column], ascii_case_insensitive: bool) -> PolarsResult<Column> {
    let haystack = s[0].str()?;
    let needles = s[1].str()?;
    let replacements = s[2].str()?;
    let replaced = polars_ops::chunked_array::strings::replace_all(
        haystack,
        needles,
        replacements,
        ascii_case_insensitive,
    )?;
    Ok(replaced.into_column())
}
462
/// Runs `polars_ops::chunked_array::strings::extract_many` on the string column `s[0]` with
/// the patterns column `s[1]`.
///
/// The inputs are first validated with `_check_same_length`; `s[0]` must be a string column.
#[cfg(feature = "find_many")]
fn extract_many(
    s: &[Column],
    ascii_case_insensitive: bool,
    overlapping: bool,
) -> PolarsResult<Column> {
    _check_same_length(s, "extract_many")?;
    let haystack = s[0].str()?;

    let extracted = polars_ops::chunked_array::strings::extract_many(
        haystack,
        s[1].as_materialized_series(),
        ascii_case_insensitive,
        overlapping,
    )?;
    Ok(extracted.into_column())
}
481
/// Runs `polars_ops::chunked_array::strings::find_many` on the string column `s[0]` with the
/// patterns column `s[1]`.
///
/// Errors if `s[0]` is not a string column or if the kernel fails.
#[cfg(feature = "find_many")]
fn find_many(
    s: &[Column],
    ascii_case_insensitive: bool,
    overlapping: bool,
) -> PolarsResult<Column> {
    let haystack = s[0].str()?;

    let found = polars_ops::chunked_array::strings::find_many(
        haystack,
        s[1].as_materialized_series(),
        ascii_case_insensitive,
        overlapping,
    )?;
    Ok(found.into_column())
}
499
500fn uppercase(s: &Column) -> PolarsResult<Column> {
501    let ca = s.str()?;
502    Ok(ca.to_uppercase().into_column())
503}
504
505fn lowercase(s: &Column) -> PolarsResult<Column> {
506    let ca = s.str()?;
507    Ok(ca.to_lowercase().into_column())
508}
509
/// Title-cases a string column; errors if `s` is not a string column.
#[cfg(feature = "nightly")]
pub(super) fn titlecase(s: &Column) -> PolarsResult<Column> {
    s.str().map(|strings| strings.to_titlecase().into_column())
}
515
516pub(super) fn len_chars(s: &Column) -> PolarsResult<Column> {
517    let ca = s.str()?;
518    Ok(ca.str_len_chars().into_column())
519}
520
521pub(super) fn len_bytes(s: &Column) -> PolarsResult<Column> {
522    let ca = s.str()?;
523    Ok(ca.str_len_bytes().into_column())
524}
525
/// Tests the string column `s[0]` against the pattern column `s[1]` with `contains_chunked`.
///
/// The inputs are validated with `_check_same_length` first; both must be string columns.
#[cfg(feature = "regex")]
pub(super) fn contains(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
    _check_same_length(s, "contains")?;
    let haystack = s[0].str()?;
    let pattern = s[1].str()?;
    let mask = haystack.contains_chunked(pattern, literal, strict)?;
    Ok(mask.into_column())
}
534
/// Searches the string column `s[0]` for the pattern column `s[1]` with `find_chunked`.
///
/// The inputs are validated with `_check_same_length` first; both must be string columns.
#[cfg(feature = "regex")]
pub(super) fn find(s: &[Column], literal: bool, strict: bool) -> PolarsResult<Column> {
    _check_same_length(s, "find")?;
    let haystack = s[0].str()?;
    let pattern = s[1].str()?;
    let positions = haystack.find_chunked(pattern, literal, strict)?;
    Ok(positions.into_column())
}
543
544pub(super) fn ends_with(s: &[Column]) -> PolarsResult<Column> {
545    _check_same_length(s, "ends_with")?;
546    let ca = &s[0].str()?.as_binary();
547    let suffix = &s[1].str()?.as_binary();
548
549    Ok(ca.ends_with_chunked(suffix).into_column())
550}
551
552pub(super) fn starts_with(s: &[Column]) -> PolarsResult<Column> {
553    _check_same_length(s, "starts_with")?;
554    let ca = s[0].str()?;
555    let prefix = s[1].str()?;
556    Ok(ca.starts_with_chunked(prefix).into_column())
557}
558
559/// Extract a regex pattern from the a string value.
560pub(super) fn extract(s: &[Column], group_index: usize) -> PolarsResult<Column> {
561    let ca = s[0].str()?;
562    let pat = s[1].str()?;
563    ca.extract(pat, group_index).map(|ca| ca.into_column())
564}
565
#[cfg(feature = "extract_groups")]
/// Extract every capture group of the regex `pat` from a string column, producing a column
/// of the given `dtype`.
pub(super) fn extract_groups(s: &Column, pat: &str, dtype: &DataType) -> PolarsResult<Column> {
    let strings = s.str()?;
    let groups = strings.extract_groups(pat, dtype)?;
    Ok(Column::from(groups))
}
572
/// Applies `pad_start(length, fill_char)` to a string column; errors if `s` is not a string
/// column.
#[cfg(feature = "string_pad")]
pub(super) fn pad_start(s: &Column, length: usize, fill_char: char) -> PolarsResult<Column> {
    s.str()
        .map(|strings| strings.pad_start(length, fill_char).into_column())
}
578
/// Applies `pad_end(length, fill_char)` to a string column; errors if `s` is not a string
/// column.
#[cfg(feature = "string_pad")]
pub(super) fn pad_end(s: &Column, length: usize, fill_char: char) -> PolarsResult<Column> {
    s.str()
        .map(|strings| strings.pad_end(length, fill_char).into_column())
}
584
/// Applies `zfill` to the string column `s[0]`, taking the target lengths from `s[1]`.
///
/// `s[1]` is strictly cast to `UInt64`, so a failing cast is returned as an error. The
/// inputs are validated with `_check_same_length` first.
#[cfg(feature = "string_pad")]
pub(super) fn zfill(s: &[Column]) -> PolarsResult<Column> {
    _check_same_length(s, "zfill")?;
    let strings = s[0].str()?;
    let widths_col = s[1].strict_cast(&DataType::UInt64)?;
    let widths = widths_col.u64()?;
    let padded = strings.zfill(widths);
    Ok(padded.into_column())
}
593
594pub(super) fn strip_chars(s: &[Column]) -> PolarsResult<Column> {
595    _check_same_length(s, "strip_chars")?;
596    let ca = s[0].str()?;
597    let pat_s = &s[1];
598    ca.strip_chars(pat_s).map(|ok| ok.into_column())
599}
600
601pub(super) fn strip_chars_start(s: &[Column]) -> PolarsResult<Column> {
602    _check_same_length(s, "strip_chars_start")?;
603    let ca = s[0].str()?;
604    let pat_s = &s[1];
605    ca.strip_chars_start(pat_s).map(|ok| ok.into_column())
606}
607
608pub(super) fn strip_chars_end(s: &[Column]) -> PolarsResult<Column> {
609    _check_same_length(s, "strip_chars_end")?;
610    let ca = s[0].str()?;
611    let pat_s = &s[1];
612    ca.strip_chars_end(pat_s).map(|ok| ok.into_column())
613}
614
615pub(super) fn strip_prefix(s: &[Column]) -> PolarsResult<Column> {
616    _check_same_length(s, "strip_prefix")?;
617    let ca = s[0].str()?;
618    let prefix = s[1].str()?;
619    Ok(ca.strip_prefix(prefix).into_column())
620}
621
622pub(super) fn strip_suffix(s: &[Column]) -> PolarsResult<Column> {
623    _check_same_length(s, "strip_suffix")?;
624    let ca = s[0].str()?;
625    let suffix = s[1].str()?;
626    Ok(ca.strip_suffix(suffix).into_column())
627}
628
629pub(super) fn extract_all(args: &[Column]) -> PolarsResult<Column> {
630    let s = &args[0];
631    let pat = &args[1];
632
633    let ca = s.str()?;
634    let pat = pat.str()?;
635
636    if pat.len() == 1 {
637        if let Some(pat) = pat.get(0) {
638            ca.extract_all(pat).map(|ca| ca.into_column())
639        } else {
640            Ok(Column::full_null(
641                ca.name().clone(),
642                ca.len(),
643                &DataType::List(Box::new(DataType::String)),
644            ))
645        }
646    } else {
647        ca.extract_all_many(pat).map(|ca| ca.into_column())
648    }
649}
650
651pub(super) fn count_matches(args: &[Column], literal: bool) -> PolarsResult<Column> {
652    let s = &args[0];
653    let pat = &args[1];
654
655    let ca = s.str()?;
656    let pat = pat.str()?;
657    if pat.len() == 1 {
658        if let Some(pat) = pat.get(0) {
659            ca.count_matches(pat, literal).map(|ca| ca.into_column())
660        } else {
661            Ok(Column::full_null(
662                ca.name().clone(),
663                ca.len(),
664                &DataType::UInt32,
665            ))
666        }
667    } else {
668        ca.count_matches_many(pat, literal)
669            .map(|ca| ca.into_column())
670    }
671}
672
/// Parses strings into the temporal `dtype` by dispatching to `to_date`, `to_datetime` or
/// `to_time`.
///
/// `to_date` and `to_time` receive only `s[0]`; `to_datetime` receives the whole slice.
/// Any dtype without an enabled arm results in a `ComputeError`.
#[cfg(feature = "temporal")]
pub(super) fn strptime(
    s: &[Column],
    dtype: DataType,
    options: &StrptimeOptions,
) -> PolarsResult<Column> {
    match dtype {
        #[cfg(feature = "dtype-date")]
        DataType::Date => to_date(&s[0], options),
        #[cfg(feature = "dtype-time")]
        DataType::Time => to_time(&s[0], options),
        #[cfg(feature = "dtype-datetime")]
        DataType::Datetime(unit, zone) => to_datetime(s, &unit, zone.as_ref(), options),
        unsupported => {
            polars_bail!(ComputeError: "not implemented for dtype {}", unsupported)
        },
    }
}
691
/// Splits the string column `s[0]` by the separator column `s[1]` with `n` as the split
/// count, using `split_exact_inclusive` when `inclusive` is set and `split_exact` otherwise.
#[cfg(feature = "dtype-struct")]
pub(super) fn split_exact(s: &[Column], n: usize, inclusive: bool) -> PolarsResult<Column> {
    let strings = s[0].str()?;
    let separator = s[1].str()?;

    let parts = match inclusive {
        true => strings.split_exact_inclusive(separator, n)?.into_column(),
        false => strings.split_exact(separator, n)?.into_column(),
    };
    Ok(parts)
}
703
/// Applies `splitn` to the string column `s[0]`, with the separator column `s[1]` and `n`.
#[cfg(feature = "dtype-struct")]
pub(super) fn splitn(s: &[Column], n: usize) -> PolarsResult<Column> {
    let strings = s[0].str()?;
    let separator = s[1].str()?;

    let parts = strings.splitn(separator, n)?;
    Ok(parts.into_column())
}
711
712pub(super) fn split(s: &[Column], inclusive: bool) -> PolarsResult<Column> {
713    let ca = s[0].str()?;
714    let by = s[1].str()?;
715
716    if inclusive {
717        Ok(ca.split_inclusive(by).into_column())
718    } else {
719        Ok(ca.split(by).into_column())
720    }
721}
722
/// Parses a string column into dates.
///
/// Uses `as_date` when `options.exact` is set and `as_date_not_exact` otherwise. In strict
/// mode, a null count that differs between input and output is handed to
/// `handle_casting_failures`.
#[cfg(feature = "dtype-date")]
fn to_date(s: &Column, options: &StrptimeOptions) -> PolarsResult<Column> {
    let strings = s.str()?;
    let fmt = options.format.as_deref();

    let parsed = match options.exact {
        true => strings.as_date(fmt, options.cache)?.into_column(),
        false => strings.as_date_not_exact(fmt)?.into_column(),
    };

    if options.strict && strings.null_count() != parsed.null_count() {
        handle_casting_failures(s.as_materialized_series(), parsed.as_materialized_series())?;
    }
    Ok(parsed)
}
741
/// Parses the string column `s[0]` into datetimes; `s[1]` is the string column passed on as
/// the `ambiguous` argument.
///
/// A format matching `TZ_AWARE_RE` marks the input as time-zone aware; this check is only
/// compiled with both the `regex` and `timezones` features, otherwise the flag is `false`.
/// With `timezones` enabled a supplied `time_zone` is validated before parsing. Uses
/// `as_datetime` when `options.exact` is set and `as_datetime_not_exact` otherwise. In strict
/// mode, a null count that differs between input and output is handed to
/// `handle_casting_failures`.
#[cfg(feature = "dtype-datetime")]
fn to_datetime(
    s: &[Column],
    time_unit: &TimeUnit,
    time_zone: Option<&TimeZone>,
    options: &StrptimeOptions,
) -> PolarsResult<Column> {
    let datetime_strings = s[0].str()?;
    let ambiguous = s[1].str()?;

    #[cfg(all(feature = "regex", feature = "timezones"))]
    let tz_aware = options
        .format
        .as_ref()
        .is_some_and(|format| TZ_AWARE_RE.is_match(format));
    #[cfg(not(all(feature = "regex", feature = "timezones")))]
    let tz_aware = false;

    #[cfg(feature = "timezones")]
    if let Some(tz) = time_zone {
        validate_time_zone(tz)?;
    }

    let fmt = options.format.as_deref();
    let parsed = match options.exact {
        true => datetime_strings
            .as_datetime(
                fmt,
                *time_unit,
                options.cache,
                tz_aware,
                time_zone,
                ambiguous,
            )?
            .into_column(),
        false => datetime_strings
            .as_datetime_not_exact(fmt, *time_unit, tz_aware, time_zone, ambiguous)?
            .into_column(),
    };

    if options.strict && datetime_strings.null_count() != parsed.null_count() {
        handle_casting_failures(
            s[0].as_materialized_series(),
            parsed.as_materialized_series(),
        )?;
    }
    Ok(parsed)
}
788
/// Parses a string column into times with `as_time`.
///
/// Only exact parsing is supported; `options.exact == false` yields a `ComputeError`. In
/// strict mode, a null count that differs between input and output is handed to
/// `handle_casting_failures`.
#[cfg(feature = "dtype-time")]
fn to_time(s: &Column, options: &StrptimeOptions) -> PolarsResult<Column> {
    polars_ensure!(
        options.exact, ComputeError: "non-exact not implemented for Time data type"
    );

    let strings = s.str()?;
    let fmt = options.format.as_deref();
    let parsed = strings.as_time(fmt, options.cache)?.into_column();

    if options.strict && strings.null_count() != parsed.null_count() {
        handle_casting_failures(s.as_materialized_series(), parsed.as_materialized_series())?;
    }
    Ok(parsed)
}
805
/// Casts the column to `String` and joins its values with `delimiter` via
/// `polars_ops::chunked_array::str_join`.
#[cfg(feature = "concat_str")]
pub(super) fn join(s: &Column, delimiter: &str, ignore_nulls: bool) -> PolarsResult<Column> {
    let as_string = s.cast(&DataType::String)?;
    let strings = as_string.str()?;
    Ok(polars_ops::chunked_array::str_join(strings, delimiter, ignore_nulls).into_column())
}
812
/// Casts every input column to `String` and concatenates them with `delimiter` via
/// `polars_ops::chunked_array::hor_str_concat`.
///
/// Returns the first cast error encountered, or the error from the concat kernel.
#[cfg(feature = "concat_str")]
pub(super) fn concat_hor(
    series: &[Column],
    delimiter: &str,
    ignore_nulls: bool,
) -> PolarsResult<Column> {
    let mut casted = Vec::with_capacity(series.len());
    for column in series {
        casted.push(column.cast(&DataType::String)?);
    }

    // Every column was just cast to `String`, so `str()` is expected to succeed here.
    let cas: Vec<_> = casted.iter().map(|column| column.str().unwrap()).collect();

    let concatenated = polars_ops::chunked_array::hor_str_concat(&cas, delimiter, ignore_nulls)?;
    Ok(concatenated.into_column())
}
826
827impl From<StringFunction> for FunctionExpr {
828    fn from(str: StringFunction) -> Self {
829        FunctionExpr::StringExpr(str)
830    }
831}
832
/// Returns the first value of the pattern column, or a `ComputeError` if it is null.
#[cfg(feature = "regex")]
fn get_pat(pat: &StringChunked) -> PolarsResult<&str> {
    match pat.get(0) {
        Some(first) => Ok(first),
        None => {
            Err(polars_err!(ComputeError: "pattern cannot be 'null' in 'replace' expression"))
        },
    }
}
839
840// used only if feature="regex"
841#[allow(dead_code)]
842fn iter_and_replace<'a, F>(ca: &'a StringChunked, val: &'a StringChunked, f: F) -> StringChunked
843where
844    F: Fn(&'a str, &'a str) -> Cow<'a, str>,
845{
846    let mut out: StringChunked = ca
847        .into_iter()
848        .zip(val)
849        .map(|(opt_src, opt_val)| match (opt_src, opt_val) {
850            (Some(src), Some(val)) => Some(f(src, val)),
851            _ => None,
852        })
853        .collect_trusted();
854
855    out.rename(ca.name().clone());
856    out
857}
858
859#[cfg(feature = "regex")]
/// Returns `true` when `pat` contains no ASCII punctuation character.
///
/// An empty pattern therefore counts as literal.
fn is_literal_pat(pat: &str) -> bool {
    !pat.chars().any(|ch| ch.is_ascii_punctuation())
}
863
/// Replaces at most `n` matches of a single pattern in every string of `ca`.
///
/// `pat` must hold exactly one value. `val` holds either a single replacement
/// that is used for every row, or one replacement per row of `ca`. A pattern
/// without ASCII punctuation (see `is_literal_pat`) is handled as a literal
/// even when `literal` is `false`.
///
/// With `n == 0` nothing is replaced: once the inputs have been validated, a
/// clone of `ca` is returned on the regex and per-row-value paths (the
/// literal single-value path forwards `n` to `replace_literal`).
///
/// # Errors
/// Returns a `ComputeError` (or the regex compilation error) when
/// - the pattern or the single replacement value is null,
/// - `n > 1` is combined with a regex pattern or with per-row values,
/// - the number of per-row values differs from the length of `ca`,
/// - `pat` does not hold exactly one value,
/// - the pattern is not a valid regex.
#[cfg(feature = "regex")]
fn replace_n<'a>(
    ca: &'a StringChunked,
    pat: &'a StringChunked,
    val: &'a StringChunked,
    literal: bool,
    n: usize,
) -> PolarsResult<StringChunked> {
    match (pat.len(), val.len()) {
        (1, 1) => {
            let pat = get_pat(pat)?;
            let val = val.get(0).ok_or_else(
                || polars_err!(ComputeError: "value cannot be 'null' in 'replace' expression"),
            )?;
            let literal = literal || is_literal_pat(pat);

            match literal {
                true => ca.replace_literal(pat, val, n),
                false => {
                    if n > 1 {
                        polars_bail!(ComputeError: "regex replacement with 'n > 1' not yet supported")
                    }
                    if n == 0 {
                        // Zero replacements requested: still reject an invalid
                        // pattern, but leave every string untouched instead of
                        // replacing the first match.
                        Regex::new(pat)?;
                        return Ok(ca.clone());
                    }
                    ca.replace(pat, val)
                },
            }
        },
        (1, len_val) => {
            if n > 1 {
                polars_bail!(ComputeError: "multivalue replacement with 'n > 1' not yet supported")
            }
            let mut pat = get_pat(pat)?.to_string();
            polars_ensure!(
                len_val == ca.len(),
                ComputeError:
                "replacement value length ({}) does not match string column length ({})",
                len_val, ca.len(),
            );
            let lit = is_literal_pat(&pat);
            let literal_pat = literal || lit;

            // A literal pattern is escaped so it can still be run through the regex engine.
            if literal_pat {
                pat = escape(&pat)
            }

            let reg = Regex::new(&pat)?;

            // Zero replacements requested: the closure below always replaces
            // the first match, so return the column unchanged instead.
            if n == 0 {
                return Ok(ca.clone());
            }

            // From here on `n == 1`.
            let f = |s: &'a str, val: &'a str| {
                // Fast path for short strings and punctuation-free patterns.
                // `str::replacen` inserts `val` verbatim, whereas
                // `Regex::replace` with a plain `&str` expands `$` references.
                // NOTE(review): for a `val` containing `$` the result
                // therefore depends on this 32-byte cutoff — confirm intended.
                if lit && (s.len() <= 32) {
                    Cow::Owned(s.replacen(&pat, val, 1))
                } else {
                    // According to the docs for replace
                    // when literal = True then capture groups are ignored.
                    if literal {
                        reg.replace(s, NoExpand(val))
                    } else {
                        reg.replace(s, val)
                    }
                }
            };
            Ok(iter_and_replace(ca, val, f))
        },
        _ => polars_bail!(
            ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet"
        ),
    }
}
930
/// Replaces every match of a single pattern in every string of `ca`.
///
/// `pat` must hold exactly one value. `val` holds either a single replacement
/// that is used for every row, or one replacement per row of `ca`. A pattern
/// without ASCII punctuation (see `is_literal_pat`) is handled as a literal
/// even when `literal` is `false`.
///
/// # Errors
/// Returns a `ComputeError` (or the regex compilation error) when `pat` does
/// not hold exactly one value, when the pattern or the single replacement
/// value is null, when the number of per-row values differs from the length
/// of `ca`, or when the pattern is not a valid regex.
#[cfg(feature = "regex")]
fn replace_all<'a>(
    ca: &'a StringChunked,
    pat: &'a StringChunked,
    val: &'a StringChunked,
    literal: bool,
) -> PolarsResult<StringChunked> {
    polars_ensure!(
        pat.len() == 1,
        ComputeError: "dynamic pattern length in 'str.replace' expressions is not supported yet"
    );
    let pattern = get_pat(pat)?;

    // A single replacement value: delegate to the whole-array kernels.
    if val.len() == 1 {
        let Some(replacement) = val.get(0) else {
            polars_bail!(ComputeError: "value cannot be 'null' in 'replace' expression")
        };
        return if literal || is_literal_pat(pattern) {
            ca.replace_literal_all(pattern, replacement)
        } else {
            ca.replace_all(pattern, replacement)
        };
    }

    // One replacement value per row.
    polars_ensure!(
        val.len() == ca.len(),
        ComputeError:
        "replacement value length ({}) does not match string column length ({})",
        val.len(), ca.len(),
    );

    // A literal pattern is escaped so it can still be run through the regex engine.
    let pattern = if literal || is_literal_pat(pattern) {
        escape(pattern)
    } else {
        pattern.to_string()
    };
    let reg = Regex::new(&pattern)?;

    let replace_row = |src: &'a str, new: &'a str| {
        // According to the docs for replace_all
        // when literal = True then capture groups are ignored.
        if literal {
            reg.replace_all(src, NoExpand(new))
        } else {
            reg.replace_all(src, new)
        }
    };

    Ok(iter_and_replace(ca, val, replace_row))
}
985
/// Entry point for the string replace expression.
///
/// `s` holds, in order, the string column, the pattern column and the
/// replacement value column; all three must be string columns. A negative `n`
/// dispatches to `replace_all`, any other `n` is passed to `replace_n` as the
/// replacement count.
#[cfg(feature = "regex")]
pub(super) fn replace(s: &[Column], literal: bool, n: i64) -> PolarsResult<Column> {
    let (column, pat, val) = (&s[0], &s[1], &s[2]);
    let (column, pat, val) = (column.str()?, pat.str()?, val.str()?);

    let replaced = match n {
        // negative is replace all
        count if count < 0 => replace_all(column, pat, val, literal),
        // `count` is non-negative here, so the cast does not wrap a negative value
        count => replace_n(column, pat, val, literal, count as usize),
    }?;
    Ok(replaced.into_column())
}
1004
/// Applies `str_normalize` with the given `form` to a string column.
///
/// # Errors
/// Returns the error of `s.str()` when `s` cannot be viewed as a string column.
#[cfg(feature = "string_normalize")]
pub(super) fn normalize(s: &Column, form: UnicodeForm) -> PolarsResult<Column> {
    let normalized = s.str()?.str_normalize(form);
    Ok(normalized.into_column())
}
1010
/// Applies `str_reverse` to a string column.
///
/// # Errors
/// Returns the error of `s.str()` when `s` cannot be viewed as a string column.
#[cfg(feature = "string_reverse")]
pub(super) fn reverse(s: &Column) -> PolarsResult<Column> {
    let reversed = s.str()?.str_reverse();
    Ok(reversed.into_column())
}
1016
/// Parses the strings in `s[0]` as integers using the base given in `s[1]`.
///
/// The base column is strictly cast to `UInt32` before it is passed, together
/// with `strict`, to the `to_integer` kernel.
///
/// # Errors
/// Propagates the errors of the string downcast, the strict cast of the base
/// column, and the `to_integer` kernel.
#[cfg(feature = "string_to_integer")]
pub(super) fn to_integer(s: &[Column], strict: bool) -> PolarsResult<Column> {
    let strings = s[0].str()?;
    let base = s[1].strict_cast(&DataType::UInt32)?;
    let parsed = strings.to_integer(base.u32()?, strict)?;
    Ok(parsed.into_column())
}
1024
1025fn _ensure_lengths(s: &[Column]) -> bool {
1026    // Calculate the post-broadcast length and ensure everything is consistent.
1027    let len = s
1028        .iter()
1029        .map(|series| series.len())
1030        .filter(|l| *l != 1)
1031        .max()
1032        .unwrap_or(1);
1033    s.iter()
1034        .all(|series| series.len() == 1 || series.len() == len)
1035}
1036
1037fn _check_same_length(s: &[Column], fn_name: &str) -> Result<(), PolarsError> {
1038    polars_ensure!(
1039        _ensure_lengths(s),
1040        ComputeError: "all series in `str.{}()` should have equal or unit length",
1041        fn_name
1042    );
1043    Ok(())
1044}
1045
1046pub(super) fn str_slice(s: &[Column]) -> PolarsResult<Column> {
1047    _check_same_length(s, "slice")?;
1048    let ca = s[0].str()?;
1049    let offset = &s[1];
1050    let length = &s[2];
1051    Ok(ca.str_slice(offset, length)?.into_column())
1052}
1053
1054pub(super) fn str_head(s: &[Column]) -> PolarsResult<Column> {
1055    _check_same_length(s, "head")?;
1056    let ca = s[0].str()?;
1057    let n = &s[1];
1058    Ok(ca.str_head(n)?.into_column())
1059}
1060
1061pub(super) fn str_tail(s: &[Column]) -> PolarsResult<Column> {
1062    _check_same_length(s, "tail")?;
1063    let ca = s[0].str()?;
1064    let n = &s[1];
1065    Ok(ca.str_tail(n)?.into_column())
1066}
1067
/// Applies the `hex_encode` kernel to a string column.
///
/// # Errors
/// Returns the error of `s.str()` when `s` cannot be viewed as a string column.
#[cfg(feature = "string_encoding")]
pub(super) fn hex_encode(s: &Column) -> PolarsResult<Column> {
    let strings = s.str()?;
    Ok(strings.hex_encode().into_column())
}
1072
/// Applies the `hex_decode` kernel to a string column, forwarding `strict`.
///
/// # Errors
/// Propagates the errors of the string downcast and of the decode kernel.
#[cfg(feature = "binary_encoding")]
pub(super) fn hex_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.str()?.hex_decode(strict)?;
    Ok(decoded.into_column())
}
1077
/// Applies the `base64_encode` kernel to a string column.
///
/// # Errors
/// Returns the error of `s.str()` when `s` cannot be viewed as a string column.
#[cfg(feature = "string_encoding")]
pub(super) fn base64_encode(s: &Column) -> PolarsResult<Column> {
    let strings = s.str()?;
    Ok(strings.base64_encode().into_column())
}
1082
/// Applies the `base64_decode` kernel to a string column, forwarding `strict`.
///
/// # Errors
/// Propagates the errors of the string downcast and of the decode kernel.
#[cfg(feature = "binary_encoding")]
pub(super) fn base64_decode(s: &Column, strict: bool) -> PolarsResult<Column> {
    let decoded = s.str()?.base64_decode(strict)?;
    Ok(decoded.into_column())
}
1087
/// Applies the `to_decimal` kernel to a string column, forwarding `infer_len`.
///
/// # Errors
/// Propagates the errors of the string downcast and of the `to_decimal` kernel.
#[cfg(feature = "dtype-decimal")]
pub(super) fn to_decimal(s: &Column, infer_len: usize) -> PolarsResult<Column> {
    let decimals = s.str()?.to_decimal(infer_len)?;
    Ok(Column::from(decimals))
}
1093
/// Applies the `json_decode` kernel to a string column.
///
/// `dtype` and `infer_schema_len` are forwarded unchanged to the kernel.
///
/// # Errors
/// Propagates the errors of the string downcast and of the `json_decode` kernel.
#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_decode(
    s: &Column,
    dtype: Option<DataType>,
    infer_schema_len: Option<usize>,
) -> PolarsResult<Column> {
    let decoded = s.str()?.json_decode(dtype, infer_schema_len)?;
    Ok(Column::from(decoded))
}
1103
/// Applies the `json_path_match` kernel to the strings in `s[0]` with the
/// string column `s[1]` as the pattern argument.
///
/// # Errors
/// Fails when the columns do not have equal or unit length, when either
/// column is not a string column, or when the kernel fails.
#[cfg(feature = "extract_jsonpath")]
pub(super) fn json_path_match(s: &[Column]) -> PolarsResult<Column> {
    _check_same_length(s, "json_path_match")?;
    let strings = s[0].str()?;
    let paths = s[1].str()?;
    let matched = strings.json_path_match(paths)?;
    Ok(matched.into_column())
}
1111
/// Applies the `str_escape_regex` kernel to a string column.
///
/// # Errors
/// Returns the error of `s.str()` when `s` cannot be viewed as a string column.
#[cfg(feature = "regex")]
pub(super) fn escape_regex(s: &Column) -> PolarsResult<Column> {
    let escaped = s.str()?.str_escape_regex();
    Ok(escaped.into_column())
}