1use crate::ctx::{Context, MutableContext};
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::fnc;
6use crate::iam::Action;
7use crate::sql::fmt::Fmt;
8use crate::sql::idiom::Idiom;
9use crate::sql::script::Script;
10use crate::sql::value::Value;
11use crate::sql::Permission;
12use futures::future::try_join_all;
13use reblessive::tree::Stk;
14use revision::revisioned;
15use serde::{Deserialize, Serialize};
16use std::cmp::Ordering;
17use std::fmt;
18
19use super::Kind;
20
21pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
22
23#[revisioned(revision = 2)]
24#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
25#[serde(rename = "$surrealdb::private::sql::Function")]
26#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
27#[non_exhaustive]
28pub enum Function {
29 Normal(String, Vec<Value>),
30 Custom(String, Vec<Value>),
31 Script(Script, Vec<Value>),
32 #[revision(
33 end = 2,
34 convert_fn = "convert_anonymous_arg_computation",
35 fields_name = "OldAnonymousFields"
36 )]
37 Anonymous(Value, Vec<Value>),
38 #[revision(start = 2)]
39 Anonymous(Value, Vec<Value>, bool),
40 }
42
43impl Function {
44 fn convert_anonymous_arg_computation(
45 old: OldAnonymousFields,
46 _revision: u16,
47 ) -> Result<Self, revision::Error> {
48 Ok(Function::Anonymous(old.0, old.1, false))
49 }
50}
51
/// Classification of a function call produced by
/// `Function::get_optimised_aggregate`.
pub(crate) enum OptimisedAggregate {
    /// The call is not one of the functions listed below.
    None,
    /// `count` called without arguments.
    Count,
    /// `count` called with at least one argument.
    CountFunction,
    /// `math::max`
    MathMax,
    /// `math::min`
    MathMin,
    /// `math::sum`
    MathSum,
    /// `math::mean`
    MathMean,
    /// `time::max`
    TimeMax,
    /// `time::min`
    TimeMin,
}
63
64impl PartialOrd for Function {
65 #[inline]
66 fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
67 None
68 }
69}
70
71impl Function {
72 pub fn name(&self) -> Option<&str> {
74 match self {
75 Self::Normal(n, _) => Some(n.as_str()),
76 Self::Custom(n, _) => Some(n.as_str()),
77 _ => None,
78 }
79 }
80 pub fn args(&self) -> &[Value] {
82 match self {
83 Self::Normal(_, a) => a,
84 Self::Custom(_, a) => a,
85 _ => &[],
86 }
87 }
88 pub fn to_idiom(&self) -> Idiom {
90 match self {
91 Self::Anonymous(_, _, _) => "function".to_string().into(),
92 Self::Script(_, _) => "function".to_string().into(),
93 Self::Normal(f, _) => f.to_owned().into(),
94 Self::Custom(f, _) => format!("fn::{f}").into(),
95 }
96 }
97 pub fn writeable(&self) -> bool {
99 match self {
100 Self::Custom(_, _) => true,
101 Self::Script(_, _) => true,
102 Self::Normal(f, _) if f == "api::invoke" => true,
103 _ => self.args().iter().any(Value::writeable),
104 }
105 }
106 pub fn aggregate(&self, val: Value) -> Result<Self, Error> {
108 match self {
109 Self::Normal(n, a) => {
110 let mut a = a.to_owned();
111 match a.len() {
112 0 => a.insert(0, val),
113 _ => {
114 a.remove(0);
115 a.insert(0, val);
116 }
117 }
118 Ok(Self::Normal(n.to_owned(), a))
119 }
120 _ => Err(fail!("Encountered a non-aggregate function: {self:?}")),
121 }
122 }
123 pub fn is_custom(&self) -> bool {
125 matches!(self, Self::Custom(_, _))
126 }
127
128 pub fn is_script(&self) -> bool {
130 matches!(self, Self::Script(_, _))
131 }
132
133 pub fn is_static(&self) -> bool {
135 match self {
136 Self::Normal(_, a) => a.iter().all(Value::is_static),
137 _ => false,
138 }
139 }
140
141 pub fn is_inline(&self) -> bool {
143 matches!(self, Self::Anonymous(_, _, _))
144 }
145
146 pub fn is_rolling(&self) -> bool {
148 match self {
149 Self::Normal(f, _) if f == "count" => true,
150 Self::Normal(f, _) if f == "math::max" => true,
151 Self::Normal(f, _) if f == "math::mean" => true,
152 Self::Normal(f, _) if f == "math::min" => true,
153 Self::Normal(f, _) if f == "math::sum" => true,
154 Self::Normal(f, _) if f == "time::max" => true,
155 Self::Normal(f, _) if f == "time::min" => true,
156 _ => false,
157 }
158 }
159 pub fn is_aggregate(&self) -> bool {
161 match self {
162 Self::Normal(f, _) if f == "array::distinct" => true,
163 Self::Normal(f, _) if f == "array::first" => true,
164 Self::Normal(f, _) if f == "array::flatten" => true,
165 Self::Normal(f, _) if f == "array::group" => true,
166 Self::Normal(f, _) if f == "array::last" => true,
167 Self::Normal(f, _) if f == "count" => true,
168 Self::Normal(f, _) if f == "math::bottom" => true,
169 Self::Normal(f, _) if f == "math::interquartile" => true,
170 Self::Normal(f, _) if f == "math::max" => true,
171 Self::Normal(f, _) if f == "math::mean" => true,
172 Self::Normal(f, _) if f == "math::median" => true,
173 Self::Normal(f, _) if f == "math::midhinge" => true,
174 Self::Normal(f, _) if f == "math::min" => true,
175 Self::Normal(f, _) if f == "math::mode" => true,
176 Self::Normal(f, _) if f == "math::nearestrank" => true,
177 Self::Normal(f, _) if f == "math::percentile" => true,
178 Self::Normal(f, _) if f == "math::sample" => true,
179 Self::Normal(f, _) if f == "math::spread" => true,
180 Self::Normal(f, _) if f == "math::stddev" => true,
181 Self::Normal(f, _) if f == "math::sum" => true,
182 Self::Normal(f, _) if f == "math::top" => true,
183 Self::Normal(f, _) if f == "math::trimean" => true,
184 Self::Normal(f, _) if f == "math::variance" => true,
185 Self::Normal(f, _) if f == "time::max" => true,
186 Self::Normal(f, _) if f == "time::min" => true,
187 _ => false,
188 }
189 }
190 pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
191 match self {
192 Self::Normal(f, v) if f == "count" => {
193 if v.is_empty() {
194 OptimisedAggregate::Count
195 } else {
196 OptimisedAggregate::CountFunction
197 }
198 }
199 Self::Normal(f, _) if f == "math::max" => OptimisedAggregate::MathMax,
200 Self::Normal(f, _) if f == "math::mean" => OptimisedAggregate::MathMean,
201 Self::Normal(f, _) if f == "math::min" => OptimisedAggregate::MathMin,
202 Self::Normal(f, _) if f == "math::sum" => OptimisedAggregate::MathSum,
203 Self::Normal(f, _) if f == "time::max" => OptimisedAggregate::TimeMax,
204 Self::Normal(f, _) if f == "time::min" => OptimisedAggregate::TimeMin,
205 _ => OptimisedAggregate::None,
206 }
207 }
208
209 pub(crate) fn is_count_all(&self) -> bool {
210 matches!(self, Self::Normal(f, p) if f == "count" && p.is_empty() )
211 }
212}
213
214impl Function {
215 pub(crate) async fn compute(
219 &self,
220 stk: &mut Stk,
221 ctx: &Context,
222 opt: &Options,
223 doc: Option<&CursorDoc>,
224 ) -> Result<Value, Error> {
225 let opt = &opt.new_with_futures(true);
227 match self {
229 Self::Normal(s, x) => {
230 ctx.check_allowed_function(s)?;
232 let a = stk
234 .scope(|scope| {
235 try_join_all(
236 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
237 )
238 })
239 .await?;
240 fnc::run(stk, ctx, opt, doc, s, a).await
242 }
243 Self::Anonymous(v, x, args_computed) => {
244 let val = match v {
245 c @ Value::Closure(_) => c.clone(),
246 Value::Param(p) => ctx.value(p).cloned().unwrap_or(Value::None),
247 Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
248 stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
249 }
250 _ => Value::None,
251 };
252
253 match val {
254 Value::Closure(closure) => {
255 let a =
257 match args_computed {
258 true => x.clone(),
259 false => {
260 stk.scope(|scope| {
261 try_join_all(x.iter().map(|v| {
262 scope.run(|stk| v.compute(stk, ctx, opt, doc))
263 }))
264 })
265 .await?
266 }
267 };
268
269 stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await
270 }
271 v => Err(Error::InvalidFunction {
272 name: "ANONYMOUS".to_string(),
273 message: format!("'{}' is not a function", v.kindof()),
274 }),
275 }
276 }
277 Self::Custom(s, x) => {
278 let name = format!("fn::{s}");
280 ctx.check_allowed_function(name.as_str())?;
282 let (ns, db) = opt.ns_db()?;
284 let val = ctx.tx().get_db_function(ns, db, s).await?;
285 if opt.check_perms(Action::View)? {
287 match &val.permissions {
288 Permission::Full => (),
289 Permission::None => {
290 return Err(Error::FunctionPermissions {
291 name: s.to_owned(),
292 })
293 }
294 Permission::Specific(e) => {
295 let opt = &opt.new_with_perms(false);
297 if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
299 return Err(Error::FunctionPermissions {
300 name: s.to_owned(),
301 });
302 }
303 }
304 }
305 }
306 let max_args_len = val.args.len();
308 let mut min_args_len = 0;
310 val.args.iter().rev().for_each(|(_, kind)| match kind {
312 Kind::Option(_) if min_args_len == 0 => {}
313 Kind::Any if min_args_len == 0 => {}
314 _ => min_args_len += 1,
315 });
316 if x.len() < min_args_len || max_args_len < x.len() {
318 return Err(Error::InvalidArguments {
319 name: format!("fn::{}", val.name),
320 message: match (min_args_len, max_args_len) {
321 (1, 1) => String::from("The function expects 1 argument."),
322 (r, t) if r == t => format!("The function expects {r} arguments."),
323 (r, t) => format!("The function expects {r} to {t} arguments."),
324 },
325 });
326 }
327 let a = stk
329 .scope(|scope| {
330 try_join_all(
331 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
332 )
333 })
334 .await?;
335 let mut ctx = MutableContext::new_isolated(ctx);
337 for (val, (name, kind)) in a.into_iter().zip(&val.args) {
339 ctx.add_value(name.to_raw(), val.coerce_to(kind)?.into());
340 }
341 let ctx = ctx.freeze();
342 let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
344 Err(Error::Return {
345 value,
346 }) => Ok(value),
347 res => res,
348 }?;
349
350 if let Some(ref returns) = val.returns {
351 result
352 .coerce_to(returns)
353 .map_err(|e| e.function_check_from_coerce(val.name.to_string()))
354 } else {
355 Ok(result)
356 }
357 }
358 #[allow(unused_variables)]
359 Self::Script(s, x) => {
360 #[cfg(feature = "scripting")]
361 {
362 ctx.check_allowed_scripting()?;
364 let a = stk
366 .scope(|scope| {
367 try_join_all(
368 x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
369 )
370 })
371 .await?;
372 fnc::script::run(ctx, opt, doc, s, a).await
374 }
375 #[cfg(not(feature = "scripting"))]
376 {
377 Err(Error::InvalidScript {
378 message: String::from("Embedded functions are not enabled."),
379 })
380 }
381 }
382 }
383 }
384}
385
386impl fmt::Display for Function {
387 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
388 match self {
389 Self::Normal(s, e) => write!(f, "{s}({})", Fmt::comma_separated(e)),
390 Self::Custom(s, e) => write!(f, "fn::{s}({})", Fmt::comma_separated(e)),
391 Self::Script(s, e) => write!(f, "function({}) {{{s}}}", Fmt::comma_separated(e)),
392 Self::Anonymous(p, e, _) => write!(f, "{p}({})", Fmt::comma_separated(e)),
393 }
394 }
395}