datafusion_physical_expr/
scalar_function.rs1use std::any::Any;
33use std::fmt::{self, Debug, Formatter};
34use std::hash::Hash;
35use std::sync::Arc;
36
37use crate::expressions::Literal;
38use crate::PhysicalExpr;
39
40use arrow::array::{Array, RecordBatch};
41use arrow::datatypes::{DataType, Schema};
42use datafusion_common::{internal_err, DFSchema, Result, ScalarValue};
43use datafusion_expr::interval_arithmetic::Interval;
44use datafusion_expr::sort_properties::ExprProperties;
45use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf;
46use datafusion_expr::{
47 expr_vec_fmt, ColumnarValue, Expr, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF,
48};
49
50#[derive(Eq, PartialEq, Hash)]
52pub struct ScalarFunctionExpr {
53 fun: Arc<ScalarUDF>,
54 name: String,
55 args: Vec<Arc<dyn PhysicalExpr>>,
56 return_type: DataType,
57 nullable: bool,
58}
59
60impl Debug for ScalarFunctionExpr {
61 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
62 f.debug_struct("ScalarFunctionExpr")
63 .field("fun", &"<FUNC>")
64 .field("name", &self.name)
65 .field("args", &self.args)
66 .field("return_type", &self.return_type)
67 .finish()
68 }
69}
70
71impl ScalarFunctionExpr {
72 pub fn new(
74 name: &str,
75 fun: Arc<ScalarUDF>,
76 args: Vec<Arc<dyn PhysicalExpr>>,
77 return_type: DataType,
78 ) -> Self {
79 Self {
80 fun,
81 name: name.to_owned(),
82 args,
83 return_type,
84 nullable: true,
85 }
86 }
87
88 pub fn try_new(
90 fun: Arc<ScalarUDF>,
91 args: Vec<Arc<dyn PhysicalExpr>>,
92 schema: &Schema,
93 ) -> Result<Self> {
94 let name = fun.name().to_string();
95 let arg_types = args
96 .iter()
97 .map(|e| e.data_type(schema))
98 .collect::<Result<Vec<_>>>()?;
99
100 data_types_with_scalar_udf(&arg_types, &fun)?;
102
103 let nullables = args
104 .iter()
105 .map(|e| e.nullable(schema))
106 .collect::<Result<Vec<_>>>()?;
107
108 let arguments = args
109 .iter()
110 .map(|e| {
111 e.as_any()
112 .downcast_ref::<Literal>()
113 .map(|literal| literal.value())
114 })
115 .collect::<Vec<_>>();
116 let ret_args = ReturnTypeArgs {
117 arg_types: &arg_types,
118 scalar_arguments: &arguments,
119 nullables: &nullables,
120 };
121 let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts();
122 Ok(Self {
123 fun,
124 name,
125 args,
126 return_type,
127 nullable,
128 })
129 }
130
131 pub fn fun(&self) -> &ScalarUDF {
133 &self.fun
134 }
135
136 pub fn name(&self) -> &str {
138 &self.name
139 }
140
141 pub fn args(&self) -> &[Arc<dyn PhysicalExpr>] {
143 &self.args
144 }
145
146 pub fn return_type(&self) -> &DataType {
148 &self.return_type
149 }
150
151 pub fn with_nullable(mut self, nullable: bool) -> Self {
152 self.nullable = nullable;
153 self
154 }
155
156 pub fn nullable(&self) -> bool {
157 self.nullable
158 }
159}
160
161impl fmt::Display for ScalarFunctionExpr {
162 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
163 write!(f, "{}({})", self.name, expr_vec_fmt!(self.args))
164 }
165}
166
167impl PhysicalExpr for ScalarFunctionExpr {
168 fn as_any(&self) -> &dyn Any {
170 self
171 }
172
173 fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
174 Ok(self.return_type.clone())
175 }
176
177 fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
178 Ok(self.nullable)
179 }
180
181 fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
182 let args = self
183 .args
184 .iter()
185 .map(|e| e.evaluate(batch))
186 .collect::<Result<Vec<_>>>()?;
187
188 let input_empty = args.is_empty();
189 let input_all_scalar = args
190 .iter()
191 .all(|arg| matches!(arg, ColumnarValue::Scalar(_)));
192
193 let output = self.fun.invoke_with_args(ScalarFunctionArgs {
195 args,
196 number_rows: batch.num_rows(),
197 return_type: &self.return_type,
198 })?;
199
200 if let ColumnarValue::Array(array) = &output {
201 if array.len() != batch.num_rows() {
202 let preserve_scalar =
205 array.len() == 1 && !input_empty && input_all_scalar;
206 return if preserve_scalar {
207 ScalarValue::try_from_array(array, 0).map(ColumnarValue::Scalar)
208 } else {
209 internal_err!("UDF {} returned a different number of rows than expected. Expected: {}, Got: {}",
210 self.name, batch.num_rows(), array.len())
211 };
212 }
213 }
214 Ok(output)
215 }
216
217 fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
218 self.args.iter().collect()
219 }
220
221 fn with_new_children(
222 self: Arc<Self>,
223 children: Vec<Arc<dyn PhysicalExpr>>,
224 ) -> Result<Arc<dyn PhysicalExpr>> {
225 Ok(Arc::new(
226 ScalarFunctionExpr::new(
227 &self.name,
228 Arc::clone(&self.fun),
229 children,
230 self.return_type().clone(),
231 )
232 .with_nullable(self.nullable),
233 ))
234 }
235
236 fn evaluate_bounds(&self, children: &[&Interval]) -> Result<Interval> {
237 self.fun.evaluate_bounds(children)
238 }
239
240 fn propagate_constraints(
241 &self,
242 interval: &Interval,
243 children: &[&Interval],
244 ) -> Result<Option<Vec<Interval>>> {
245 self.fun.propagate_constraints(interval, children)
246 }
247
248 fn get_properties(&self, children: &[ExprProperties]) -> Result<ExprProperties> {
249 let sort_properties = self.fun.output_ordering(children)?;
250 let preserves_lex_ordering = self.fun.preserves_lex_ordering(children)?;
251 let children_range = children
252 .iter()
253 .map(|props| &props.range)
254 .collect::<Vec<_>>();
255 let range = self.fun().evaluate_bounds(&children_range)?;
256
257 Ok(ExprProperties {
258 sort_properties,
259 range,
260 preserves_lex_ordering,
261 })
262 }
263}
264
265#[deprecated(since = "45.0.0", note = "use ScalarFunctionExpr::new() instead")]
267pub fn create_physical_expr(
268 fun: &ScalarUDF,
269 input_phy_exprs: &[Arc<dyn PhysicalExpr>],
270 input_schema: &Schema,
271 args: &[Expr],
272 input_dfschema: &DFSchema,
273) -> Result<Arc<dyn PhysicalExpr>> {
274 let input_expr_types = input_phy_exprs
275 .iter()
276 .map(|e| e.data_type(input_schema))
277 .collect::<Result<Vec<_>>>()?;
278
279 data_types_with_scalar_udf(&input_expr_types, fun)?;
281
282 let return_type =
284 fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
285
286 Ok(Arc::new(
287 ScalarFunctionExpr::new(
288 fun.name(),
289 Arc::new(fun.clone()),
290 input_phy_exprs.to_vec(),
291 return_type,
292 )
293 .with_nullable(fun.is_nullable(args, input_dfschema)),
294 ))
295}