use super::{live, Stream};
use crate::api::conn::Command;
use crate::api::err::Error;
use crate::api::method::BoxFuture;
use crate::api::opt;
use crate::api::Connection;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::method::OnceLockExt;
use crate::method::Stats;
use crate::method::WithStats;
use crate::value::Notification;
use crate::{Surreal, Value};
use futures::future::Either;
use futures::stream::SelectAll;
use futures::StreamExt;
use indexmap::IndexMap;
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::borrow::Cow;
use std::collections::HashMap;
use std::future::IntoFuture;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use surrealdb_core::sql::{
self, to_value as to_core_value, Object as CoreObject, Statement, Value as CoreValue,
};
/// A query that has not been sent yet.
///
/// The building methods (`query`, `bind`) only record their work here; nothing
/// is sent until the value is awaited (see its `IntoFuture` implementation).
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Query<'r, C: Connection> {
/// Client used to execute the query, either borrowed or owned (see `into_owned`).
pub(crate) client: Cow<'r, Surreal<C>>,
/// The query built so far, or the first error hit while building it.
/// A stored error is only reported when the query is awaited.
pub(crate) inner: Result<ValidQuery>,
}
/// The two shapes a successfully built query can take.
#[derive(Debug)]
pub(crate) enum ValidQuery {
/// Unparsed query text, executed via `Command::RawQuery`.
/// Further statements cannot be appended to it.
Raw {
query: Cow<'static, str>,
/// Variables sent alongside the query.
bindings: CoreObject,
},
/// Parsed statements, executed via `Command::Query`.
Normal {
query: Vec<Statement>,
/// When true, each successful `LIVE` statement gets a notification
/// stream registered after the query has run.
register_live_queries: bool,
/// Variables sent alongside the query.
bindings: CoreObject,
},
}
impl<'r, C> Query<'r, C>
where
    C: Connection,
{
    /// Builds a query from already-parsed statements.
    ///
    /// `register_live_queries` decides whether `LIVE` statements get a
    /// notification stream registered when the query is awaited.
    pub(crate) fn normal(
        client: Cow<'r, Surreal<C>>,
        query: Vec<Statement>,
        bindings: CoreObject,
        register_live_queries: bool,
    ) -> Self {
        let valid = ValidQuery::Normal {
            query,
            bindings,
            register_live_queries,
        };
        Self {
            client,
            inner: Ok(valid),
        }
    }

    /// Applies `f` to the query if it is still valid.
    ///
    /// If an earlier step already failed, that error is kept unchanged and
    /// `f` is never called.
    pub(crate) fn map_valid<F>(self, f: F) -> Self
    where
        F: FnOnce(ValidQuery) -> Result<ValidQuery>,
    {
        let Self {
            client,
            inner,
        } = self;
        Self {
            client,
            inner: inner.and_then(f),
        }
    }

    /// Detaches the query from the `'r` borrow by taking ownership of the
    /// client; a borrowed client is cloned via `Cow::into_owned`.
    pub fn into_owned(self) -> Query<'static, C> {
        let Self {
            client,
            inner,
        } = self;
        Query {
            client: Cow::Owned(client.into_owned()),
            inner,
        }
    }
}
impl<'r, Client> IntoFuture for Query<'r, Client>
where
Client: Connection,
{
type Output = Result<Response>;
type IntoFuture = BoxFuture<'r, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let router = self.client.router.extract()?;
match self.inner? {
ValidQuery::Raw {
query,
bindings,
} => {
router
.execute_query(Command::RawQuery {
query,
variables: bindings,
})
.await
}
ValidQuery::Normal {
query,
register_live_queries,
bindings,
} => {
let query_statements = query;
let query_indicies = if register_live_queries {
query_statements
.iter()
.filter(|x| {
!matches!(
x,
Statement::Begin(_)
| Statement::Commit(_) | Statement::Cancel(_)
)
})
.enumerate()
.filter(|(_, x)| matches!(x, Statement::Live(_)))
.map(|(i, _)| i)
.collect()
} else {
Vec::new()
};
if !query_indicies.is_empty()
&& !router.features.contains(&ExtraFeatures::LiveQueries)
{
return Err(Error::LiveQueriesNotSupported.into());
}
let mut query = sql::Query::default();
query.0 .0 = query_statements;
let mut response = router
.execute_query(Command::Query {
query,
variables: bindings,
})
.await?;
for idx in query_indicies {
let Some((_, result)) = response.results.get(&idx) else {
continue;
};
let res = match result {
Ok(id) => {
let CoreValue::Uuid(uuid) = id else {
return Err(Error::InternalError(
"successfull live query did not return a uuid".to_string(),
)
.into());
};
live::register(router, uuid.0).await.map(|rx| {
Stream::new(
Surreal::new_from_router_waiter(
self.client.router.clone(),
self.client.waiter.clone(),
),
uuid.0,
Some(rx),
)
})
}
Err(_) => Err(crate::Error::from(Error::NotLiveQuery(idx))),
};
response.live_queries.insert(idx, res);
}
Ok(response)
}
}
})
}
}
impl<'r, Client> IntoFuture for WithStats<Query<'r, Client>>
where
Client: Connection,
{
type Output = Result<WithStats<Response>>;
type IntoFuture = BoxFuture<'r, Self::Output>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let response = self.0.await?;
Ok(WithStats(response))
})
}
}
impl<C> Query<'_, C>
where
C: Connection,
{
pub fn query(self, surql: impl opt::IntoQuery) -> Self {
self.map_valid(move |valid| match valid {
ValidQuery::Raw {
..
} => {
Err(Error::InvalidParams("Appending to raw queries is not supported".to_owned())
.into())
}
ValidQuery::Normal {
mut query,
register_live_queries,
bindings,
} => match surql.into_query() {
Ok(stmts) => {
query.extend(stmts);
Ok(ValidQuery::Normal {
query,
register_live_queries,
bindings,
})
}
Err(crate::Error::Api(crate::api::err::Error::RawQuery(..))) => {
Err(Error::InvalidParams("Appending raw queries is not supported".to_owned())
.into())
}
Err(error) => Err(error),
},
})
}
pub const fn with_stats(self) -> WithStats<Self> {
WithStats(self)
}
pub fn bind(self, bindings: impl Serialize + 'static) -> Self {
self.map_valid(move |mut valid| {
let current_bindings = match &mut valid {
ValidQuery::Raw {
bindings,
..
} => bindings,
ValidQuery::Normal {
bindings,
..
} => bindings,
};
let bindings = to_core_value(bindings)?;
match bindings {
CoreValue::Object(mut map) => current_bindings.append(&mut map.0),
CoreValue::Array(array) => {
if array.len() != 2 || !matches!(array[0], CoreValue::Strand(_)) {
let bindings = CoreValue::Array(array);
let bindings = Value::from_inner(bindings);
return Err(Error::InvalidBindings(bindings).into());
}
let mut iter = array.into_iter();
let Some(CoreValue::Strand(key)) = iter.next() else {
unreachable!()
};
let Some(value) = iter.next() else {
unreachable!()
};
current_bindings.insert(key.0, value);
}
_ => {
let bindings = Value::from_inner(bindings);
return Err(Error::InvalidBindings(bindings).into());
}
}
Ok(valid)
})
}
}
/// Outcome of a single statement: its value, or the error it produced.
pub(crate) type QueryResult = Result<CoreValue>;
/// The results of an executed query.
#[derive(Debug)]
pub struct Response {
/// Per-statement statistics and outcome, keyed by statement index.
pub(crate) results: IndexMap<usize, (Stats, QueryResult)>,
/// Notification streams for `LIVE` statements, keyed by the same statement
/// index as in `results`; only filled when live queries were registered.
pub(crate) live_queries: IndexMap<usize, Result<Stream<Value>>>,
}
/// Notifications from live queries: either a single stream or several merged
/// with `SelectAll` (NOTE(review): presumably one vs. many selected statements;
/// the construction happens outside this file section).
#[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"]
pub struct QueryStream<R>(pub(crate) Either<Stream<R>, SelectAll<Stream<R>>>);
impl futures::Stream for QueryStream<Value> {
    type Item = Notification<Value>;

