1use crate::ctx::Context;
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::sql::value::Value;
6
7use reblessive::tree::Stk;
8use revision::revisioned;
9use serde::{Deserialize, Serialize};
10use std::fmt;
11
12#[cfg(feature = "ml")]
13use crate::iam::Action;
14#[cfg(feature = "ml")]
15use crate::sql::Permission;
16#[cfg(feature = "ml")]
17use futures::future::try_join_all;
18#[cfg(feature = "ml")]
19use std::collections::HashMap;
20#[cfg(feature = "ml")]
21use surrealml::errors::error::SurrealError;
22#[cfg(feature = "ml")]
23use surrealml::execution::compute::ModelComputation;
24#[cfg(feature = "ml")]
25use surrealml::storage::surml_file::SurMlFile;
26
/// Message attached to `Error::InvalidArguments` whenever a model is called with
/// arguments it cannot use (wrong count, or values that cannot be converted).
#[cfg(feature = "ml")]
const ARGUMENTS: &str = concat!(
	"The model expects 1 argument. ",
	"The argument can be either a number, an object, or an array of numbers."
);

/// The same string as the `#[serde(rename = ...)]` name given to `Model`.
/// NOTE(review): presumably used to recognise serialized `Model` values — confirm at call sites.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Model";
31
/// A call to a stored machine learning model, written as `ml::name<version>(args)`.
///
/// Field order is significant: it determines the derived `PartialOrd`/`Hash`
/// behaviour and the serialized (revisioned / serde) layout.
#[revisioned(revision = 1)]
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
#[serde(rename = "$surrealdb::private::sql::Model")]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[non_exhaustive]
pub struct Model {
	/// Model name, rendered after the `ml::` prefix.
	pub name: String,
	/// Model version, rendered between `<` and `>`.
	pub version: String,
	/// Argument expressions; `compute` evaluates them and requires exactly one.
	pub args: Vec<Value>,
}
42
43impl fmt::Display for Model {
44 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45 write!(f, "ml::{}<{}>(", self.name, self.version)?;
46 for (idx, p) in self.args.iter().enumerate() {
47 if idx != 0 {
48 write!(f, ",")?;
49 }
50 write!(f, "{}", p)?;
51 }
52 write!(f, ")")
53 }
54}
55
56impl Model {
57 #[cfg(feature = "ml")]
58 pub(crate) async fn compute(
59 &self,
60 stk: &mut Stk,
61 ctx: &Context,
62 opt: &Options,
63 doc: Option<&CursorDoc>,
64 ) -> Result<Value, Error> {
65 let opt = &opt.new_with_futures(true);
67 let name = format!("ml::{}", self.name);
69 ctx.check_allowed_function(name.as_str())?;
71 let (ns, db) = opt.ns_db()?;
73 let val = ctx.tx().get_db_model(ns, db, &self.name, &self.version).await?;
74 let (ns, db) = opt.ns_db()?;
76 let path = format!("ml/{}/{}/{}-{}-{}.surml", ns, db, self.name, self.version, val.hash);
77 if opt.check_perms(Action::View)? {
79 match &val.permissions {
80 Permission::Full => (),
81 Permission::None => {
82 return Err(Error::FunctionPermissions {
83 name: self.name.to_owned(),
84 })
85 }
86 Permission::Specific(e) => {
87 let opt = &opt.new_with_perms(false);
89 if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
91 return Err(Error::FunctionPermissions {
92 name: self.name.to_owned(),
93 });
94 }
95 }
96 }
97 }
98 let mut args = stk
100 .scope(|stk| {
101 try_join_all(self.args.iter().map(|v| stk.run(|stk| v.compute(stk, ctx, opt, doc))))
102 })
103 .await?;
104 if args.len() != 1 {
106 return Err(Error::InvalidArguments {
107 name: format!("ml::{}<{}>", self.name, self.version),
108 message: ARGUMENTS.into(),
109 });
110 }
111 match args.swap_remove(0) {
113 Value::Object(v) => {
115 let mut args = v
117 .into_iter()
118 .map(|(k, v)| Ok((k, Value::try_into(v)?)))
119 .collect::<Result<HashMap<String, f32>, Error>>()
120 .map_err(|_| Error::InvalidArguments {
121 name: format!("ml::{}<{}>", self.name, self.version),
122 message: ARGUMENTS.into(),
123 })?;
124 let bytes = crate::obs::get(&path).await?;
126 let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
128 let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
129 Error::ModelComputation(err.message.to_string())
130 })?;
131 let compute_unit = ModelComputation {
132 surml_file: &mut file,
133 };
134 compute_unit.buffered_compute(&mut args).map_err(|err: SurrealError| {
135 Error::ModelComputation(err.message.to_string())
136 })
137 })
138 .await
139 .unwrap()?;
140 Ok(outcome.into())
142 }
143 Value::Number(v) => {
145 let args: f32 = v.try_into().map_err(|_| Error::InvalidArguments {
147 name: format!("ml::{}<{}>", self.name, self.version),
148 message: ARGUMENTS.into(),
149 })?;
150 let bytes = crate::obs::get(&path).await?;
152 let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
154 let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
156 let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
157 Error::ModelComputation(err.message.to_string())
158 })?;
159 let compute_unit = ModelComputation {
160 surml_file: &mut file,
161 };
162 compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
163 Error::ModelComputation(err.message.to_string())
164 })
165 })
166 .await
167 .unwrap()?;
168 Ok(outcome.into())
170 }
171 Value::Array(v) => {
173 let args = v
175 .into_iter()
176 .map(Value::try_into)
177 .collect::<Result<Vec<f32>, Error>>()
178 .map_err(|_| Error::InvalidArguments {
179 name: format!("ml::{}<{}>", self.name, self.version),
180 message: ARGUMENTS.into(),
181 })?;
182 let bytes = crate::obs::get(&path).await?;
184 let tensor = ndarray::arr1::<f32>(&args).into_dyn();
186 let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
188 let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
189 Error::ModelComputation(err.message.to_string())
190 })?;
191 let compute_unit = ModelComputation {
192 surml_file: &mut file,
193 };
194 compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
195 Error::ModelComputation(err.message.to_string())
196 })
197 })
198 .await
199 .unwrap()?;
200 Ok(outcome.into())
202 }
203 _ => Err(Error::InvalidArguments {
205 name: format!("ml::{}<{}>", self.name, self.version),
206 message: ARGUMENTS.into(),
207 }),
208 }
209 }
210
211 #[cfg(not(feature = "ml"))]
212 pub(crate) async fn compute(
213 &self,
214 _stk: &mut Stk,
215 _ctx: &Context,
216 _opt: &Options,
217 _doc: Option<&CursorDoc>,
218 ) -> Result<Value, Error> {
219 Err(Error::InvalidModel {
220 message: String::from("Machine learning computation is not enabled."),
221 })
222 }
223}