surrealdb_core/sql/
model.rs

1use crate::ctx::Context;
2use crate::dbs::Options;
3use crate::doc::CursorDoc;
4use crate::err::Error;
5use crate::sql::value::Value;
6
7use reblessive::tree::Stk;
8use revision::revisioned;
9use serde::{Deserialize, Serialize};
10use std::fmt;
11
12#[cfg(feature = "ml")]
13use crate::iam::Action;
14#[cfg(feature = "ml")]
15use crate::sql::Permission;
16#[cfg(feature = "ml")]
17use futures::future::try_join_all;
18#[cfg(feature = "ml")]
19use std::collections::HashMap;
20#[cfg(feature = "ml")]
21use surrealml::errors::error::SurrealError;
22#[cfg(feature = "ml")]
23use surrealml::execution::compute::ModelComputation;
24#[cfg(feature = "ml")]
25use surrealml::storage::surml_file::SurMlFile;
26
/// Message attached to `Error::InvalidArguments` whenever a model is called
/// with an unsupported number or kind of arguments.
#[cfg(feature = "ml")]
const ARGUMENTS: &str = concat!(
	"The model expects 1 argument. ",
	"The argument can be either a number, an object, or an array of numbers."
);

/// The same name that [`Model`] is given via `#[serde(rename = ...)]`;
/// keep the two strings in sync.
pub(crate) const TOKEN: &str = "$surrealdb::private::sql::Model";
31
32#[revisioned(revision = 1)]
33#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Hash)]
34#[serde(rename = "$surrealdb::private::sql::Model")]
35#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
36#[non_exhaustive]
37pub struct Model {
38	pub name: String,
39	pub version: String,
40	pub args: Vec<Value>,
41}
42
43impl fmt::Display for Model {
44	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
45		write!(f, "ml::{}<{}>(", self.name, self.version)?;
46		for (idx, p) in self.args.iter().enumerate() {
47			if idx != 0 {
48				write!(f, ",")?;
49			}
50			write!(f, "{}", p)?;
51		}
52		write!(f, ")")
53	}
54}
55
56impl Model {
57	#[cfg(feature = "ml")]
58	pub(crate) async fn compute(
59		&self,
60		stk: &mut Stk,
61		ctx: &Context,
62		opt: &Options,
63		doc: Option<&CursorDoc>,
64	) -> Result<Value, Error> {
65		// Ensure futures are run
66		let opt = &opt.new_with_futures(true);
67		// Get the full name of this model
68		let name = format!("ml::{}", self.name);
69		// Check this function is allowed
70		ctx.check_allowed_function(name.as_str())?;
71		// Get the model definition
72		let (ns, db) = opt.ns_db()?;
73		let val = ctx.tx().get_db_model(ns, db, &self.name, &self.version).await?;
74		// Calculate the model path
75		let (ns, db) = opt.ns_db()?;
76		let path = format!("ml/{}/{}/{}-{}-{}.surml", ns, db, self.name, self.version, val.hash);
77		// Check permissions
78		if opt.check_perms(Action::View)? {
79			match &val.permissions {
80				Permission::Full => (),
81				Permission::None => {
82					return Err(Error::FunctionPermissions {
83						name: self.name.to_owned(),
84					})
85				}
86				Permission::Specific(e) => {
87					// Disable permissions
88					let opt = &opt.new_with_perms(false);
89					// Process the PERMISSION clause
90					if !stk.run(|stk| e.compute(stk, ctx, opt, doc)).await?.is_truthy() {
91						return Err(Error::FunctionPermissions {
92							name: self.name.to_owned(),
93						});
94					}
95				}
96			}
97		}
98		// Compute the function arguments
99		let mut args = stk
100			.scope(|stk| {
101				try_join_all(self.args.iter().map(|v| stk.run(|stk| v.compute(stk, ctx, opt, doc))))
102			})
103			.await?;
104		// Check the minimum argument length
105		if args.len() != 1 {
106			return Err(Error::InvalidArguments {
107				name: format!("ml::{}<{}>", self.name, self.version),
108				message: ARGUMENTS.into(),
109			});
110		}
111		// Take the first and only specified argument
112		match args.swap_remove(0) {
113			// Perform bufferered compute
114			Value::Object(v) => {
115				// Compute the model function arguments
116				let mut args = v
117					.into_iter()
118					.map(|(k, v)| Ok((k, Value::try_into(v)?)))
119					.collect::<Result<HashMap<String, f32>, Error>>()
120					.map_err(|_| Error::InvalidArguments {
121						name: format!("ml::{}<{}>", self.name, self.version),
122						message: ARGUMENTS.into(),
123					})?;
124				// Get the model file as bytes
125				let bytes = crate::obs::get(&path).await?;
126				// Run the compute in a blocking task
127				let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
128					let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
129						Error::ModelComputation(err.message.to_string())
130					})?;
131					let compute_unit = ModelComputation {
132						surml_file: &mut file,
133					};
134					compute_unit.buffered_compute(&mut args).map_err(|err: SurrealError| {
135						Error::ModelComputation(err.message.to_string())
136					})
137				})
138				.await
139				.unwrap()?;
140				// Convert the output to a value
141				Ok(outcome.into())
142			}
143			// Perform raw compute
144			Value::Number(v) => {
145				// Compute the model function arguments
146				let args: f32 = v.try_into().map_err(|_| Error::InvalidArguments {
147					name: format!("ml::{}<{}>", self.name, self.version),
148					message: ARGUMENTS.into(),
149				})?;
150				// Get the model file as bytes
151				let bytes = crate::obs::get(&path).await?;
152				// Convert the argument to a tensor
153				let tensor = ndarray::arr1::<f32>(&[args]).into_dyn();
154				// Run the compute in a blocking task
155				let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
156					let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
157						Error::ModelComputation(err.message.to_string())
158					})?;
159					let compute_unit = ModelComputation {
160						surml_file: &mut file,
161					};
162					compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
163						Error::ModelComputation(err.message.to_string())
164					})
165				})
166				.await
167				.unwrap()?;
168				// Convert the output to a value
169				Ok(outcome.into())
170			}
171			// Perform raw compute
172			Value::Array(v) => {
173				// Compute the model function arguments
174				let args = v
175					.into_iter()
176					.map(Value::try_into)
177					.collect::<Result<Vec<f32>, Error>>()
178					.map_err(|_| Error::InvalidArguments {
179						name: format!("ml::{}<{}>", self.name, self.version),
180						message: ARGUMENTS.into(),
181					})?;
182				// Get the model file as bytes
183				let bytes = crate::obs::get(&path).await?;
184				// Convert the argument to a tensor
185				let tensor = ndarray::arr1::<f32>(&args).into_dyn();
186				// Run the compute in a blocking task
187				let outcome: Vec<f32> = tokio::task::spawn_blocking(move || {
188					let mut file = SurMlFile::from_bytes(bytes).map_err(|err: SurrealError| {
189						Error::ModelComputation(err.message.to_string())
190					})?;
191					let compute_unit = ModelComputation {
192						surml_file: &mut file,
193					};
194					compute_unit.raw_compute(tensor, None).map_err(|err: SurrealError| {
195						Error::ModelComputation(err.message.to_string())
196					})
197				})
198				.await
199				.unwrap()?;
200				// Convert the output to a value
201				Ok(outcome.into())
202			}
203			//
204			_ => Err(Error::InvalidArguments {
205				name: format!("ml::{}<{}>", self.name, self.version),
206				message: ARGUMENTS.into(),
207			}),
208		}
209	}
210
211	#[cfg(not(feature = "ml"))]
212	pub(crate) async fn compute(
213		&self,
214		_stk: &mut Stk,
215		_ctx: &Context,
216		_opt: &Options,
217		_doc: Option<&CursorDoc>,
218	) -> Result<Value, Error> {
219		Err(Error::InvalidModel {
220			message: String::from("Machine learning computation is not enabled."),
221		})
222	}
223}