use crate::ctx::{Context, MutableContext};
use crate::dbs::Options;
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::fnc;
use crate::iam::Action;
use crate::sql::fmt::Fmt;
use crate::sql::idiom::Idiom;
use crate::sql::script::Script;
use crate::sql::value::Value;
use crate::sql::Permission;
use futures::future::try_join_all;
use reblessive::tree::Stk;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::fmt;
use super::Kind;
/// Serde type name used when (de)serializing [`Function`].
///
/// The same string is repeated literally in the `#[serde(rename = ...)]`
/// attribute on `Function`; the two must be kept in sync by hand.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Function";
/// A function call expression.
///
/// Every variant carries the call's argument expressions as `Vec<Value>`.
/// The serde name below must match the `TOKEN` constant.
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Function")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub enum Function {
/// A built-in function, identified by its name (e.g. `math::max`), and its arguments.
Normal(String, Vec<Value>),
/// A user-defined function; the name is stored without the `fn::` prefix.
Custom(String, Vec<Value>),
/// An embedded script function: the script body and its arguments.
Script(Script, Vec<Value>),
/// A call whose callee is itself an expression (e.g. a closure literal or a
/// parameter holding a closure), together with its arguments.
Anonymous(Value, Vec<Value>),
}
/// Classification of built-in aggregate calls, as returned by
/// `Function::get_optimised_aggregate`.
///
/// NOTE(review): presumably consumed elsewhere to compute these aggregates with a
/// dedicated (incremental) implementation — confirm at the call sites.
pub(crate) enum OptimisedAggregate {
/// The call is not one of the recognised aggregates.
None,
/// `count()` called with no arguments.
Count,
/// `count(...)` called with one or more arguments.
CountFunction,
/// `math::max`.
MathMax,
/// `math::min`.
MathMin,
/// `math::sum`.
MathSum,
/// `math::mean`.
MathMean,
/// `time::max`.
TimeMax,
/// `time::min`.
TimeMin,
}
/// Function calls have no natural ordering.
///
/// The [`PartialOrd`] contract requires `a == b` to hold exactly when
/// `a.partial_cmp(b) == Some(Ordering::Equal)`, so equal calls compare as
/// `Equal`. Every other pair is incomparable and yields `None`.
impl PartialOrd for Function {
	#[inline]
	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
		(self == other).then_some(Ordering::Equal)
	}
}
impl Function {
	/// Returns the name of a built-in (`Normal`) or user-defined (`Custom`) function.
	///
	/// Custom function names are returned without their `fn::` prefix. Script and
	/// anonymous functions have no name, so `None` is returned for them.
	pub fn name(&self) -> Option<&str> {
		match self {
			Self::Normal(n, _) | Self::Custom(n, _) => Some(n.as_str()),
			Self::Script(..) | Self::Anonymous(..) => None,
		}
	}

	/// Returns the argument expressions of a `Normal` or `Custom` function call.
	///
	/// Script and anonymous calls report an empty slice, even when they were
	/// given arguments.
	pub fn args(&self) -> &[Value] {
		match self {
			Self::Normal(_, a) | Self::Custom(_, a) => a,
			Self::Script(..) | Self::Anonymous(..) => &[],
		}
	}

	/// Returns an [`Idiom`] naming this call: the built-in name for `Normal`,
	/// `fn::<name>` for `Custom`, and the literal `function` for script and
	/// anonymous calls.
	pub fn to_idiom(&self) -> Idiom {
		match self {
			Self::Anonymous(..) | Self::Script(..) => "function".to_string().into(),
			Self::Normal(f, _) => f.to_owned().into(),
			Self::Custom(f, _) => format!("fn::{f}").into(),
		}
	}

	/// Returns a copy of this built-in call with its first argument replaced by
	/// `val`, or with `val` as the sole argument when the call had none.
	///
	/// # Errors
	/// Fails for every variant other than `Normal`.
	pub fn aggregate(&self, val: Value) -> Result<Self, Error> {
		match self {
			Self::Normal(n, a) => {
				// Build the new argument list in one pass. The old first argument is
				// skipped rather than cloned and then removed, and no element has to
				// be shifted to make room at index 0.
				let args = std::iter::once(val).chain(a.iter().skip(1).cloned()).collect();
				Ok(Self::Normal(n.to_owned(), args))
			}
			_ => Err(fail!("Encountered a non-aggregate function: {self:?}")),
		}
	}

	/// Returns `true` for a user-defined (`fn::...`) function call.
	pub fn is_custom(&self) -> bool {
		matches!(self, Self::Custom(_, _))
	}

	/// Returns `true` for an embedded script function.
	pub fn is_script(&self) -> bool {
		matches!(self, Self::Script(_, _))
	}

