use std::{
collections::HashMap,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock,
},
};
use crate::{WasiProcess, WasiProcessId};
/// Owning entry point to a WASI control plane: a table of processes plus a
/// counter of active tasks, optionally bounded by [`ControlPlaneConfig`].
///
/// Cloning is cheap: every clone shares the same reference-counted inner state.
#[derive(Debug, Clone)]
pub struct WasiControlPlane {
    state: Arc<State>,
}
/// A non-owning reference to a [`WasiControlPlane`].
///
/// A handle does not keep the shared state alive. It has to be upgraded before
/// use, and upgrading fails once every strong reference to that state is gone.
#[derive(Debug, Clone)]
pub struct WasiControlPlaneHandle {
    inner: std::sync::Weak<State>,
}

impl WasiControlPlaneHandle {
    /// Builds a handle that weakly points at `inner`.
    fn new(inner: &Arc<State>) -> Self {
        let weak = Arc::downgrade(inner);
        Self { inner: weak }
    }

    /// Returns an owning [`WasiControlPlane`] if its shared state still exists,
    /// or `None` if it has already been dropped.
    pub fn upgrade(&self) -> Option<WasiControlPlane> {
        let state = self.inner.upgrade()?;
        Some(WasiControlPlane { state })
    }

    /// Same as [`upgrade`](Self::upgrade), but panics instead of returning `None`.
    ///
    /// # Panics
    ///
    /// Panics with `"control plane unavailable"` if the shared state has
    /// already been dropped.
    pub fn must_upgrade(&self) -> WasiControlPlane {
        self.upgrade().expect("control plane unavailable")
    }
}
/// Limits applied by a [`WasiControlPlane`].
#[derive(Debug, Clone, Default)]
pub struct ControlPlaneConfig {
    /// Maximum number of tasks that may be registered at once;
    /// `None` disables the limit.
    pub max_task_count: Option<usize>,
}

impl ControlPlaneConfig {
    /// Creates a configuration with no task limit (same as [`Default`]).
    pub fn new() -> Self {
        Self::default()
    }
}
/// Inner state shared (through an `Arc`) by every clone of a [`WasiControlPlane`].
#[derive(Debug)]
struct State {
    /// Limits supplied at construction time.
    config: ControlPlaneConfig,
    /// Number of currently registered tasks. A clone of this `Arc` is handed to
    /// every `TaskCountGuard`, which decrements the counter when dropped.
    task_count: Arc<AtomicUsize>,
    /// Process table and id seed, guarded by a reader/writer lock.
    mutable: RwLock<MutableState>,
}
/// Lock-protected portion of [`State`].
#[derive(Debug)]
struct MutableState {
    /// Most recently issued id; the next id is this value plus one.
    process_seed: u32,
    /// Processes created through `new_process`, keyed by their pid.
    processes: HashMap<WasiProcessId, WasiProcess>,
}
impl WasiControlPlane {
    /// Creates a control plane with the given limits, zero registered tasks
    /// and an empty process table.
    pub fn new(config: ControlPlaneConfig) -> Self {
        Self {
            state: Arc::new(State {
                config,
                task_count: Arc::new(AtomicUsize::new(0)),
                mutable: RwLock::new(MutableState {
                    process_seed: 0,
                    processes: Default::default(),
                }),
            }),
        }
    }

    /// Returns a weak [`WasiControlPlaneHandle`] pointing at this control plane.
    pub fn handle(&self) -> WasiControlPlaneHandle {
        WasiControlPlaneHandle::new(&self.state)
    }

    /// Number of tasks whose [`TaskCountGuard`] has not been dropped yet.
    fn active_task_count(&self) -> usize {
        self.state.task_count.load(Ordering::SeqCst)
    }

    /// Registers one more active task.
    ///
    /// The returned guard releases the task's slot when dropped.
    ///
    /// # Errors
    ///
    /// Returns [`ControlPlaneError::TaskLimitReached`] with the configured limit
    /// when `max_task_count` tasks are already registered.
    pub(super) fn register_task(&self) -> Result<TaskCountGuard, ControlPlaneError> {
        // Reserve a slot first. `fetch_add` returns the count from *before* this
        // task, so admitting only while that value is below `max` caps the total
        // at exactly `max`. This matches the `>=` check in `new_process`.
        let previous = self.state.task_count.fetch_add(1, Ordering::SeqCst);
        if let Some(max) = self.state.config.max_task_count {
            if previous >= max {
                // Undo the reservation before reporting failure.
                self.state.task_count.fetch_sub(1, Ordering::SeqCst);
                // Report the configured limit, not the observed count, so the
                // error message ("maximum ... ({max})") is accurate.
                return Err(ControlPlaneError::TaskLimitReached { max });
            }
        }
        Ok(TaskCountGuard(Arc::clone(&self.state.task_count)))
    }

