arrow_arith/
numeric.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Defines numeric arithmetic kernels on [`PrimitiveArray`], such as [`add`]
19
20use std::cmp::Ordering;
21use std::fmt::Formatter;
22use std::sync::Arc;
23
24use arrow_array::cast::AsArray;
25use arrow_array::timezone::Tz;
26use arrow_array::types::*;
27use arrow_array::*;
28use arrow_buffer::{ArrowNativeType, IntervalDayTime, IntervalMonthDayNano};
29use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit};
30
31use crate::arity::{binary, try_binary};
32
33/// Perform `lhs + rhs`, returning an error on overflow
34pub fn add(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
35    arithmetic_op(Op::Add, lhs, rhs)
36}
37
38/// Perform `lhs + rhs`, wrapping on overflow for [`DataType::is_integer`]
39pub fn add_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
40    arithmetic_op(Op::AddWrapping, lhs, rhs)
41}
42
43/// Perform `lhs - rhs`, returning an error on overflow
44pub fn sub(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
45    arithmetic_op(Op::Sub, lhs, rhs)
46}
47
48/// Perform `lhs - rhs`, wrapping on overflow for [`DataType::is_integer`]
49pub fn sub_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
50    arithmetic_op(Op::SubWrapping, lhs, rhs)
51}
52
53/// Perform `lhs * rhs`, returning an error on overflow
54pub fn mul(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
55    arithmetic_op(Op::Mul, lhs, rhs)
56}
57
58/// Perform `lhs * rhs`, wrapping on overflow for [`DataType::is_integer`]
59pub fn mul_wrapping(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
60    arithmetic_op(Op::MulWrapping, lhs, rhs)
61}
62
63/// Perform `lhs / rhs`
64///
65/// Overflow or division by zero will result in an error, with exception to
66/// floating point numbers, which instead follow the IEEE 754 rules
67pub fn div(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
68    arithmetic_op(Op::Div, lhs, rhs)
69}
70
71/// Perform `lhs % rhs`
72///
73/// Overflow or division by zero will result in an error, with exception to
74/// floating point numbers, which instead follow the IEEE 754 rules
75pub fn rem(lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
76    arithmetic_op(Op::Rem, lhs, rhs)
77}
78
/// Expands to `Ok(ArrayRef)` holding the checked negation of every value of `$a`
/// viewed as a `PrimitiveArray<$t>`; the first overflow error is propagated with `?`
macro_rules! neg_checked {
    ($t:ty, $a:ident) => {{
        let primitive = $a.as_primitive::<$t>();
        let negated = primitive.try_unary::<_, $t, _>(|value| value.neg_checked())?;
        Ok(Arc::new(negated))
    }};
}

/// Expands to `Ok(ArrayRef)` holding the wrapping negation of every value of `$a`
/// viewed as a `PrimitiveArray<$t>`
macro_rules! neg_wrapping {
    ($t:ty, $a:ident) => {{
        let primitive = $a.as_primitive::<$t>();
        let negated = primitive.unary::<_, $t>(|value| value.neg_wrapping());
        Ok(Arc::new(negated))
    }};
}
94
95/// Negates each element of  `array`, returning an error on overflow
96///
97/// Note: negation of unsigned arrays is not supported and will return in an error,
98/// for wrapping unsigned negation consider using [`neg_wrapping`][neg_wrapping()]
99pub fn neg(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
100    use DataType::*;
101    use IntervalUnit::*;
102    use TimeUnit::*;
103
104    match array.data_type() {
105        Int8 => neg_checked!(Int8Type, array),
106        Int16 => neg_checked!(Int16Type, array),
107        Int32 => neg_checked!(Int32Type, array),
108        Int64 => neg_checked!(Int64Type, array),
109        Float16 => neg_wrapping!(Float16Type, array),
110        Float32 => neg_wrapping!(Float32Type, array),
111        Float64 => neg_wrapping!(Float64Type, array),
112        Decimal128(p, s) => {
113            let a = array
114                .as_primitive::<Decimal128Type>()
115                .try_unary::<_, Decimal128Type, _>(|x| x.neg_checked())?;
116
117            Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
118        }
119        Decimal256(p, s) => {
120            let a = array
121                .as_primitive::<Decimal256Type>()
122                .try_unary::<_, Decimal256Type, _>(|x| x.neg_checked())?;
123
124            Ok(Arc::new(a.with_precision_and_scale(*p, *s)?))
125        }
126        Duration(Second) => neg_checked!(DurationSecondType, array),
127        Duration(Millisecond) => neg_checked!(DurationMillisecondType, array),
128        Duration(Microsecond) => neg_checked!(DurationMicrosecondType, array),
129        Duration(Nanosecond) => neg_checked!(DurationNanosecondType, array),
130        Interval(YearMonth) => neg_checked!(IntervalYearMonthType, array),
131        Interval(DayTime) => {
132            let a = array
133                .as_primitive::<IntervalDayTimeType>()
134                .try_unary::<_, IntervalDayTimeType, ArrowError>(|x| {
135                    let (days, ms) = IntervalDayTimeType::to_parts(x);
136                    Ok(IntervalDayTimeType::make_value(
137                        days.neg_checked()?,
138                        ms.neg_checked()?,
139                    ))
140                })?;
141            Ok(Arc::new(a))
142        }
143        Interval(MonthDayNano) => {
144            let a = array
145                .as_primitive::<IntervalMonthDayNanoType>()
146                .try_unary::<_, IntervalMonthDayNanoType, ArrowError>(|x| {
147                    let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(x);
148                    Ok(IntervalMonthDayNanoType::make_value(
149                        months.neg_checked()?,
150                        days.neg_checked()?,
151                        nanos.neg_checked()?,
152                    ))
153                })?;
154            Ok(Arc::new(a))
155        }
156        t => Err(ArrowError::InvalidArgumentError(format!(
157            "Invalid arithmetic operation: !{t}"
158        ))),
159    }
160}
161
162/// Negates each element of  `array`, wrapping on overflow for [`DataType::is_integer`]
163pub fn neg_wrapping(array: &dyn Array) -> Result<ArrayRef, ArrowError> {
164    downcast_integer! {
165        array.data_type() => (neg_wrapping, array),
166        _ => neg(array),
167    }
168}
169
/// The binary arithmetic operations implemented by this module
///
/// Funnelling every kernel through one enum lets them share a single type-dispatch path
#[derive(Debug, Copy, Clone)]
enum Op {
    /// `+`, wrapping on integer overflow
    AddWrapping,
    /// Checked `+`
    Add,
    /// `-`, wrapping on integer overflow
    SubWrapping,
    /// Checked `-`
    Sub,
    /// `*`, wrapping on integer overflow
    MulWrapping,
    /// Checked `*`
    Mul,
    /// Checked `/`
    Div,
    /// Checked `%`
    Rem,
}

impl std::fmt::Display for Op {
    /// Writes the operator symbol; wrapping and checked variants print the same symbol
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Add | Self::AddWrapping => "+",
            Self::Sub | Self::SubWrapping => "-",
            Self::Mul | Self::MulWrapping => "*",
            Self::Div => "/",
            Self::Rem => "%",
        };
        f.write_str(symbol)
    }
}

