datafusion_functions/core/
nvl2.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
18use arrow::array::Array;
19use arrow::compute::is_not_null;
20use arrow::compute::kernels::zip::zip;
21use arrow::datatypes::DataType;
22use datafusion_common::{internal_err, utils::take_function_args, Result};
23use datafusion_expr::{
24    type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
25    ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_macros::user_doc;
28use std::sync::Arc;
29
30#[user_doc(
31    doc_section(label = "Conditional Functions"),
32    description = "Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.",
33    syntax_example = "nvl2(expression1, expression2, expression3)",
34    sql_example = r#"```sql
35> select nvl2(null, 'a', 'b');
36+--------------------------------+
37| nvl2(NULL,Utf8("a"),Utf8("b")) |
38+--------------------------------+
39| b                              |
40+--------------------------------+
41> select nvl2('data', 'a', 'b');
42+----------------------------------------+
43| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) |
44+----------------------------------------+
45| a                                      |
46+----------------------------------------+
47```
48"#,
49    argument(
50        name = "expression1",
51        description = "Expression to test for null. Can be a constant, column, or function, and any combination of operators."
52    ),
53    argument(
54        name = "expression2",
55        description = "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators."
56    ),
57    argument(
58        name = "expression3",
59        description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
60    )
61)]
62#[derive(Debug)]
63pub struct NVL2Func {
64    signature: Signature,
65}
66
67impl Default for NVL2Func {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl NVL2Func {
74    pub fn new() -> Self {
75        Self {
76            signature: Signature::user_defined(Volatility::Immutable),
77        }
78    }
79}
80
81impl ScalarUDFImpl for NVL2Func {
82    fn as_any(&self) -> &dyn std::any::Any {
83        self
84    }
85
86    fn name(&self) -> &str {
87        "nvl2"
88    }
89
90    fn signature(&self) -> &Signature {
91        &self.signature
92    }
93
94    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95        Ok(arg_types[1].clone())
96    }
97
98    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99        nvl2_func(&args.args)
100    }
101
102    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
103        let [tested, if_non_null, if_null] = take_function_args(self.name(), arg_types)?;
104        let new_type =
105            [if_non_null, if_null]
106                .iter()
107                .try_fold(tested.clone(), |acc, x| {
108                    // The coerced types found by `comparison_coercion` are not guaranteed to be
109                    // coercible for the arguments. `comparison_coercion` returns more loose
110                    // types that can be coerced to both `acc` and `x` for comparison purpose.
111                    // See `maybe_data_types` for the actual coercion.
112                    let coerced_type = comparison_coercion(&acc, x);
113                    if let Some(coerced_type) = coerced_type {
114                        Ok(coerced_type)
115                    } else {
116                        internal_err!("Coercion from {acc:?} to {x:?} failed.")
117                    }
118                })?;
119        Ok(vec![new_type; arg_types.len()])
120    }
121
122    fn documentation(&self) -> Option<&Documentation> {
123        self.doc()
124    }
125}
126
127fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
128    let mut len = 1;
129    let mut is_array = false;
130    for arg in args {
131        if let ColumnarValue::Array(array) = arg {
132            len = array.len();
133            is_array = true;
134            break;
135        }
136    }
137    if is_array {
138        let args = args
139            .iter()
140            .map(|arg| match arg {
141                ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len),
142                ColumnarValue::Array(array) => Ok(Arc::clone(array)),
143            })
144            .collect::<Result<Vec<_>>>()?;
145        let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
146        let to_apply = is_not_null(&tested)?;
147        let value = zip(&to_apply, &if_non_null, &if_null)?;
148        Ok(ColumnarValue::Array(value))
149    } else {
150        let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
151        match &tested {
152            ColumnarValue::Array(_) => {
153                internal_err!("except Scalar value, but got Array")
154            }
155            ColumnarValue::Scalar(scalar) => {
156                if scalar.is_null() {
157                    Ok(if_null.clone())
158                } else {
159                    Ok(if_non_null.clone())
160                }
161            }
162        }
163    }
164}