datafusion_functions/math/
nanvl.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, Float32Array, Float64Array};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type};
26use datafusion_common::{exec_err, DataFusionError, Result};
27use datafusion_expr::TypeSignature::Exact;
28use datafusion_expr::{
29    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
30    Volatility,
31};
32use datafusion_macros::user_doc;
33
// The `user_doc` attribute below carries the user-facing documentation for the
// SQL function (section label, description, syntax example and per-argument
// text). `documentation()` in the `ScalarUDFImpl` impl returns `self.doc()`,
// which is presumably generated from this attribute — confirm against the
// `datafusion_macros::user_doc` implementation.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = r#"Returns the first argument if it's not _NaN_.
Returns the second argument otherwise."#,
    syntax_example = "nanvl(expression_x, expression_y)",
    argument(
        name = "expression_x",
        description = "Numeric expression to return if it's not _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    ),
    argument(
        name = "expression_y",
        description = "Numeric expression to return if the first expression is _NaN_. Can be a constant, column, or function, and any combination of arithmetic operators."
    )
)]
/// Scalar UDF implementation registered under the name `nanvl`.
///
/// Holds only the [`Signature`] describing which argument type combinations
/// the function accepts.
#[derive(Debug)]
pub struct NanvlFunc {
    /// Accepted argument types and volatility, built once in `NanvlFunc::new`.
    signature: Signature,
}
52
53impl Default for NanvlFunc {
54    fn default() -> Self {
55        NanvlFunc::new()
56    }
57}
58
59impl NanvlFunc {
60    pub fn new() -> Self {
61        use DataType::*;
62        Self {
63            signature: Signature::one_of(
64                vec![Exact(vec![Float32, Float32]), Exact(vec![Float64, Float64])],
65                Volatility::Immutable,
66            ),
67        }
68    }
69}
70
71impl ScalarUDFImpl for NanvlFunc {
72    fn as_any(&self) -> &dyn Any {
73        self
74    }
75
76    fn name(&self) -> &str {
77        "nanvl"
78    }
79
80    fn signature(&self) -> &Signature {
81        &self.signature
82    }
83
84    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
85        match &arg_types[0] {
86            Float32 => Ok(Float32),
87            _ => Ok(Float64),
88        }
89    }
90
91    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
92        make_scalar_function(nanvl, vec![])(&args.args)
93    }
94
95    fn documentation(&self) -> Option<&Documentation> {
96        self.doc()
97    }
98}
99
100/// Nanvl SQL function
101fn nanvl(args: &[ArrayRef]) -> Result<ArrayRef> {
102    match args[0].data_type() {
103        Float64 => {
104            let compute_nanvl = |x: f64, y: f64| {
105                if x.is_nan() {
106                    y
107                } else {
108                    x
109                }
110            };
111
112            let x = args[0].as_primitive() as &Float64Array;
113            let y = args[1].as_primitive() as &Float64Array;
114            arrow::compute::binary::<_, _, _, Float64Type>(x, y, compute_nanvl)
115                .map(|res| Arc::new(res) as _)
116                .map_err(DataFusionError::from)
117        }
118        Float32 => {
119            let compute_nanvl = |x: f32, y: f32| {
120                if x.is_nan() {
121                    y
122                } else {
123                    x
124                }
125            };
126
127            let x = args[0].as_primitive() as &Float32Array;
128            let y = args[1].as_primitive() as &Float32Array;
129            arrow::compute::binary::<_, _, _, Float32Type>(x, y, compute_nanvl)
130                .map(|res| Arc::new(res) as _)
131                .map_err(DataFusionError::from)
132        }
133        other => exec_err!("Unsupported data type {other:?} for function nanvl"),
134    }
135}
136
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::nanvl::nanvl;

    use arrow::array::{ArrayRef, Float32Array, Float64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    /// `Float64` inputs: rows where `x` is NaN take the value from `y`
    /// (which may itself be NaN); all other rows keep `x`.
    #[test]
    fn test_nanvl_f64() {
        // First argument: the value returned when it is not NaN.
        let x: ArrayRef =
            Arc::new(Float64Array::from(vec![1.0, f64::NAN, 3.0, f64::NAN]));
        // Second argument: the replacement used where `x` is NaN.
        let y: ArrayRef =
            Arc::new(Float64Array::from(vec![5.0, 6.0, f64::NAN, f64::NAN]));

        let result = nanvl(&[x, y]).expect("failed to initialize function nanvl");
        let floats =
            as_float64_array(&result).expect("failed to initialize function nanvl");

        assert_eq!(floats.len(), 4);
        for (row, expected) in [1.0_f64, 6.0, 3.0].into_iter().enumerate() {
            assert_eq!(floats.value(row), expected);
        }
        // Both inputs are NaN in the last row, so the result is NaN as well.
        assert!(floats.value(3).is_nan());
    }

    /// `Float32` inputs: same expectations as the `Float64` test.
    #[test]
    fn test_nanvl_f32() {
        // First argument: the value returned when it is not NaN.
        let x: ArrayRef =
            Arc::new(Float32Array::from(vec![1.0, f32::NAN, 3.0, f32::NAN]));
        // Second argument: the replacement used where `x` is NaN.
        let y: ArrayRef =
            Arc::new(Float32Array::from(vec![5.0, 6.0, f32::NAN, f32::NAN]));

        let result = nanvl(&[x, y]).expect("failed to initialize function nanvl");
        let floats =
            as_float32_array(&result).expect("failed to initialize function nanvl");

        assert_eq!(floats.len(), 4);
        for (row, expected) in [1.0_f32, 6.0, 3.0].into_iter().enumerate() {
            assert_eq!(floats.value(row), expected);
        }
        // Both inputs are NaN in the last row, so the result is NaN as well.
        assert!(floats.value(3).is_nan());
    }
}