// datafusion_functions/core/nvl.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::Array;
19use arrow::compute::is_not_null;
20use arrow::compute::kernels::zip::zip;
21use arrow::datatypes::DataType;
22use datafusion_common::{utils::take_function_args, Result};
23use datafusion_expr::{
24    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
25    Volatility,
26};
27use datafusion_macros::user_doc;
28use std::sync::Arc;
29
30#[user_doc(
31    doc_section(label = "Conditional Functions"),
32    description = "Returns _expression2_ if _expression1_ is NULL otherwise it returns _expression1_.",
33    syntax_example = "nvl(expression1, expression2)",
34    sql_example = r#"```sql
35> select nvl(null, 'a');
36+---------------------+
37| nvl(NULL,Utf8("a")) |
38+---------------------+
39| a                   |
40+---------------------+\
41> select nvl('b', 'a');
42+--------------------------+
43| nvl(Utf8("b"),Utf8("a")) |
44+--------------------------+
45| b                        |
46+--------------------------+
47```
48"#,
49    argument(
50        name = "expression1",
51        description = "Expression to return if not null. Can be a constant, column, or function, and any combination of operators."
52    ),
53    argument(
54        name = "expression2",
55        description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
56    )
57)]
58#[derive(Debug)]
59pub struct NVLFunc {
60    signature: Signature,
61    aliases: Vec<String>,
62}
63
64/// Currently supported types by the nvl/ifnull function.
65/// The order of these types correspond to the order on which coercion applies
66/// This should thus be from least informative to most informative
67static SUPPORTED_NVL_TYPES: &[DataType] = &[
68    DataType::Boolean,
69    DataType::UInt8,
70    DataType::UInt16,
71    DataType::UInt32,
72    DataType::UInt64,
73    DataType::Int8,
74    DataType::Int16,
75    DataType::Int32,
76    DataType::Int64,
77    DataType::Float32,
78    DataType::Float64,
79    DataType::Utf8View,
80    DataType::Utf8,
81    DataType::LargeUtf8,
82];
83
84impl Default for NVLFunc {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
90impl NVLFunc {
91    pub fn new() -> Self {
92        Self {
93            signature: Signature::uniform(
94                2,
95                SUPPORTED_NVL_TYPES.to_vec(),
96                Volatility::Immutable,
97            ),
98            aliases: vec![String::from("ifnull")],
99        }
100    }
101}
102
103impl ScalarUDFImpl for NVLFunc {
104    fn as_any(&self) -> &dyn std::any::Any {
105        self
106    }
107
108    fn name(&self) -> &str {
109        "nvl"
110    }
111
112    fn signature(&self) -> &Signature {
113        &self.signature
114    }
115
116    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
117        Ok(arg_types[0].clone())
118    }
119
120    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
121        nvl_func(&args.args)
122    }
123
124    fn aliases(&self) -> &[String] {
125        &self.aliases
126    }
127
128    fn documentation(&self) -> Option<&Documentation> {
129        self.doc()
130    }
131}
132
133fn nvl_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
134    let [lhs, rhs] = take_function_args("nvl/ifnull", args)?;
135    let (lhs_array, rhs_array) = match (lhs, rhs) {
136        (ColumnarValue::Array(lhs), ColumnarValue::Scalar(rhs)) => {
137            (Arc::clone(lhs), rhs.to_array_of_size(lhs.len())?)
138        }
139        (ColumnarValue::Array(lhs), ColumnarValue::Array(rhs)) => {
140            (Arc::clone(lhs), Arc::clone(rhs))
141        }
142        (ColumnarValue::Scalar(lhs), ColumnarValue::Array(rhs)) => {
143            (lhs.to_array_of_size(rhs.len())?, Arc::clone(rhs))
144        }
145        (ColumnarValue::Scalar(lhs), ColumnarValue::Scalar(rhs)) => {
146            let mut current_value = lhs;
147            if lhs.is_null() {
148                current_value = rhs;
149            }
150            return Ok(ColumnarValue::Scalar(current_value.clone()));
151        }
152    };
153    let to_apply = is_not_null(&lhs_array)?;
154    let value = zip(&to_apply, &lhs_array, &rhs_array)?;
155    Ok(ColumnarValue::Array(value))
156}
157
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow::array::*;

    use super::*;
    use datafusion_common::ScalarValue;

    #[test]
    fn nvl_int32() -> Result<()> {
        let a = Int32Array::from(vec![
            Some(1),
            Some(2),
            None,
            None,
            Some(3),
            None,
            None,
            Some(4),
            Some(5),
        ]);
        let a = ColumnarValue::Array(Arc::new(a));

        let lit = ColumnarValue::Scalar(ScalarValue::Int32(Some(6i32)));

        let result = nvl_func(&[a, lit])?.into_array(0)?;

        let expected = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(2),
            Some(6),
            Some(6),
            Some(3),
            Some(6),
            Some(6),
            Some(4),
            Some(5),
        ])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    // Ensure that arrays with no nulls can also invoke nvl() correctly
    #[test]
    fn nvl_int32_non_nulls() -> Result<()> {
        let a = Int32Array::from(vec![1, 3, 10, 7, 8, 1, 2, 4, 5]);
        let a = ColumnarValue::Array(Arc::new(a));

        let lit = ColumnarValue::Scalar(ScalarValue::Int32(Some(20i32)));

        let result = nvl_func(&[a, lit])?.into_array(0)?;

        let expected = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(3),
            Some(10),
            Some(7),
            Some(8),
            Some(1),
            Some(2),
            Some(4),
            Some(5),
        ])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    #[test]
    fn nvl_boolean() -> Result<()> {
        let a = BooleanArray::from(vec![Some(true), Some(false), None]);
        let a = ColumnarValue::Array(Arc::new(a));

        let lit = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false)));

        let result = nvl_func(&[a, lit])?.into_array(0)?;

        let expected = Arc::new(BooleanArray::from(vec![
            Some(true),
            Some(false),
            Some(false),
        ])) as ArrayRef;

        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    #[test]
    fn nvl_string() -> Result<()> {
        let a = StringArray::from(vec![Some("foo"), Some("bar"), None, Some("baz")]);
        let a = ColumnarValue::Array(Arc::new(a));

        let lit = ColumnarValue::Scalar(ScalarValue::from("bax"));

        let result = nvl_func(&[a, lit])?.into_array(0)?;

        let expected = Arc::new(StringArray::from(vec![
            Some("foo"),
            Some("bar"),
            Some("bax"),
            Some("baz"),
        ])) as ArrayRef;

        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    #[test]
    fn nvl_literal_first() -> Result<()> {
        let a = Int32Array::from(vec![Some(1), Some(2), None, None, Some(3), Some(4)]);
        let a = ColumnarValue::Array(Arc::new(a));

        let lit = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));

        let result = nvl_func(&[lit, a])?.into_array(0)?;

        let expected = Arc::new(Int32Array::from(vec![
            Some(2),
            Some(2),
            Some(2),
            Some(2),
            Some(2),
            Some(2),
        ])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    // Both arguments are arrays: nulls in lhs are filled from the same row of
    // rhs, and a row stays null only when both sides are null.
    #[test]
    fn nvl_array_array() -> Result<()> {
        let lhs = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(3),
            None,
        ])));
        let rhs = ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(10),
            Some(20),
            None,
            None,
        ])));

        let result = nvl_func(&[lhs, rhs])?.into_array(0)?;

        let expected = Arc::new(Int32Array::from(vec![
            Some(1),
            Some(20),
            Some(3),
            None,
        ])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    #[test]
    fn nvl_scalar() -> Result<()> {
        let lhs_null = ColumnarValue::Scalar(ScalarValue::Int32(None));
        let rhs_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));

        let result = nvl_func(&[lhs_null, rhs_value])?.into_array(1)?;
        let expected = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());

        let lhs_value = ColumnarValue::Scalar(ScalarValue::Int32(Some(2i32)));
        let rhs_other = ColumnarValue::Scalar(ScalarValue::Int32(Some(1i32)));

        let result = nvl_func(&[lhs_value, rhs_other])?.into_array(1)?;
        let expected = Arc::new(Int32Array::from(vec![Some(2i32)])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());

        Ok(())
    }

    #[test]
    fn nvl_scalar_both_null() -> Result<()> {
        let lhs = ColumnarValue::Scalar(ScalarValue::Int32(None));
        let rhs = ColumnarValue::Scalar(ScalarValue::Int32(None));

        let result = nvl_func(&[lhs, rhs])?.into_array(1)?;
        let expected = Arc::new(Int32Array::from(vec![None::<i32>])) as ArrayRef;
        assert_eq!(expected.as_ref(), result.as_ref());
        Ok(())
    }

    #[test]
    fn nvl_wrong_arg_count() {
        let only = ColumnarValue::Scalar(ScalarValue::Int32(Some(1)));
        assert!(nvl_func(&[only]).is_err());
    }

    #[test]
    fn nvl_return_type() -> Result<()> {
        let udf = NVLFunc::new();
        assert_eq!(
            udf.return_type(&[DataType::Int32, DataType::Int32])?,
            DataType::Int32
        );
        Ok(())
    }
}