datafusion_functions/string/
btrim.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use crate::string::common::*;
19use crate::utils::{make_scalar_function, utf8_to_str_type};
20use arrow::array::{ArrayRef, OffsetSizeTrait};
21use arrow::datatypes::DataType;
22use datafusion_common::types::logical_string;
23use datafusion_common::{exec_err, Result};
24use datafusion_expr::function::Hint;
25use datafusion_expr::{
26    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
27    TypeSignature, TypeSignatureClass, Volatility,
28};
29use datafusion_macros::user_doc;
30use std::any::Any;
31use std::sync::Arc;
32
33/// Returns the longest string with leading and trailing characters removed. If the characters are not specified, whitespace is removed.
34/// btrim('xyxtrimyyx', 'xyz') = 'trim'
35fn btrim<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
36    let use_string_view = args[0].data_type() == &DataType::Utf8View;
37    let args = if args.len() > 1 {
38        let arg1 = arrow::compute::kernels::cast::cast(&args[1], args[0].data_type())?;
39        vec![Arc::clone(&args[0]), arg1]
40    } else {
41        args.to_owned()
42    };
43    general_trim::<T>(&args, TrimType::Both, use_string_view)
44}
45
46#[user_doc(
47    doc_section(label = "String Functions"),
48    description = "Trims the specified trim string from the start and end of a string. If no trim string is provided, all whitespace is removed from the start and end of the input string.",
49    syntax_example = "btrim(str[, trim_str])",
50    sql_example = r#"```sql
51> select btrim('__datafusion____', '_');
52+-------------------------------------------+
53| btrim(Utf8("__datafusion____"),Utf8("_")) |
54+-------------------------------------------+
55| datafusion                                |
56+-------------------------------------------+
57```"#,
58    standard_argument(name = "str", prefix = "String"),
59    argument(
60        name = "trim_str",
61        description = r"String expression to operate on. Can be a constant, column, or function, and any combination of operators. _Default is whitespace characters._"
62    ),
63    alternative_syntax = "trim(BOTH trim_str FROM str)",
64    alternative_syntax = "trim(trim_str FROM str)",
65    related_udf(name = "ltrim"),
66    related_udf(name = "rtrim")
67)]
68#[derive(Debug)]
69pub struct BTrimFunc {
70    signature: Signature,
71    aliases: Vec<String>,
72}
73
74impl Default for BTrimFunc {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl BTrimFunc {
81    pub fn new() -> Self {
82        Self {
83            signature: Signature::one_of(
84                vec![
85                    TypeSignature::Coercible(vec![
86                        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
87                        Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
88                    ]),
89                    TypeSignature::Coercible(vec![Coercion::new_exact(
90                        TypeSignatureClass::Native(logical_string()),
91                    )]),
92                ],
93                Volatility::Immutable,
94            ),
95            aliases: vec![String::from("trim")],
96        }
97    }
98}
99
100impl ScalarUDFImpl for BTrimFunc {
101    fn as_any(&self) -> &dyn Any {
102        self
103    }
104
105    fn name(&self) -> &str {
106        "btrim"
107    }
108
109    fn signature(&self) -> &Signature {
110        &self.signature
111    }
112
113    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
114        if arg_types[0] == DataType::Utf8View {
115            Ok(DataType::Utf8View)
116        } else {
117            utf8_to_str_type(&arg_types[0], "btrim")
118        }
119    }
120
121    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
122        match args.args[0].data_type() {
123            DataType::Utf8 | DataType::Utf8View => make_scalar_function(
124                btrim::<i32>,
125                vec![Hint::Pad, Hint::AcceptsSingular],
126            )(&args.args),
127            DataType::LargeUtf8 => make_scalar_function(
128                btrim::<i64>,
129                vec![Hint::Pad, Hint::AcceptsSingular],
130            )(&args.args),
131            other => exec_err!(
132                "Unsupported data type {other:?} for function btrim,\
133                expected Utf8, LargeUtf8 or Utf8View."
134            ),
135        }
136    }
137
138    fn aliases(&self) -> &[String] {
139        &self.aliases
140    }
141
142    fn documentation(&self) -> Option<&Documentation> {
143        self.doc()
144    }
145}
146
#[cfg(test)]
mod tests {
    use arrow::array::{Array, StringArray, StringViewArray};
    use arrow::datatypes::DataType::{Utf8, Utf8View};

    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::btrim::BTrimFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() {
        // String view cases for checking normal logic
        test_function!(
            BTrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                String::from("alphabet  ")
            )))],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                String::from("  alphabet  ")
            ))),],
            Ok(Some("alphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("t")))),
            ],
            Ok(Some("alphabe")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabe"
                )))),
            ],
            Ok(Some("t")),
            &str,
            Utf8View,
            StringViewArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "alphabet"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
            ],
            Ok(None),
            &str,
            Utf8View,
            StringViewArray
        );
        // Special string view case for checking non-inlined output (len > 12)
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "xxxalphabetalphabetxxx"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("x")))),
            ],
            Ok(Some("alphabetalphabet")),
            &str,
            Utf8View,
            StringViewArray
        );
        // String cases
        test_function!(
            BTrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                String::from("alphabet  ")
            ))),],
            Ok(Some("alphabet")),
            &str,
            Utf8,
            StringArray
        );
        // Leading and trailing whitespace, mirroring the Utf8View case above
        test_function!(
            BTrimFunc::new(),
            vec![ColumnarValue::Scalar(ScalarValue::Utf8(Some(
                String::from("  alphabet  ")
            ))),],
            Ok(Some("alphabet")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("t")))),
            ],
            Ok(Some("alphabe")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabe")))),
            ],
            Ok(Some("t")),
            &str,
            Utf8,
            StringArray
        );
        test_function!(
            BTrimFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("alphabet")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );
    }
}