datafusion_functions/math/
gcd.rs1use arrow::array::{new_null_array, ArrayRef, AsArray, Int64Array, PrimitiveArray};
19use arrow::compute::try_binary;
20use arrow::datatypes::{DataType, Int64Type};
21use arrow::error::ArrowError;
22use std::any::Any;
23use std::mem::swap;
24use std::sync::Arc;
25
26use datafusion_common::{exec_err, internal_datafusion_err, Result, ScalarValue};
27use datafusion_expr::{
28 ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
29 Volatility,
30};
31use datafusion_macros::user_doc;
32
33#[user_doc(
34 doc_section(label = "Math Functions"),
35 description = "Returns the greatest common divisor of `expression_x` and `expression_y`. Returns 0 if both inputs are zero.",
36 syntax_example = "gcd(expression_x, expression_y)",
37 standard_argument(name = "expression_x", prefix = "First numeric"),
38 standard_argument(name = "expression_y", prefix = "Second numeric")
39)]
40#[derive(Debug)]
41pub struct GcdFunc {
42 signature: Signature,
43}
44
45impl Default for GcdFunc {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51impl GcdFunc {
52 pub fn new() -> Self {
53 Self {
54 signature: Signature::uniform(
55 2,
56 vec![DataType::Int64],
57 Volatility::Immutable,
58 ),
59 }
60 }
61}
62
63impl ScalarUDFImpl for GcdFunc {
64 fn as_any(&self) -> &dyn Any {
65 self
66 }
67
68 fn name(&self) -> &str {
69 "gcd"
70 }
71
72 fn signature(&self) -> &Signature {
73 &self.signature
74 }
75
76 fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77 Ok(DataType::Int64)
78 }
79
80 fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
81 let args: [ColumnarValue; 2] = args.args.try_into().map_err(|_| {
82 internal_datafusion_err!("Expected 2 arguments for function gcd")
83 })?;
84
85 match args {
86 [ColumnarValue::Array(a), ColumnarValue::Array(b)] => {
87 compute_gcd_for_arrays(&a, &b)
88 }
89 [ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Scalar(ScalarValue::Int64(b))] => {
90 match (a, b) {
91 (Some(a), Some(b)) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(
92 Some(compute_gcd(a, b)?),
93 ))),
94 _ => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
95 }
96 }
97 [ColumnarValue::Array(a), ColumnarValue::Scalar(ScalarValue::Int64(b))] => {
98 compute_gcd_with_scalar(&a, b)
99 }
100 [ColumnarValue::Scalar(ScalarValue::Int64(a)), ColumnarValue::Array(b)] => {
101 compute_gcd_with_scalar(&b, a)
102 }
103 _ => exec_err!("Unsupported argument types for function gcd"),
104 }
105 }
106
107 fn documentation(&self) -> Option<&Documentation> {
108 self.doc()
109 }
110}
111
112fn compute_gcd_for_arrays(a: &ArrayRef, b: &ArrayRef) -> Result<ColumnarValue> {
113 let a = a.as_primitive::<Int64Type>();
114 let b = b.as_primitive::<Int64Type>();
115 try_binary(a, b, compute_gcd)
116 .map(|arr: PrimitiveArray<Int64Type>| {
117 ColumnarValue::Array(Arc::new(arr) as ArrayRef)
118 })
119 .map_err(Into::into) }
121
122fn compute_gcd_with_scalar(arr: &ArrayRef, scalar: Option<i64>) -> Result<ColumnarValue> {
123 match scalar {
124 Some(scalar_value) => {
125 let result: Result<Int64Array> = arr
126 .as_primitive::<Int64Type>()
127 .iter()
128 .map(|val| match val {
129 Some(val) => Ok(Some(compute_gcd(val, scalar_value)?)),
130 _ => Ok(None),
131 })
132 .collect();
133
134 result.map(|arr| ColumnarValue::Array(Arc::new(arr) as ArrayRef))
135 }
136 None => Ok(ColumnarValue::Array(new_null_array(
137 &DataType::Int64,
138 arr.len(),
139 ))),
140 }
141}
142
143pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
145 if a == 0 {
146 return b;
147 }
148 if b == 0 {
149 return a;
150 }
151
152 let shift = (a | b).trailing_zeros();
153 a >>= a.trailing_zeros();
154 loop {
155 b >>= b.trailing_zeros();
156 if a > b {
157 swap(&mut a, &mut b);
158 }
159 b -= a;
160 if b == 0 {
161 return a << shift;
162 }
163 }
164}
165
166pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
168 let a = x.unsigned_abs();
169 let b = y.unsigned_abs();
170 let r = unsigned_gcd(a, b);
171 r.try_into().map_err(|_| {
173 ArrowError::ComputeError(format!("Signed integer overflow in GCD({x}, {y})"))
174 })
175}