datafusion_functions/string/overlay.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
22use arrow::datatypes::DataType;
23
24use crate::utils::{make_scalar_function, utf8_to_str_type};
25use datafusion_common::cast::{
26    as_generic_string_array, as_int64_array, as_string_view_array,
27};
28use datafusion_common::{exec_err, Result};
29use datafusion_expr::{ColumnarValue, Documentation, TypeSignature, Volatility};
30use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature};
31use datafusion_macros::user_doc;
32
// Scalar UDF implementing SQL `OVERLAY`. The `user_doc` attribute carries the
// user-facing documentation; `documentation()` in the `ScalarUDFImpl` impl
// returns it via `self.doc()`, which is presumably generated by this attribute
// (no `doc` method is defined in the visible code).
#[user_doc(
    doc_section(label = "String Functions"),
    description = "Returns the string which is replaced by another string from the specified position and specified count length.",
    syntax_example = "overlay(str PLACING substr FROM pos [FOR count])",
    sql_example = r#"```sql
> select overlay('Txxxxas' placing 'hom' from 2 for 4);
+--------------------------------------------------------+
| overlay(Utf8("Txxxxas"),Utf8("hom"),Int64(2),Int64(4)) |
+--------------------------------------------------------+
| Thomas                                                 |
+--------------------------------------------------------+
```"#,
    standard_argument(name = "str", prefix = "String"),
    argument(name = "substr", description = "Substring to replace in str."),
    argument(
        name = "pos",
        description = "The start position to start the replace in str."
    ),
    argument(
        name = "count",
        description = "The count of characters to be replaced from start position of str. If not specified, will use substr length instead."
    )
)]
#[derive(Debug)]
pub struct OverlayFunc {
    // Accepted argument type combinations; built in `OverlayFunc::new`.
    signature: Signature,
}
60
61impl Default for OverlayFunc {
62    fn default() -> Self {
63        Self::new()
64    }
65}
66
67impl OverlayFunc {
68    pub fn new() -> Self {
69        use DataType::*;
70        Self {
71            signature: Signature::one_of(
72                vec![
73                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64, Int64]),
74                    TypeSignature::Exact(vec![Utf8, Utf8, Int64, Int64]),
75                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64, Int64]),
76                    TypeSignature::Exact(vec![Utf8View, Utf8View, Int64]),
77                    TypeSignature::Exact(vec![Utf8, Utf8, Int64]),
78                    TypeSignature::Exact(vec![LargeUtf8, LargeUtf8, Int64]),
79                ],
80                Volatility::Immutable,
81            ),
82        }
83    }
84}
85
86impl ScalarUDFImpl for OverlayFunc {
87    fn as_any(&self) -> &dyn Any {
88        self
89    }
90
91    fn name(&self) -> &str {
92        "overlay"
93    }
94
95    fn signature(&self) -> &Signature {
96        &self.signature
97    }
98
99    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
100        utf8_to_str_type(&arg_types[0], "overlay")
101    }
102
103    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104        match args.args[0].data_type() {
105            DataType::Utf8View | DataType::Utf8 => {
106                make_scalar_function(overlay::<i32>, vec![])(&args.args)
107            }
108            DataType::LargeUtf8 => {
109                make_scalar_function(overlay::<i64>, vec![])(&args.args)
110            }
111            other => exec_err!("Unsupported data type {other:?} for function overlay"),
112        }
113    }
114
115    fn documentation(&self) -> Option<&Documentation> {
116        self.doc()
117    }
118}
119
/// Returns the byte offset at which the `char_idx`-th character (0-based) of
/// `s` starts, or `s.len()` when `s` has no more than `char_idx` characters.
///
/// SQL positions count characters while `str` slicing takes byte offsets;
/// converting through this helper keeps every slice on a UTF-8 char boundary.
fn char_to_byte_offset(s: &str, char_idx: usize) -> usize {
    s.char_indices()
        .nth(char_idx)
        .map_or(s.len(), |(offset, _)| offset)
}

/// Overlays a single row: inserts `characters` into `string` at the 1-based
/// character position `start_pos`, replacing `replace_len` characters of
/// `string`. All positions and lengths are counted in characters, not bytes.
///
/// Edge cases:
/// * `start_pos <= 1` keeps no prefix of `string`.
/// * If the replaced span ends before the first character (a non-positive
///   `start_pos` or a negative `replace_len` can cause this), all of `string`
///   is kept as the suffix.
/// * If the replaced span reaches the end of `string`, no suffix is kept.
/// * Extreme positions/lengths saturate instead of overflowing.
///
/// NOTE(review): when `start_pos - 1` is at or beyond the character length of
/// `string`, the prefix is dropped as well. For example,
/// `overlay('123' placing 'abc' from 4 for 5)` yields `'abc'`, and existing
/// tests expect this. A `prefix || substr || suffix` reading of SQL OVERLAY
/// would keep it (`'123abc'`); confirm the intended semantics before changing.
fn overlay_str(string: &str, characters: &str, start_pos: i64, replace_len: i64) -> String {
    let string_len = string.chars().count() as i64;
    // Byte lengths bound the output size from above.
    let mut res = String::with_capacity(string.len() + characters.len());

    // SQL positions are 1-based, so the kept prefix is `start_pos - 1` chars.
    if start_pos > 1 && start_pos - 1 < string_len {
        let prefix_end = char_to_byte_offset(string, (start_pos - 1) as usize);
        res.push_str(&string[..prefix_end]);
    }
    res.push_str(characters);

    // 0-based char index just past the replaced span; saturating so that
    // user-supplied positions/lengths near the i64 limits cannot overflow.
    let suffix_start = start_pos.saturating_add(replace_len).saturating_sub(1);
    if suffix_start < string_len {
        // A negative index would wrap to a huge value when cast to `usize`
        // and panic on slicing; clamp it so the whole string is kept instead.
        let suffix_from = char_to_byte_offset(string, suffix_start.max(0) as usize);
        res.push_str(&string[suffix_from..]);
    }
    res
}

// Applies `overlay_str` row by row over (`string`, `characters`, `pos`
// [, `len`]) arrays. Any NULL input produces a NULL output row. The caller must
// have a `T: OffsetSizeTrait` in scope: the result is a `GenericStringArray<T>`.
macro_rules! process_overlay {
    // Three arguments: the replaced span is as long as `characters`
    // (measured in characters).
    ($string_array:expr, $characters_array:expr, $pos_num:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_num.iter())
            .map(|((string, characters), start_pos)| {
                match (string, characters, start_pos) {
                    (Some(string), Some(characters), Some(start_pos)) => {
                        let replace_len = characters.chars().count() as i64;
                        Ok(Some(overlay_str(string, characters, start_pos, replace_len)))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};

    // Four arguments: an explicit replaced-character count, capped at the
    // character length of `string`.
    ($string_array:expr, $characters_array:expr, $pos_num:expr, $len_num:expr) => {{
        $string_array
            .iter()
            .zip($characters_array.iter())
            .zip($pos_num.iter())
            .zip($len_num.iter())
            .map(|(((string, characters), start_pos), len)| {
                match (string, characters, start_pos, len) {
                    (Some(string), Some(characters), Some(start_pos), Some(len)) => {
                        let replace_len = len.min(string.chars().count() as i64);
                        Ok(Some(overlay_str(string, characters, start_pos, replace_len)))
                    }
                    _ => Ok(None),
                }
            })
            .collect::<Result<GenericStringArray<T>>>()
    }};
}
190
191/// OVERLAY(string1 PLACING string2 FROM integer FOR integer2)
192/// Replaces a substring of string1 with string2 starting at the integer bit
193/// pgsql overlay('Txxxxas' placing 'hom' from 2 for 4) → Thomas
194/// overlay('Txxxxas' placing 'hom' from 2) -> Thomxas, without for option, str2's len is instead
195fn overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
196    let use_string_view = args[0].data_type() == &DataType::Utf8View;
197    if use_string_view {
198        string_view_overlay::<T>(args)
199    } else {
200        string_overlay::<T>(args)
201    }
202}
203
204pub fn string_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
205    match args.len() {
206        3 => {
207            let string_array = as_generic_string_array::<T>(&args[0])?;
208            let characters_array = as_generic_string_array::<T>(&args[1])?;
209            let pos_num = as_int64_array(&args[2])?;
210
211            let result = process_overlay!(string_array, characters_array, pos_num)?;
212            Ok(Arc::new(result) as ArrayRef)
213        }
214        4 => {
215            let string_array = as_generic_string_array::<T>(&args[0])?;
216            let characters_array = as_generic_string_array::<T>(&args[1])?;
217            let pos_num = as_int64_array(&args[2])?;
218            let len_num = as_int64_array(&args[3])?;
219
220            let result =
221                process_overlay!(string_array, characters_array, pos_num, len_num)?;
222            Ok(Arc::new(result) as ArrayRef)
223        }
224        other => {
225            exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
226        }
227    }
228}
229
230pub fn string_view_overlay<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
231    match args.len() {
232        3 => {
233            let string_array = as_string_view_array(&args[0])?;
234            let characters_array = as_string_view_array(&args[1])?;
235            let pos_num = as_int64_array(&args[2])?;
236
237            let result = process_overlay!(string_array, characters_array, pos_num)?;
238            Ok(Arc::new(result) as ArrayRef)
239        }
240        4 => {
241            let string_array = as_string_view_array(&args[0])?;
242            let characters_array = as_string_view_array(&args[1])?;
243            let pos_num = as_int64_array(&args[2])?;
244            let len_num = as_int64_array(&args[3])?;
245
246            let result =
247                process_overlay!(string_array, characters_array, pos_num, len_num)?;
248            Ok(Arc::new(result) as ArrayRef)
249        }
250        other => {
251            exec_err!("overlay was called with {other} arguments. It requires 3 or 4.")
252        }
253    }
254}
255
#[cfg(test)]
mod tests {
    use arrow::array::{Int64Array, StringArray};

    use super::*;

    /// Four-argument overlay over ASCII inputs, including a start position
    /// just past the end of the input string.
    #[test]
    fn to_overlay() -> Result<()> {
        let args: [ArrayRef; 4] = [
            Arc::new(StringArray::from(vec!["123", "abcdefg", "xyz", "Txxxxas"])),
            Arc::new(StringArray::from(vec!["abc", "qwertyasdfg", "ijk", "hom"])),
            // 1-based start positions
            Arc::new(Int64Array::from(vec![4, 1, 1, 2])),
            // lengths of the replaced spans
            Arc::new(Int64Array::from(vec![5, 7, 2, 4])),
        ];

        let output = overlay::<i32>(&args)?;
        let actual = as_generic_string_array::<i32>(&output)?;

        let expected = StringArray::from(vec!["abc", "qwertyasdfg", "ijkz", "Thomas"]);
        assert_eq!(&expected, actual);
        Ok(())
    }
}