use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::fmt::Fmt;
use crate::fnc;
use crate::iam::Action;
use crate::idiom::Idiom;
use crate::script::Script;
use crate::value::Value;
use crate::Permission;
use async_recursion::async_recursion;
use futures::future::try_join_all;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt;
use super::Kind;
// NOTE(review): presumably used by custom (de)serialization code to recognise
// function values — confirm at the use sites.
/// Serialization token for [`Function`]; identical to the `serde(rename)`
/// name applied to the enum below.
pub(crate) const TOKEN: &str = "$surrealdb::private::crate::Function";
/// A function invocation together with its argument expressions: a builtin
/// function, a user-defined `fn::` function, or an embedded script.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::crate::Function")]
#[revisioned(revision = 1)]
pub enum Function {
    // Builtin function referenced by its full name (e.g. `math::max`) with its
    // argument expressions. Displayed as `name(args)`.
    Normal(String, Vec<Value>),
    // User-defined function. The name is stored WITHOUT the `fn::` prefix;
    // the prefix is added back when displaying or resolving the function.
    Custom(String, Vec<Value>),
    // Embedded script body with its argument expressions. Displayed as
    // `function(args) {body}`.
    Script(Script, Vec<Value>),
}
impl PartialOrd for Function {
    /// Reports every pair of functions as incomparable by always returning
    /// `None`.
    ///
    /// NOTE(review): this also holds when `self == other`. That departs from
    /// the `PartialOrd` contract, under which `a == b` should imply
    /// `Some(Ordering::Equal)`. Confirm this is intentional before relying on
    /// it.
    #[inline]
    fn partial_cmp(&self, _other: &Self) -> Option<Ordering> {
        Option::<Ordering>::None
    }
}
impl Function {
    /// Returns the name of a builtin (`Normal`) or user-defined (`Custom`)
    /// function, or `None` for an embedded script.
    ///
    /// Custom function names are returned without their `fn::` prefix.
    pub fn name(&self) -> Option<&str> {
        match self {
            Self::Normal(n, _) | Self::Custom(n, _) => Some(n.as_str()),
            Self::Script(_, _) => None,
        }
    }

    /// Returns the argument expressions of a builtin or user-defined function.
    ///
    /// Embedded scripts report an empty slice here, even though they carry
    /// argument expressions of their own.
    pub fn args(&self) -> &[Value] {
        match self {
            Self::Normal(_, a) | Self::Custom(_, a) => a,
            Self::Script(_, _) => &[],
        }
    }

    /// Converts the function's name into an [`Idiom`].
    ///
    /// Scripts become `function`, builtins keep their full name, and custom
    /// functions are re-prefixed with `fn::`.
    pub fn to_idiom(&self) -> Idiom {
        match self {
            Self::Script(_, _) => "function".to_string().into(),
            Self::Normal(f, _) => f.to_owned().into(),
            Self::Custom(f, _) => format!("fn::{f}").into(),
        }
    }

    /// Returns a copy of this builtin function whose first argument has been
    /// replaced by `val`. If the function had no arguments, `val` becomes its
    /// only argument.
    ///
    /// # Panics
    ///
    /// Panics (via `unreachable!`) when called on a `Custom` or `Script`
    /// function.
    pub fn aggregate(&self, val: Value) -> Self {
        match self {
            Self::Normal(n, a) => {
                // Build the new argument list directly rather than cloning all
                // arguments and shifting the vector with remove(0)/insert(0, _).
                // This way the replaced first argument is never cloned.
                let args = std::iter::once(val).chain(a.iter().skip(1).cloned()).collect();
                Self::Normal(n.to_owned(), args)
            }
            Self::Custom(_, _) | Self::Script(_, _) => {
                unreachable!("Function::aggregate is only valid on builtin functions")
            }
        }
    }

    /// Returns `true` if this is a user-defined (`fn::`) function.
    pub fn is_custom(&self) -> bool {
        matches!(self, Self::Custom(_, _))
    }

