datafusion_functions/math/
factorial.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use arrow::{
19    array::{ArrayRef, Int64Array},
20    error::ArrowError,
21};
22use std::any::Any;
23use std::sync::Arc;
24
25use arrow::datatypes::DataType;
26use arrow::datatypes::DataType::Int64;
27
28use crate::utils::make_scalar_function;
29use datafusion_common::{
30    arrow_datafusion_err, exec_err, internal_datafusion_err, DataFusionError, Result,
31};
32use datafusion_expr::{
33    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
34    Volatility,
35};
36use datafusion_macros::user_doc;
37
/// Scalar UDF backing the SQL `factorial` function.
///
/// The signature built in `new` declares a single `Int64` argument, and the
/// return type is always `Int64`. The per-row work is done by the private
/// `factorial` array kernel in this module. That kernel maps values below 2 to
/// `1`, keeps nulls as nulls, and returns an overflow error when the result
/// does not fit in an `i64`.
#[user_doc(
    doc_section(label = "Math Functions"),
    description = "Factorial. Returns 1 if value is less than 2.",
    syntax_example = "factorial(numeric_expression)",
    standard_argument(name = "numeric_expression", prefix = "Numeric")
)]
#[derive(Debug)]
pub struct FactorialFunc {
    // Returned verbatim by `ScalarUDFImpl::signature`.
    signature: Signature,
}
48
49impl Default for FactorialFunc {
50    fn default() -> Self {
51        FactorialFunc::new()
52    }
53}
54
55impl FactorialFunc {
56    pub fn new() -> Self {
57        Self {
58            signature: Signature::uniform(1, vec![Int64], Volatility::Immutable),
59        }
60    }
61}
62
63impl ScalarUDFImpl for FactorialFunc {
64    fn as_any(&self) -> &dyn Any {
65        self
66    }
67
68    fn name(&self) -> &str {
69        "factorial"
70    }
71
72    fn signature(&self) -> &Signature {
73        &self.signature
74    }
75
76    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
77        Ok(Int64)
78    }
79
80    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
81        make_scalar_function(factorial, vec![])(&args.args)
82    }
83
84    fn documentation(&self) -> Option<&Documentation> {
85        self.doc()
86    }
87}
88
89/// Factorial SQL function
90fn factorial(args: &[ArrayRef]) -> Result<ArrayRef> {
91    match args[0].data_type() {
92        Int64 => {
93            let arg = downcast_named_arg!((&args[0]), "value", Int64Array);
94            Ok(arg
95                .iter()
96                .map(|a| match a {
97                    Some(a) => (2..=a)
98                        .try_fold(1i64, i64::checked_mul)
99                        .ok_or_else(|| {
100                            arrow_datafusion_err!(ArrowError::ComputeError(format!(
101                                "Overflow happened on FACTORIAL({a})"
102                            )))
103                        })
104                        .map(Some),
105                    _ => Ok(None),
106                })
107                .collect::<Result<Int64Array>>()
108                .map(Arc::new)? as ArrayRef)
109        }
110        other => exec_err!("Unsupported data type {other:?} for function factorial."),
111    }
112}
113
#[cfg(test)]
mod test {

    use arrow::array::Float64Array;
    use datafusion_common::cast::as_int64_array;

    use super::*;

    /// Small non-negative inputs, including the `0! = 1! = 1` edge cases.
    #[test]
    fn test_factorial_i64() {
        let args: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![0, 1, 2, 4])), // input
        ];

        let result = factorial(&args).expect("factorial should succeed");
        let ints = as_int64_array(&result).expect("factorial should return an Int64Array");

        let expected = Int64Array::from(vec![1, 1, 2, 24]);

        assert_eq!(ints, &expected);
    }

    /// Nulls are preserved, negatives map to 1, and 20! (the largest factorial
    /// that fits in an i64) is computed exactly.
    #[test]
    fn test_factorial_nulls_negatives_and_max() {
        let args: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![
            Some(5),
            None,
            Some(-3),
            Some(20),
        ]))];

        let result = factorial(&args).expect("factorial should succeed");
        let ints = as_int64_array(&result).expect("factorial should return an Int64Array");

        let expected = Int64Array::from(vec![
            Some(120),
            None,
            Some(1),
            Some(2_432_902_008_176_640_000),
        ]);

        assert_eq!(ints, &expected);
    }

    /// 21! exceeds i64::MAX, so the kernel must report an overflow error.
    #[test]
    fn test_factorial_overflow() {
        let args: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![21]))];

        let err = factorial(&args).expect_err("21! does not fit in an i64");

        assert!(
            err.to_string().contains("Overflow happened on FACTORIAL(21)"),
            "unexpected error: {err}"
        );
    }

    /// Any input type other than Int64 is rejected with an execution error.
    #[test]
    fn test_factorial_unsupported_type() {
        let args: Vec<ArrayRef> = vec![Arc::new(Float64Array::from(vec![1.0]))];

        let err = factorial(&args).expect_err("Float64 input is not supported");

        assert!(
            err.to_string()
                .contains("Unsupported data type Float64 for function factorial."),
            "unexpected error: {err}"
        );
    }
}
135}