// datafusion_physical_expr/scalar_function.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.
17
//! Declaration of built-in (scalar) functions.
//! This module contains built-in functions' enumeration and metadata.
//!
//! Generally, a function has:
//! * a signature
//! * a return type, that is a function of the incoming argument's types
//! * the computation, that must accept each valid signature
//!
//! * Signature: see `Signature`
//! * Return type: a function `(arg_types) -> return_type`. E.g. for sqrt, ([f32]) -> f32, ([f64]) -> f64.
//!
//! This module also has a set of coercion rules to improve user experience: if an argument i32 is passed
//! to a function that supports f64, it is coerced to f64.
use std::any::Any;
use std::fmt::{self, Debug, Formatter};
use std::hash::Hash;
use std::sync::Arc;

use crate::expressions::Literal;
use crate::PhysicalExpr;

use arrow::array::{Array, RecordBatch};
use arrow::datatypes::{DataType, Schema};
use datafusion_common::{internal_err, DFSchema, Result, ScalarValue};
use datafusion_expr::interval_arithmetic::Interval;
use datafusion_expr::sort_properties::ExprProperties;
use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
use datafusion_expr::{
    expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
};
49
/// Physical expression of a scalar function
#[derive(Eq, PartialEq, Hash)]
pub struct ScalarFunctionExpr {
    // The scalar UDF implementation invoked by `evaluate`
    fun: Arc<ScalarUDF>,
    // Display name of the function (set from `fun.name()` in `try_new`)
    name: String,
    // Physical expressions producing the function's input arguments
    args: Vec<Arc<dyn PhysicalExpr>>,
    // Data type this expression evaluates to
    return_type: DataType,
    // Whether the result may be null; defaults to `true` in `new`,
    // derived from the UDF's `return_type_from_args` in `try_new`
    nullable: bool,
}
59
60impl Debug for ScalarFunctionExpr {
61    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
62        f.debug_struct("ScalarFunctionExpr")
63            .field("fun", &"<FUNC>")
64            .field("name", &self.name)
65            .field("args", &self.args)
66            .field("return_type", &self.return_type)
67            .finish()
68    }
69}
70
71impl ScalarFunctionExpr {
72    /// Create a new Scalar function
73    pub fn new(
74        name: &str,
75        fun: Arc<ScalarUDF>,
76        args: Vec<Arc<dyn PhysicalExpr>>,
77        return_type: DataType,
78    ) -> Self {
79        Self {
80            fun,
81            name: name.to_owned(),
82            args,
83            return_type,
84            nullable: true,
85        }
86    }
87
88    /// Create a new Scalar function
89    pub fn try_new(
90        fun: Arc<ScalarUDF>,
91        args: Vec<Arc<dyn PhysicalExpr>>,
92        schema: &Schema,
93    ) -> Result<Self> {
94        let name = fun.name().to_string();
95        let arg_types = args
96            .iter()
97            .map(|e| e.data_type(schema))
98            .collect::<Result<Vec<_>>>()?;
99
100        // verify that input data types is consistent with function's `TypeSignature`
101        data_types_with_scalar_udf(&arg_types, &fun)?;
102
103        let nullables = args
104            .iter()
105            .map(|e| e.nullable(schema))
106            .collect::<Result<Vec<_>>>()?;
107
108        let arguments = args
109            .iter()
110            .map(|e| {
111                e.as_any()
112                    .downcast_ref::<Literal>()
113                    .map(|literal| literal.value())
114            })
115            .collect::<Vec<_>>();
116        let ret_args = ReturnTypeArgs {
117            arg_types: &arg_types,
118            scalar_arguments: &arguments,
119            nullables: &nullables,
120        };
121        let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts();
122        Ok(Self {
123            fun,
124            name,
125            args,
126            return_type,
127            nullable,
128        })
129    }
130
131    /// Get the scalar function implementation
132    pub fn fun(&self) -> &ScalarUDF {
133        &self.fun
134    }
135
136    /// The name for this expression
137    pub fn name(&self) -> &str {
138        &self.name
139    }
140
141    /// Input arguments
142    pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
143        &self.args
144    }
145
146    /// Data type produced by this expression
147    pub fn return_type(&self) -> &DataType {
148        &self.return_type
149    }
150
151    pub fn with_nullable(mut self, nullable: bool) -> Self {
152        self.nullable = nullable;
153        self
154    }
155
156    pub fn nullable(&self) -> bool {
157        self.nullable
158    }
159}
160
161impl fmt::Display for ScalarFunctionExpr {
162    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
163        write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
164    }
165}
166
167impl PhysicalExpr for ScalarFunctionExpr {
168    /// Return a reference to Any that can be used for downcasting
169    fn as_any(&self) -> &dyn Any {
170        self
171    }
172
173    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
174        Ok(self.return_type.clone())
175    }
176
177    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
178        Ok(self.nullable)
179    }
180
181    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
182        let args = self
183            .args
184            .iter()
185            .map(|e| e.evaluate(batch))
186            .collect::<Result<Vec<_>>>()?;
187
188        let input_empty = args.is_empty();
189        let input_all_scalar = args
190            .iter()
191            .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
192
193        // evaluate the function
194        let output = self.fun.invoke_with_args(ScalarFunctionArgs {
195            args,
196            number_rows: batch.num_rows(),
197            return_type: &self.return_type,
198        })?;
199
200        if let ColumnarValue::Array(array) = &output {
201            if array.len() != batch.num_rows() {
202                // If the arguments are a non-empty slice of scalar values, we can assume that
203                // returning a one-element array is equivalent to returning a scalar.
204                let preserve_scalar =
205                    array.len() == 1 && !input_empty && input_all_scalar;
206                return if preserve_scalar {
207                    ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
208                } else {
209                    internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
210                            self.name, batch.num_rows(), array.len())
211                };
212            }
213        }
214        Ok(output)
215    }
216
217    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
218        self.args.iter().collect()
219    }
220
221    fn with_new_children(
222        self: Arc<Self>,
223        children: Vec<Arc<dyn PhysicalExpr>>,
224    ) -> Result<Arc<dyn PhysicalExpr>> {
225        Ok(Arc::new(
226            ScalarFunctionExpr::new(
227                &self.name,
228                Arc::clone(&self.fun),
229                children,
230                self.return_type().clone(),
231            )
232            .with_nullable(self.nullable),
233        ))
234    }
235
236    fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
237        self.fun.evaluate_bounds(children)
238    }
239
240    fn propagate_constraints(
241        &self,
242        interval: &Interval,
243        children: &[&Interval],
244    ) -> Result<Option<Vec<Interval>>> {
245        self.fun.propagate_constraints(interval, children)
246    }
247
248    fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
249        let sort_properties = self.fun.output_ordering(children)?;
250        let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
251        let children_range = children
252            .iter()
253            .map(|props| &props.range)
254            .collect::<Vec<_>>();
255        let range = self.fun().evaluate_bounds(&children_range)?;
256
257        Ok(ExprProperties {
258            sort_properties,
259            range,
260            preserves_lex_ordering,
261        })
262    }
263}
264
265/// Create a physical expression for the UDF.
266#[deprecated(since = "45.0.0", note = "use ScalarFunctionExpr::new() instead")]
267pub fn create_physical_expr(
268    fun: &ScalarUDF,
269    input_phy_exprs: &[Arc<dyn PhysicalExpr>],
270    input_schema: &Schema,
271    args: &[Expr],
272    input_dfschema: &DFSchema,
273) -> Result<Arc<dyn PhysicalExpr>> {
274    let input_expr_types = input_phy_exprs
275        .iter()
276        .map(|e| e.data_type(input_schema))
277        .collect::<Result<Vec<_>>>()?;
278
279    // verify that input data types is consistent with function's `TypeSignature`
280    data_types_with_scalar_udf(&input_expr_types, fun)?;
281
282    // Since we have arg_types, we don't need args and schema.
283    let return_type =
284        fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
285
286    Ok(Arc::new(
287        ScalarFunctionExpr::new(
288            fun.name(),
289            Arc::new(fun.clone()),
290            input_phy_exprs.to_vec(),
291            return_type,
292        )
293        .with_nullable(fun.is_nullable(args, input_dfschema)),
294    ))
295}