datafusion_functions/math/
trunc.rs1use std::any::Any;
19use std::sync::Arc;
20
21use crate::utils::make_scalar_function;
22
23use arrow::array::{ArrayRef, AsArray, PrimitiveArray};
24use arrow::datatypes::DataType::{Float32, Float64};
25use arrow::datatypes::{DataType, Float32Type, Float64Type, Int64Type};
26use datafusion_common::ScalarValue::Int64;
27use datafusion_common::{exec_err, Result};
28use datafusion_expr::sort_properties::{ExprProperties, SortProperties};
29use datafusion_expr::TypeSignature::Exact;
30use datafusion_expr::{
31 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
32 Volatility,
33};
34use datafusion_macros::user_doc;
35
36#[user_doc(
37 doc_section(label = "Math Functions"),
38 description = "Truncates a number to a whole number or truncated to the specified decimal places.",
39 syntax_example = "trunc(numeric_expression[, decimal_places])",
40 standard_argument(name = "numeric_expression", prefix = "Numeric"),
41 argument(
42 name = "decimal_places",
43 description = r#"Optional. The number of decimal places to
44 truncate to. Defaults to 0 (truncate to a whole number). If
45 `decimal_places` is a positive integer, truncates digits to the
46 right of the decimal point. If `decimal_places` is a negative
47 integer, replaces digits to the left of the decimal point with `0`."#
48 )
49)]
50#[derive(Debug)]
51pub struct TruncFunc {
52 signature: Signature,
53}
54
55impl Default for TruncFunc {
56 fn default() -> Self {
57 TruncFunc::new()
58 }
59}
60
61impl TruncFunc {
62 pub fn new() -> Self {
63 use DataType::*;
64 Self {
65 signature: Signature::one_of(
71 vec![
72 Exact(vec![Float32, Int64]),
73 Exact(vec![Float64, Int64]),
74 Exact(vec![Float64]),
75 Exact(vec![Float32]),
76 ],
77 Volatility::Immutable,
78 ),
79 }
80 }
81}
82
83impl ScalarUDFImpl for TruncFunc {
84 fn as_any(&self) -> &dyn Any {
85 self
86 }
87
88 fn name(&self) -> &str {
89 "trunc"
90 }
91
92 fn signature(&self) -> &Signature {
93 &self.signature
94 }
95
96 fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
97 match arg_types[0] {
98 Float32 => Ok(Float32),
99 _ => Ok(Float64),
100 }
101 }
102
103 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
104 make_scalar_function(trunc, vec![])(&args.args)
105 }
106
107 fn output_ordering(&self, input: &[ExprProperties]) -> Result<SortProperties> {
108 let value = &input[0];
110 let precision = input.get(1);
111
112 if precision
113 .map(|r| r.sort_properties.eq(&SortProperties::Singleton))
114 .unwrap_or(true)
115 {
116 Ok(value.sort_properties)
117 } else {
118 Ok(SortProperties::Unordered)
119 }
120 }
121
122 fn documentation(&self) -> Option<&Documentation> {
123 self.doc()
124 }
125}
126
127fn trunc(args: &[ArrayRef]) -> Result<ArrayRef> {
129 if args.len() != 1 && args.len() != 2 {
130 return exec_err!(
131 "truncate function requires one or two arguments, got {}",
132 args.len()
133 );
134 }
135
136 let num = &args[0];
139 let precision = if args.len() == 1 {
140 ColumnarValue::Scalar(Int64(Some(0)))
141 } else {
142 ColumnarValue::Array(Arc::clone(&args[1]))
143 };
144
145 match num.data_type() {
146 Float64 => match precision {
147 ColumnarValue::Scalar(Int64(Some(0))) => {
148 Ok(Arc::new(
149 args[0]
150 .as_primitive::<Float64Type>()
151 .unary::<_, Float64Type>(|x: f64| {
152 if x == 0_f64 {
153 0_f64
154 } else {
155 x.trunc()
156 }
157 }),
158 ) as ArrayRef)
159 }
160 ColumnarValue::Array(precision) => {
161 let num_array = num.as_primitive::<Float64Type>();
162 let precision_array = precision.as_primitive::<Int64Type>();
163 let result: PrimitiveArray<Float64Type> =
164 arrow::compute::binary(num_array, precision_array, |x, y| {
165 compute_truncate64(x, y)
166 })?;
167
168 Ok(Arc::new(result) as ArrayRef)
169 }
170 _ => exec_err!("trunc function requires a scalar or array for precision"),
171 },
172 Float32 => match precision {
173 ColumnarValue::Scalar(Int64(Some(0))) => {
174 Ok(Arc::new(
175 args[0]
176 .as_primitive::<Float32Type>()
177 .unary::<_, Float32Type>(|x: f32| {
178 if x == 0_f32 {
179 0_f32
180 } else {
181 x.trunc()
182 }
183 }),
184 ) as ArrayRef)
185 }
186 ColumnarValue::Array(precision) => {
187 let num_array = num.as_primitive::<Float32Type>();
188 let precision_array = precision.as_primitive::<Int64Type>();
189 let result: PrimitiveArray<Float32Type> =
190 arrow::compute::binary(num_array, precision_array, |x, y| {
191 compute_truncate32(x, y)
192 })?;
193
194 Ok(Arc::new(result) as ArrayRef)
195 }
196 _ => exec_err!("trunc function requires a scalar or array for precision"),
197 },
198 other => exec_err!("Unsupported data type {other:?} for function trunc"),
199 }
200}
201
202fn compute_truncate32(x: f32, y: i64) -> f32 {
203 let factor = 10.0_f32.powi(y as i32);
204 (x * factor).round() / factor
205}
206
207fn compute_truncate64(x: f64, y: i64) -> f64 {
208 let factor = 10.0_f64.powi(y as i32);
209 (x * factor).round() / factor
210}
211
212#[cfg(test)]
213mod test {
214 use std::sync::Arc;
215
216 use crate::math::trunc::trunc;
217
218 use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
219 use datafusion_common::cast::{as_float32_array, as_float64_array};
220
221 #[test]
222 fn test_truncate_32() {
223 let args: Vec<ArrayRef> = vec![
224 Arc::new(Float32Array::from(vec![
225 15.0,
226 1_234.267_8,
227 1_233.123_4,
228 3.312_979_2,
229 -21.123_4,
230 ])),
231 Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
232 ];
233
234 let result = trunc(&args).expect("failed to initialize function truncate");
235 let floats =
236 as_float32_array(&result).expect("failed to initialize function truncate");
237
238 assert_eq!(floats.len(), 5);
239 assert_eq!(floats.value(0), 15.0);
240 assert_eq!(floats.value(1), 1_234.268);
241 assert_eq!(floats.value(2), 1_233.12);
242 assert_eq!(floats.value(3), 3.312_98);
243 assert_eq!(floats.value(4), -21.123_4);
244 }
245
246 #[test]
247 fn test_truncate_64() {
248 let args: Vec<ArrayRef> = vec![
249 Arc::new(Float64Array::from(vec![
250 5.0,
251 234.267_812_176,
252 123.123_456_789,
253 123.312_979_313_2,
254 -321.123_1,
255 ])),
256 Arc::new(Int64Array::from(vec![0, 3, 2, 5, 6])),
257 ];
258
259 let result = trunc(&args).expect("failed to initialize function truncate");
260 let floats =
261 as_float64_array(&result).expect("failed to initialize function truncate");
262
263 assert_eq!(floats.len(), 5);
264 assert_eq!(floats.value(0), 5.0);
265 assert_eq!(floats.value(1), 234.268);
266 assert_eq!(floats.value(2), 123.12);
267 assert_eq!(floats.value(3), 123.312_98);
268 assert_eq!(floats.value(4), -321.123_1);
269 }
270
271 #[test]
272 fn test_truncate_64_one_arg() {
273 let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![
274 5.0,
275 234.267_812,
276 123.123_45,
277 123.312_979_313_2,
278 -321.123,
279 ]))];
280
281 let result = trunc(&args).expect("failed to initialize function truncate");
282 let floats =
283 as_float64_array(&result).expect("failed to initialize function truncate");
284
285 assert_eq!(floats.len(), 5);
286 assert_eq!(floats.value(0), 5.0);
287 assert_eq!(floats.value(1), 234.0);
288 assert_eq!(floats.value(2), 123.0);
289 assert_eq!(floats.value(3), 123.0);
290 assert_eq!(floats.value(4), -321.0);
291 }
292}