datafusion_functions/math/
trunc.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
26use datafusion_common::ScalarValue::Int64;
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32    Volatility,
33};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37    doc_section(label = "Math Functions"),
38    description = "Truncates a number to a whole number or truncated to the specified decimal places.",
39    syntax_example = "trunc(numeric_expression[, decimal_places])",
40    standard_argument(name = "numeric_expression", prefix = "Numeric"),
41    argument(
42        name = "decimal_places",
43        description = r#"Optional. The number of decimal places to
44  truncate to. Defaults to 0 (truncate to a whole number). If
45  `decimal_places` is a positive integer, truncates digits to the
46  right of the decimal point. If `decimal_places` is a negative
47  integer, replaces digits to the left of the decimal point with `0`."#
48    )
49)]
50#[derive(Debug)]
51pub struct TruncFunc {
52    signature: Signature,
53}
54
55impl Default for TruncFunc {
56    fn default() -> Self {
57        TruncFunc::new()
58    }
59}
60
61impl TruncFunc {
62    pub fn new() -> Self {
63        use DataType::*;
64        Self {
65            // math expressions expect 1 argument of type f64 or f32
66            // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
67            // return the best approximation for it (in f64).
68            // We accept f32 because in this case it is clear that the best approximation
69            // will be as good as the number of digits in the number
70            signature: Signature::one_of(
71                vec![
72                    Exact(vec![Float32, Int64]),
73                    Exact(vec![Float64, Int64]),
74                    Exact(vec![Float64]),
75                    Exact(vec![Float32]),
76                ],
77                Volatility::Immutable,
78            ),
79        }
80    }
81}
82
83impl ScalarUDFImpl for TruncFunc {
84    fn as_any(&self) -> &dyn Any {
85        self
86    }
87
88    fn name(&self) -> &str {
89        "trunc"
90    }
91
92    fn signature(&self) -> &Signature {
93        &self.signature
94    }
95
96    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97        match arg_types[0] {
98            Float32 => Ok(Float32),
99            _ => Ok(Float64),
100        }
101    }
102
103    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104        make_scalar_function(trunc, vec![])(&args.args)
105    }
106
107    fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
108        // trunc preserves the order of the first argument
109        let value = &input[0];
110        let precision = input.get(1);
111
112        if precision
113            .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
114            .unwrap_or(true)
115        {
116            Ok(value.sort_properties)
117        } else {
118            Ok(SortProperties::Unordered)
119        }
120    }
121
122    fn documentation(&self) -> Option<&Documentation> {
123        self.doc()
124    }
125}
126
127/// Truncate(numeric, decimalPrecision) and trunc(numeric) SQL function
128fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
129    if args.len() != 1 && args.len() != 2 {
130        return exec_err!(
131            "truncate function requires one or two arguments, got {}",
132            args.len()
133        );
134    }
135
136    // If only one arg then invoke toolchain trunc(num) and precision = 0 by default
137    // or then invoke the compute_truncate method to process precision
138    let num = &args[0];
139    let precision = if args.len() == 1 {
140        ColumnarValue::Scalar(Int64(Some(0)))
141    } else {
142        ColumnarValue::Array(Arc::clone(&args[1]))
143    };
144
145    match num.data_type() {
146        Float64 => match precision {
147            ColumnarValue::Scalar(Int64(Some(0))) => {
148                Ok(Arc::new(
149                    args[0]
150                        .as_primitive::<Float64Type>()
151                        .unary::<_, Float64Type>(|x: f64| {
152                            if x == 0_f64 {
153                                0_f64
154                            } else {
155                                x.trunc()
156                            }
157                        }),
158                ) as ArrayRef)
159            }
160            ColumnarValue::Array(precision) => {
161                let num_array = num.as_primitive::<Float64Type>();
162                let precision_array = precision.as_primitive::<Int64Type>();
163                let result: PrimitiveArray<Float64Type> =
164                    arrow::compute::binary(num_array, precision_array, |x, y| {
165                        compute_truncate64(x, y)
166                    })?;
167
168                Ok(Arc::new(result) as ArrayRef)
169            }
170            _ => exec_err!("trunc function requires a scalar or array for precision"),
171        },
172        Float32 => match precision {
173            ColumnarValue::Scalar(Int64(Some(0))) => {
174                Ok(Arc::new(
175                    args[0]
176                        .as_primitive::<Float32Type>()
177                        .unary::<_, Float32Type>(|x: f32| {
178                            if x == 0_f32 {
179                                0_f32
180                            } else {
181                                x.trunc()
182                            }
183                        }),
184                ) as ArrayRef)
185            }
186            ColumnarValue::Array(precision) => {
187                let num_array = num.as_primitive::<Float32Type>();
188                let precision_array = precision.as_primitive::<Int64Type>();
189                let result: PrimitiveArray<Float32Type> =
190                    arrow::compute::binary(num_array, precision_array, |x, y| {
191                        compute_truncate32(x, y)
192                    })?;
193
194                Ok(Arc::new(result) as ArrayRef)
195            }
196            _ => exec_err!("trunc function requires a scalar or array for precision"),
197        },
198        other => exec_err!("Unsupported data type {other:?} for function trunc"),
199    }
200}
201
/// Converts a user-supplied precision into the `i32` exponent accepted by
/// `powi`. It saturates rather than wrapping, so an out-of-range precision
/// behaves like the largest or smallest `i32` instead of an unrelated small
/// value (e.g. `i64::MAX as i32` would be `-1`).
fn precision_exponent(precision: i64) -> i32 {
    i32::try_from(precision).unwrap_or(if precision < 0 { i32::MIN } else { i32::MAX })
}

