datafusion_functions/string/
contains.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::make_scalar_function;
19use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::contains as arrow_contains;
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
23use datafusion_common::types::logical_string;
24use datafusion_common::{exec_err, DataFusionError, Result};
25use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
26use datafusion_expr::{
27    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    TypeSignatureClass, Volatility,
29};
30use datafusion_macros::user_doc;
31use std::any::Any;
32use std::sync::Arc;
33
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Return true if search_str is found within string (case-sensitive).",
    syntax_example = "contains(str, search_str)",
    sql_example = r#"```sql
> select contains('the quick brown fox', 'row');
+---------------------------------------------------+
| contains(Utf8("the quick brown fox"),Utf8("row")) |
+---------------------------------------------------+
| true                                              |
+---------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "search_str", description = "The string to search for in str.")
)]
#[derive(Debug)]
// Scalar UDF registered under the name `contains`; it takes two string
// arguments and reports a `Boolean` return type.
pub struct ContainsFunc {
    // Argument coercion rules and volatility, built once in `ContainsFunc::new`
    // and handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
}
53
54impl Default for ContainsFunc {
55    fn default() -> Self {
56        ContainsFunc::new()
57    }
58}
59
60impl ContainsFunc {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::coercible(
64                vec![
65                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
66                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
67                ],
68                Volatility::Immutable,
69            ),
70        }
71    }
72}
73
74impl ScalarUDFImpl for ContainsFunc {
75    fn as_any(&self) -> &dyn Any {
76        self
77    }
78
79    fn name(&self) -> &str {
80        "contains"
81    }
82
83    fn signature(&self) -> &Signature {
84        &self.signature
85    }
86
87    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
88        Ok(Boolean)
89    }
90
91    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92        make_scalar_function(contains, vec![])(&args.args)
93    }
94
95    fn documentation(&self) -> Option<&Documentation> {
96        self.doc()
97    }
98}
99
100/// use `arrow::compute::contains` to do the calculation for contains
101fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
102    if let Some(coercion_data_type) =
103        string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
104            binary_to_string_coercion(args[0].data_type(), args[1].data_type())
105        })
106    {
107        let arg0 = if args[0].data_type() == &coercion_data_type {
108            Arc::clone(&args[0])
109        } else {
110            arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
111        };
112        let arg1 = if args[1].data_type() == &coercion_data_type {
113            Arc::clone(&args[1])
114        } else {
115            arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
116        };
117
118        match coercion_data_type {
119            Utf8View => {
120                let mod_str = arg0.as_string_view();
121                let match_str = arg1.as_string_view();
122                let res = arrow_contains(mod_str, match_str)?;
123                Ok(Arc::new(res) as ArrayRef)
124            }
125            Utf8 => {
126                let mod_str = arg0.as_string::<i32>();
127                let match_str = arg1.as_string::<i32>();
128                let res = arrow_contains(mod_str, match_str)?;
129                Ok(Arc::new(res) as ArrayRef)
130            }
131            LargeUtf8 => {
132                let mod_str = arg0.as_string::<i64>();
133                let match_str = arg1.as_string::<i64>();
134                let res = arrow_contains(mod_str, match_str)?;
135                Ok(Arc::new(res) as ArrayRef)
136            }
137            other => {
138                exec_err!("Unsupported data type {other:?} for function `contains`.")
139            }
140        }
141    } else {
142        exec_err!(
143            "Unsupported data type {:?}, {:?} for function `contains`.",
144            args[0].data_type(),
145            args[1].data_type()
146        )
147    }
148}
149
#[cfg(test)]
mod test {
    use super::ContainsFunc;
    use arrow::array::{BooleanArray, StringArray};
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Invokes the UDF with an array of strings to search in and a scalar
    /// string to search for, and checks the per-row boolean result.
    ///
    /// NOTE(review): the search string includes `?` and `(`, presumably to
    /// confirm it is matched as plain text rather than as a pattern.
    #[test]
    fn test_contains_udf() {
        let haystacks = StringArray::from(vec![Some("xxx?()"), Some("yyy?()")]);
        let needle = ScalarValue::Utf8(Some("x?(".to_string()));

        let args = ScalarFunctionArgs {
            args: vec![
                ColumnarValue::Array(Arc::new(haystacks)),
                ColumnarValue::Scalar(needle),
            ],
            number_rows: 2,
            return_type: &DataType::Boolean,
        };

        let actual = ContainsFunc::new()
            .invoke_with_args(args)
            .unwrap()
            .into_array(2)
            .unwrap();

        let expected_rows = BooleanArray::from(vec![Some(true), Some(false)]);
        let expected = ColumnarValue::Array(Arc::new(expected_rows))
            .into_array(2)
            .unwrap();

        assert_eq!(*actual, *expected);
    }
}