    /// Forwards polling to the wrapped single or merged notification stream.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = Pin::into_inner(self);
        this.0.poll_next_unpin(cx)
    }
}

impl<R> futures::Stream for QueryStream<Notification<R>>
where
    R: DeserializeOwned + Unpin,
{
    type Item = Result<Notification<R>>;

    /// Forwards polling to the wrapped stream; each item is a notification
    /// already converted to `R`, or the error from that conversion.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = Pin::into_inner(self);
        this.0.poll_next_unpin(cx)
    }
}
impl Response {
    /// Creates a response holding no statement results and no live queries.
    pub(crate) fn new() -> Self {
        Self {
            results: Default::default(),
            live_queries: Default::default(),
        }
    }

    /// Deserializes the result selected by `index` into `R`.
    ///
    /// `index` is anything implementing `opt::QueryResult` (the tests use
    /// statement indices and field names); the work is delegated to
    /// `index.query_result`.
    pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Result<R>
    where
        R: DeserializeOwned,
    {
        index.query_result(self)
    }

    /// Returns the live-query notifications for the statement(s) selected by
    /// `index`, delegating to `index.query_stream`.
    pub fn stream<R>(&mut self, index: impl opt::QueryStream<R>) -> Result<QueryStream<R>> {
        index.query_stream(self)
    }

