datafusion_functions/math/
round.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::compute::{cast_with_options, CastOptions};
25use arrow::datatypes::DataType::{Float32, Float64, Int32};
26use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type};
27use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35
/// Scalar UDF implementation for the SQL `round` function.
///
/// The `user_doc` attribute below supplies the user-facing documentation
/// (section, description, syntax example and argument descriptions) that is
/// returned through `ScalarUDFImpl::documentation`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Rounds a number to the nearest integer.",
    syntax_example = "round(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = "Optional. The number of decimal places to round to. Defaults to 0."
    )
)]
#[derive(Debug)]
pub struct RoundFunc {
    /// Accepted argument type combinations and volatility, built once in
    /// `RoundFunc::new` and handed out by `ScalarUDFImpl::signature`.
    signature: Signature,
}
50
51impl Default for RoundFunc {
52    fn default() -> Self {
53        RoundFunc::new()
54    }
55}
56
57impl RoundFunc {
58    pub fn new() -> Self {
59        use DataType::*;
60        Self {
61            signature: Signature::one_of(
62                vec![
63                    Exact(vec![Float64, Int64]),
64                    Exact(vec![Float32, Int64]),
65                    Exact(vec![Float64]),
66                    Exact(vec![Float32]),
67                ],
68                Volatility::Immutable,
69            ),
70        }
71    }
72}
73
74impl ScalarUDFImpl for RoundFunc {
75    fn as_any(&self) -> &dyn Any {
76        self
77    }
78
79    fn name(&self) -> &str {
80        "round"
81    }
82
83    fn signature(&self) -> &Signature {
84        &self.signature
85    }
86
87    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88        match arg_types[0] {
89            Float32 => Ok(Float32),
90            _ => Ok(Float64),
91        }
92    }
93
94    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
95        make_scalar_function(round, vec![])(&args.args)
96    }
97
98    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
99        // round preserves the order of the first argument
100        let value = &input[0];
101        let precision = input.get(1);
102
103        if precision
104            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
105            .unwrap_or(true)
106        {
107            Ok(value.sort_properties)
108        } else {
109            Ok(SortProperties::Unordered)
110        }
111    }
112
113    fn documentation(&self) -> Option<&Documentation> {
114        self.doc()
115    }
116}
117
118/// Round SQL function
119pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
120    if args.len() != 1 && args.len() != 2 {
121        return exec_err!(
122            "round function requires one or two arguments, got {}",
123            args.len()
124        );
125    }
126
127    let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0)));
128
129    if args.len() == 2 {
130        decimal_places = ColumnarValue::Array(Arc::clone(&args[1]));
131    }
132
133    match args[0].data_type() {
134        Float64 => match decimal_places {
135            ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
136                let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
137                    exec_datafusion_err!(
138                        "Invalid value for decimal places: {decimal_places}: {e}"
139                    )
140                })?;
141
142                let result = args[0]
143                    .as_primitive::<Float64Type>()
144                    .unary::<_, Float64Type>(|value: f64| {
145                        (value * 10.0_f64.powi(decimal_places)).round()
146                            / 10.0_f64.powi(decimal_places)
147                    });
148                Ok(Arc::new(result) as _)
149            }
150            ColumnarValue::Array(decimal_places) => {
151                let options = CastOptions {
152                    safe: false, // raise error if the cast is not possible
153                    ..Default::default()
154                };
155                let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
156                    .map_err(|e| {
157                        exec_datafusion_err!("Invalid values for decimal places: {e}")
158                    })?;
159
160                let values = args[0].as_primitive::<Float64Type>();
161                let decimal_places = decimal_places.as_primitive::<Int32Type>();
162                let result = arrow::compute::binary::<_, _, _, Float64Type>(
163                    values,
164                    decimal_places,
165                    |value, decimal_places| {
166                        (value * 10.0_f64.powi(decimal_places)).round()
167                            / 10.0_f64.powi(decimal_places)
168                    },
169                )?;
170                Ok(Arc::new(result) as _)
171            }
172            _ => {
173                exec_err!("round function requires a scalar or array for decimal_places")
174            }
175        },
176
177        Float32 => match decimal_places {
178            ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
179                let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
180                    exec_datafusion_err!(
181                        "Invalid value for decimal places: {decimal_places}: {e}"
182                    )
183                })?;
184                let result = args[0]
185                    .as_primitive::<Float32Type>()
186                    .unary::<_, Float32Type>(|value: f32| {
187                        (value * 10.0_f32.powi(decimal_places)).round()
188                            / 10.0_f32.powi(decimal_places)
189                    });
190                Ok(Arc::new(result) as _)
191            }
192            ColumnarValue::Array(_) => {
193                let ColumnarValue::Array(decimal_places) =
194                    decimal_places.cast_to(&Int32, None).map_err(|e| {
195                        exec_datafusion_err!("Invalid values for decimal places: {e}")
196                    })?
197                else {
198                    panic!("Unexpected result of ColumnarValue::Array.cast")
199                };
200
201                let values = args[0].as_primitive::<Float32Type>();
202                let decimal_places = decimal_places.as_primitive::<Int32Type>();
203                let result: PrimitiveArray<Float32Type> = arrow::compute::binary(
204                    values,
205                    decimal_places,
206                    |value, decimal_places| {
207                        (value * 10.0_f32.powi(decimal_places)).round()
208                            / 10.0_f32.powi(decimal_places)
209                    },
210                )?;
211                Ok(Arc::new(result) as _)
212            }
213            _ => {
214                exec_err!("round function requires a scalar or array for decimal_places")
215            }
216        },
217
218        other => exec_err!("Unsupported data type {other:?} for function round"),
219    }
220}
221
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::round::round;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::DataFusionError;

    /// Float32 input with a per-row `decimal_places` array, covering positive,
    /// zero and negative precisions.
    #[test]
    fn test_round_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345; 10])), // input
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
        ];

        let result = round(&args).expect("round should succeed for valid input");
        let floats =
            as_float32_array(&result).expect("round should return a Float32 array");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    /// Float64 input with a per-row `decimal_places` array, covering positive,
    /// zero and negative precisions.
    #[test]
    fn test_round_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345; 10])), // input
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
        ];

        let result = round(&args).expect("round should succeed for valid input");
        let floats =
            as_float64_array(&result).expect("round should return a Float64 array");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    /// Float32 input without `decimal_places` rounds to whole numbers.
    #[test]
    fn test_round_f32_one_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input
        ];

        let result = round(&args).expect("round should succeed for valid input");
        let floats =
            as_float32_array(&result).expect("round should return a Float32 array");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    /// Float64 input without `decimal_places` rounds to whole numbers.
    #[test]
    fn test_round_f64_one_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input
        ];

        let result = round(&args).expect("round should succeed for valid input");
        let floats =
            as_float64_array(&result).expect("round should return a Float64 array");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    /// A `decimal_places` value outside the `Int32` range must surface as an
    /// execution error on the Float32 path.
    #[test]
    fn test_round_f32_cast_fail() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345])), // input
            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
        ];

        let result = round(&args);

        assert!(result.is_err());
        assert!(matches!(result, Err(DataFusionError::Execution { .. })));
    }

    /// A `decimal_places` value outside the `Int32` range must surface as an
    /// execution error on the Float64 path.
    #[test]
    fn test_round_f64_cast_fail() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345])), // input
            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
        ];

        let result = round(&args);

        assert!(result.is_err());
        assert!(matches!(result, Err(DataFusionError::Execution { .. })));
    }
}