use crate::ctx::Context;
use crate::dbs::{Options, Transaction};
use crate::doc::CursorDoc;
use crate::err::Error;
use crate::value::Value;
use derive::Store;
use revision::revisioned;
use serde::{Deserialize, Serialize};
use std::fmt;
#[cfg(feature = "ml")]
use crate::iam::Action;
#[cfg(feature = "ml")]
use crate::Permission;
#[cfg(feature = "ml")]
use futures::future::try_join_all;
#[cfg(feature = "ml")]
use std::collections::HashMap;
#[cfg(feature = "ml")]
use surrealml_core::execution::compute::ModelComputation;
#[cfg(feature = "ml")]
use surrealml_core::storage::surml_file::SurMlFile;
#[cfg(feature = "ml")]
/// Message attached to every `Error::InvalidArguments` raised by `Model::compute`
/// when the model's argument is missing, duplicated, or of an unsupported type.
const ARGUMENTS: &str = "The model expects 1 argument. The argument can be either a number, an object, or an array of numbers.";
/// A call to a stored machine-learning model, displayed as `ml::name<version>(args)`.
///
/// Field order matters: it determines the derived `PartialOrd`/`Hash` behavior and the
/// serialized (revisioned) layout.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Serialize, Deserialize, Store, Hash)]
#[revisioned(revision = 1)]
pub struct Model {
	// Model name, without the `ml::` prefix (the prefix is added when formatting/checking)
	pub name: String,
	// Model version; together with `name` it selects the stored model definition
	pub version: String,
	// Argument expressions, evaluated at compute time (exactly one is accepted)
	pub args: Vec<Value>,
}
/// Renders the model call as `ml::name<version>(a,b,c)`: the arguments are separated by a
/// bare comma with no surrounding whitespace, and an empty argument list gives `()`.
impl fmt::Display for Model {
	fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
		write!(f, "ml::{}<{}>(", self.name, self.version)?;
		let mut remaining = self.args.iter();
		// Emit the first argument on its own, then prefix every following one with a comma
		if let Some(first) = remaining.next() {
			write!(f, "{}", first)?;
			for arg in remaining {
				write!(f, ",{}", arg)?;
			}
		}
		f.write_str(")")
	}
}
impl Model {
	/// Runs this model against its single argument and returns the model's first output.
	///
	/// Steps, in order:
	/// 1. check that `ml::<name>` is an allowed function in `ctx`;
	/// 2. fetch the model definition for `name`/`version` in the current namespace and
	///    database via `get_and_cache_db_model`;
	/// 3. when `opt.check_perms(Action::View)` holds, enforce the model's permissions,
	///    evaluating a `Specific` clause with permission checks turned off;
	/// 4. evaluate the argument expressions, which must yield exactly one value;
	/// 5. read the `.surml` file via `crate::obs::get` and run it on a blocking thread.
	///
	/// An object argument is run through `buffered_compute` with every field converted to
	/// `f32`. A number or an array of numbers is run through `raw_compute` as a
	/// one-dimensional `f32` tensor.
	///
	/// # Errors
	///
	/// * `Error::InvalidArguments` if there is not exactly one argument, if it is not an
	///   object, number or array, or if any of its values cannot be converted to `f32`.
	/// * `Error::FunctionPermissions` if the model's permissions deny access.
	/// * `Error::ModelComputation` if the model file cannot be loaded, if the computation
	///   fails, if its blocking task does not complete, or if the model returns no output.
	/// * Any error from the function check, the model lookup, argument or permission
	///   evaluation, or reading the model file.
	#[cfg(feature = "ml")]
	pub(crate) async fn compute(
		&self,
		ctx: &Context<'_>,
		opt: &Options,
		txn: &Transaction,
		doc: Option<&CursorDoc<'_>>,
	) -> Result<Value, Error> {
		// The two ways a model can be fed, chosen from the shape of the argument
		enum Input {
			// Named `f32` inputs, passed to `buffered_compute`
			Buffered(HashMap<String, f32>),
			// A one-dimensional tensor, passed to `raw_compute`
			Raw(ndarray::ArrayD<f32>),
		}
		let opt = &opt.new_with_futures(true);
		// Ensure this model may be called as a function
		let name = format!("ml::{}", self.name);
		ctx.check_allowed_function(name.as_str())?;
		// Fetch the model definition, holding the transaction lock only for the lookup
		let val = {
			let mut run = txn.lock().await;
			run.get_and_cache_db_model(opt.ns(), opt.db(), &self.name, &self.version).await?
		};
		// Location of the stored model file
		let path = format!(
			"ml/{}/{}/{}-{}-{}.surml",
			opt.ns(),
			opt.db(),
			self.name,
			self.version,
			val.hash
		);
		// Enforce the model's permissions
		if opt.check_perms(Action::View) {
			match &val.permissions {
				Permission::Full => (),
				Permission::None => {
					return Err(Error::FunctionPermissions {
						name: self.name.to_owned(),
					})
				}
				Permission::Specific(e) => {
					// Evaluate the permission clause without further permission checks
					let opt = &opt.new_with_perms(false);
					if !e.compute(ctx, opt, txn, doc).await?.is_truthy() {
						return Err(Error::FunctionPermissions {
							name: self.name.to_owned(),
						});
					}
				}
			}
		}
		// Evaluate the arguments; exactly one is accepted
		let args =
			try_join_all(self.args.iter().map(|v| v.compute(ctx, opt, txn, doc))).await?;
		let [arg]: [Value; 1] = args.try_into().map_err(|_| self.invalid_arguments())?;
		// Validate and convert the argument before reading the model file
		let input = match arg {
			Value::Object(v) => {
				let args = v
					.into_iter()
					.map(|(k, v)| Ok((k, Value::try_into(v)?)))
					.collect::<Result<HashMap<String, f32>, Error>>()
					.map_err(|_| self.invalid_arguments())?;
				Input::Buffered(args)
			}
			Value::Number(v) => {
				let arg: f32 = v.try_into().map_err(|_| self.invalid_arguments())?;
				Input::Raw(ndarray::arr1::<f32>(&[arg]).into_dyn())
			}
			Value::Array(v) => {
				let args = v
					.into_iter()
					.map(Value::try_into)
					.collect::<Result<Vec<f32>, Error>>()
					.map_err(|_| self.invalid_arguments())?;
				Input::Raw(ndarray::arr1::<f32>(&args).into_dyn())
			}
			_ => return Err(self.invalid_arguments()),
		};
		let bytes = crate::obs::get(&path).await?;
		// Run the model on a blocking thread rather than on the async executor
		let outcome = tokio::task::spawn_blocking(move || {
			// A corrupt or truncated model file is reported as an error, not a panic
			let mut file = SurMlFile::from_bytes(bytes).map_err(|_| {
				Error::ModelComputation("Unable to load the model file".to_string())
			})?;
			let compute_unit = ModelComputation {
				surml_file: &mut file,
			};
			let result = match input {
				Input::Buffered(mut args) => compute_unit.buffered_compute(&mut args),
				Input::Raw(tensor) => compute_unit.raw_compute(tensor, None),
			};
			result.map_err(Error::ModelComputation)
		})
		.await
		// Joining fails only if the blocking task panicked or was cancelled
		.map_err(|e| Error::ModelComputation(e.to_string()))??;
		// Only the first output value is returned; an empty output is an error, not a panic
		match outcome.first() {
			Some(v) => Ok((*v).into()),
			None => Err(Error::ModelComputation("The model returned no output".to_string())),
		}
	}

	/// Builds the `InvalidArguments` error reported when this model's argument is unusable.
	#[cfg(feature = "ml")]
	fn invalid_arguments(&self) -> Error {
		Error::InvalidArguments {
			name: format!("ml::{}<{}>", self.name, self.version),
			message: ARGUMENTS.into(),
		}
	}

	/// Fallback used when the `ml` feature is disabled.
	///
	/// This version always returns `Error::InvalidModel` because model computation is not
	/// compiled into this build. None of the inputs are inspected.
	#[cfg(not(feature = "ml"))]
	pub(crate) async fn compute(
		&self,
		_ctx: &Context<'_>,
		_opt: &Options,
		_txn: &Transaction,
		_doc: Option<&CursorDoc<'_>>,
	) -> Result<Value, Error> {
		Err(Error::InvalidModel {
			message: String::from("Machine learning computation is not enabled."),
		})
	}
}