datafusion_functions/math/
abs.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Math expressions: the `abs` (absolute value) scalar function.

20use std::any::Any;
21use std::sync::Arc;
22
23use arrow::array::{
24    ArrayRef, Decimal128Array, Decimal256Array, Float32Array, Float64Array, Int16Array,
25    Int32Array, Int64Array, Int8Array,
26};
27use arrow::datatypes::DataType;
28use arrow::error::ArrowError;
29use datafusion_common::{
30    internal_datafusion_err, not_impl_err, utils::take_function_args, Result,
31};
32use datafusion_expr::interval_arithmetic::Interval;
33use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
34use datafusion_expr::{
35    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
36    Volatility,
37};
38use datafusion_macros::user_doc;
39
40type MathArrayFunction = fn(&ArrayRef) -> Result<ArrayRef>;
41
/// Expands to a non-capturing closure that applies `abs` to every element of
/// a `$ARRAY_TYPE` array, for element types whose `abs` cannot overflow.
///
/// Because the closure captures nothing, it coerces to `MathArrayFunction`.
macro_rules! make_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |arg: &ArrayRef| {
            let typed = downcast_named_arg!(&arg, "abs arg", $ARRAY_TYPE);
            let abs_values: $ARRAY_TYPE = typed.unary(|v| v.abs());
            let out: ArrayRef = Arc::new(abs_values);
            Ok(out)
        }
    }};
}
51
/// Expands to a non-capturing closure that applies `checked_abs` to every
/// element of a `$ARRAY_TYPE` array.
///
/// If any element has no representable absolute value, the whole call fails
/// with an `ArrowError::ComputeError` naming the array type and the offending
/// value.
macro_rules! make_try_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |arg: &ArrayRef| {
            let typed = downcast_named_arg!(&arg, "abs arg", $ARRAY_TYPE);
            let abs_values: $ARRAY_TYPE = typed.try_unary(|v| match v.checked_abs() {
                Some(magnitude) => Ok(magnitude),
                None => Err(ArrowError::ComputeError(format!(
                    "{} overflow on abs({})",
                    stringify!($ARRAY_TYPE),
                    v
                ))),
            })?;
            let out: ArrayRef = Arc::new(abs_values);
            Ok(out)
        }
    }};
}
69
/// Expands to a non-capturing closure that applies `wrapping_abs` to the
/// native value of every element of a decimal `$ARRAY_TYPE` array.
///
/// The result keeps the input's exact data type (precision and scale).
macro_rules! make_decimal_abs_function {
    ($ARRAY_TYPE:ident) => {{
        |arg: &ArrayRef| {
            let typed = downcast_named_arg!(&arg, "abs arg", $ARRAY_TYPE);
            // NOTE(review): `wrapping_abs` maps the native minimum value to
            // itself. That value is presumably outside any valid decimal
            // precision, so it should not occur in practice; confirm before
            // relying on it.
            let abs_values: $ARRAY_TYPE = typed.unary(|v| v.wrapping_abs());
            // `unary` does not carry the input's decimal type over, so the
            // input's precision and scale are re-applied to the result here.
            let abs_values = abs_values.with_data_type(arg.data_type().clone());
            let out: ArrayRef = Arc::new(abs_values);
            Ok(out)
        }
    }};
}
81
82/// Abs SQL function
83/// Return different implementations based on input datatype to reduce branches during execution
84fn create_abs_function(input_data_type: &DataType) -> Result<MathArrayFunction> {
85    match input_data_type {
86        DataType::Float32 => Ok(make_abs_function!(Float32Array)),
87        DataType::Float64 => Ok(make_abs_function!(Float64Array)),
88
89        // Types that may overflow, such as abs(-128_i8).
90        DataType::Int8 => Ok(make_try_abs_function!(Int8Array)),
91        DataType::Int16 => Ok(make_try_abs_function!(Int16Array)),
92        DataType::Int32 => Ok(make_try_abs_function!(Int32Array)),
93        DataType::Int64 => Ok(make_try_abs_function!(Int64Array)),
94
95        // Types of results are the same as the input.
96        DataType::Null
97        | DataType::UInt8
98        | DataType::UInt16
99        | DataType::UInt32
100        | DataType::UInt64 => Ok(|input: &ArrayRef| Ok(Arc::clone(input))),
101
102        // Decimal types
103        DataType::Decimal128(_, _) => Ok(make_decimal_abs_function!(Decimal128Array)),
104        DataType::Decimal256(_, _) => Ok(make_decimal_abs_function!(Decimal256Array)),
105
106        other => not_impl_err!("Unsupported data type {other:?} for function abs"),
107    }
108}
109#[user_doc(
110    doc_section(label = "Math Functions"),
111    description = "Returns the absolute value of a number.",
112    syntax_example = "abs(numeric_expression)",
113    standard_argument(name = "numeric_expression", prefix = "Numeric")
114)]
115#[derive(Debug)]
116pub struct AbsFunc {
117    signature: Signature,
118}
119
120impl Default for AbsFunc {
121    fn default() -> Self {
122        Self::new()
123    }
124}
125
126impl AbsFunc {
127    pub fn new() -> Self {
128        Self {
129            signature: Signature::numeric(1, Volatility::Immutable),
130        }
131    }
132}
133
134impl ScalarUDFImpl for AbsFunc {
135    fn as_any(&self) -> &dyn Any {
136        self
137    }
138    fn name(&self) -> &str {
139        "abs"
140    }
141
142    fn signature(&self) -> &Signature {
143        &self.signature
144    }
145
146    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
147        match arg_types[0] {
148            DataType::Float32 => Ok(DataType::Float32),
149            DataType::Float64 => Ok(DataType::Float64),
150            DataType::Int8 => Ok(DataType::Int8),
151            DataType::Int16 => Ok(DataType::Int16),
152            DataType::Int32 => Ok(DataType::Int32),
153            DataType::Int64 => Ok(DataType::Int64),
154            DataType::Null => Ok(DataType::Null),
155            DataType::UInt8 => Ok(DataType::UInt8),
156            DataType::UInt16 => Ok(DataType::UInt16),
157            DataType::UInt32 => Ok(DataType::UInt32),
158            DataType::UInt64 => Ok(DataType::UInt64),
159            DataType::Decimal128(precision, scale) => {
160                Ok(DataType::Decimal128(precision, scale))
161            }
162            DataType::Decimal256(precision, scale) => {
163                Ok(DataType::Decimal256(precision, scale))
164            }
165            _ => not_impl_err!(
166                "Unsupported data type {} for function abs",
167                arg_types[0].to_string()
168            ),
169        }
170    }
171
172    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
173        let args = ColumnarValue::values_to_arrays(&args.args)?;
174        let [input] = take_function_args(self.name(), args)?;
175
176        let input_data_type = input.data_type();
177        let abs_fun = create_abs_function(input_data_type)?;
178
179        abs_fun(&input).map(ColumnarValue::Array)
180    }
181
182    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
183        // Non-decreasing for x ≥ 0 and symmetrically non-increasing for x ≤ 0.
184        let arg = &input[0];
185        let range = &arg.range;
186        let zero_point = Interval::make_zero(&range.lower().data_type())?;
187
188        if range.gt_eq(&zero_point)? == Interval::CERTAINLY_TRUE {
189            Ok(arg.sort_properties)
190        } else if range.lt_eq(&zero_point)? == Interval::CERTAINLY_TRUE {
191            Ok(-arg.sort_properties)
192        } else {
193            Ok(SortProperties::Unordered)
194        }
195    }
196
197    fn documentation(&self) -> Option<&Documentation> {
198        self.doc()
199    }
200}