datafusion_functions/math/
nanvl.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type};
26use datafusion_common::{exec_err, DataFusionError, Result};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30 Volatility,
31};
32use datafusion_macros::user_doc;
33
#[user_doc(
    doc_section(label = "Math Functions"),
    description = r#"Returns the first argument if it's not _NaN_.
Returns the second argument otherwise."#,
    syntax_example = "nanvl(expression_x, expression_y)",
    argument(
        name = "expression_x",
        description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "expression_y",
        description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    )
)]
/// Scalar UDF implementing the SQL `nanvl(x, y)` function.
///
/// Accepts two arguments of the same float type (`Float32, Float32` or
/// `Float64, Float64`, see [`NanvlFunc::new`]) and yields `x` unless it is
/// NaN, in which case `y` is yielded.
#[derive(Debug)]
pub struct NanvlFunc {
    /// Accepted argument types and volatility, built once in [`NanvlFunc::new`].
    signature: Signature,
}
52
53impl Default for NanvlFunc {
54 fn default() -> Self {
55 NanvlFunc::new()
56 }
57}
58
59impl NanvlFunc {
60 pub fn new() -> Self {
61 use DataType::*;
62 Self {
63 signature: Signature::one_of(
64 vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
65 Volatility::Immutable,
66 ),
67 }
68 }
69}
70
71impl ScalarUDFImpl for NanvlFunc {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn name(&self) -> &str {
77 "nanvl"
78 }
79
80 fn signature(&self) -> &Signature {
81 &self.signature
82 }
83
84 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
85 match &arg_types[0] {
86 Float32 => Ok(Float32),
87 _ => Ok(Float64),
88 }
89 }
90
91 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92 make_scalar_function(nanvl, vec![])(&args.args)
93 }
94
95 fn documentation(&self) -> Option<&Documentation> {
96 self.doc()
97 }
98}
99
100fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
102 match args[0].data_type() {
103 Float64 => {
104 let compute_nanvl = |x: f64, y: f64| {
105 if x.is_nan() {
106 y
107 } else {
108 x
109 }
110 };
111
112 let x = args[0].as_primitive() as &Float64Array;
113 let y = args[1].as_primitive() as &Float64Array;
114 arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl)
115 .map(|res| Arc::new(res) as _)
116 .map_err(DataFusionError::from)
117 }
118 Float32 => {
119 let compute_nanvl = |x: f32, y: f32| {
120 if x.is_nan() {
121 y
122 } else {
123 x
124 }
125 };
126
127 let x = args[0].as_primitive() as &Float32Array;
128 let y = args[1].as_primitive() as &Float32Array;
129 arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl)
130 .map(|res| Arc::new(res) as _)
131 .map_err(DataFusionError::from)
132 }
133 other => exec_err!("Unsupported data type {other:?} for function nanvl"),
134 }
135}
136
#[cfg(test)]
mod test {
    // Unit tests for the `nanvl` array kernel: NaN substitution for both float
    // widths, null propagation, and rejection of unsupported input types.
    use std::sync::Arc;

    use crate::math::nanvl::nanvl;

    use arrow::array::{Array, ArrayRef, Float32Array, Float64Array, Int32Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    #[test]
    fn test_nanvl_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN])),
            Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN])),
        ];

        let result = nanvl(&args).expect("failed to initialize function nanvl");
        let floats = as_float64_array(&result).expect("nanvl result should be Float64");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        // Both sides NaN: the NaN from `y` is returned.
        assert!(floats.value(3).is_nan());
    }

    #[test]
    fn test_nanvl_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN])),
            Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN])),
        ];

        let result = nanvl(&args).expect("failed to initialize function nanvl");
        let floats = as_float32_array(&result).expect("nanvl result should be Float32");

        assert_eq!(floats.len(), 4);
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), 6.0);
        assert_eq!(floats.value(2), 3.0);
        assert!(floats.value(3).is_nan());
    }

    #[test]
    fn test_nanvl_nulls_propagate() {
        // A null on either side yields null, even when `x` is a regular number,
        // because the kernel is built on `arrow::compute::binary`.
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![Some(1.0), None, Some(f64::NAN), None])),
            Arc::new(Float64Array::from(vec![None, Some(2.0), None, None])),
        ];

        let result = nanvl(&args).expect("failed to initialize function nanvl");
        let floats = as_float64_array(&result).expect("nanvl result should be Float64");

        assert_eq!(floats.len(), 4);
        assert!((0..floats.len()).all(|i| floats.is_null(i)));
    }

    #[test]
    fn test_nanvl_rejects_non_float_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Int32Array::from(vec![1, 2])),
            Arc::new(Int32Array::from(vec![3, 4])),
        ];

        assert!(nanvl(&args).is_err());
    }
}