impl Op {
    /// Returns `true` if the operands of this operation may be swapped
    ///
    /// Only the addition variants report `true`. This is consulted by `arithmetic_op`
    /// to move a date/timestamp operand to the left-hand side.
    fn commutative(&self) -> bool {
        matches!(self, Self::AddWrapping | Self::Add)
    }
}
202
203/// Dispatch the given `op` to the appropriate specialized kernel
204fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result<ArrayRef, ArrowError> {
205    use DataType::*;
206    use IntervalUnit::*;
207    use TimeUnit::*;
208
209    macro_rules! integer_helper {
210        ($t:ty, $op:ident, $l:ident, $l_scalar:ident, $r:ident, $r_scalar:ident) => {
211            integer_op::<$t>($op, $l, $l_scalar, $r, $r_scalar)
212        };
213    }
214
215    let (l, l_scalar) = lhs.get();
216    let (r, r_scalar) = rhs.get();
217    downcast_integer! {
218        l.data_type(), r.data_type() => (integer_helper, op, l, l_scalar, r, r_scalar),
219        (Float16, Float16) => float_op::<Float16Type>(op, l, l_scalar, r, r_scalar),
220        (Float32, Float32) => float_op::<Float32Type>(op, l, l_scalar, r, r_scalar),
221        (Float64, Float64) => float_op::<Float64Type>(op, l, l_scalar, r, r_scalar),
222        (Timestamp(Second, _), _) => timestamp_op::<TimestampSecondType>(op, l, l_scalar, r, r_scalar),
223        (Timestamp(Millisecond, _), _) => timestamp_op::<TimestampMillisecondType>(op, l, l_scalar, r, r_scalar),
224        (Timestamp(Microsecond, _), _) => timestamp_op::<TimestampMicrosecondType>(op, l, l_scalar, r, r_scalar),
225        (Timestamp(Nanosecond, _), _) => timestamp_op::<TimestampNanosecondType>(op, l, l_scalar, r, r_scalar),
226        (Duration(Second), Duration(Second)) => duration_op::<DurationSecondType>(op, l, l_scalar, r, r_scalar),
227        (Duration(Millisecond), Duration(Millisecond)) => duration_op::<DurationMillisecondType>(op, l, l_scalar, r, r_scalar),
228        (Duration(Microsecond), Duration(Microsecond)) => duration_op::<DurationMicrosecondType>(op, l, l_scalar, r, r_scalar),
229        (Duration(Nanosecond), Duration(Nanosecond)) => duration_op::<DurationNanosecondType>(op, l, l_scalar, r, r_scalar),
230        (Interval(YearMonth), Interval(YearMonth)) => interval_op::<IntervalYearMonthType>(op, l, l_scalar, r, r_scalar),
231        (Interval(DayTime), Interval(DayTime)) => interval_op::<IntervalDayTimeType>(op, l, l_scalar, r, r_scalar),
232        (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::<IntervalMonthDayNanoType>(op, l, l_scalar, r, r_scalar),
233        (Date32, _) => date_op::<Date32Type>(op, l, l_scalar, r, r_scalar),
234        (Date64, _) => date_op::<Date64Type>(op, l, l_scalar, r, r_scalar),
235        (Decimal128(_, _), Decimal128(_, _)) => decimal_op::<Decimal128Type>(op, l, l_scalar, r, r_scalar),
236        (Decimal256(_, _), Decimal256(_, _)) => decimal_op::<Decimal256Type>(op, l, l_scalar, r, r_scalar),
237        (l_t, r_t) => match (l_t, r_t) {
238            (Duration(_) | Interval(_), Date32 | Date64 | Timestamp(_, _)) if op.commutative() => {
239                arithmetic_op(op, rhs, lhs)
240            }
241            _ => Err(ArrowError::InvalidArgumentError(
242              format!("Invalid arithmetic operation: {l_t} {op} {r_t}")
243            ))
244        }
245    }
246}
247
/// Applies the infallible expression `$op` to `$l` and `$r`, either of which may be a
/// single-element scalar (flagged by `$l_s` / `$r_s`)
///
/// Inside `$op`, `$l` and `$r` are rebound to native values. When exactly one side is a
/// scalar, its value is read once and the other side is mapped with `unary`. A null
/// scalar short-circuits to an all-null array with the length of the other side.
macro_rules! op {
    ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
        match ($l_s, $r_s) {
            (true, false) => {
                if $l.null_count() == 0 {
                    let $l = $l.value(0);
                    $r.unary(|$r| $op)
                } else {
                    PrimitiveArray::new_null($r.len())
                }
            }
            (false, true) => {
                if $r.null_count() == 0 {
                    let $r = $r.value(0);
                    $l.unary(|$l| $op)
                } else {
                    PrimitiveArray::new_null($l.len())
                }
            }
            // Two arrays, or two scalars: combine element by element
            _ => binary($l, $r, |$l, $r| $op)?,
        }
    };
}

