datafusion_functions/string/
replace.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
26use datafusion_common::types::logical_string;
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::type_coercion::binary::{
29    binary_to_string_coercion, string_coercion,
30};
31use datafusion_expr::{
32    Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
33    TypeSignatureClass, Volatility,
34};
35use datafusion_macros::user_doc;
36#[user_doc(
37    doc_section(label = "String Functions"),
38    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
39    syntax_example = "replace(str, substr, replacement)",
40    sql_example = r#"```sql
41> select replace('ABabbaBA', 'ab', 'cd');
42+-------------------------------------------------+
43| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
44+-------------------------------------------------+
45| ABcdbaBA                                        |
46+-------------------------------------------------+
47```"#,
48    standard_argument(name = "str", prefix = "String"),
49    standard_argument(
50        name = "substr",
51        prefix = "Substring expression to replace in the input string. Substring"
52    ),
53    standard_argument(name = "replacement", prefix = "Replacement substring")
54)]
55#[derive(Debug)]
56pub struct ReplaceFunc {
57    signature: Signature,
58}
59
60impl Default for ReplaceFunc {
61    fn default() -> Self {
62        Self::new()
63    }
64}
65
66impl ReplaceFunc {
67    pub fn new() -> Self {
68        Self {
69            signature: Signature::coercible(
70                vec![
71                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
72                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
73                    Coercion::new_exact(TypeSignatureClass::Native(logical_string())),
74                ],
75                Volatility::Immutable,
76            ),
77        }
78    }
79}
80
81impl ScalarUDFImpl for ReplaceFunc {
82    fn as_any(&self) -> &dyn Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "replace"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95        if let Some(coercion_data_type) = string_coercion(&arg_types[0], &arg_types[1])
96            .and_then(|dt| string_coercion(&dt, &arg_types[2]))
97            .or_else(|| {
98                binary_to_string_coercion(&arg_types[0], &arg_types[1])
99                    .and_then(|dt| binary_to_string_coercion(&dt, &arg_types[2]))
100            })
101        {
102            utf8_to_str_type(&coercion_data_type, "replace")
103        } else {
104            exec_err!("Unsupported data types for replace. Expected Utf8, LargeUtf8 or Utf8View")
105        }
106    }
107
108    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
109        let data_types = args
110            .args
111            .iter()
112            .map(|arg| arg.data_type())
113            .collect::<Vec<_>>();
114
115        if let Some(coercion_type) = string_coercion(&data_types[0], &data_types[1])
116            .and_then(|dt| string_coercion(&dt, &data_types[2]))
117            .or_else(|| {
118                binary_to_string_coercion(&data_types[0], &data_types[1])
119                    .and_then(|dt| binary_to_string_coercion(&dt, &data_types[2]))
120            })
121        {
122            let mut converted_args = Vec::with_capacity(args.args.len());
123            for arg in &args.args {
124                if arg.data_type() == coercion_type {
125                    converted_args.push(arg.clone());
126                } else {
127                    let converted = arg.cast_to(&coercion_type, None)?;
128                    converted_args.push(converted);
129                }
130            }
131
132            match coercion_type {
133                DataType::Utf8 => {
134                    make_scalar_function(replace::<i32>, vec![])(&converted_args)
135                }
136                DataType::LargeUtf8 => {
137                    make_scalar_function(replace::<i64>, vec![])(&converted_args)
138                }
139                DataType::Utf8View => {
140                    make_scalar_function(replace_view, vec![])(&converted_args)
141                }
142                other => exec_err!(
143                    "Unsupported coercion data type {other:?} for function replace"
144                ),
145            }
146        } else {
147            exec_err!(
148                "Unsupported data type {:?}, {:?}, {:?} for function replace.",
149                data_types[0],
150                data_types[1],
151                data_types[2]
152            )
153        }
154    }
155
156    fn documentation(&self) -> Option<&Documentation> {
157        self.doc()
158    }
159}
160
161fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
162    let string_array = as_string_view_array(&args[0])?;
163    let from_array = as_string_view_array(&args[1])?;
164    let to_array = as_string_view_array(&args[2])?;
165
166    let result = string_array
167        .iter()
168        .zip(from_array.iter())
169        .zip(to_array.iter())
170        .map(|((string, from), to)| match (string, from, to) {
171            (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
172            _ => None,
173        })
174        .collect::<StringArray>();
175
176    Ok(Arc::new(result) as ArrayRef)
177}
178
179/// Replaces all occurrences in string of substring from with substring to.
180/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
181fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
182    let string_array = as_generic_string_array::<T>(&args[0])?;
183    let from_array = as_generic_string_array::<T>(&args[1])?;
184    let to_array = as_generic_string_array::<T>(&args[2])?;
185
186    let result = string_array
187        .iter()
188        .zip(from_array.iter())
189        .zip(to_array.iter())
190        .map(|((string, from), to)| match (string, from, to) {
191            (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
192            _ => None,
193        })
194        .collect::<GenericStringArray<T>>();
195
196    Ok(Arc::new(result) as ArrayRef)
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::{Array, LargeStringArray, StringArray};
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    /// Wraps `value` as a non-null `Utf8` scalar argument.
    fn utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from(value))))
    }

    /// Wraps `value` as a non-null `LargeUtf8` scalar argument.
    fn large_utf8(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(value))))
    }

    /// Wraps `value` as a non-null `Utf8View` scalar argument.
    fn utf8_view(value: &str) -> ColumnarValue {
        ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(value))))
    }

    #[test]
    fn test_functions() -> Result<()> {
        // Utf8 inputs: both occurrences of "bb" are replaced.
        test_function!(
            ReplaceFunc::new(),
            vec![utf8("aabbdqcbb"), utf8("bb"), utf8("ccc")],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        // LargeUtf8 inputs produce a LargeUtf8 result.
        test_function!(
            ReplaceFunc::new(),
            vec![large_utf8("aabbb"), large_utf8("bbb"), large_utf8("cc")],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // Utf8View inputs are expected to come back as a plain Utf8 array.
        test_function!(
            ReplaceFunc::new(),
            vec![utf8_view("aabbbcw"), utf8_view("bb"), utf8_view("cc")],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}