surrealdb/api/engine/local/
native.rsuse crate::{
api::{
conn::{Connection, Route, Router},
engine::local::Db,
method::BoxFuture,
opt::{Endpoint, EndpointKind},
ExtraFeatures, OnceLockExt, Result, Surreal,
},
engine::tasks,
opt::{auth::Root, WaitFor},
value::Notification,
Action,
};
use channel::{Receiver, Sender};
use futures::{stream::poll_fn, StreamExt};
use std::{
collections::{BTreeMap, HashMap, HashSet},
sync::{atomic::AtomicI64, Arc, OnceLock},
task::Poll,
};
use surrealdb_core::{dbs::Session, iam::Level, kvs::Datastore, options::EngineOptions};
use tokio::sync::{watch, RwLock};
use tokio_util::sync::CancellationToken;
/// `Db` is usable as a public connection type for `Surreal<Db>`.
impl crate::api::Connection for Db {}

impl Connection for Db {
	/// Starts an embedded datastore for `address` and returns a client wired to it.
	///
	/// A background task running `run_router` owns the datastore. Requests reach
	/// it over a route channel that is unbounded when `capacity` is `0` and
	/// bounded to `capacity` otherwise. The returned future resolves only after
	/// the router has reported the outcome of datastore initialisation; an
	/// initialisation error is returned to the caller.
	fn connect(address: Endpoint, capacity: usize) -> BoxFuture<'static, Result<Surreal<Self>>> {
		Box::pin(async move {
			let (route_tx, route_rx) = if capacity == 0 {
				channel::unbounded()
			} else {
				channel::bounded(capacity)
			};
			// Single-slot channel on which the router reports whether startup succeeded.
			let (ready_tx, ready_rx) = channel::bounded(1);
			tokio::spawn(run_router(address, ready_tx, route_rx));
			// Outer `?`: nothing could be received (e.g. the router ended without
			// reporting); inner `?`: the router reported a startup error.
			ready_rx.recv().await??;
			let features = HashSet::from([ExtraFeatures::Backup, ExtraFeatures::LiveQueries]);
			let router = Router {
				features,
				sender: route_tx,
				last_id: AtomicI64::new(0),
			};
			let waiter = watch::channel(Some(WaitFor::Connection));
			Ok(Surreal::new_from_router_waiter(
				Arc::new(OnceLock::with_value(router)),
				Arc::new(waiter),
			))
		})
	}
}
/// Owns an embedded [`Datastore`] and serves requests for one local connection.
///
/// Startup: the datastore is opened at `address`, using the full URL for TiKV
/// endpoints and `address.path` otherwise. Its version is then checked, it is
/// bootstrapped and, when the configured auth level is `Level::Root`, root
/// credentials are initialised. The outcome of this sequence is sent once on
/// `conn_tx` (send errors are ignored); on failure the function returns
/// without serving anything.
///
/// Serving: each `Route` received on `route_rx` is handled on its own spawned
/// task, and the result is sent back on `route.response`. When the configured
/// capabilities allow live query notifications, datastore notifications are
/// forwarded to the sender registered in `live_queries` for that query id. If
/// forwarding fails, the entry is removed and the live query is killed.
///
/// Shutdown: once receiving from `route_rx` fails (the channel is closed), the
/// background tasks are cancelled and awaited, and the datastore is shut down.
pub(crate) async fn run_router(
	address: Endpoint,
	conn_tx: Sender<Result<()>>,
	route_rx: Receiver<Route>,
) {
	// Root credentials are only provisioned when the connection asks for root-level auth.
	let configured_root = match address.config.auth {
		Level::Root => Some(Root {
			username: &address.config.username,
			password: &address.config.password,
		}),
		_ => None,
	};
	// TiKV endpoints are opened with the full URL; every other engine with the parsed path.
	let endpoint = match EndpointKind::from(address.url.scheme()) {
		EndpointKind::TiKv => address.url.as_str(),
		_ => &address.path,
	};
	let kvs = match Datastore::new(endpoint).await {
		Ok(kvs) => {
			if let Err(error) = kvs.check_version().await {
				conn_tx.send(Err(error.into())).await.ok();
				return;
			};
			if let Err(error) = kvs.bootstrap().await {
				conn_tx.send(Err(error.into())).await.ok();
				return;
			}
			if let Some(root) = configured_root {
				if let Err(error) = kvs.initialise_credentials(root.username, root.password).await {
					conn_tx.send(Err(error.into())).await.ok();
					return;
				}
			}
			// Startup succeeded: unblock the waiting `connect` call.
			conn_tx.send(Ok(())).await.ok();
			kvs.with_auth_enabled(configured_root.is_some())
		}
		Err(error) => {
			conn_tx.send(Err(error.into())).await.ok();
			return;
		}
	};
	let kvs = if address.config.capabilities.allows_live_query_notifications() {
		kvs.with_notifications()
	} else {
		kvs
	};
	let kvs = kvs
		.with_strict_mode(address.config.strict)
		.with_query_timeout(address.config.query_timeout)
		.with_transaction_timeout(address.config.transaction_timeout)
		.with_capabilities(address.config.capabilities);
	#[cfg(storage)]
	let kvs = kvs.with_temporary_directory(address.config.temporary_directory);
	let kvs = Arc::new(kvs);
	let vars = Arc::new(RwLock::new(BTreeMap::default()));
	let live_queries = Arc::new(RwLock::new(HashMap::new()));
	let session = Arc::new(RwLock::new(Session::default().with_rt(true)));
	let canceller = CancellationToken::new();
	// Engine options start from defaults; only intervals explicitly configured are overridden.
	let mut opt = EngineOptions::default();
	if let Some(interval) = address.config.node_membership_refresh_interval {
		opt.node_membership_refresh_interval = interval;
	}
	if let Some(interval) = address.config.node_membership_check_interval {
		opt.node_membership_check_interval = interval;
	}
	if let Some(interval) = address.config.node_membership_cleanup_interval {
		opt.node_membership_cleanup_interval = interval;
	}
	if let Some(interval) = address.config.changefeed_gc_interval {
		opt.changefeed_gc_interval = interval;
	}
	let tasks = tasks::init(kvs.clone(), canceller.clone(), &opt);
	let mut notifications = kvs.notifications().map(Box::pin);
	// Without a notification channel this stream stays pending forever, so the
	// `select!` below only ever makes progress on incoming routes.
	let mut notification_stream = poll_fn(move |cx| match &mut notifications {
		Some(rx) => rx.poll_next_unpin(cx),
		None => Poll::Pending,
	});
	loop {
		// Shared handles (Arc clones) moved into whichever task this iteration spawns.
		let kvs = kvs.clone();
		let session = session.clone();
		let vars = vars.clone();
		let live_queries = live_queries.clone();
		tokio::select! {
			route = route_rx.recv() => {
				// A receive error means the route channel is closed: stop serving.
				let Ok(route) = route else {
					break
				};
				tokio::spawn(async move {
					match super::router(route.request, &kvs, &session, &vars, &live_queries)
						.await
					{
						Ok(value) => {
							route.response.send(Ok(value)).await.ok();
						}
						Err(error) => {
							route.response.send(Err(error)).await.ok();
						}
					}
				});
			}
			notification = notification_stream.next() => {
				let Some(notification) = notification else {
					continue
				};
				let notification = Notification {
					query_id: *notification.id,
					action: Action::from_core(notification.action),
					data: notification.result,
				};
				// NOTE(review): one task per notification means notifications for the
				// same live query are not guaranteed to be delivered in order — confirm
				// this is acceptable.
				tokio::spawn(async move {
					let id = notification.query_id;
					// The read guard is a temporary of this `let` statement, so it is
					// released before the write lock below is requested. Awaiting the
					// write lock while our own read guard is still alive can never
					// complete (tokio's `RwLock` is not reentrant). Because writers are
					// queued fairly, that would also stall every later reader and
					// writer of `live_queries`.
					let delivery_failed = match live_queries.read().await.get(&id) {
						Some(sender) => sender.send(notification).await.is_err(),
						None => false,
					};
					if delivery_failed {
						// The subscriber is gone: forget it and kill the live query.
						live_queries.write().await.remove(&id);
						if let Err(error) =
							super::kill_live_query(&kvs, id, &*session.read().await, vars.read().await.clone()).await
						{
							warn!("Failed to kill live query '{id}'; {error}");
						}
					}
				});
			}
		}
	}
	// Route channel closed: stop background tasks, wait for them, then close the datastore.
	canceller.cancel();
	tasks.resolve().await.ok();
	kvs.shutdown().await.ok();
}