	/// Returns `true` for a built-in call whose arguments are all static values.
	/// All other variants are treated as non-static.
	pub fn is_static(&self) -> bool {
		match self {
			Self::Normal(_, a) => a.iter().all(Value::is_static),
			_ => false,
		}
	}

	/// Returns `true` for an anonymous call (callee given as an expression).
	pub fn is_inline(&self) -> bool {
		matches!(self, Self::Anonymous(_, _))
	}

	/// Returns `true` for the aggregates recognised by
	/// [`Self::get_optimised_aggregate`]: `count`, `math::max`, `math::mean`,
	/// `math::min`, `math::sum`, `time::max` and `time::min`.
	///
	/// Derived from that function so the two lists cannot drift apart.
	pub fn is_rolling(&self) -> bool {
		!matches!(self.get_optimised_aggregate(), OptimisedAggregate::None)
	}

	/// Returns `true` when this is a call to one of the built-in aggregate
	/// functions.
	pub fn is_aggregate(&self) -> bool {
		matches!(
			self,
			Self::Normal(f, _) if matches!(
				f.as_str(),
				"array::distinct"
					| "array::first"
					| "array::flatten"
					| "array::group"
					| "array::last"
					| "count"
					| "math::bottom"
					| "math::interquartile"
					| "math::max"
					| "math::mean"
					| "math::median"
					| "math::midhinge"
					| "math::min"
					| "math::mode"
					| "math::nearestrank"
					| "math::percentile"
					| "math::sample"
					| "math::spread"
					| "math::stddev"
					| "math::sum"
					| "math::top"
					| "math::trimean"
					| "math::variance"
					| "time::max"
					| "time::min"
			)
		)
	}

	/// Classifies this call as one of the [`OptimisedAggregate`] kinds.
	///
	/// `count` is split by arity: no arguments gives `Count`, and any arguments
	/// give `CountFunction`. Anything unrecognised yields `OptimisedAggregate::None`.
	pub(crate) fn get_optimised_aggregate(&self) -> OptimisedAggregate {
		match self {
			Self::Normal(name, args) => match name.as_str() {
				"count" if args.is_empty() => OptimisedAggregate::Count,
				"count" => OptimisedAggregate::CountFunction,
				"math::max" => OptimisedAggregate::MathMax,
				"math::mean" => OptimisedAggregate::MathMean,
				"math::min" => OptimisedAggregate::MathMin,
				"math::sum" => OptimisedAggregate::MathSum,
				"time::max" => OptimisedAggregate::TimeMax,
				"time::min" => OptimisedAggregate::TimeMin,
				_ => OptimisedAggregate::None,
			},
			_ => OptimisedAggregate::None,
		}
	}