/// Like `op!`, but types the result as `PrimitiveArray<$t>` and wraps it in an `Arc`
macro_rules! op_ref {
    ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{
        let result: PrimitiveArray<$t> = op!($l, $l_s, $r, $r_s, $op);
        Arc::new(result)
    }};
}
272
/// Applies the fallible expression `$op` (which yields a `Result`) to `$l` and `$r`,
/// either of which may be a single-element scalar (flagged by `$l_s` / `$r_s`)
///
/// Mirrors `op!`, except that the first error produced by `$op` is propagated with `?`.
/// A null scalar short-circuits to an all-null array without evaluating `$op`.
macro_rules! try_op {
    ($l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {
        match ($l_s, $r_s) {
            (true, false) => {
                if $l.null_count() == 0 {
                    let $l = $l.value(0);
                    $r.try_unary(|$r| $op)?
                } else {
                    PrimitiveArray::new_null($r.len())
                }
            }
            (false, true) => {
                if $r.null_count() == 0 {
                    let $r = $r.value(0);
                    $l.try_unary(|$l| $op)?
                } else {
                    PrimitiveArray::new_null($l.len())
                }
            }
            // Two arrays, or two scalars: combine element by element
            _ => try_binary($l, $r, |$l, $r| $op)?,
        }
    };
}

/// Like `try_op!`, but types the result as `PrimitiveArray<$t>` and wraps it in an `Arc`
macro_rules! try_op_ref {
    ($t:ty, $l:ident, $l_s:expr, $r:ident, $r_s:expr, $op:expr) => {{
        let result: PrimitiveArray<$t> = try_op!($l, $l_s, $r, $r_s, $op);
        Arc::new(result)
    }};
}
297
298/// Perform an arithmetic operation on integers
299fn integer_op<T: ArrowPrimitiveType>(
300    op: Op,
301    l: &dyn Array,
302    l_s: bool,
303    r: &dyn Array,
304    r_s: bool,
305) -> Result<ArrayRef, ArrowError> {
306    let l = l.as_primitive::<T>();
307    let r = r.as_primitive::<T>();
308    let array: PrimitiveArray<T> = match op {
309        Op::AddWrapping => op!(l, l_s, r, r_s, l.add_wrapping(r)),
310        Op::Add => try_op!(l, l_s, r, r_s, l.add_checked(r)),
311        Op::SubWrapping => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
312        Op::Sub => try_op!(l, l_s, r, r_s, l.sub_checked(r)),
313        Op::MulWrapping => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
314        Op::Mul => try_op!(l, l_s, r, r_s, l.mul_checked(r)),
315        Op::Div => try_op!(l, l_s, r, r_s, l.div_checked(r)),
316        Op::Rem => try_op!(l, l_s, r, r_s, l.mod_checked(r)),
317    };
318    Ok(Arc::new(array))
319}
320
321/// Perform an arithmetic operation on floats
322fn float_op<T: ArrowPrimitiveType>(
323    op: Op,
324    l: &dyn Array,
325    l_s: bool,
326    r: &dyn Array,
327    r_s: bool,
328) -> Result<ArrayRef, ArrowError> {
329    let l = l.as_primitive::<T>();
330    let r = r.as_primitive::<T>();
331    let array: PrimitiveArray<T> = match op {
332        Op::AddWrapping | Op::Add => op!(l, l_s, r, r_s, l.add_wrapping(r)),
333        Op::SubWrapping | Op::Sub => op!(l, l_s, r, r_s, l.sub_wrapping(r)),
334        Op::MulWrapping | Op::Mul => op!(l, l_s, r, r_s, l.mul_wrapping(r)),
335        Op::Div => op!(l, l_s, r, r_s, l.div_wrapping(r)),
336        Op::Rem => op!(l, l_s, r, r_s, l.mod_wrapping(r)),
337    };
338    Ok(Arc::new(array))
339}
340
341/// Arithmetic trait for timestamp arrays
342trait TimestampOp: ArrowTimestampType {
343    type Duration: ArrowPrimitiveType<Native = i64>;
344
345    fn add_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option<i64>;
346    fn add_day_time(timestamp: i64, delta: IntervalDayTime, tz: Tz) -> Option<i64>;
347    fn add_month_day_nano(timestamp: i64, delta: IntervalMonthDayNano, tz: Tz) -> Option<i64>;
348
349    fn sub_year_month(timestamp: i64, delta: i32, tz: Tz) -> Option<i64>;
350    fn sub_day_time(timestamp: i64, delta: IntervalDayTime, tz: Tz) -> Option<i64>;
351    fn sub_month_day_nano(timestamp: i64, delta: IntervalMonthDayNano, tz: Tz) -> Option<i64>;
352}
353
354macro_rules! timestamp {
355    ($t:ty, $d:ty) => {
356        impl TimestampOp for $t {
357            type Duration = $d;
358
359            fn add_year_month(left: i64, right: i32, tz: Tz) -> Option<i64> {
360                Self::add_year_months(left, right, tz)
361            }
362
363            fn add_day_time(left: i64, right: IntervalDayTime, tz: Tz) -> Option<i64> {
364                Self::add_day_time(left, right, tz)
365            }
366
367            fn add_month_day_nano(left: i64, right: IntervalMonthDayNano, tz: Tz) -> Option<i64> {
368                Self::add_month_day_nano(left, right, tz)
369            }
370
371            fn sub_year_month(left: i64, right: i32, tz: Tz) -> Option<i64> {
372                Self::subtract_year_months(left, right, tz)
373            }
374
375            fn sub_day_time(left: i64, right: IntervalDayTime, tz: Tz) -> Option<i64> {
376                Self::subtract_day_time(left, right, tz)
377            }
378
379            fn sub_month_day_nano(left: i64, right: IntervalMonthDayNano, tz: Tz) -> Option<i64> {
380                Self::subtract_month_day_nano(left, right, tz)
381            }
382        }
383    };
384}
385timestamp!(TimestampSecondType, DurationSecondType);
386timestamp!(TimestampMillisecondType, DurationMillisecondType);
387timestamp!(TimestampMicrosecondType, DurationMicrosecondType);
388timestamp!(TimestampNanosecondType, DurationNanosecondType);
389
390/// Perform arithmetic operation on a timestamp array
391fn timestamp_op<T: TimestampOp>(
392    op: Op,
393    l: &dyn Array,
394    l_s: bool,
395    r: &dyn Array,
396    r_s: bool,
397) -> Result<ArrayRef, ArrowError> {
398    use DataType::*;
399    use IntervalUnit::*;
400
401    let l = l.as_primitive::<T>();
402    let l_tz: Tz = l.timezone().unwrap_or("+00:00").parse()?;
403
404    let array: PrimitiveArray<T> = match (op, r.data_type()) {
405        (Op::Sub | Op::SubWrapping, Timestamp(unit, _)) if unit == &T::UNIT => {
406            let r = r.as_primitive::<T>();
407            return Ok(try_op_ref!(T::Duration, l, l_s, r, r_s, l.sub_checked(r)));
408        }
409
410        (Op::Add | Op::AddWrapping, Duration(unit)) if unit == &T::UNIT => {
411            let r = r.as_primitive::<T::Duration>();
412            try_op!(l, l_s, r, r_s, l.add_checked(r))
413        }
414        (Op::Sub | Op::SubWrapping, Duration(unit)) if unit == &T::UNIT => {
415            let r = r.as_primitive::<T::Duration>();
416            try_op!(l, l_s, r, r_s, l.sub_checked(r))
417        }
418
419        (Op::Add | Op::AddWrapping, Interval(YearMonth)) => {
420            let r = r.as_primitive::<IntervalYearMonthType>();
421            try_op!(
422                l,
423                l_s,
424                r,
425                r_s,
426                T::add_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError(
427                    "Timestamp out of range".to_string()
428                ))
429            )
430        }
431        (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => {
432            let r = r.as_primitive::<IntervalYearMonthType>();
433            try_op!(
434                l,
435                l_s,
436                r,
437                r_s,
438                T::sub_year_month(l, r, l_tz).ok_or(ArrowError::ComputeError(
439                    "Timestamp out of range".to_string()
440                ))
441            )
442        }
443
444        (Op::Add | Op::AddWrapping, Interval(DayTime)) => {
445            let r = r.as_primitive::<IntervalDayTimeType>();
446            try_op!(
447                l,
448                l_s,
449                r,
450                r_s,
451                T::add_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError(
452                    "Timestamp out of range".to_string()
453                ))
454            )
455        }
456        (Op::Sub | Op::SubWrapping, Interval(DayTime)) => {
457            let r = r.as_primitive::<IntervalDayTimeType>();
458            try_op!(
459                l,
460                l_s,
461                r,
462                r_s,
463                T::sub_day_time(l, r, l_tz).ok_or(ArrowError::ComputeError(
464                    "Timestamp out of range".to_string()
465                ))
466            )
467        }
468
469        (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => {
470            let r = r.as_primitive::<IntervalMonthDayNanoType>();
471            try_op!(
472                l,
473                l_s,
474                r,
475                r_s,
476                T::add_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError(
477                    "Timestamp out of range".to_string()
478                ))
479            )
480        }
481        (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => {
482            let r = r.as_primitive::<IntervalMonthDayNanoType>();
483            try_op!(
484                l,
485                l_s,
486                r,
487                r_s,
488                T::sub_month_day_nano(l, r, l_tz).ok_or(ArrowError::ComputeError(
489                    "Timestamp out of range".to_string()
490                ))
491            )
492        }
493        _ => {
494            return Err(ArrowError::InvalidArgumentError(format!(
495                "Invalid timestamp arithmetic operation: {} {op} {}",
496                l.data_type(),
497                r.data_type()
498            )))
499        }
500    };
501    Ok(Arc::new(array.with_timezone_opt(l.timezone())))
502}
503
504/// Arithmetic trait for date arrays
505///
506/// Note: these should be fallible (#4456)
507trait DateOp: ArrowTemporalType {
508    fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native;
509    fn add_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native;
510    fn add_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native;
511
512    fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native;
513    fn sub_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native;
514    fn sub_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native;
515}
516
517macro_rules! date {
518    ($t:ty) => {
519        impl DateOp for $t {
520            fn add_year_month(left: Self::Native, right: i32) -> Self::Native {
521                Self::add_year_months(left, right)
522            }
523
524            fn add_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native {
525                Self::add_day_time(left, right)
526            }
527
528            fn add_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native {
529                Self::add_month_day_nano(left, right)
530            }
531
532            fn sub_year_month(left: Self::Native, right: i32) -> Self::Native {
533                Self::subtract_year_months(left, right)
534            }
535
536            fn sub_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native {
537                Self::subtract_day_time(left, right)
538            }
539
540            fn sub_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native {
541                Self::subtract_month_day_nano(left, right)
542            }
543        }
544    };
545}
546date!(Date32Type);
547date!(Date64Type);
548
549/// Arithmetic trait for interval arrays
550trait IntervalOp: ArrowPrimitiveType {
551    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
552    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError>;
553}
554
555impl IntervalOp for IntervalYearMonthType {
556    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
557        left.add_checked(right)
558    }
559
560    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
561        left.sub_checked(right)
562    }
563}
564
565impl IntervalOp for IntervalDayTimeType {
566    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
567        let (l_days, l_ms) = Self::to_parts(left);
568        let (r_days, r_ms) = Self::to_parts(right);
569        let days = l_days.add_checked(r_days)?;
570        let ms = l_ms.add_checked(r_ms)?;
571        Ok(Self::make_value(days, ms))
572    }
573
574    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
575        let (l_days, l_ms) = Self::to_parts(left);
576        let (r_days, r_ms) = Self::to_parts(right);
577        let days = l_days.sub_checked(r_days)?;
578        let ms = l_ms.sub_checked(r_ms)?;
579        Ok(Self::make_value(days, ms))
580    }
581}
582
583impl IntervalOp for IntervalMonthDayNanoType {
584    fn add(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
585        let (l_months, l_days, l_nanos) = Self::to_parts(left);
586        let (r_months, r_days, r_nanos) = Self::to_parts(right);
587        let months = l_months.add_checked(r_months)?;
588        let days = l_days.add_checked(r_days)?;
589        let nanos = l_nanos.add_checked(r_nanos)?;
590        Ok(Self::make_value(months, days, nanos))
591    }
592
593    fn sub(left: Self::Native, right: Self::Native) -> Result<Self::Native, ArrowError> {
594        let (l_months, l_days, l_nanos) = Self::to_parts(left);
595        let (r_months, r_days, r_nanos) = Self::to_parts(right);
596        let months = l_months.sub_checked(r_months)?;
597        let days = l_days.sub_checked(r_days)?;
598        let nanos = l_nanos.sub_checked(r_nanos)?;
599        Ok(Self::make_value(months, days, nanos))
600    }
601}
602
603/// Perform arithmetic operation on an interval array
604fn interval_op<T: IntervalOp>(
605    op: Op,
606    l: &dyn Array,
607    l_s: bool,
608    r: &dyn Array,
609    r_s: bool,
610) -> Result<ArrayRef, ArrowError> {
611    let l = l.as_primitive::<T>();
612    let r = r.as_primitive::<T>();
613    match op {
614        Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::add(l, r))),
615        Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub(l, r))),
616        _ => Err(ArrowError::InvalidArgumentError(format!(
617            "Invalid interval arithmetic operation: {} {op} {}",
618            l.data_type(),
619            r.data_type()
620        ))),
621    }
622}
623
624fn duration_op<T: ArrowPrimitiveType>(
625    op: Op,
626    l: &dyn Array,
627    l_s: bool,
628    r: &dyn Array,
629    r_s: bool,
630) -> Result<ArrayRef, ArrowError> {
631    let l = l.as_primitive::<T>();
632    let r = r.as_primitive::<T>();
633    match op {
634        Op::Add | Op::AddWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.add_checked(r))),
635        Op::Sub | Op::SubWrapping => Ok(try_op_ref!(T, l, l_s, r, r_s, l.sub_checked(r))),
636        _ => Err(ArrowError::InvalidArgumentError(format!(
637            "Invalid duration arithmetic operation: {} {op} {}",
638            l.data_type(),
639            r.data_type()
640        ))),
641    }
642}
643
644/// Perform arithmetic operation on a date array
645fn date_op<T: DateOp>(
646    op: Op,
647    l: &dyn Array,
648    l_s: bool,
649    r: &dyn Array,
650    r_s: bool,
651) -> Result<ArrayRef, ArrowError> {
652    use DataType::*;
653    use IntervalUnit::*;
654
655    const NUM_SECONDS_IN_DAY: i64 = 60 * 60 * 24;
656
657    let r_t = r.data_type();
658    match (T::DATA_TYPE, op, r_t) {
659        (Date32, Op::Sub | Op::SubWrapping, Date32) => {
660            let l = l.as_primitive::<Date32Type>();
661            let r = r.as_primitive::<Date32Type>();
662            return Ok(op_ref!(
663                DurationSecondType,
664                l,
665                l_s,
666                r,
667                r_s,
668                ((l as i64) - (r as i64)) * NUM_SECONDS_IN_DAY
669            ));
670        }
671        (Date64, Op::Sub | Op::SubWrapping, Date64) => {
672            let l = l.as_primitive::<Date64Type>();
673            let r = r.as_primitive::<Date64Type>();
674            let result = try_op_ref!(DurationMillisecondType, l, l_s, r, r_s, l.sub_checked(r));
675            return Ok(result);
676        }
677        _ => {}
678    }
679
680    let l = l.as_primitive::<T>();
681    match (op, r_t) {
682        (Op::Add | Op::AddWrapping, Interval(YearMonth)) => {
683            let r = r.as_primitive::<IntervalYearMonthType>();
684            Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r)))
685        }
686        (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => {
687            let r = r.as_primitive::<IntervalYearMonthType>();
688            Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r)))
689        }
690
691        (Op::Add | Op::AddWrapping, Interval(DayTime)) => {
692            let r = r.as_primitive::<IntervalDayTimeType>();
693            Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r)))
694        }
695        (Op::Sub | Op::SubWrapping, Interval(DayTime)) => {
696            let r = r.as_primitive::<IntervalDayTimeType>();
697            Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r)))
698        }
699
700        (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => {
701            let r = r.as_primitive::<IntervalMonthDayNanoType>();
702            Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r)))
703        }
704        (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => {
705            let r = r.as_primitive::<IntervalMonthDayNanoType>();
706            Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r)))
707        }
708
709        _ => Err(ArrowError::InvalidArgumentError(format!(
710            "Invalid date arithmetic operation: {} {op} {}",
711            l.data_type(),
712            r.data_type()
713        ))),
714    }
715}
716
717/// Perform arithmetic operation on decimal arrays
718fn decimal_op<T: DecimalType>(
719    op: Op,
720    l: &dyn Array,
721    l_s: bool,
722    r: &dyn Array,
723    r_s: bool,
724) -> Result<ArrayRef, ArrowError> {
725    let l = l.as_primitive::<T>();
726    let r = r.as_primitive::<T>();
727
728    let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) {
729        (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2),
730        (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2),
731        _ => unreachable!(),
732    };
733
734    // Follow the Hive decimal arithmetic rules
735    // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
736    let array: PrimitiveArray<T> = match op {
737        Op::Add | Op::AddWrapping | Op::Sub | Op::SubWrapping => {
738            // max(s1, s2)
739            let result_scale = *s1.max(s2);
740
741            // max(s1, s2) + max(p1-s1, p2-s2) + 1
742            let result_precision =
743                (result_scale.saturating_add((*p1 as i8 - s1).max(*p2 as i8 - s2)) as u8)
744                    .saturating_add(1)
745                    .min(T::MAX_PRECISION);
746
747            let l_mul = T::Native::usize_as(10).pow_checked((result_scale - s1) as _)?;
748            let r_mul = T::Native::usize_as(10).pow_checked((result_scale - s2) as _)?;
749
750            match op {
751                Op::Add | Op::AddWrapping => {
752                    try_op!(
753                        l,
754                        l_s,
755                        r,
756                        r_s,
757                        l.mul_checked(l_mul)?.add_checked(r.mul_checked(r_mul)?)
758                    )
759                }
760                Op::Sub | Op::SubWrapping => {
761                    try_op!(
762                        l,
763                        l_s,
764                        r,
765                        r_s,
766                        l.mul_checked(l_mul)?.sub_checked(r.mul_checked(r_mul)?)
767                    )
768                }
769                _ => unreachable!(),
770            }
771            .with_precision_and_scale(result_precision, result_scale)?
772        }
773        Op::Mul | Op::MulWrapping => {
774            let result_precision = p1.saturating_add(p2 + 1).min(T::MAX_PRECISION);
775            let result_scale = s1.saturating_add(*s2);
776            if result_scale > T::MAX_SCALE {
777                // SQL standard says that if the resulting scale of a multiply operation goes
778                // beyond the maximum, rounding is not acceptable and thus an error occurs
779                return Err(ArrowError::InvalidArgumentError(format!(
780                    "Output scale of {} {op} {} would exceed max scale of {}",
781                    l.data_type(),
782                    r.data_type(),
783                    T::MAX_SCALE
784                )));
785            }
786
787            try_op!(l, l_s, r, r_s, l.mul_checked(r))
788                .with_precision_and_scale(result_precision, result_scale)?
789        }
790
791        Op::Div => {
792            // Follow postgres and MySQL adding a fixed scale increment of 4
793            // s1 + 4
794            let result_scale = s1.saturating_add(4).min(T::MAX_SCALE);
795            let mul_pow = result_scale - s1 + s2;
796
797            // p1 - s1 + s2 + result_scale
798            let result_precision = (mul_pow.saturating_add(*p1 as i8) as u8).min(T::MAX_PRECISION);
799
800            let (l_mul, r_mul) = match mul_pow.cmp(&0) {
801                Ordering::Greater => (
802                    T::Native::usize_as(10).pow_checked(mul_pow as _)?,
803                    T::Native::ONE,
804                ),
805                Ordering::Equal => (T::Native::ONE, T::Native::ONE),
806                Ordering::Less => (
807                    T::Native::ONE,
808                    T::Native::usize_as(10).pow_checked(mul_pow.neg_wrapping() as _)?,
809                ),
810            };
811
812            try_op!(
813                l,
814                l_s,
815                r,
816                r_s,
817                l.mul_checked(l_mul)?.div_checked(r.mul_checked(r_mul)?)
818            )
819            .with_precision_and_scale(result_precision, result_scale)?
820        }
821
822        Op::Rem => {
823            // max(s1, s2)
824            let result_scale = *s1.max(s2);
825            // min(p1-s1, p2 -s2) + max( s1,s2 )
826            let result_precision =
827                (result_scale.saturating_add((*p1 as i8 - s1).min(*p2 as i8 - s2)) as u8)
828                    .min(T::MAX_PRECISION);
829
830            let l_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s1) as _);
831            let r_mul = T::Native::usize_as(10).pow_wrapping((result_scale - s2) as _);
832
833            try_op!(
834                l,
835                l_s,
836                r,
837                r_s,
838                l.mul_checked(l_mul)?.mod_checked(r.mul_checked(r_mul)?)
839            )
840            .with_precision_and_scale(result_precision, result_scale)?
841        }
842    };
843
844    Ok(Arc::new(array))
845}
846
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::temporal_conversions::{as_date, as_datetime};
    use arrow_buffer::{i256, ScalarBuffer};
    use chrono::{DateTime, NaiveDate};

    /// Negates `input` with `neg` and checks the outcome.
    ///
    /// `Ok(expected)` asserts the output values equal `expected`; `Err(msg)` asserts
    /// that `neg` fails and that its error renders exactly as `msg`.
    fn test_neg_primitive<T: ArrowPrimitiveType>(
        input: &[T::Native],
        out: Result<&[T::Native], &str>,
    ) {
        let a = PrimitiveArray::<T>::new(ScalarBuffer::from(input.to_vec()), None);
        match out {
            Ok(expected) => {
                let result = neg(&a).unwrap();
                assert_eq!(result.as_primitive::<T>().values(), expected);
            }
            Err(e) => {
                let err = neg(&a).unwrap_err().to_string();
                assert_eq!(e, err);
            }
        }
    }

    #[test]
    fn test_neg() {
        let input = &[1, -5, 2, 693, 3929];
        let output = &[-1, 5, -2, -693, -3929];
        test_neg_primitive::<Int32Type>(input, Ok(output));

        let input = &[1, -5, 2, 693, 3929];
        let output = &[-1, 5, -2, -693, -3929];
        test_neg_primitive::<Int64Type>(input, Ok(output));
        test_neg_primitive::<DurationSecondType>(input, Ok(output));
        test_neg_primitive::<DurationMillisecondType>(input, Ok(output));
        test_neg_primitive::<DurationMicrosecondType>(input, Ok(output));
        test_neg_primitive::<DurationNanosecondType>(input, Ok(output));

        let input = &[f32::MAX, f32::MIN, f32::INFINITY, 1.3, 0.5];
        let output = &[f32::MIN, f32::MAX, f32::NEG_INFINITY, -1.3, -0.5];
        test_neg_primitive::<Float32Type>(input, Ok(output));

        // The minimum signed value has no positive counterpart, so checked `neg` errors
        test_neg_primitive::<Int32Type>(
            &[i32::MIN],
            Err("Arithmetic overflow: Overflow happened on: - -2147483648"),
        );
        test_neg_primitive::<Int64Type>(
            &[i64::MIN],
            Err("Arithmetic overflow: Overflow happened on: - -9223372036854775808"),
        );
        test_neg_primitive::<DurationSecondType>(
            &[i64::MIN],
            Err("Arithmetic overflow: Overflow happened on: - -9223372036854775808"),
        );

        // `neg_wrapping` wraps for plain integers...
        let r = neg_wrapping(&Int32Array::from(vec![i32::MIN])).unwrap();
        assert_eq!(r.as_primitive::<Int32Type>().value(0), i32::MIN);

        let r = neg_wrapping(&Int64Array::from(vec![i64::MIN])).unwrap();
        assert_eq!(r.as_primitive::<Int64Type>().value(0), i64::MIN);

        // ...but durations still report overflow
        let err = neg_wrapping(&DurationSecondArray::from(vec![i64::MIN]))
            .unwrap_err()
            .to_string();

        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: - -9223372036854775808"
        );

        // Decimal negation keeps the input precision and scale
        let a = Decimal128Array::from(vec![1, 3, -44, 2, 4])
            .with_precision_and_scale(9, 6)
            .unwrap();

        let r = neg(&a).unwrap();
        assert_eq!(r.data_type(), a.data_type());
        assert_eq!(
            r.as_primitive::<Decimal128Type>().values(),
            &[-1, -3, 44, -2, -4]
        );

        let a = Decimal256Array::from(vec![
            i256::from_i128(342),
            i256::from_i128(-4949),
            i256::from_i128(3),
        ])
        .with_precision_and_scale(9, 6)
        .unwrap();

        let r = neg(&a).unwrap();
        assert_eq!(r.data_type(), a.data_type());
        assert_eq!(
            r.as_primitive::<Decimal256Type>().values(),
            &[
                i256::from_i128(-342),
                i256::from_i128(4949),
                i256::from_i128(-3),
            ]
        );

        // Interval negation negates every component independently
        let a = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(2, 4),
            IntervalYearMonthType::make_value(2, -4),
            IntervalYearMonthType::make_value(-3, -5),
        ]);
        let r = neg(&a).unwrap();
        assert_eq!(
            r.as_primitive::<IntervalYearMonthType>().values(),
            &[
                IntervalYearMonthType::make_value(-2, -4),
                IntervalYearMonthType::make_value(-2, 4),
                IntervalYearMonthType::make_value(3, 5),
            ]
        );

        let a = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(2, 4),
            IntervalDayTimeType::make_value(2, -4),
            IntervalDayTimeType::make_value(-3, -5),
        ]);
        let r = neg(&a).unwrap();
        assert_eq!(
            r.as_primitive::<IntervalDayTimeType>().values(),
            &[
                IntervalDayTimeType::make_value(-2, -4),
                IntervalDayTimeType::make_value(-2, 4),
                IntervalDayTimeType::make_value(3, 5),
            ]
        );

        let a = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(2, 4, 5953394),
            IntervalMonthDayNanoType::make_value(2, -4, -45839),
            IntervalMonthDayNanoType::make_value(-3, -5, 6944),
        ]);
        let r = neg(&a).unwrap();
        assert_eq!(
            r.as_primitive::<IntervalMonthDayNanoType>().values(),
            &[
                IntervalMonthDayNanoType::make_value(-2, -4, -5953394),
                IntervalMonthDayNanoType::make_value(-2, 4, 45839),
                IntervalMonthDayNanoType::make_value(3, 5, -6944),
            ]
        );
    }

    #[test]
    fn test_integer() {
        let a = Int32Array::from(vec![4, 3, 5, -6, 100]);
        let b = Int32Array::from(vec![6, 2, 5, -7, 3]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Int32Array::from(vec![10, 5, 10, -13, 103])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![-2, 1, 0, 1, 97]));
        let result = div(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![0, 1, 1, 0, 33]));
        let result = mul(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![24, 6, 25, 42, 300]));
        let result = rem(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int32Array::from(vec![4, 1, 0, -6, 1]));

        // A null in either input yields a null output slot
        let a = Int8Array::from(vec![Some(2), None, Some(45)]);
        let b = Int8Array::from(vec![Some(5), Some(3), None]);
        let result = add(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &Int8Array::from(vec![Some(7), None, None]));

        // Checked kernels error on overflow; the `*_wrapping` variants wrap instead
        let a = UInt8Array::from(vec![56, 5, 3]);
        let b = UInt8Array::from(vec![200, 2, 5]);
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 56 + 200");
        let result = add_wrapping(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &UInt8Array::from(vec![0, 7, 8]));

        let a = UInt8Array::from(vec![34, 5, 3]);
        let b = UInt8Array::from(vec![200, 2, 5]);
        let err = sub(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 34 - 200");
        let result = sub_wrapping(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &UInt8Array::from(vec![90, 3, 254]));

        let a = UInt8Array::from(vec![34, 5, 3]);
        let b = UInt8Array::from(vec![200, 2, 5]);
        let err = mul(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 34 * 200");
        let result = mul_wrapping(&a, &b).unwrap();
        assert_eq!(result.as_ref(), &UInt8Array::from(vec![144, 10, 15]));

        // MIN / -1 overflows for signed integers
        let a = Int16Array::from(vec![i16::MIN]);
        let b = Int16Array::from(vec![-1]);
        let err = div(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: -32768 / -1"
        );

        let a = Int16Array::from(vec![21]);
        let b = Int16Array::from(vec![0]);
        let err = div(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");

        let a = Int16Array::from(vec![21]);
        let b = Int16Array::from(vec![0]);
        let err = rem(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");
    }

    #[test]
    fn test_float() {
        // Float kernels never raise overflow / divide-by-zero errors: results follow
        // IEEE-754, e.g. MAX + MAX is infinity and 0 / 0 is NaN
        let a = Float32Array::from(vec![1., f32::MAX, 6., -4., -1., 0.]);
        let b = Float32Array::from(vec![1., f32::MAX, f32::MAX, -3., 45., 0.]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float32Array::from(vec![2., f32::INFINITY, f32::MAX, -7., 44.0, 0.])
        );

        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float32Array::from(vec![0., 0., f32::MIN, -1., -46., 0.])
        );

        let result = mul(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &Float32Array::from(vec![1., f32::INFINITY, f32::INFINITY, 12., -45., 0.])
        );

        let result = div(&a, &b).unwrap();
        let r = result.as_primitive::<Float32Type>();
        assert_eq!(r.value(0), 1.);
        assert_eq!(r.value(1), 1.);
        assert!(r.value(2) < f32::EPSILON);
        assert_eq!(r.value(3), -4. / -3.);
        assert_eq!(r.value(4), -1. / 45.);
        assert!(r.value(5).is_nan());

        let result = rem(&a, &b).unwrap();
        let r = result.as_primitive::<Float32Type>();
        assert_eq!(&r.values()[..5], &[0., 0., 6., -1., -1.]);
        assert!(r.value(5).is_nan());
    }

    #[test]
    fn test_decimal() {
        // 0.015 0.000 -0.577 0.334 -0.078 0.003
        let a = Decimal128Array::from(vec![15, 0, -577, 334, -78, 3])
            .with_precision_and_scale(12, 3)
            .unwrap();

        // 5.4 3.4 -35.6 0.3 0.6 74.5
        let b = Decimal128Array::from(vec![54, 34, -356, 3, 6, 745])
            .with_precision_and_scale(12, 1)
            .unwrap();

        // scale = max(3, 1) = 3; precision = 3 + max(12 - 3, 12 - 1) + 1 = 15
        let result = add(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(15, 3));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[5415, 3400, -36177, 634, 522, 74503]
        );

        let result = sub(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(15, 3));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[-5385, -3400, 35023, 34, -678, -74497]
        );

        // scale = 3 + 1 = 4; precision = 12 + 12 + 1 = 25
        let result = mul(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(25, 4));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[810, 0, 205412, 1002, -468, 2235]
        );

        // scale = 3 + 4 = 7; precision = (7 - 3 + 1) + 12 = 17; quotients truncate
        let result = div(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(17, 7));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[27777, 0, 162078, 11133333, -1300000, 402]
        );

        // scale = max(3, 1) = 3; precision = 3 + min(12 - 3, 12 - 1) = 12
        let result = rem(&a, &b).unwrap();
        assert_eq!(result.data_type(), &DataType::Decimal128(12, 3));
        assert_eq!(
            result.as_primitive::<Decimal128Type>().values(),
            &[15, 0, -577, 34, -78, 3]
        );

        // Multiplying scales 3 + 37 = 40 exceeds the Decimal128 maximum scale
        let a = Decimal128Array::from(vec![1])
            .with_precision_and_scale(3, 3)
            .unwrap();
        let b = Decimal128Array::from(vec![1])
            .with_precision_and_scale(37, 37)
            .unwrap();
        let err = mul(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38");

        // Rescaling scale -2 up to 37 needs a 10^39 multiplier, which overflows i128
        let a = Decimal128Array::from(vec![1])
            .with_precision_and_scale(3, -2)
            .unwrap();
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Arithmetic overflow: Overflow happened on: 10 ^ 39");

        // Rescaling scale -1 up to 37 uses 10^38, which fits, but scaling the value overflows
        let a = Decimal128Array::from(vec![10])
            .with_precision_and_scale(3, -1)
            .unwrap();
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 10 * 100000000000000000000000000000000000000"
        );

        let b = Decimal128Array::from(vec![0])
            .with_precision_and_scale(1, 1)
            .unwrap();
        let err = div(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");
        let err = rem(&a, &b).unwrap_err().to_string();
        assert_eq!(err, "Divide by zero error");
    }

    /// Exercises arithmetic for the timestamp type `T`:
    /// - timestamp - timestamp (producing `T::Duration`);
    /// - duration + timestamp in both operand orders;
    /// - timestamp +/- each interval type.
    fn test_timestamp_impl<T: TimestampOp>() {
        let a = PrimitiveArray::<T>::new(vec![2000000, 434030324, 53943340].into(), None);
        let b = PrimitiveArray::<T>::new(vec![329593, 59349, 694994].into(), None);

        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_primitive::<T::Duration>().values(),
            &[1670407, 433970975, 53248346]
        );

        let r2 = add(&b, &result.as_ref()).unwrap();
        assert_eq!(r2.as_ref(), &a);

        let r3 = add(&result.as_ref(), &b).unwrap();
        assert_eq!(r3.as_ref(), &a);

        // Renders each timestamp in the array as a datetime string for comparison
        let format_array = |x: &dyn Array| -> Vec<String> {
            x.as_primitive::<T>()
                .values()
                .into_iter()
                .map(|x| as_datetime::<T>(*x).unwrap().to_string())
                .collect()
        };

        let values = vec![
            "1970-01-01T00:00:00Z",
            "2010-04-01T04:00:20Z",
            "1960-01-30T04:23:20Z",
        ]
        .into_iter()
        .map(|x| T::make_value(DateTime::parse_from_rfc3339(x).unwrap().naive_utc()).unwrap())
        .collect();

        let a = PrimitiveArray::<T>::new(values, None);
        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(5, 34),
            IntervalYearMonthType::make_value(-2, 4),
            IntervalYearMonthType::make_value(7, -4),
        ]);
        let r4 = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(r4.as_ref()),
            &[
                "1977-11-01 00:00:00".to_string(),
                "2008-08-01 04:00:20".to_string(),
                "1966-09-30 04:23:20".to_string()
            ]
        );

        let r5 = sub(&r4, &b).unwrap();
        assert_eq!(r5.as_ref(), &a);

        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(5, 454000),
            IntervalDayTimeType::make_value(-34, 0),
            IntervalDayTimeType::make_value(7, -4000),
        ]);
        let r6 = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(r6.as_ref()),
            &[
                "1970-01-06 00:07:34".to_string(),
                "2010-02-26 04:00:20".to_string(),
                "1960-02-06 04:23:16".to_string()
            ]
        );

        let r7 = sub(&r6, &b).unwrap();
        assert_eq!(r7.as_ref(), &a);

        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(344, 34, -43_000_000_000),
            IntervalMonthDayNanoType::make_value(-593, -33, 13_000_000_000),
            IntervalMonthDayNanoType::make_value(5, 2, 493_000_000_000),
        ]);
        let r8 = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(r8.as_ref()),
            &[
                "1998-10-04 23:59:17".to_string(),
                "1960-09-29 04:00:33".to_string(),
                "1960-07-02 04:31:33".to_string()
            ]
        );

        let r9 = sub(&r8, &b).unwrap();
        // Note: subtraction is not the inverse of addition for intervals
        assert_eq!(
            &format_array(r9.as_ref()),
            &[
                "1970-01-02 00:00:00".to_string(),
                "2010-04-02 04:00:20".to_string(),
                "1960-01-31 04:23:20".to_string()
            ]
        );
    }

    #[test]
    fn test_timestamp() {
        test_timestamp_impl::<TimestampSecondType>();
        test_timestamp_impl::<TimestampMillisecondType>();
        test_timestamp_impl::<TimestampMicrosecondType>();
        test_timestamp_impl::<TimestampNanosecondType>();
    }

    #[test]
    fn test_interval() {
        // Interval add/sub operates component-wise, without normalising
        // (e.g. months are not carried into years)
        let a = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(32, 4),
            IntervalYearMonthType::make_value(32, 4),
        ]);
        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(-4, 6),
            IntervalYearMonthType::make_value(-3, 23),
        ]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalYearMonthArray::from(vec![
                IntervalYearMonthType::make_value(28, 10),
                IntervalYearMonthType::make_value(29, 27)
            ])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalYearMonthArray::from(vec![
                IntervalYearMonthType::make_value(36, -2),
                IntervalYearMonthType::make_value(35, -19)
            ])
        );

        let a = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(32, 4),
            IntervalDayTimeType::make_value(32, 4),
        ]);
        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(-4, 6),
            IntervalDayTimeType::make_value(-3, 23),
        ]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalDayTimeArray::from(vec![
                IntervalDayTimeType::make_value(28, 10),
                IntervalDayTimeType::make_value(29, 27)
            ])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalDayTimeArray::from(vec![
                IntervalDayTimeType::make_value(36, -2),
                IntervalDayTimeType::make_value(35, -19)
            ])
        );
        let a = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(32, 4, 4000000000000),
            IntervalMonthDayNanoType::make_value(32, 4, 45463000000000000),
        ]);
        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(-4, 6, 46000000000000),
            IntervalMonthDayNanoType::make_value(-3, 23, 3564000000000000),
        ]);
        let result = add(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalMonthDayNanoArray::from(vec![
                IntervalMonthDayNanoType::make_value(28, 10, 50000000000000),
                IntervalMonthDayNanoType::make_value(29, 27, 49027000000000000)
            ])
        );
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_ref(),
            &IntervalMonthDayNanoArray::from(vec![
                IntervalMonthDayNanoType::make_value(36, -2, -42000000000000),
                IntervalMonthDayNanoType::make_value(35, -19, 41899000000000000)
            ])
        );
        // Overflow in any single component is reported as an error
        let a = IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNano::MAX]);
        let b = IntervalMonthDayNanoArray::from(vec![IntervalMonthDayNano::ONE]);
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 2147483647 + 1"
        );
    }

    /// Checks the duration type `T`:
    /// - add/sub succeed;
    /// - mul/div/rem are rejected as invalid duration arithmetic;
    /// - checked add reports overflow.
    fn test_duration_impl<T: ArrowPrimitiveType<Native = i64>>() {
        let a = PrimitiveArray::<T>::new(vec![1000, 4394, -3944].into(), None);
        let b = PrimitiveArray::<T>::new(vec![4, -5, -243].into(), None);

        let result = add(&a, &b).unwrap();
        assert_eq!(result.as_primitive::<T>().values(), &[1004, 4389, -4187]);
        let result = sub(&a, &b).unwrap();
        assert_eq!(result.as_primitive::<T>().values(), &[996, 4399, -3701]);

        let err = mul(&a, &b).unwrap_err().to_string();
        assert!(
            err.contains("Invalid duration arithmetic operation"),
            "{err}"
        );

        let err = div(&a, &b).unwrap_err().to_string();
        assert!(
            err.contains("Invalid duration arithmetic operation"),
            "{err}"
        );

        let err = rem(&a, &b).unwrap_err().to_string();
        assert!(
            err.contains("Invalid duration arithmetic operation"),
            "{err}"
        );

        let a = PrimitiveArray::<T>::new(vec![i64::MAX].into(), None);
        let b = PrimitiveArray::<T>::new(vec![1].into(), None);
        let err = add(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 9223372036854775807 + 1"
        );
    }

    #[test]
    fn test_duration() {
        test_duration_impl::<DurationSecondType>();
        test_duration_impl::<DurationMillisecondType>();
        test_duration_impl::<DurationMicrosecondType>();
        test_duration_impl::<DurationNanosecondType>();
    }

    /// Checks date + interval and date - interval for the date type `T`.
    ///
    /// `f` converts a `NaiveDate` into `T`'s native value.
    fn test_date_impl<T: ArrowPrimitiveType, F>(f: F)
    where
        F: Fn(NaiveDate) -> T::Native,
        T::Native: TryInto<i64>,
    {
        let a = PrimitiveArray::<T>::new(
            vec![
                f(NaiveDate::from_ymd_opt(1979, 1, 30).unwrap()),
                f(NaiveDate::from_ymd_opt(2010, 4, 3).unwrap()),
                f(NaiveDate::from_ymd_opt(2008, 2, 29).unwrap()),
            ]
            .into(),
            None,
        );

        let b = IntervalYearMonthArray::from(vec![
            IntervalYearMonthType::make_value(34, 2),
            IntervalYearMonthType::make_value(3, -3),
            IntervalYearMonthType::make_value(-12, 4),
        ]);

        // Renders each date in the array as a `YYYY-MM-DD` string for comparison
        let format_array = |x: &dyn Array| -> Vec<String> {
            x.as_primitive::<T>()
                .values()
                .into_iter()
                .map(|x| {
                    as_date::<T>((*x).try_into().ok().unwrap())
                        .unwrap()
                        .to_string()
                })
                .collect()
        };

        let result = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "2013-03-30".to_string(),
                "2013-01-03".to_string(),
                "1996-06-29".to_string(),
            ]
        );
        let result = sub(&result, &b).unwrap();
        assert_eq!(result.as_ref(), &a);

        let b = IntervalDayTimeArray::from(vec![
            IntervalDayTimeType::make_value(34, 2),
            IntervalDayTimeType::make_value(3, -3),
            IntervalDayTimeType::make_value(-12, 4),
        ]);

        let result = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "1979-03-05".to_string(),
                "2010-04-06".to_string(),
                "2008-02-17".to_string(),
            ]
        );
        let result = sub(&result, &b).unwrap();
        assert_eq!(result.as_ref(), &a);

        let b = IntervalMonthDayNanoArray::from(vec![
            IntervalMonthDayNanoType::make_value(34, 2, -34353534),
            IntervalMonthDayNanoType::make_value(3, -3, 2443),
            IntervalMonthDayNanoType::make_value(-12, 4, 2323242423232),
        ]);

        let result = add(&a, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "1981-12-02".to_string(),
                "2010-06-30".to_string(),
                "2007-03-04".to_string(),
            ]
        );
        // Unlike the cases above, subtracting this interval does not round-trip to `a`
        let result = sub(&result, &b).unwrap();
        assert_eq!(
            &format_array(result.as_ref()),
            &[
                "1979-01-31".to_string(),
                "2010-04-02".to_string(),
                "2008-02-29".to_string(),
            ]
        );
    }

    #[test]
    fn test_date() {
        test_date_impl::<Date32Type, _>(Date32Type::from_naive_date);
        test_date_impl::<Date64Type, _>(Date64Type::from_naive_date);

        // Date32 - Date32 yields a DurationSecond (days * 86400)
        let a = Date32Array::from(vec![i32::MIN, i32::MAX, 23, 7684]);
        let b = Date32Array::from(vec![i32::MIN, i32::MIN, -2, 45]);
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_primitive::<DurationSecondType>().values(),
            &[0, 371085174288000, 2160000, 660009600]
        );

        // Date64 - Date64 yields a DurationMillisecond
        let a = Date64Array::from(vec![4343, 76676, 3434]);
        let b = Date64Array::from(vec![3, -5, 5]);
        let result = sub(&a, &b).unwrap();
        assert_eq!(
            result.as_primitive::<DurationMillisecondType>().values(),
            &[4340, 76681, 3429]
        );

        let a = Date64Array::from(vec![i64::MAX]);
        let b = Date64Array::from(vec![-1]);
        let err = sub(&a, &b).unwrap_err().to_string();
        assert_eq!(
            err,
            "Arithmetic overflow: Overflow happened on: 9223372036854775807 - -1"
        );
    }
}