datafusion_functions/core/
nvl2.rs1use arrow::array::Array;
19use arrow::compute::is_not_null;
20use arrow::compute::kernels::zip::zip;
21use arrow::datatypes::DataType;
22use datafusion_common::{internal_err, utils::take_function_args, Result};
23use datafusion_expr::{
24 type_coercion::binary::comparison_coercion, ColumnarValue, Documentation,
25 ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility,
26};
27use datafusion_macros::user_doc;
28use std::sync::Arc;
29
#[user_doc(
    doc_section(label = "Conditional Functions"),
    description = "Returns _expression2_ if _expression1_ is not NULL; otherwise it returns _expression3_.",
    syntax_example = "nvl2(expression1, expression2, expression3)",
    sql_example = r#"```sql
> select nvl2(null, 'a', 'b');
+--------------------------------+
| nvl2(NULL,Utf8("a"),Utf8("b")) |
+--------------------------------+
| b |
+--------------------------------+
> select nvl2('data', 'a', 'b');
+----------------------------------------+
| nvl2(Utf8("data"),Utf8("a"),Utf8("b")) |
+----------------------------------------+
| a |
+----------------------------------------+
```
"#,
    argument(
        name = "expression1",
        description = "Expression to test for null. Can be a constant, column, or function, and any combination of operators."
    ),
    argument(
        name = "expression2",
        description = "Expression to return if expr1 is not null. Can be a constant, column, or function, and any combination of operators."
    ),
    argument(
        name = "expression3",
        description = "Expression to return if expr1 is null. Can be a constant, column, or function, and any combination of operators."
    )
)]
#[derive(Debug)]
/// Scalar UDF implementing the SQL `nvl2` conditional function.
///
/// The user-facing documentation for the function is declared in the
/// `#[user_doc]` attribute on this type.
pub struct NVL2Func {
    // Signature reported to the planner through `ScalarUDFImpl::signature`.
    signature: Signature,
}
66
67impl Default for NVL2Func {
68 fn default() -> Self {
69 Self::new()
70 }
71}
72
73impl NVL2Func {
74 pub fn new() -> Self {
75 Self {
76 signature: Signature::user_defined(Volatility::Immutable),
77 }
78 }
79}
80
81impl ScalarUDFImpl for NVL2Func {
82 fn as_any(&self) -> &dyn std::any::Any {
83 self
84 }
85
86 fn name(&self) -> &str {
87 "nvl2"
88 }
89
90 fn signature(&self) -> &Signature {
91 &self.signature
92 }
93
94 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
95 Ok(arg_types[1].clone())
96 }
97
98 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
99 nvl2_func(&args.args)
100 }
101
102 fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
103 let [tested, if_non_null, if_null] = take_function_args(self.name(), arg_types)?;
104 let new_type =
105 [if_non_null, if_null]
106 .iter()
107 .try_fold(tested.clone(), |acc, x| {
108 let coerced_type = comparison_coercion(&acc, x);
113 if let Some(coerced_type) = coerced_type {
114 Ok(coerced_type)
115 } else {
116 internal_err!("Coercion from {acc:?} to {x:?} failed.")
117 }
118 })?;
119 Ok(vec![new_type; arg_types.len()])
120 }
121
122 fn documentation(&self) -> Option<&Documentation> {
123 self.doc()
124 }
125}
126
127fn nvl2_func(args: &[ColumnarValue]) -> Result<ColumnarValue> {
128 let mut len = 1;
129 let mut is_array = false;
130 for arg in args {
131 if let ColumnarValue::Array(array) = arg {
132 len = array.len();
133 is_array = true;
134 break;
135 }
136 }
137 if is_array {
138 let args = args
139 .iter()
140 .map(|arg| match arg {
141 ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(len),
142 ColumnarValue::Array(array) => Ok(Arc::clone(array)),
143 })
144 .collect::<Result<Vec<_>>>()?;
145 let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
146 let to_apply = is_not_null(&tested)?;
147 let value = zip(&to_apply, &if_non_null, &if_null)?;
148 Ok(ColumnarValue::Array(value))
149 } else {
150 let [tested, if_non_null, if_null] = take_function_args("nvl2", args)?;
151 match &tested {
152 ColumnarValue::Array(_) => {
153 internal_err!("except Scalar value, but got Array")
154 }
155 ColumnarValue::Scalar(scalar) => {
156 if scalar.is_null() {
157 Ok(if_null.clone())
158 } else {
159 Ok(if_non_null.clone())
160 }
161 }
162 }
163 }
164}