    /// Returns `true` if this is an embedded script.
    pub fn is_script(&self) -> bool {
        matches!(self, Self::Script(_, _))
    }

    /// Returns `true` for a builtin function whose arguments all satisfy
    /// `Value::is_static`. Custom functions and scripts are never static.
    pub fn is_static(&self) -> bool {
        match self {
            Self::Normal(_, a) => a.iter().all(Value::is_static),
            Self::Custom(_, _) | Self::Script(_, _) => false,
        }
    }

    /// Returns `true` for the builtin aggregate functions listed below. Each
    /// of them is also reported by [`Function::is_aggregate`].
    ///
    /// NOTE(review): "rolling" presumably means the aggregate can be updated
    /// incrementally as values arrive — confirm against the callers.
    pub fn is_rolling(&self) -> bool {
        match self {
            Self::Normal(f, _) => matches!(
                f.as_str(),
                "count"
                    | "math::max"
                    | "math::mean"
                    | "math::min"
                    | "math::sum"
                    | "time::max"
                    | "time::min"
            ),
            Self::Custom(_, _) | Self::Script(_, _) => false,
        }
    }

    /// Returns `true` if this is one of the builtin functions treated as an
    /// aggregate. Custom functions and scripts never are.
    pub fn is_aggregate(&self) -> bool {
        match self {
            Self::Normal(f, _) => matches!(
                f.as_str(),
                "array::distinct"
                    | "array::first"
                    | "array::flatten"
                    | "array::group"
                    | "array::last"
                    | "count"
                    | "math::bottom"
                    | "math::interquartile"
                    | "math::max"
                    | "math::mean"
                    | "math::median"
                    | "math::midhinge"
                    | "math::min"
                    | "math::mode"
                    | "math::nearestrank"
                    | "math::percentile"
                    | "math::sample"
                    | "math::spread"
                    | "math::stddev"
                    | "math::sum"
                    | "math::top"
                    | "math::trimean"
                    | "math::variance"
                    | "time::max"
                    | "time::min"
            ),
            Self::Custom(_, _) | Self::Script(_, _) => false,
        }
    }
}
// Evaluation of function invocations.
impl Function {
    /// Evaluates this function invocation and returns its result.
    ///
    /// Options are re-derived with `new_with_futures(true)` before anything
    /// else runs. Argument expressions are computed with `try_join_all`, so
    /// the first argument that fails aborts the call with that error.
    ///
    /// * `Normal`: checks the name with `ctx.check_allowed_function`, computes
    ///   the arguments and dispatches to `fnc::run`.
    /// * `Custom`: checks the full `fn::<name>` against the allow-list and
    ///   loads the definition through the transaction. It then enforces the
    ///   definition's `permissions` and validates the argument count. Finally
    ///   it binds each argument, coerced to its declared kind, in a new
    ///   `Context` derived from `ctx`, and evaluates the function's block.
    /// * `Script`: with the `scripting` feature, checks that scripting is
    ///   allowed, computes the arguments and runs `fnc::script::run`. Without
    ///   the feature, always returns an error.
    ///
    /// # Errors
    ///
    /// * `Error::FunctionPermissions` when a custom function's permissions
    ///   deny access.
    /// * `Error::InvalidArguments` when a custom function receives too few or
    ///   too many arguments.
    /// * `Error::InvalidScript` when the `scripting` feature is not compiled in.
    /// * Errors from the allow-list checks, definition lookup, argument
    ///   evaluation, type coercion or the function itself are propagated.
    #[cfg_attr(not(target_arch = "wasm32"), async_recursion)]
    #[cfg_attr(target_arch = "wasm32", async_recursion(?Send))]
    pub(crate) async fn compute(
        &self,
        ctx: &Context<'_>,
        opt: &Options,
        txn: &Transaction,
        doc: Option<&'async_recursion CursorDoc<'_>>,
    ) -> Result<Value, Error> {
        // Shadows the caller's options for the remainder of this call.
        let opt = &opt.new_with_futures(true);
        match self {
            Self::Normal(s, x) => {
                ctx.check_allowed_function(s)?;
                let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
                fnc::run(ctx, opt, txn, doc, s, a).await
            }
            Self::Custom(s, x) => {
                // Custom names are stored without the prefix, but the
                // allow-list check uses the full `fn::` name.
                let name = format!("fn::{s}");
                ctx.check_allowed_function(name.as_str())?;
                // The inner block scopes the transaction lock guard, so it is
                // dropped as soon as the definition has been fetched.
                let val = {
                    let mut run = txn.lock().await;
                    run.get_and_cache_db_function(opt.ns(), opt.db(), s).await?
                };
                if opt.check_perms(Action::View) {
                    match &val.permissions {
                        Permission::Full => (),
                        Permission::None => {
                            return Err(Error::FunctionPermissions {
                                name: s.to_owned(),
                            })
                        }
                        Permission::Specific(e) => {
                            // The permission expression is evaluated with
                            // `new_with_perms(false)` options (presumably so it
                            // is not itself subject to permission checks).
                            let opt = &opt.new_with_perms(false);
                            if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
                                return Err(Error::FunctionPermissions {
                                    name: s.to_owned(),
                                });
                            }
                        }
                    }
                }
                // Arity: every declared parameter may be supplied. Only a
                // trailing run of `Kind::Option` parameters may be omitted. An
                // optional parameter followed by a required one still counts
                // towards the minimum.
                let max_args_len = val.args.len();
                let mut min_args_len = 0;
                val.args.iter().rev().for_each(|(_, kind)| match kind {
                    Kind::Option(_) if min_args_len == 0 => {}
                    _ => min_args_len += 1,
                });
                if x.len() < min_args_len || max_args_len < x.len() {
                    return Err(Error::InvalidArguments {
                        name: format!("fn::{}", val.name),
                        message: match (min_args_len, max_args_len) {
                            (1, 1) => String::from("The function expects 1 argument."),
                            (r, t) if r == t => format!("The function expects {r} arguments."),
                            (r, t) => format!("The function expects {r} to {t} arguments."),
                        },
                    });
                }
                let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
                // Arguments are bound by parameter name. `zip` stops at the
                // shorter side, so omitted trailing optional parameters are
                // simply not bound here.
                // NOTE: the loop variable `val` shadows the function
                // definition only inside the loop. `&val.args` is evaluated
                // before the loop, and `val.block` below again refers to the
                // definition.
                let mut ctx = Context::new(ctx);
                for (val, (name, kind)) in a.into_iter().zip(&val.args) {
                    ctx.add_value(name.to_raw(), val.coerce_to(kind)?);
                }
                val.block.compute(&ctx, opt, txn, doc).await
            }
            // `s` and `x` are unused when the `scripting` feature is disabled.
            #[allow(unused_variables)]
            Self::Script(s, x) => {
                #[cfg(feature = "scripting")]
                {
                    ctx.check_allowed_scripting()?;
                    let a = try_join_all(x.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
                    fnc::script::run(ctx, opt, txn, doc, s, a).await
                }
                #[cfg(not(feature = "scripting"))]
                {
                    Err(Error::InvalidScript {
                        message: String::from("Embedded functions are not enabled."),
                    })
                }
            }
        }
    }
}
impl fmt::Display for Function {
    /// Renders the invocation as text:
    /// * `name(args)` for builtins,
    /// * `fn::name(args)` for custom functions,
    /// * `function(args) {body}` for embedded scripts.
    ///
    /// In every form the arguments are comma-separated.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            Self::Script(body, args) => {
                write!(f, "function({}) {{{body}}}", Fmt::comma_separated(args))
            }
            Self::Custom(name, args) => {
                write!(f, "fn::{name}({})", Fmt::comma_separated(args))
            }
            Self::Normal(name, args) => {
                write!(f, "{name}({})", Fmt::comma_separated(args))
            }
        }
    }
}