datafusion_functions/unicode/
initcap.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{
22    Array, ArrayRef, GenericStringBuilder, OffsetSizeTrait, StringViewBuilder,
23};
24use arrow::datatypes::DataType;
25
26use crate::utils::{make_scalar_function, utf8_to_str_type};
27use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
28use datafusion_common::types::logical_string;
29use datafusion_common::{exec_err, Result};
30use datafusion_expr::{
31    Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignatureClass,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35
/// Scalar UDF type backing the SQL function `initcap`.
///
/// The `user_doc` attribute below carries the user-facing documentation
/// (description, syntax, SQL example and related functions) for the function.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Capitalizes the first character in each word in the input string. \
            Words are delimited by non-alphanumeric characters.",
    syntax_example = "initcap(str)",
    sql_example = r#"```sql
> select initcap('apache datafusion');
+------------------------------------+
| initcap(Utf8("apache datafusion")) |
+------------------------------------+
| Apache Datafusion                  |
+------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    related_udf(name = "lower"),
    related_udf(name = "upper")
)]
#[derive(Debug)]
pub struct InitcapFunc {
    /// Argument signature of the function; populated by [`InitcapFunc::new`].
    signature: Signature,
}
57
58impl Default for InitcapFunc {
59    fn default() -> Self {
60        InitcapFunc::new()
61    }
62}
63
64impl InitcapFunc {
65    pub fn new() -> Self {
66        Self {
67            signature: Signature::coercible(
68                vec![Coercion::new_exact(TypeSignatureClass::Native(
69                    logical_string(),
70                ))],
71                Volatility::Immutable,
72            ),
73        }
74    }
75}
76
77impl ScalarUDFImpl for InitcapFunc {
78    fn as_any(&self) -> &dyn Any {
79        self
80    }
81
82    fn name(&self) -> &str {
83        "initcap"
84    }
85
86    fn signature(&self) -> &Signature {
87        &self.signature
88    }
89
90    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
91        if let DataType::Utf8View = arg_types[0] {
92            Ok(DataType::Utf8View)
93        } else {
94            utf8_to_str_type(&arg_types[0], "initcap")
95        }
96    }
97
98    fn invoke_with_args(
99        &self,
100        args: datafusion_expr::ScalarFunctionArgs,
101    ) -> Result<ColumnarValue> {
102        let args = &args.args;
103        match args[0].data_type() {
104            DataType::Utf8 => make_scalar_function(initcap::<i32>, vec![])(args),
105            DataType::LargeUtf8 => make_scalar_function(initcap::<i64>, vec![])(args),
106            DataType::Utf8View => make_scalar_function(initcap_utf8view, vec![])(args),
107            other => {
108                exec_err!("Unsupported data type {other:?} for function `initcap`")
109            }
110        }
111    }
112
113    fn documentation(&self) -> Option<&Documentation> {
114        self.doc()
115    }
116}
117
118/// Converts the first letter of each word to upper case and the rest to lower
119/// case. Words are sequences of alphanumeric characters separated by
120/// non-alphanumeric characters.
121///
122/// Example:
123/// ```sql
124/// initcap('hi THOMAS') = 'Hi Thomas'
125/// ```
126fn initcap<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
127    let string_array = as_generic_string_array::<T>(&args[0])?;
128
129    let mut builder = GenericStringBuilder::<T>::with_capacity(
130        string_array.len(),
131        string_array.value_data().len(),
132    );
133
134    string_array.iter().for_each(|str| match str {
135        Some(s) => {
136            let initcap_str = initcap_string(s);
137            builder.append_value(initcap_str);
138        }
139        None => builder.append_null(),
140    });
141
142    Ok(Arc::new(builder.finish()) as ArrayRef)
143}
144
145fn initcap_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
146    let string_view_array = as_string_view_array(&args[0])?;
147
148    let mut builder = StringViewBuilder::with_capacity(string_view_array.len());
149
150    string_view_array.iter().for_each(|str| match str {
151        Some(s) => {
152            let initcap_str = initcap_string(s);
153            builder.append_value(initcap_str);
154        }
155        None => builder.append_null(),
156    });
157
158    Ok(Arc::new(builder.finish()) as ArrayRef)
159}
160
/// Returns `input` with the first character of every word upper-cased and all
/// following characters of the word lower-cased.
///
/// A character starts a word when the character before it is not
/// alphanumeric (or when it is the first character of the input). Pure-ASCII
/// input uses the ASCII-only case conversions; anything else goes through the
/// Unicode conversions, which may expand one character into several.
fn initcap_string(input: &str) -> String {
    let mut output = String::with_capacity(input.len());
    // True while the previously visited character was alphanumeric, i.e. the
    // current character continues a word rather than starting one.
    let mut inside_word = false;

    if !input.is_ascii() {
        for ch in input.chars() {
            if inside_word {
                output.extend(ch.to_lowercase());
            } else {
                output.extend(ch.to_uppercase());
            }
            inside_word = ch.is_alphanumeric();
        }
        return output;
    }

    for ch in input.chars() {
        let cased = if inside_word {
            ch.to_ascii_lowercase()
        } else {
            ch.to_ascii_uppercase()
        };
        output.push(cased);
        inside_word = ch.is_ascii_alphanumeric();
    }

    output
}
187
#[cfg(test)]
mod tests {
    use crate::unicode::initcap::InitcapFunc;
    use crate::utils::test::test_function;
    use arrow::array::{Array, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{Utf8, Utf8View};
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    /// Exercises `initcap` through the UDF entry point for `Utf8` and
    /// `Utf8View` scalars: ASCII, non-ASCII, empty and null inputs.
    #[test]
    fn test_functions() -> Result<()> {
        // --- Utf8 ---
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from("hi THOMAS"))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                "êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(""))],
            Ok(Some("")),
            &str,
            Utf8,
            StringArray
        );
        // Non-alphanumeric ASCII characters such as `_` and `-` delimit words,
        // while digits count as part of a word.
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::from(
                "hello_world 2nd-place"
            ))],
            Ok(Some("Hello_World 2nd-Place")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(None))],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // --- Utf8View ---
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS".to_string()
            )))],
            Ok(Some("Hi Thomas")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "hi THOMAS wIth M0re ThAN 12 ChaRs".to_string()
            )))],
            Ok(Some("Hi Thomas With M0re Than 12 Chars")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "đẸp đẼ êM ả ñAnDÚ ÁrBOL ОлЕГ ИвАНОВИч ÍslENsku ÞjóðaRiNNaR εΛλΗΝΙκΉ"
                    .to_string()
            )))],
            Ok(Some(
                "Đẹp Đẽ Êm Ả Ñandú Árbol Олег Иванович Íslensku Þjóðarinnar Ελληνική"
            )),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                "".to_string()
            )))],
            Ok(Some("")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            InitcapFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(None))],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );

        Ok(())
    }
}