datafusion_functions/string/replace.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
use std::any::Any;
use std::sync::Arc;

use arrow::array::{
    ArrayRef, GenericStringArray, GenericStringBuilder, OffsetSizeTrait, StringArray,
};
use arrow::datatypes::DataType;

use crate::utils::{make_scalar_function, utf8_to_str_type};
use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
use datafusion_common::{exec_err, Result};
use datafusion_expr::{ColumnarValue, Documentation, Volatility};
use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
use datafusion_macros::user_doc;
/// The `replace(str, substr, replacement)` scalar function.
///
/// Array kernels are selected in `invoke_with_args`: `replace::<i32>` for
/// `Utf8`, `replace::<i64>` for `LargeUtf8`, and `replace_view` for `Utf8View`.
/// The SQL-facing documentation comes from the `user_doc` attribute below.
// NOTE(review): `standard_argument` presumably appends its own boilerplate
// text after `prefix`, which would explain why the `substr` prefix ends with
// the bare word "Substring" — confirm against `datafusion_macros::user_doc`.
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Replaces all occurrences of a specified substring in a string with a new substring.",
    syntax_example = "replace(str, substr, replacement)",
    sql_example = r#"```sql
> select replace('ABabbaBA', 'ab', 'cd');
+-------------------------------------------------+
| replace(Utf8("ABabbaBA"),Utf8("ab"),Utf8("cd")) |
+-------------------------------------------------+
| ABcdbaBA                                        |
+-------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    standard_argument(
        name = "substr",
        prefix = "Substring expression to replace in the input string. Substring"
    ),
    standard_argument(name = "replacement", prefix = "Replacement substring")
)]
#[derive(Debug)]
pub struct ReplaceFunc {
    /// Accepted argument types; built in [`ReplaceFunc::new`].
    signature: Signature,
}
53
54impl Default for ReplaceFunc {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl ReplaceFunc {
61    pub fn new() -> Self {
62        Self {
63            signature: Signature::string(3, Volatility::Immutable),
64        }
65    }
66}
67
68impl ScalarUDFImpl for ReplaceFunc {
69    fn as_any(&self) -> &dyn Any {
70        self
71    }
72
73    fn name(&self) -> &str {
74        "replace"
75    }
76
77    fn signature(&self) -> &Signature {
78        &self.signature
79    }
80
81    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
82        utf8_to_str_type(&arg_types[0], "replace")
83    }
84
85    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86        match args.args[0].data_type() {
87            DataType::Utf8 => make_scalar_function(replace::<i32>, vec![])(&args.args),
88            DataType::LargeUtf8 => {
89                make_scalar_function(replace::<i64>, vec![])(&args.args)
90            }
91            DataType::Utf8View => make_scalar_function(replace_view, vec![])(&args.args),
92            other => {
93                exec_err!("Unsupported data type {other:?} for function replace")
94            }
95        }
96    }
97
98    fn documentation(&self) -> Option<&Documentation> {
99        self.doc()
100    }
101}
102
103fn replace_view(args: &[ArrayRef]) -> Result<ArrayRef> {
104    let string_array = as_string_view_array(&args[0])?;
105    let from_array = as_string_view_array(&args[1])?;
106    let to_array = as_string_view_array(&args[2])?;
107
108    let result = string_array
109        .iter()
110        .zip(from_array.iter())
111        .zip(to_array.iter())
112        .map(|((string, from), to)| match (string, from, to) {
113            (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
114            _ => None,
115        })
116        .collect::<StringArray>();
117
118    Ok(Arc::new(result) as ArrayRef)
119}
120/// Replaces all occurrences in string of substring from with substring to.
121/// replace('abcdefabcdef', 'cd', 'XX') = 'abXXefabXXef'
122fn replace<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
123    let string_array = as_generic_string_array::<T>(&args[0])?;
124    let from_array = as_generic_string_array::<T>(&args[1])?;
125    let to_array = as_generic_string_array::<T>(&args[2])?;
126
127    let result = string_array
128        .iter()
129        .zip(from_array.iter())
130        .zip(to_array.iter())
131        .map(|((string, from), to)| match (string, from, to) {
132            (Some(string), Some(from), Some(to)) => Some(string.replace(from, to)),
133            _ => None,
134        })
135        .collect::<GenericStringArray<T>>();
136
137    Ok(Arc::new(result) as ArrayRef)
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::test::test_function;
    use arrow::array::Array;
    use arrow::array::LargeStringArray;
    use arrow::array::StringArray;
    use arrow::datatypes::DataType::{LargeUtf8, Utf8};
    use datafusion_common::ScalarValue;

    #[test]
    fn test_functions() -> Result<()> {
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbdqcbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("ccc")))),
            ],
            Ok(Some("aacccdqcccc")),
            &str,
            Utf8,
            StringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from(
                    "aabbb"
                )))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("bbb")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("cc")))),
            ],
            Ok(Some("aacc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
                    "aabbbcw"
                )))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("bb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("cc")))),
            ],
            Ok(Some("aaccbcw")),
            &str,
            Utf8,
            StringArray
        );

        // A null in any argument makes the whole row null.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("aabbb")))),
                ColumnarValue::Scalar(ScalarValue::Utf8(None)),
                ColumnarValue::Scalar(ScalarValue::Utf8(Some(String::from("cc")))),
            ],
            Ok(None),
            &str,
            Utf8,
            StringArray
        );

        // No occurrence of `from`: the input comes back unchanged.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("z")))),
                ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(String::from("X")))),
            ],
            Ok(Some("abc")),
            &str,
            LargeUtf8,
            LargeStringArray
        );

        // An empty `from` follows Rust's `str::replace`: `to` is inserted
        // before every character and at the end.
        test_function!(
            ReplaceFunc::new(),
            vec![
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("abc")))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::new()))),
                ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("X")))),
            ],
            Ok(Some("XaXbXcX")),
            &str,
            Utf8,
            StringArray
        );

        Ok(())
    }
}