	/// Returns `true` for a bare `count()` call with no arguments.
	pub(crate) fn is_count_all(&self) -> bool {
		matches!(self, Self::Normal(f, p) if f == "count" && p.is_empty())
	}
}
impl Function {
/// Evaluates this function call and returns its result.
///
/// Argument expressions are evaluated against `ctx`/`doc` with `try_join_all`,
/// so the first argument that fails to evaluate aborts the call with that error.
///
/// # Errors
/// - the function (or scripting) is denied by `ctx` (`check_allowed_function`,
///   `check_allowed_scripting`);
/// - an `Anonymous` callee does not resolve to a closure (`Error::InvalidFunction`);
/// - a `Custom` function is denied by its permissions (`Error::FunctionPermissions`),
///   receives the wrong number of arguments (`Error::InvalidArguments`), or an
///   argument / the result cannot be coerced to its declared kind;
/// - scripting support is compiled out (`Error::InvalidScript`);
/// - any error raised while evaluating the arguments or the function itself.
pub(crate) async fn compute(
&self,
stk: &mut Stk,
ctx: &Context,
opt: &Options,
doc: Option<&CursorDoc>,
) -> Result<Value, Error> {
// Everything below runs with futures enabled on the options.
// NOTE(review): presumably this lets `<future>` values be computed inside
// function calls — confirm against `Options::new_with_futures`.
let opt = &opt.new_with_futures(true);
match self {
Self::Normal(s, x) => {
// Built-in function: reject it up front if the context does not allow it.
ctx.check_allowed_function(s)?;
let a = stk
.scope(|scope| {
try_join_all(
x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
)
})
.await?;
// Dispatch to the built-in implementation by name.
fnc::run(stk, ctx, opt, doc, s, a).await
}
Self::Anonymous(v, x) => {
// Resolve the callee before evaluating any argument: closure literals are
// used as-is, parameters are looked up in the context (missing -> NONE),
// and blocks, subqueries, idioms and function calls are evaluated. Any
// other expression resolves to NONE and is rejected below.
let val = match v {
c @ Value::Closure(_) => c.clone(),
Value::Param(p) => ctx.value(p).cloned().unwrap_or(Value::None),
Value::Block(_) | Value::Subquery(_) | Value::Idiom(_) | Value::Function(_) => {
stk.run(|stk| v.compute(stk, ctx, opt, doc)).await?
}
_ => Value::None,
};
match val {
Value::Closure(closure) => {
let a = stk
.scope(|scope| {
try_join_all(
x.iter()
.map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
)
})
.await?;
stk.run(|stk| closure.compute(stk, ctx, opt, doc, a)).await
}
v => Err(Error::InvalidFunction {
name: "ANONYMOUS".to_string(),
message: format!("'{}' is not a function", v.kindof()),
}),
}
}
Self::Custom(s, x) => {
// User-defined function: capability checks use the fully-qualified name.
let name = format!("fn::{s}");
ctx.check_allowed_function(name.as_str())?;
// Load the function definition for the current namespace and database.
let val = ctx.tx().get_db_function(opt.ns()?, opt.db()?, s).await?;
// Permissions are enforced only when the options require a View check.
if opt.check_perms(Action::View)? {
match &val.permissions {
Permission::Full => (),
Permission::None => {
return Err(Error::FunctionPermissions {
name: s.to_owned(),
})
}
Permission::Specific(e) => {
// The permission expression itself runs with permission checks disabled.
let opt = &opt.new_with_perms(false);
if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
return Err(Error::FunctionPermissions {
name: s.to_owned(),
});
}
}
}
}
// Work out the accepted argument count. Walking the parameters from
// last to first, trailing `Kind::Option` / `Kind::Any` parameters are
// optional. Once a required parameter has been seen, every earlier
// parameter counts as required, including `Option`/`Any` ones.
let max_args_len = val.args.len();
let mut min_args_len = 0;
val.args.iter().rev().for_each(|(_, kind)| match kind {
Kind::Option(_) if min_args_len == 0 => {}
Kind::Any if min_args_len == 0 => {}
_ => min_args_len += 1,
});
if x.len() < min_args_len || max_args_len < x.len() {
return Err(Error::InvalidArguments {
name: format!("fn::{}", val.name),
message: match (min_args_len, max_args_len) {
(1, 1) => String::from("The function expects 1 argument."),
(r, t) if r == t => format!("The function expects {r} arguments."),
(r, t) => format!("The function expects {r} to {t} arguments."),
},
});
}
let a = stk
.scope(|scope| {
try_join_all(
x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
)
})
.await?;
// Bind each supplied argument as a named parameter in an isolated child
// context, coercing it to the declared kind. Parameters with no supplied
// argument are not bound at all. Note: `&val.args` is evaluated before the
// loop starts, so it still refers to the function definition; inside the
// loop body `val` is shadowed by the argument value.
let mut ctx = MutableContext::new_isolated(ctx);
for (val, (name, kind)) in a.into_iter().zip(&val.args) {
ctx.add_value(name.to_raw(), val.coerce_to(kind)?.into());
}
let ctx = ctx.freeze();
// An early return from the body is signalled as `Error::Return { value }`;
// unwrap it into the function's result instead of propagating it.
let result = match stk.run(|stk| val.block.compute(stk, &ctx, opt, doc)).await {
Err(Error::Return {
value,
}) => Ok(value),
res => res,
}?;
// Coerce the result to the declared return kind, if there is one.
if let Some(ref returns) = val.returns {
result
.coerce_to(returns)
.map_err(|e| e.function_check_from_coerce(val.name.to_string()))
} else {
Ok(result)
}
}
// Without the `scripting` feature, `s` and `x` are unused in this arm.
#[allow(unused_variables)]
Self::Script(s, x) => {
#[cfg(feature = "scripting")]
{
ctx.check_allowed_scripting()?;
let a = stk
.scope(|scope| {
try_join_all(
x.iter().map(|v| scope.run(|stk| v.compute(stk, ctx, opt, doc))),
)
})
.await?;
fnc::script::run(ctx, opt, doc, s, a).await
}
#[cfg(not(feature = "scripting"))]
{
Err(Error::InvalidScript {
message: String::from("Embedded functions are not enabled."),
})
}
}
}
}
}
impl fmt::Display for Function {
	/// Writes the call back out as SurrealQL:
	/// `name(args)`, `fn::name(args)`, `function(args) {script}` or `callee(args)`,
	/// with the arguments separated by commas.
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		match self {
			Self::Normal(name, params) => {
				let params = Fmt::comma_separated(params);
				write!(f, "{name}({params})")
			}
			Self::Custom(name, params) => {
				let params = Fmt::comma_separated(params);
				write!(f, "fn::{name}({params})")
			}
			Self::Script(body, params) => {
				let params = Fmt::comma_separated(params);
				write!(f, "function({params}) {{{body}}}")
			}
			Self::Anonymous(callee, params) => {
				let params = Fmt::comma_separated(params);
				write!(f, "{callee}({params})")
			}
		}
	}
}