surrealdb_core/sql/function.rs

1use crate::ctx::{Context, MutableContext};
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use futures::future::try_join_all;
13use reblessive::tree::Stk;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
/// Name under which [`Function`] is registered with serde.
///
/// The same string is spelled out in the `#[serde(rename = ...)]` attribute on
/// [`Function`]; keep the two in sync.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
22
/// A function invocation inside a query.
///
/// The variants differ in how the callee is identified: a built-in name, a
/// user-defined `fn::` function, an embedded script, or an arbitrary value that
/// must evaluate to a closure.
#[revisioned(revision = 2)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub enum Function {
	/// A built-in function: its name (e.g. `math::max`) and argument expressions.
	Normal(String, Vec<Value>),
	/// A user-defined database function and its argument expressions. The stored
	/// name omits the `fn::` prefix, which is added when displaying or checking it.
	Custom(String, Vec<Value>),
	/// An embedded script: the script source and its argument expressions.
	Script(Script, Vec<Value>),
	// Revision-1 layout of `Anonymous` (no argument-computation flag). Values in
	// this layout are upgraded by `convert_anonymous_arg_computation`.
	#[revision(
		end = 2,
		convert_fn = "convert_anonymous_arg_computation",
		fields_name = "OldAnonymousFields"
	)]
	Anonymous(Value, Vec<Value>),
	/// A call on a value that must evaluate to a closure: the callee expression,
	/// the arguments, and whether the arguments are already computed (when
	/// `true` they are passed to the closure as-is instead of being evaluated).
	#[revision(start = 2)]
	Anonymous(Value, Vec<Value>, bool),
	// Add new variants here
}
42
43impl Function {
44	fn convert_anonymous_arg_computation(
45		old: OldAnonymousFields,
46		_revision: u16,
47	) -> Result<Self, revision::Error> {
48		Ok(Function::Anonymous(old.0, old.1, false))
49	}
50}
51
/// Classification of an aggregate call, as returned by
/// `Function::get_optimised_aggregate`.
///
/// `None` means the call is not one of the recognised aggregates. This is a
/// plain fieldless tag, so it is `Copy` and comparable.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum OptimisedAggregate {
	/// Not one of the recognised aggregate functions.
	None,
	/// `count()` called without arguments.
	Count,
	/// `count(...)` called with arguments.
	CountFunction,
	/// `math::max`.
	MathMax,
	/// `math::min`.
	MathMin,
	/// `math::sum`.
	MathSum,
	/// `math::mean`.
	MathMean,
	/// `time::max`.
	TimeMax,
	/// `time::min`.
	TimeMin,
}
63
64impl PartialOrd for Function {
65	#[inline]
66	fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
67		None
68	}
69}
70
71impl Function {
72	/// Get function name if applicable
73	pub fn name(&self) -> Option<&str> {
74		match self {
75			Self::Normal(n, _) => Some(n.as_str()),
76			Self::Custom(n, _) => Some(n.as_str()),
77			_ => None,
78		}
79	}
80	/// Get function arguments if applicable
81	pub fn args(&self) -> &[Value] {
82		match self {
83			Self::Normal(_, a) => a,
84			Self::Custom(_, a) => a,
85			_ => &[],
86		}
87	}
88	/// Convert function call to a field name
89	pub fn to_idiom(&self) -> Idiom {
90		match self {
91			Self::Anonymous(_, _, _) => "function".to_string().into(),
92			Self::Script(_, _) => "function".to_string().into(),
93			Self::Normal(f, _) => f.to_owned().into(),
94			Self::Custom(f, _) => format!("fn::{f}").into(),
95		}
96	}
97	/// Checks if this function invocation is writable
98	pub fn writeable(&self) -> bool {
99		match self {
100			Self::Custom(_, _) => true,
101			Self::Script(_, _) => true,
102			Self::Normal(f, _) if f == "api::invoke" => true,
103			_ => self.args().iter().any(Value::writeable),
104		}
105	}
106	/// Convert this function to an aggregate
107	pub fn aggregate(&self, val: Value) -> Result<Self, Error> {
108		match self {
109			Self::Normal(n, a) => {
110				let mut a = a.to_owned();
111				match a.len() {
112					0 => a.insert(0, val),
113					_ => {
114						a.remove(0);
115						a.insert(0, val);
116					}
117				}
118				Ok(Self::Normal(n.to_owned(), a))
119			}
120			_ => Err(fail!("Encountered a non-aggregate function: {self:?}")),
121		}
122	}
123	/// Check if this function is a custom function
124	pub fn is_custom(&self) -> bool {
125		matches!(self, Self::Custom(_, _))
126	}
127
128	/// Check if this function is a scripting function
129	pub fn is_script(&self) -> bool {
130		matches!(self, Self::Script(_, _))
131	}
132
133	/// Check if all arguments are static values
134	pub fn is_static(&self) -> bool {
135		match self {
136			Self::Normal(_, a) => a.iter().all(Value::is_static),
137			_ => false,
138		}
139	}
140
141	/// Check if this function is a closure function
142	pub fn is_inline(&self) -> bool {
143		matches!(self, Self::Anonymous(_, _, _))
144	}
145
146	/// Check if this function is a rolling function
147	pub fn is_rolling(&self) -> bool {
148		match self {
149			Self::Normal(f, _) if f == "count" => true,
150			Self::Normal(f, _) if f == "math::max" => true,
151			Self::Normal(f, _) if f == "math::mean" => true,
152			Self::Normal(f, _) if f == "math::min" => true,
153			Self::Normal(f, _) if f == "math::sum" => true,
154			Self::Normal(f, _) if f == "time::max" => true,
155			Self::Normal(f, _) if f == "time::min" => true,
156			_ => false,
157		}
158	}
159	/// Check if this function is a grouping function
160	pub fn is_aggregate(&self) -> bool {
161		match self {
162			Self::Normal(f, _) if f == "array::distinct" => true,
163			Self::Normal(f, _) if f == "array::first" => true,
164			Self::Normal(f, _) if f == "array::flatten" => true,
165			Self::Normal(f, _) if f == "array::group" => true,
166			Self::Normal(f, _) if f == "array::last" => true,
167			Self::Normal(f, _) if f == "count" => true,
168			Self::Normal(f, _) if f == "math::bottom" => true,
169			Self::Normal(f, _) if f == "math::interquartile" => true,
170			Self::Normal(f, _) if f == "math::max" => true,
171			Self::Normal(f, _) if f == "math::mean" => true,
172			Self::Normal(f, _) if f == "math::median" => true,
173			Self::Normal(f, _) if f == "math::midhinge" => true,
174			Self::Normal(f, _) if f == "math::min" => true,
175			Self::Normal(f, _) if f == "math::mode" => true,
176			Self::Normal(f, _) if f == "math::nearestrank" => true,
177			Self::Normal(f, _) if f == "math::percentile" => true,
178			Self::Normal(f, _) if f == "math::sample" => true,
179			Self::Normal(f, _) if f == "math::spread" => true,
180			Self::Normal(f, _) if f == "math::stddev" => true,
181			Self::Normal(f, _) if f == "math::sum" => true,
182			Self::Normal(f, _) if f == "math::top" => true,
183			Self::Normal(f, _) if f == "math::trimean" => true,
184			Self::Normal(f, _) if f == "math::variance" => true,
185			Self::Normal(f, _) if f == "time::max" => true,
186			Self::Normal(f, _) if f == "time::min" => true,
187			_ => false,
188		}
189	}
190	pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
191		match self {
192			Self::Normal(f, v) if f == "count" => {
193				if v.is_empty() {
194					OptimisedAggregate::Count
195				} else {
196					OptimisedAggregate::CountFunction
197				}
198			}
199			Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
200			Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
201			Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,
202			Self::Normal(f, _) if f == "math::sum" => OptimisedAggregate::MathSum,
203			Self::Normal(f, _) if f == "time::max" => OptimisedAggregate::TimeMax,
204			Self::Normal(f, _) if f == "time::min" => OptimisedAggregate::TimeMin,
205			_ => OptimisedAggregate::None,
206		}
207	}
208
209	pub(crate) fn is_count_all(&self) -> bool {
210		matches!(self, Self::Normal(f, p) if f == "count" && p.is_empty() )
211	}
212}
213
214impl Function {
215	/// Process this type returning a computed simple Value
216	///
217	/// Was marked recursive
218	pub(crate) async fn compute(
219		&self,
220		stk: &mut Stk,
221		ctx: &Context,
222		opt: &Options,
223		doc: Option<&CursorDoc>,
224	) -> Result<Value, Error> {
225		// Ensure futures are run
226		let opt = &opt.new_with_futures(true);
227		// Process the function type
228		match self {
229			Self::Normal(s, x) => {
230				// Check this function is allowed
231				ctx.check_allowed_function(s)?;
232				// Compute the function arguments
233				let a = stk
234					.scope(|scope| {
235						try_join_all(
236							x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
237						)
238					})
239					.await?;
240				// Run the normal function
241				fnc::run(stk, ctx, opt, doc, s, a).await
242			}
243			Self::Anonymous(v, x, args_computed) => {
244				let val = match v {
245					c @ Value::Closure(_) => c.clone(),
246					Value::Param(p) => ctx.value(p).cloned().unwrap_or(Value::None),
247					Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
248						stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
249					}
250					_ => Value::None,
251				};
252
253				match val {
254					Value::Closure(closure) => {
255						// Compute the function arguments
256						let a =
257							match args_computed {
258								true => x.clone(),
259								false => {
260									stk.scope(|scope| {
261										try_join_all(x.iter().map(|v| {
262											scope.run(|stk| v.compute(stk, ctx, opt, doc))
263										}))
264									})
265									.await?
266								}
267							};
268
269						stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await
270					}
271					v => Err(Error::InvalidFunction {
272						name: "ANONYMOUS".to_string(),
273						message: format!("'{}' is not a function", v.kindof()),
274					}),
275				}
276			}
277			Self::Custom(s, x) => {
278				// Get the full name of this function
279				let name = format!("fn::{s}");
280				// Check this function is allowed
281				ctx.check_allowed_function(name.as_str())?;
282				// Get the function definition
283				let (ns, db) = opt.ns_db()?;
284				let val = ctx.tx().get_db_function(ns, db, s).await?;
285				// Check permissions
286				if opt.check_perms(Action::View)? {
287					match &val.permissions {
288						Permission::Full => (),
289						Permission::None => {
290							return Err(Error::FunctionPermissions {
291								name: s.to_owned(),
292							})
293						}
294						Permission::Specific(e) => {
295							// Disable permissions
296							let opt = &opt.new_with_perms(false);
297							// Process the PERMISSION clause
298							if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
299								return Err(Error::FunctionPermissions {
300									name: s.to_owned(),
301								});
302							}
303						}
304					}
305				}
306				// Get the number of function arguments
307				let max_args_len = val.args.len();
308				// Track the number of required arguments
309				let mut min_args_len = 0;
310				// Check for any final optional arguments
311				val.args.iter().rev().for_each(|(_, kind)| match kind {
312					Kind::Option(_) if min_args_len == 0 => {}
313					Kind::Any if min_args_len == 0 => {}
314					_ => min_args_len += 1,
315				});
316				// Check the necessary arguments are passed
317				if x.len() < min_args_len || max_args_len < x.len() {
318					return Err(Error::InvalidArguments {
319						name: format!("fn::{}", val.name),
320						message: match (min_args_len, max_args_len) {
321							(1, 1) => String::from("The function expects 1 argument."),
322							(r, t) if r == t => format!("The function expects {r} arguments."),
323							(r, t) => format!("The function expects {r} to {t} arguments."),
324						},
325					});
326				}
327				// Compute the function arguments
328				let a = stk
329					.scope(|scope| {
330						try_join_all(
331							x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
332						)
333					})
334					.await?;
335				// Duplicate context
336				let mut ctx = MutableContext::new_isolated(ctx);
337				// Process the function arguments
338				for (val, (name, kind)) in a.into_iter().zip(&val.args) {
339					ctx.add_value(name.to_raw(), val.coerce_to(kind)?.into());
340				}
341				let ctx = ctx.freeze();
342				// Run the custom function
343				let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
344					Err(Error::Return {
345						value,
346					}) => Ok(value),
347					res => res,
348				}?;
349
350				if let Some(ref returns) = val.returns {
351					result
352						.coerce_to(returns)
353						.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
354				} else {
355					Ok(result)
356				}
357			}
358			#[allow(unused_variables)]
359			Self::Script(s, x) => {
360				#[cfg(feature = "scripting")]
361				{
362					// Check if scripting is allowed
363					ctx.check_allowed_scripting()?;
364					// Compute the function arguments
365					let a = stk
366						.scope(|scope| {
367							try_join_all(
368								x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
369							)
370						})
371						.await?;
372					// Run the script function
373					fnc::script::run(ctx, opt, doc, s, a).await
374				}
375				#[cfg(not(feature = "scripting"))]
376				{
377					Err(Error::InvalidScript {
378						message: String::from("Embedded functions are not enabled."),
379					})
380				}
381			}
382		}
383	}
384}
385
386impl fmt::Display for Function {
387	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
388		match self {
389			Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
390			Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
391			Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
392			Self::Anonymous(p, e, _) => write!(f, "{p}({})", Fmt::comma_separated(e)),
393		}
394	}
395}