    /// Removes every failed statement from the response and returns the
    /// errors keyed by statement index. Successful results stay in place.
    pub fn take_errors(&mut self) -> HashMap<usize, crate::Error> {
        let failed: Vec<usize> = self
            .results
            .iter()
            .filter(|(_, (_, result))| result.is_err())
            .map(|(index, _)| *index)
            .collect();
        let mut errors = HashMap::with_capacity(failed.len());
        for index in failed {
            if let Some((_, Err(error))) = self.results.swap_remove(&index) {
                errors.insert(index, error);
            }
        }
        errors
    }

    /// Returns the response unchanged if no statement failed.
    ///
    /// # Errors
    /// Returns the error of the failed statement with the lowest statement
    /// index.
    pub fn check(mut self) -> Result<Self> {
        // Keys are statement indices, but `swap_remove` (used here and
        // possibly by other accessors) moves the last entry into the removed
        // slot, so iteration order can drift from statement order. Select the
        // smallest failing index explicitly instead of the first one iterated.
        let first_failed = self
            .results
            .iter()
            .filter(|(_, (_, result))| result.is_err())
            .map(|(index, _)| *index)
            .min();
        if let Some(index) = first_failed {
            if let Some((_, Err(error))) = self.results.swap_remove(&index) {
                return Err(error);
            }
        }
        Ok(self)
    }