/// Truncates `x` toward zero at `y` decimal places (Float32 variant).
///
/// A positive `y` keeps `y` digits to the right of the decimal point, and `0`
/// truncates to a whole number. A negative `y` replaces `-y` digits to the left
/// of the decimal point with `0`, e.g. `(1234.5, -2) -> 1200.0`. NaN and ±inf
/// are returned unchanged.
///
/// Scaling is done in binary floating point, so a value whose decimal form is
/// not exactly representable may land just below a digit boundary and be
/// truncated one unit lower.
fn compute_truncate32(x: f32, y: i64) -> f32 {
    let factor = 10.0_f32.powi(precision_exponent(y));
    let scaled = x * factor;
    if !scaled.is_finite() {
        // `x` is NaN/±inf, or scaling overflowed. On overflow, either `x` is
        // already so large that it is a whole number, or `y` is so large that
        // truncating there cannot change `x` (barring extreme subnormals).
        // Either way, `x` is already truncated.
        return x;
    }
    if factor == 0.0 {
        // `10^y` underflowed (very negative `y`), so every digit of `x` is
        // truncated away; `scaled` is a zero carrying the sign of `x`.
        return scaled;
    }
    scaled.trunc() / factor
}

/// Truncates `x` toward zero at `y` decimal places (Float64 variant).
///
/// A positive `y` keeps `y` digits to the right of the decimal point, and `0`
/// truncates to a whole number. A negative `y` replaces `-y` digits to the left
/// of the decimal point with `0`, e.g. `(1234.5, -2) -> 1200.0`. NaN and ±inf
/// are returned unchanged.
///
/// Scaling is done in binary floating point, so a value whose decimal form is
/// not exactly representable may land just below a digit boundary and be
/// truncated one unit lower (e.g. `0.29 * 100.0` is `28.999999999999996`).
fn compute_truncate64(x: f64, y: i64) -> f64 {
    let factor = 10.0_f64.powi(precision_exponent(y));
    let scaled = x * factor;
    if !scaled.is_finite() {
        // `x` is NaN/±inf, or scaling overflowed. On overflow, either `x` is
        // already so large that it is a whole number, or `y` is so large that
        // truncating there cannot change `x` (barring extreme subnormals).
        // Either way, `x` is already truncated.
        return x;
    }
    if factor == 0.0 {
        // `10^y` underflowed (very negative `y`), so every digit of `x` is
        // truncated away; `scaled` is a zero carrying the sign of `x`.
        return scaled;
    }
    scaled.trunc() / factor
}

