datafusion_functions/math/
round.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::compute::{cast_with_options, CastOptions};
25use arrow::datatypes::DataType::{Float32, Float64, Int32};
26use arrow::datatypes::{DataType, Float32Type, Float64Type, Int32Type};
27use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
/// Scalar UDF for SQL `round(numeric_expression[, decimal_places])`.
///
/// The user-facing documentation returned by `documentation()` comes from the
/// `#[user_doc]` attribute below; the accepted argument types are built in
/// `RoundFunc::new`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Rounds a number to the nearest integer.",
    syntax_example = "round(numeric_expression[, decimal_places])",
    standard_argument(name = "numeric_expression", prefix = "Numeric"),
    argument(
        name = "decimal_places",
        description = "Optional. The number of decimal places to round to. Defaults to 0."
    )
)]
#[derive(Debug)]
pub struct RoundFunc {
    /// Accepted argument types and volatility, constructed in `RoundFunc::new`.
    signature: Signature,
}
50
51impl Default for RoundFunc {
52 fn default() -> Self {
53 RoundFunc::new()
54 }
55}
56
57impl RoundFunc {
58 pub fn new() -> Self {
59 use DataType::*;
60 Self {
61 signature: Signature::one_of(
62 vec![
63 Exact(vec![Float64, Int64]),
64 Exact(vec![Float32, Int64]),
65 Exact(vec![Float64]),
66 Exact(vec![Float32]),
67 ],
68 Volatility::Immutable,
69 ),
70 }
71 }
72}
73
74impl ScalarUDFImpl for RoundFunc {
75 fn as_any(&self) -> &dyn Any {
76 self
77 }
78
79 fn name(&self) -> &str {
80 "round"
81 }
82
83 fn signature(&self) -> &Signature {
84 &self.signature
85 }
86
87 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
88 match arg_types[0] {
89 Float32 => Ok(Float32),
90 _ => Ok(Float64),
91 }
92 }
93
94 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
95 make_scalar_function(round, vec![])(&args.args)
96 }
97
98 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
99 let value = &input[0];
101 let precision = input.get(1);
102
103 if precision
104 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
105 .unwrap_or(true)
106 {
107 Ok(value.sort_properties)
108 } else {
109 Ok(SortProperties::Unordered)
110 }
111 }
112
113 fn documentation(&self) -> Option<&Documentation> {
114 self.doc()
115 }
116}
117
118pub fn round(args: &[ArrayRef]) -> Result<ArrayRef> {
120 if args.len() != 1 && args.len() != 2 {
121 return exec_err!(
122 "round function requires one or two arguments, got {}",
123 args.len()
124 );
125 }
126
127 let mut decimal_places = ColumnarValue::Scalar(ScalarValue::Int64(Some(0)));
128
129 if args.len() == 2 {
130 decimal_places = ColumnarValue::Array(Arc::clone(&args[1]));
131 }
132
133 match args[0].data_type() {
134 Float64 => match decimal_places {
135 ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
136 let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
137 exec_datafusion_err!(
138 "Invalid value for decimal places: {decimal_places}: {e}"
139 )
140 })?;
141
142 let result = args[0]
143 .as_primitive::<Float64Type>()
144 .unary::<_, Float64Type>(|value: f64| {
145 (value * 10.0_f64.powi(decimal_places)).round()
146 / 10.0_f64.powi(decimal_places)
147 });
148 Ok(Arc::new(result) as _)
149 }
150 ColumnarValue::Array(decimal_places) => {
151 let options = CastOptions {
152 safe: false, ..Default::default()
154 };
155 let decimal_places = cast_with_options(&decimal_places, &Int32, &options)
156 .map_err(|e| {
157 exec_datafusion_err!("Invalid values for decimal places: {e}")
158 })?;
159
160 let values = args[0].as_primitive::<Float64Type>();
161 let decimal_places = decimal_places.as_primitive::<Int32Type>();
162 let result = arrow::compute::binary::<_, _, _, Float64Type>(
163 values,
164 decimal_places,
165 |value, decimal_places| {
166 (value * 10.0_f64.powi(decimal_places)).round()
167 / 10.0_f64.powi(decimal_places)
168 },
169 )?;
170 Ok(Arc::new(result) as _)
171 }
172 _ => {
173 exec_err!("round function requires a scalar or array for decimal_places")
174 }
175 },
176
177 Float32 => match decimal_places {
178 ColumnarValue::Scalar(ScalarValue::Int64(Some(decimal_places))) => {
179 let decimal_places: i32 = decimal_places.try_into().map_err(|e| {
180 exec_datafusion_err!(
181 "Invalid value for decimal places: {decimal_places}: {e}"
182 )
183 })?;
184 let result = args[0]
185 .as_primitive::<Float32Type>()
186 .unary::<_, Float32Type>(|value: f32| {
187 (value * 10.0_f32.powi(decimal_places)).round()
188 / 10.0_f32.powi(decimal_places)
189 });
190 Ok(Arc::new(result) as _)
191 }
192 ColumnarValue::Array(_) => {
193 let ColumnarValue::Array(decimal_places) =
194 decimal_places.cast_to(&Int32, None).map_err(|e| {
195 exec_datafusion_err!("Invalid values for decimal places: {e}")
196 })?
197 else {
198 panic!("Unexpected result of ColumnarValue::Array.cast")
199 };
200
201 let values = args[0].as_primitive::<Float32Type>();
202 let decimal_places = decimal_places.as_primitive::<Int32Type>();
203 let result: PrimitiveArray<Float32Type> = arrow::compute::binary(
204 values,
205 decimal_places,
206 |value, decimal_places| {
207 (value * 10.0_f32.powi(decimal_places)).round()
208 / 10.0_f32.powi(decimal_places)
209 },
210 )?;
211 Ok(Arc::new(result) as _)
212 }
213 _ => {
214 exec_err!("round function requires a scalar or array for decimal_places")
215 }
216 },
217
218 other => exec_err!("Unsupported data type {other:?} for function round"),
219 }
220}
221
#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::round::round;

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};
    use datafusion_common::DataFusionError;

    #[test]
    fn test_round_f32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345; 10])), // input
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
        ];

        let result = round(&args).expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345; 10])), // input
            Arc::new(Int64Array::from(vec![0, 1, 2, 3, 4, 5, -1, -2, -3, -4])), // decimal_places
        ];

        let result = round(&args).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![
            125.0, 125.2, 125.23, 125.235, 125.2345, 125.2345, 130.0, 100.0, 0.0, 0.0,
        ]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f32_one_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input
        ];

        let result = round(&args).expect("failed to initialize function round");
        let floats =
            as_float32_array(&result).expect("failed to initialize function round");

        let expected = Float32Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    #[test]
    fn test_round_f64_one_input() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345, 12.345, 1.234, 0.1234])), // input
        ];

        let result = round(&args).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![125.0, 12.0, 1.0, 0.0]);

        assert_eq!(floats, &expected);
    }

    /// A null decimal-places entry yields a null output for that row.
    #[test]
    fn test_round_f64_null_decimal_places() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345, 125.2345])), // input
            Arc::new(Int64Array::from(vec![Some(1), None])),        // decimal_places
        ];

        let result = round(&args).expect("failed to initialize function round");
        let floats =
            as_float64_array(&result).expect("failed to initialize function round");

        let expected = Float64Array::from(vec![Some(125.2), None]);

        assert_eq!(floats, &expected);
    }

    /// Decimal places that do not fit in an `i32` must error on the f32 path.
    #[test]
    fn test_round_f32_cast_fail() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![125.2345])), // input
            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
        ];

        let result = round(&args);

        assert!(matches!(result, Err(DataFusionError::Execution { .. })));
    }

    /// Decimal places that do not fit in an `i32` must error on the f64 path.
    #[test]
    fn test_round_f64_cast_fail() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![125.2345])), // input
            Arc::new(Int64Array::from(vec![2147483648])), // decimal_places
        ];

        let result = round(&args);

        assert!(matches!(result, Err(DataFusionError::Execution { .. })));
    }
}