    /// Number of statement results currently held. This shrinks as results
    /// are removed, e.g. by `take_errors`.
    pub fn num_statements(&self) -> usize {
        self.results.len()
    }
}
impl WithStats<Response> {
pub fn take<R>(&mut self, index: impl opt::QueryResult<R>) -> Option<(Stats, Result<R>)>
where
R: DeserializeOwned,
{
let stats = index.stats(&self.0)?;
let result = index.query_result(&mut self.0);
Some((stats, result))
}
pub fn take_errors(&mut self) -> HashMap<usize, (Stats, crate::Error)> {
let mut keys = Vec::new();
for (key, result) in &self.0.results {
if result.1.is_err() {
keys.push(*key);
}
}
let mut errors = HashMap::with_capacity(keys.len());
for key in keys {
if let Some((stats, Err(error))) = self.0.results.swap_remove(&key) {
errors.insert(key, (stats, error));
}
}
errors
}
pub fn check(self) -> Result<Self> {
let response = self.0.check()?;
Ok(Self(response))
}
pub fn num_statements(&self) -> usize {
self.0.num_statements()
}
pub fn into_inner(self) -> Response {
self.0
}
}
// Unit tests for `Response`: `take` with numeric and field-name selectors,
// `check`, and `take_errors`, all driven by hand-built result maps.
#[cfg(test)]
mod tests {
use super::*;
use crate::{value::to_value, Error::Api};
use serde::Deserialize;
use surrealdb_core::sql::Value as CoreValue;
// Single-field record used by the field-name selector tests.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Summary {
title: String,
}
// Two-field record used to take several fields from one result.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct Article {
title: String,
body: String,
}
// Keys each result by its position in `vec` and pairs it with zeroed `Stats`.
fn to_map(vec: Vec<QueryResult>) -> IndexMap<usize, (Stats, QueryResult)> {
vec.into_iter()
.map(|result| {
let stats = Stats {
execution_time: Default::default(),
};
(stats, result)
})
.enumerate()
.collect()
}
// A missing index yields an empty `Value`, `None`, or an empty `Vec`.
#[test]
fn take_from_an_empty_response() {
let mut response = Response::new();
let value: Value = response.take(0).unwrap();
assert!(value.into_inner().is_none());
let mut response = Response::new();
let option: Option<String> = response.take(0).unwrap();
assert!(option.is_none());
let mut response = Response::new();
let vec: Vec<String> = response.take(0).unwrap();
assert!(vec.is_empty());
}
// A failed statement surfaces its error through `take`.
#[test]
fn take_from_an_errored_query() {
let mut response = Response {
results: to_map(vec![Err(Error::ConnectionUninitialised.into())]),
..Response::new()
};
response.take::<Option<()>>(0).unwrap_err();
}
// An empty result map behaves like a missing index.
#[test]
fn take_from_empty_records() {
let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value, Default::default());
let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let option: Option<String> = response.take(0).unwrap();
assert!(option.is_none());
let mut response = Response {
results: to_map(vec![]),
..Response::new()
};
let vec: Vec<String> = response.take(0).unwrap();
assert!(vec.is_empty());
}
// A scalar result can be taken as a `Value`, an `Option<T>` or a one-element `Vec<T>`.
#[test]
fn take_from_a_scalar_response() {
let scalar = 265;
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value.into_inner(), CoreValue::from(scalar));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let option: Option<_> = response.take(0).unwrap();
assert_eq!(option, Some(scalar));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let vec: Vec<i64> = response.take(0).unwrap();
assert_eq!(vec, vec![scalar]);
let scalar = true;
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value.into_inner(), CoreValue::from(scalar));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let option: Option<_> = response.take(0).unwrap();
assert_eq!(option, Some(scalar));
let mut response = Response {
results: to_map(vec![Ok(scalar.into())]),
..Response::new()
};
let vec: Vec<bool> = response.take(0).unwrap();
assert_eq!(vec, vec![scalar]);
}
// Numeric indices address statements by index, whatever order they are taken in.
#[test]
fn take_preserves_order() {
let mut response = Response {
results: to_map(vec![
Ok(0.into()),
Ok(1.into()),
Ok(2.into()),
Ok(3.into()),
Ok(4.into()),
Ok(5.into()),
Ok(6.into()),
Ok(7.into()),
]),
..Response::new()
};
let Some(four): Option<i32> = response.take(4).unwrap() else {
panic!("query not found");
};
assert_eq!(four, 4);
let Some(six): Option<i32> = response.take(6).unwrap() else {
panic!("query not found");
};
assert_eq!(six, 6);
let Some(zero): Option<i32> = response.take(0).unwrap() else {
panic!("query not found");
};
assert_eq!(zero, 0);
let one: Value = response.take(1).unwrap();
assert_eq!(one.into_inner(), CoreValue::from(1));
}
// A field-name selector extracts that field from an object result.
#[test]
fn take_key() {
let summary = Summary {
title: "Lorem Ipsum".to_owned(),
};
let value = to_value(summary.clone()).unwrap();
let mut response = Response {
results: to_map(vec![Ok(value.clone().into_inner())]),
..Response::new()
};
let title: Value = response.take("title").unwrap();
assert_eq!(title.into_inner(), CoreValue::from(summary.title.as_str()));
let mut response = Response {
results: to_map(vec![Ok(value.clone().into_inner())]),
..Response::new()
};
let Some(title): Option<String> = response.take("title").unwrap() else {
panic!("title not found");
};
assert_eq!(title, summary.title);
let mut response = Response {
results: to_map(vec![Ok(value.into_inner())]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![summary.title]);
let article = Article {
title: "Lorem Ipsum".to_owned(),
body: "Lorem Ipsum Lorem Ipsum".to_owned(),
};
let value = to_value(article.clone()).unwrap();
let mut response = Response {
results: to_map(vec![Ok(value.clone().into_inner())]),
..Response::new()
};
let Some(title): Option<String> = response.take("title").unwrap() else {
panic!("title not found");
};
assert_eq!(title, article.title);
let Some(body): Option<String> = response.take("body").unwrap() else {
panic!("body not found");
};
assert_eq!(body, article.body);
let mut response = Response {
results: to_map(vec![Ok(value.clone().into_inner())]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![article.title.clone()]);
let mut response = Response {
results: to_map(vec![Ok(value.into_inner())]),
..Response::new()
};
let value: Value = response.take("title").unwrap();
assert_eq!(value.into_inner(), CoreValue::from(article.title));
}
// Several fields can be taken, one after another, from the same result.
#[test]
fn take_key_multi() {
let article = Article {
title: "Lorem Ipsum".to_owned(),
body: "Lorem Ipsum Lorem Ipsum".to_owned(),
};
let value = to_value(article.clone()).unwrap();
let mut response = Response {
results: to_map(vec![Ok(value.clone().into_inner())]),
..Response::new()
};
let title: Vec<String> = response.take("title").unwrap();
assert_eq!(title, vec![article.title.clone()]);
let body: Vec<String> = response.take("body").unwrap();
assert_eq!(body, vec![article.body]);
let mut response = Response {
results: to_map(vec![Ok(value.clone().into_inner())]),
..Response::new()
};
let vec: Vec<String> = response.take("title").unwrap();
assert_eq!(vec, vec![article.title]);
}
// Taking a multi-record result as `Option<T>` fails with `LossyTake`
// instead of silently dropping records; the error carries the records back.
#[test]
fn take_partial_records() {
let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let value: Value = response.take(0).unwrap();
assert_eq!(value.into_inner(), vec![CoreValue::from(true), CoreValue::from(false)].into());
let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let vec: Vec<bool> = response.take(0).unwrap();
assert_eq!(vec, vec![true, false]);
let mut response = Response {
results: to_map(vec![Ok(vec![true, false].into())]),
..Response::new()
};
let Err(Api(Error::LossyTake(Response {
results: mut map,
..
}))): Result<Option<bool>> = response.take(0)
else {
panic!("silently dropping records not allowed");
};
let records = map.swap_remove(&0).unwrap().1.unwrap();
assert_eq!(records, vec![true, false].into());
}
// `check` reports the error of the earliest failed statement (index 3).
#[test]
fn check_returns_the_first_error() {
let response = vec![
Ok(0.into()),
Ok(1.into()),
Ok(2.into()),
Err(Error::ConnectionUninitialised.into()),
Ok(3.into()),
Ok(4.into()),
Ok(5.into()),
Err(Error::BackupsNotSupported.into()),
Ok(6.into()),
Ok(7.into()),
Err(Error::DuplicateRequestId(0).into()),
];
let response = Response {
results: to_map(response),
..Response::new()
};
let crate::Error::Api(Error::ConnectionUninitialised) = response.check().unwrap_err()
else {
panic!("check did not return the first error");
};
}
// `take_errors` removes all three failures, keyed by their original index,
// and leaves the eight successful results addressable by their own indices.
#[test]
fn take_errors() {
let response = vec![
Ok(0.into()),
Ok(1.into()),
Ok(2.into()),
Err(Error::ConnectionUninitialised.into()),
Ok(3.into()),
Ok(4.into()),
Ok(5.into()),
Err(Error::BackupsNotSupported.into()),
Ok(6.into()),
Ok(7.into()),
Err(Error::DuplicateRequestId(0).into()),
];
let mut response = Response {
results: to_map(response),
..Response::new()
};
let errors = response.take_errors();
assert_eq!(response.num_statements(), 8);
assert_eq!(errors.len(), 3);
let crate::Error::Api(Error::DuplicateRequestId(0)) = errors.get(&10).unwrap() else {
panic!("index `10` is not `DuplicateRequestId`");
};
let crate::Error::Api(Error::BackupsNotSupported) = errors.get(&7).unwrap() else {
panic!("index `7` is not `BackupsNotSupported`");
};
let crate::Error::Api(Error::ConnectionUninitialised) = errors.get(&3).unwrap() else {
panic!("index `3` is not `ConnectionUninitialised`");
};
let Some(value): Option<i32> = response.take(2).unwrap() else {
panic!("statement not found");
};
assert_eq!(value, 2);
let value: Value = response.take(4).unwrap();
assert_eq!(value.into_inner(), CoreValue::from(3));
}
}