#[cfg(test)]
mod test {
    use std::sync::Arc;

    use crate::math::trunc::{compute_truncate32, compute_truncate64, trunc};

    use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
    use datafusion_common::cast::{as_float32_array, as_float64_array};

    #[test]
    fn test_truncate_32() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float32Array::from(vec![
                15.0,
                1_234.267_8,
                1_233.123_4,
                3.312_979_2,
                -21.123_4,
            ])),
            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
        ];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float32_array(&result).expect("failed to initialize function truncate");

        // Digits beyond the requested precision are dropped, not rounded.
        assert_eq!(floats.len(), 5);
        assert_eq!(floats.value(0), 15.0);
        assert_eq!(floats.value(1), 1_234.267);
        assert_eq!(floats.value(2), 1_233.12);
        assert_eq!(floats.value(3), 3.312_97);
        assert_eq!(floats.value(4), -21.123_4);
    }

    #[test]
    fn test_truncate_64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![
                5.0,
                234.267_812_176,
                123.123_456_789,
                123.312_979_313_2,
                -321.123_1,
            ])),
            Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
        ];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        // Digits beyond the requested precision are dropped, not rounded.
        assert_eq!(floats.len(), 5);
        assert_eq!(floats.value(0), 5.0);
        assert_eq!(floats.value(1), 234.267);
        assert_eq!(floats.value(2), 123.12);
        assert_eq!(floats.value(3), 123.312_97);
        assert_eq!(floats.value(4), -321.123_1);
    }

    #[test]
    fn test_truncate_64_one_arg() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
            5.0,
            234.267_812,
            123.123_45,
            123.312_979_313_2,
            -321.123,
        ]))];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        assert_eq!(floats.len(), 5);
        assert_eq!(floats.value(0), 5.0);
        assert_eq!(floats.value(1), 234.0);
        assert_eq!(floats.value(2), 123.0);
        assert_eq!(floats.value(3), 123.0);
        assert_eq!(floats.value(4), -321.0);
    }

    #[test]
    fn test_truncate_64_truncates_toward_zero() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Float64Array::from(vec![
                1.7, -1.7, 42.438_2, -42.438_2, 1_234.5,
            ])),
            Arc::new(Int64Array::from(vec![0, 0, 2, 2, -2])),
        ];

        let result = trunc(&args).expect("failed to initialize function truncate");
        let floats =
            as_float64_array(&result).expect("failed to initialize function truncate");

        // An explicit precision of 0 must agree with the one-argument form.
        assert_eq!(floats.value(0), 1.0);
        assert_eq!(floats.value(1), -1.0);
        assert_eq!(floats.value(2), 42.43);
        assert_eq!(floats.value(3), -42.43);
        assert_eq!(floats.value(4), 1_200.0);
    }

    #[test]
    fn test_compute_truncate_edge_cases() {
        // Large magnitudes must not overflow to infinity while scaling.
        assert_eq!(compute_truncate32(1e30, 10), 1e30);
        assert_eq!(compute_truncate64(1e300, 10), 1e300);
        // Precisions beyond the representable range keep the value (positive)
        // or truncate everything (negative) instead of producing NaN.
        assert_eq!(compute_truncate64(1.5, 400), 1.5);
        assert_eq!(compute_truncate64(1_234.5, -400), 0.0);
        // Precisions outside the i32 range saturate instead of wrapping.
        assert_eq!(compute_truncate64(1.5, i64::MAX), 1.5);
        assert_eq!(compute_truncate64(1_234.5, i64::MIN), 0.0);
        // Non-finite inputs pass through unchanged.
        assert!(compute_truncate64(f64::NAN, 2).is_nan());
        assert_eq!(compute_truncate64(f64::INFINITY, 2), f64::INFINITY);
        assert_eq!(compute_truncate32(f32::NEG_INFINITY, -2), f32::NEG_INFINITY);
    }
}