// datafusion_functions/string/ends_with.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::ArrayRef;
22use arrow::datatypes::DataType;
23
24use crate::utils::make_scalar_function;
25use datafusion_common::types::logical_string;
26use datafusion_common::{internal_err, Result};
27use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
28use datafusion_expr::{
29    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    TypeSignatureClass, Volatility,
31};
32use datafusion_macros::user_doc;
33
// Scalar UDF implementing `ends_with(str, substr)`. The user-facing
// documentation below is attached through the `user_doc` attribute and is
// returned from `ScalarUDFImpl::documentation` via `self.doc()`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Tests if a string ends with a substring.",
    syntax_example = "ends_with(str, substr)",
    sql_example = r#"```sql
>  select ends_with('datafusion', 'soin');
+--------------------------------------------+
| ends_with(Utf8("datafusion"),Utf8("soin")) |
+--------------------------------------------+
| false                                      |
+--------------------------------------------+
> select ends_with('datafusion', 'sion');
+--------------------------------------------+
| ends_with(Utf8("datafusion"),Utf8("sion")) |
+--------------------------------------------+
| true                                       |
+--------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to test for.")
)]
#[derive(Debug)]
pub struct EndsWithFunc {
    /// Accepted argument types and volatility; built once in `EndsWithFunc::new`.
    signature: Signature,
}
59
60impl Default for EndsWithFunc {
61    fn default() -> Self {
62        EndsWithFunc::new()
63    }
64}
65
66impl EndsWithFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::coercible(
70                vec![
71                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                ],
74                Volatility::Immutable,
75            ),
76        }
77    }
78}
79
80impl ScalarUDFImpl for EndsWithFunc {
81    fn as_any(&self) -> &dyn Any {
82        self
83    }
84
85    fn name(&self) -> &str {
86        "ends_with"
87    }
88
89    fn signature(&self) -> &Signature {
90        &self.signature
91    }
92
93    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
94        Ok(DataType::Boolean)
95    }
96
97    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
98        match args.args[0].data_type() {
99            DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
100                make_scalar_function(ends_with, vec![])(&args.args)
101            }
102            other => {
103                internal_err!("Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View")?
104            }
105        }
106    }
107
108    fn documentation(&self) -> Option<&Documentation> {
109        self.doc()
110    }
111}
112
113/// Returns true if string ends with suffix.
114/// ends_with('alphabet', 'abet') = 't'
115fn ends_with(args: &[ArrayRef]) -> Result<ArrayRef> {
116    if let Some(coercion_data_type) =
117        string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
118            binary_to_string_coercion(args[0].data_type(), args[1].data_type())
119        })
120    {
121        let arg0 = if args[0].data_type() == &coercion_data_type {
122            Arc::clone(&args[0])
123        } else {
124            arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
125        };
126        let arg1 = if args[1].data_type() == &coercion_data_type {
127            Arc::clone(&args[1])
128        } else {
129            arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
130        };
131        let result = arrow::compute::kernels::comparison::ends_with(&arg0, &arg1)?;
132        Ok(Arc::new(result) as ArrayRef)
133    } else {
134        internal_err!(
135            "Unsupported data types for ends_with. Expected Utf8, LargeUtf8 or Utf8View"
136        )
137    }
138}
139
#[cfg(test)]
mod tests {
    use arrow::array::{Array, BooleanArray};
    use arrow::datatypes::DataType::Boolean;

    use datafusion_common::Result;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarUDFImpl};

    use crate::string::ends_with::EndsWithFunc;
    use crate::utils::test::test_function;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from("alph")),
            ],
            Ok(Some(false)),
            bool,
            Boolean,
            BooleanArray
        );
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from("bet")),
            ],
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::from("alph")),
            ],
            Ok(None),
            bool,
            Boolean,
            BooleanArray
        );
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
            ],
            Ok(None),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }

    /// Exercises the other advertised string types plus suffix edge cases.
    #[test]
    fn test_other_string_types_and_edge_cases() -> Result<()> {
        // LargeUtf8 on both sides.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
                    "alphabet".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))),
            ],
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );
        // Utf8View on both sides.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
                    "alphabet".to_string()
                ))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some("alph".to_string()))),
            ],
            Ok(Some(false)),
            bool,
            Boolean,
            BooleanArray
        );
        // An empty suffix matches every non-null string.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
                ColumnarValue::Scalar(ScalarValue::from("")),
            ],
            Ok(Some(true)),
            bool,
            Boolean,
            BooleanArray
        );
        // A suffix longer than the string can never match.
        test_function!(
            EndsWithFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::from("bet")),
                ColumnarValue::Scalar(ScalarValue::from("alphabet")),
            ],
            Ok(Some(false)),
            bool,
            Boolean,
            BooleanArray
        );

        Ok(())
    }
}