datafusion_functions/string/
contains.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::utils::make_scalar_function;
19use arrow::array::{Array, ArrayRef, AsArray};
20use arrow::compute::contains as arrow_contains;
21use arrow::datatypes::DataType;
22use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View};
23use datafusion_common::exec_err;
24use datafusion_common::DataFusionError;
25use datafusion_common::Result;
26use datafusion_expr::{
27    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
28    Volatility,
29};
30use datafusion_macros::user_doc;
31use std::any::Any;
32use std::sync::Arc;
33
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Return true if search_str is found within string (case-sensitive).",
    syntax_example = "contains(str, search_str)",
    sql_example = r#"```sql
> select contains('the quick brown fox', 'row');
+---------------------------------------------------+
| contains(Utf8("the quick brown fox"),Utf8("row")) |
+---------------------------------------------------+
| true                                              |
+---------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "search_str", description = "The string to search for in str.")
)]
/// Scalar UDF implementing `contains(str, search_str)`, a case-sensitive
/// substring test that returns a boolean.
#[derive(Debug)]
pub struct ContainsFunc {
    /// Argument signature reported to the planner; built in `ContainsFunc::new`.
    signature: Signature,
}
53
54impl Default for ContainsFunc {
55    fn default() -> Self {
56        ContainsFunc::new()
57    }
58}
59
60impl ContainsFunc {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::string(2, Volatility::Immutable),
64        }
65    }
66}
67
68impl ScalarUDFImpl for ContainsFunc {
69    fn as_any(&self) -> &dyn Any {
70        self
71    }
72
73    fn name(&self) -> &str {
74        "contains"
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn return_type(&self, _: &[DataType]) -> Result<DataType> {
82        Ok(Boolean)
83    }
84
85    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86        make_scalar_function(contains, vec![])(&args.args)
87    }
88
89    fn documentation(&self) -> Option<&Documentation> {
90        self.doc()
91    }
92}
93
94/// use `arrow::compute::contains` to do the calculation for contains
95pub fn contains(args: &[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
96    match (args[0].data_type(), args[1].data_type()) {
97        (Utf8View, Utf8View) => {
98            let mod_str = args[0].as_string_view();
99            let match_str = args[1].as_string_view();
100            let res = arrow_contains(mod_str, match_str)?;
101            Ok(Arc::new(res) as ArrayRef)
102        }
103        (Utf8, Utf8) => {
104            let mod_str = args[0].as_string::<i32>();
105            let match_str = args[1].as_string::<i32>();
106            let res = arrow_contains(mod_str, match_str)?;
107            Ok(Arc::new(res) as ArrayRef)
108        }
109        (LargeUtf8, LargeUtf8) => {
110            let mod_str = args[0].as_string::<i64>();
111            let match_str = args[1].as_string::<i64>();
112            let res = arrow_contains(mod_str, match_str)?;
113            Ok(Arc::new(res) as ArrayRef)
114        }
115        other => {
116            exec_err!("Unsupported data type {other:?} for function `contains`.")
117        }
118    }
119}
120
#[cfg(test)]
mod test {
    use super::ContainsFunc;
    use arrow::array::{
        ArrayRef, BooleanArray, LargeStringArray, StringArray, StringViewArray,
    };
    use arrow::datatypes::DataType;
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    /// Invokes `contains(haystack, needle)` over `number_rows` rows and
    /// returns the result materialized as an array.
    fn invoke_contains(
        haystack: ColumnarValue,
        needle: ColumnarValue,
        number_rows: usize,
    ) -> ArrayRef {
        let args = ScalarFunctionArgs {
            args: vec![haystack, needle],
            number_rows,
            return_type: &DataType::Boolean,
        };
        ContainsFunc::new()
            .invoke_with_args(args)
            .unwrap()
            .into_array(number_rows)
            .unwrap()
    }

    #[test]
    fn test_contains_udf() {
        let haystack = ColumnarValue::Array(Arc::new(StringArray::from(vec![
            Some("xxx?()"),
            Some("yyy?()"),
        ])));
        // Regex metacharacters in the needle are matched literally.
        let needle = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string())));

        let actual = invoke_contains(haystack, needle, 2);
        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), Some(false)]));
        assert_eq!(*actual, *expected);
    }

    #[test]
    fn test_contains_large_utf8_is_case_sensitive() {
        let haystack = ColumnarValue::Array(Arc::new(LargeStringArray::from(vec![
            Some("the quick brown fox"),
            Some("THE QUICK BROWN FOX"),
        ])));
        let needle =
            ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("row".to_string())));

        let actual = invoke_contains(haystack, needle, 2);
        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), Some(false)]));
        assert_eq!(*actual, *expected);
    }

    #[test]
    fn test_contains_utf8_view_propagates_nulls() {
        let haystack = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
            Some("apple"),
            None,
            Some("pineapple"),
        ])));
        let needle =
            ColumnarValue::Scalar(ScalarValue::Utf8View(Some("apple".to_string())));

        let actual = invoke_contains(haystack, needle, 3);
        let expected: ArrayRef =
            Arc::new(BooleanArray::from(vec![Some(true), None, Some(true)]));
        assert_eq!(*actual, *expected);
    }
}