    /// Creates a process, assigns it a fresh pid and records it in the process table.
    ///
    /// This only checks the task limit; it does not register a task itself.
    ///
    /// # Errors
    ///
    /// - [`ControlPlaneError::TaskLimitReached`] with the configured limit if the
    ///   number of active tasks has already reached `max_task_count`.
    /// - [`ControlPlaneError::TaskLimitReached`] if the pid space is exhausted.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn new_process(&self) -> Result<WasiProcess, ControlPlaneError> {
        if let Some(max) = self.state.config.max_task_count {
            if self.active_task_count() >= max {
                return Err(ControlPlaneError::TaskLimitReached { max });
            }
        }
        // Built with a placeholder pid; the real pid is allocated under the lock below.
        let mut proc = WasiProcess::new(WasiProcessId::from(0), self.handle());
        let mut mutable = self.state.mutable.write().unwrap();
        let pid = mutable.next_process_id()?;
        proc.set_pid(pid);
        mutable.processes.insert(pid, proc.clone());
        Ok(proc)
    }

    /// Allocates a fresh id from the same sequence used for process pids,
    /// without registering a process.
    ///
    /// # Errors
    ///
    /// Returns [`ControlPlaneError::TaskLimitReached`] if the id space is exhausted.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn generate_id(&self) -> Result<WasiProcessId, ControlPlaneError> {
        let mut mutable = self.state.mutable.write().unwrap();
        mutable.next_process_id()
    }

    /// Returns a clone of the process registered under `pid`, if any.
    ///
    /// # Panics
    ///
    /// Panics if the state lock is poisoned.
    pub fn get_process(&self, pid: WasiProcessId) -> Option<WasiProcess> {
        self.state
            .mutable
            .read()
            .unwrap()
            .processes
            .get(&pid)
            .cloned()
    }
}
impl MutableState {
    /// Advances the seed by one and returns the new value as an id.
    ///
    /// If incrementing the `u32` seed would overflow, an error is returned and
    /// the seed is left unchanged.
    fn next_process_id(&mut self) -> Result<WasiProcessId, ControlPlaneError> {
        match self.process_seed.checked_add(1) {
            Some(next) => {
                self.process_seed = next;
                Ok(WasiProcessId::from(next))
            }
            None => Err(ControlPlaneError::TaskLimitReached {
                max: u32::MAX as usize,
            }),
        }
    }
}
impl Default for WasiControlPlane {
    /// Builds a control plane from [`ControlPlaneConfig::default`] (no task limit).
    fn default() -> Self {
        Self::new(ControlPlaneConfig::default())
    }
}
/// Token representing one registered task.
///
/// Dropping the guard decrements the shared task counter it holds.
#[derive(Debug)]
pub struct TaskCountGuard(Arc<AtomicUsize>);
impl Drop for TaskCountGuard {
    fn drop(&mut self) {
        let TaskCountGuard(counter) = self;
        counter.fetch_sub(1, Ordering::SeqCst);
    }
}
#[derive(thiserror::Error, PartialEq, Eq, Clone, Debug)]
pub enum ControlPlaneError {
#[error("The maximum number of execution tasks has been reached ({max})")]
TaskLimitReached {
max: usize,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a control plane configured with a task limit of two.
    fn plane_with_two_task_limit() -> WasiControlPlane {
        WasiControlPlane::new(ControlPlaneConfig {
            max_task_count: Some(2),
        })
    }

    #[test]
    fn test_control_plane_task_limits() {
        let plane = plane_with_two_task_limit();
        let process = plane.new_process().unwrap();
        let _first = process.new_thread().unwrap();
        let _second = process.new_thread().unwrap();

        let err = plane.new_process().unwrap_err();
        assert_eq!(err, ControlPlaneError::TaskLimitReached { max: 2 });
    }

    #[test]
    fn test_control_plane_task_limits_with_dropped_threads() {
        let plane = plane_with_two_task_limit();
        let process = plane.new_process().unwrap();
        // Each thread handle goes out of scope at the end of its iteration;
        // presumably that frees its task slot (that is what this test checks).
        for _ in 0..10 {
            let _short_lived = process.new_thread().unwrap();
        }
        let _first = process.new_thread().unwrap();
        let _second = process.new_thread().unwrap();

        let err = plane.new_process().unwrap_err();
        assert_eq!(err, ControlPlaneError::TaskLimitReached { max: 2 });
    }
}