use crate::state::{
State,
StateWatcher,
};
use anyhow::anyhow;
use futures::FutureExt;
use tokio::sync::watch;
use tracing::Instrument;
/// Shared-ownership pointer used throughout the service machinery; an alias for
/// [`std::sync::Arc`].
pub type Shared<T> = std::sync::Arc<T>;
/// A cheaply clonable handle to a value guarded by a [`parking_lot::Mutex`].
///
/// Clones of the handle share the same underlying mutex (the mutex lives behind a
/// [`Shared`] pointer); access the value through `SharedMutex::apply`.
#[derive(Debug)]
pub struct SharedMutex<T>(Shared<parking_lot::Mutex<T>>);
impl<T> Clone for SharedMutex<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Unit marker usable as `RunnableService::SharedData` when a service has nothing
/// to share with its owner (the mock service in the tests uses it this way).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmptyShared;
/// Control interface for a service: start it, stop it, and observe its
/// lifecycle [`State`].
#[async_trait::async_trait]
pub trait Service {
    /// Requests the service to start.
    ///
    /// Returns an error if the service cannot be started, e.g. because it was
    /// already started.
    fn start(&self) -> anyhow::Result<()>;
    /// Requests the service to start, then waits until it is no longer starting
    /// and returns the state it reached.
    async fn start_and_await(&self) -> anyhow::Result<State>;
    /// Waits until the service is no longer starting and returns the observed state.
    async fn await_start_or_stop(&self) -> anyhow::Result<State>;
    /// Requests the service to stop; returns `true` if this call changed the state.
    fn stop(&self) -> bool;
    /// Requests the service to stop, then waits until it is stopped and returns
    /// the final state.
    async fn stop_and_await(&self) -> anyhow::Result<State>;
    /// Waits until the service is stopped and returns the final state.
    async fn await_stop(&self) -> anyhow::Result<State>;
    /// Returns a snapshot of the current state.
    fn state(&self) -> State;
    /// Returns a watcher for observing subsequent state changes.
    fn state_watcher(&self) -> StateWatcher;
}
/// A service definition that `ServiceRunner` can drive.
///
/// The runner takes [`Self::shared_data`] when it is created. Once the service
/// enters the starting state, the runner consumes the service through
/// [`Self::into_task`] to obtain the task it will run.
#[async_trait::async_trait]
pub trait RunnableService: Send {
    /// Name of the service, used in error messages and as a tracing span field.
    const NAME: &'static str;
    /// Data exposed to the owner of the runner (the runner's `shared` field).
    type SharedData: Clone + Send + Sync;
    /// The task produced by [`Self::into_task`] and executed by the runner.
    type Task: RunnableTask;
    /// Returns the data to share with the owner of the runner.
    fn shared_data(&self) -> Self::SharedData;
    /// Converts the service into its runnable task.
    ///
    /// The runner calls this once the service is starting. If it returns an
    /// error, the runner ends the service in an error state.
    async fn into_task(self, state_watcher: &StateWatcher) -> anyhow::Result<Self::Task>;
}
/// The unit of work that `ServiceRunner` runs repeatedly while the service is started.
#[async_trait::async_trait]
pub trait RunnableTask: Send {
    /// Performs one iteration of the task.
    ///
    /// Return `Ok(true)` to be called again or `Ok(false)` to end the run loop.
    /// The runner logs a returned error and keeps looping while the service is
    /// still started.
    async fn run(&mut self, watcher: &mut StateWatcher) -> anyhow::Result<bool>;
    /// Releases the task. The runner calls it once after the run loop ends,
    /// including when `run` panicked.
    async fn shutdown(self) -> anyhow::Result<()>;
}
/// Runs a [`RunnableService`] on a background tokio task and exposes
/// [`Service`] controls for it.
///
/// Dropping the runner requests the service to stop.
#[derive(Debug)]
pub struct ServiceRunner<S>
where
S: RunnableService + 'static,
{
// Data obtained from `RunnableService::shared_data` when the runner was created.
pub shared: S::SharedData,
// Sender side of the lifecycle state channel; the background task holds a clone.
state: Shared<watch::Sender<State>>,
}
/// Requests a stop when the runner goes away.
///
/// The background task keeps its own handle to the state channel. Without this,
/// a service that was never started would leave that task waiting forever for a
/// state change.
impl<S> Drop for ServiceRunner<S>
where
    S: RunnableService + 'static,
{
    fn drop(&mut self) {
        // It does not matter whether the state actually changed.
        let _requested = self.stop();
    }
}
impl<S> ServiceRunner<S>
where
    S: RunnableService + 'static,
{
    /// Creates the runner and spawns the background task that will drive `service`.
    ///
    /// The service stays not-started until `start` (or `stop`) is called. This
    /// spawns a tokio task, so it must be called from within a tokio runtime.
    pub fn new(service: S) -> Self {
        // Take the shared data before the service is moved into the background task.
        let shared = service.shared_data();
        Self {
            shared,
            state: initialize_loop(service),
        }
    }

    /// Waits on `watcher` until the state is no longer "starting" and returns it.
    async fn _await_start_or_stop(
        &self,
        mut watcher: StateWatcher,
    ) -> anyhow::Result<State> {
        loop {
            // The borrow guard is released at the end of this statement, before
            // any await.
            let current = watcher.borrow().clone();
            if current.starting() {
                watcher.changed().await?;
            } else {
                return Ok(current)
            }
        }
    }

    /// Waits on `watcher` until the state is "stopped" and returns it.
    async fn _await_stop(&self, mut watcher: StateWatcher) -> anyhow::Result<State> {
        loop {
            let current = watcher.borrow().clone();
            if !current.stopped() {
                watcher.changed().await?;
                continue
            }
            return Ok(current)
        }
    }
}
#[async_trait::async_trait]
impl<S> Service for ServiceRunner<S>
where
    S: RunnableService + 'static,
{
    /// Moves the state from not-started to `Starting`.
    ///
    /// Fails if the service has already left the not-started state.
    fn start(&self) -> anyhow::Result<()> {
        let transitioned = self.state.send_if_modified(|current| {
            let may_start = current.not_started();
            if may_start {
                *current = State::Starting;
            }
            may_start
        });
        if !transitioned {
            return Err(anyhow!(
                "The service `{}` already has been started.",
                S::NAME
            ))
        }
        Ok(())
    }

    /// Subscribes to state changes, requests the start, then waits until the
    /// service is no longer starting.
    async fn start_and_await(&self) -> anyhow::Result<State> {
        let watcher = self.state_watcher();
        self.start()?;
        self._await_start_or_stop(watcher).await
    }

    async fn await_start_or_stop(&self) -> anyhow::Result<State> {
        self._await_start_or_stop(self.state_watcher()).await
    }

    /// Moves the state to `Stopping` if it is currently not-started, starting
    /// or started; returns whether the state changed.
    fn stop(&self) -> bool {
        self.state.send_if_modified(|current| {
            let may_stop = current.not_started() || current.starting() || current.started();
            if may_stop {
                *current = State::Stopping;
            }
            may_stop
        })
    }

    /// Subscribes to state changes, requests a stop, then waits until the
    /// service is stopped.
    async fn stop_and_await(&self) -> anyhow::Result<State> {
        let watcher = self.state_watcher();
        self.stop();
        self._await_stop(watcher).await
    }

    async fn await_stop(&self) -> anyhow::Result<State> {
        self._await_stop(self.state_watcher()).await
    }

    fn state(&self) -> State {
        State::clone(&self.state.borrow())
    }

    fn state_watcher(&self) -> StateWatcher {
        self.state.subscribe().into()
    }
}
/// Spawns the background task that drives `service` and returns the sender side
/// of the service's state channel, which starts out as `State::NotStarted`.
///
/// When `run` finishes, the task publishes the final state:
/// - `State::Stopped` on a normal return.
/// - `State::StoppedWithError` with the panic message if `run` panicked.
///
/// An existing stopped state is never overwritten. After a panic, the task
/// re-raises it once the final state has been published.
#[tracing::instrument(skip_all, fields(service = S::NAME))]
fn initialize_loop<S>(service: S) -> Shared<watch::Sender<State>>
where
    S: RunnableService + 'static,
{
    let (sender, _) = watch::channel(State::NotStarted);
    let state = Shared::new(sender);
    let task_sender = Shared::clone(&state);

    let driver = async move {
        tracing::debug!("running");
        let guarded_run = std::panic::AssertUnwindSafe(run(service, task_sender.clone()));
        tracing::debug!("awaiting run");
        let outcome = guarded_run.catch_unwind().await;

        let stopped_state = match outcome {
            Ok(()) => State::Stopped,
            Err(payload) => State::StoppedWithError(panic_to_string(payload)),
        };
        tracing::debug!("shutting down {:?}", stopped_state);

        let _ = task_sender.send_if_modified(|current| {
            if current.stopped() {
                tracing::debug!("Was already stopped.");
                return false
            }
            *current = stopped_state.clone();
            tracing::debug!("Wasn't stopped, so sent stop.");
            true
        });

        // Re-raise the panic inside the spawned task now that the state is published.
        if let State::StoppedWithError(message) = stopped_state {
            std::panic::resume_unwind(Box::new(message));
        }
    };
    tokio::task::spawn(driver.in_current_span());

    state
}
async fn run<S>(service: S, sender: Shared<watch::Sender<State>>)
where
S: RunnableService + 'static,
{
let mut state: StateWatcher = sender.subscribe().into();
if state.borrow_and_update().not_started() {
state.changed().await.expect("The service is destroyed");
}
if !state.borrow().starting() {
return
}
let mut task = service
.into_task(&state)
.await
.expect("The initialization of the service failed.");
sender.send_if_modified(|s| {
if s.starting() {
*s = State::Started;
true
} else {
false
}
});
let mut got_panic = None;
while state.borrow_and_update().started() {
let task = std::panic::AssertUnwindSafe(task.run(&mut state));
let panic_result = task.catch_unwind().await;
match panic_result {
Ok(Ok(should_continue)) => {
if !should_continue {
tracing::debug!("stopping");
break
}
tracing::debug!("run loop");
}
Ok(Err(e)) => {
let e: &dyn std::error::Error = &*e;
tracing::error!(e);
}
Err(panic) => {
tracing::debug!("got a panic");
got_panic = Some(panic);
break
}
}
}
let shutdown = std::panic::AssertUnwindSafe(task.shutdown());
match shutdown.catch_unwind().await {
Ok(Ok(_)) => {}
Ok(Err(e)) => {
tracing::error!("Go an error during shutdown of the task: {e}");
}
Err(e) => {
if got_panic.is_some() {
let panic_information = panic_to_string(e);
tracing::error!(
"Go a panic during execution and shutdown of the task. \
The error during shutdown: {panic_information}"
);
} else {
got_panic = Some(e);
}
}
}
if let Some(panic) = got_panic {
std::panic::resume_unwind(panic)
}
}
impl<T> SharedMutex<T> {
    /// Wraps `t` in a new mutex behind a shared pointer.
    pub fn new(t: T) -> Self {
        let mutex = parking_lot::Mutex::new(t);
        Self(Shared::new(mutex))
    }

    /// Runs `f` with exclusive access to the guarded value and returns its result.
    ///
    /// The lock is held for the whole call to `f` and released afterwards. `f`
    /// must not lock this same mutex again: `parking_lot::Mutex` is not
    /// reentrant, so doing so would deadlock.
    pub fn apply<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        f(&mut *self.0.lock())
    }
}
/// Extracts a human-readable message from a panic payload.
///
/// `panic!` payloads are typically either a `String` (formatted messages) or a
/// `&'static str` (plain literals). Both are returned as the message text. Any
/// other payload type yields `"Unknown Source of Error"`.
fn panic_to_string(e: Box<dyn core::any::Any + Send>) -> String {
    match e.downcast::<String>() {
        Ok(message) => *message,
        Err(e) => match e.downcast::<&str>() {
            Ok(message) => message.to_string(),
            Err(_) => "Unknown Source of Error".to_owned(),
        },
    }
}
// Unit tests for `ServiceRunner`, using mockall mocks of `RunnableService` and
// `RunnableTask`.
#[cfg(test)]
mod tests {
use super::*;
use futures::future::BoxFuture;
// Mock service; each test sets the expectations for `shared_data`/`into_task`.
mockall::mock! {
Service {}
#[async_trait::async_trait]
impl RunnableService for Service {
const NAME: &'static str = "MockService";
type SharedData = EmptyShared;
type Task = MockTask;
fn shared_data(&self) -> EmptyShared;
async fn into_task(self, state: &StateWatcher) -> anyhow::Result<MockTask>;
}
}
// Mock task. `run` is written in its desugared `async_trait` form (a boxed future
// with explicit lifetimes) rather than as `async fn`, presumably so mockall can
// mock the borrowed `&mut StateWatcher` argument. TODO confirm.
mockall::mock! {
Task {}
#[async_trait::async_trait]
impl RunnableTask for Task {
fn run<'_self, '_state, 'a>(
&'_self mut self,
state: &'_state mut StateWatcher
) -> BoxFuture<'a, anyhow::Result<bool>>
where
'_self: 'a,
'_state: 'a,
Self: Sync + 'a;
async fn shutdown(self) -> anyhow::Result<()>;
}
}
impl MockService {
// Builds a mock service whose task waits on a cloned watcher (`while_started`,
// presumably until the service leaves `Started`), then asks the runner to
// stop. `shutdown` is expected exactly once.
fn new_empty() -> Self {
let mut mock = MockService::default();
mock.expect_shared_data().returning(|| EmptyShared);
mock.expect_into_task().returning(|_| {
let mut mock = MockTask::default();
mock.expect_run().returning(|watcher| {
let mut watcher = watcher.clone();
Box::pin(async move {
watcher.while_started().await.unwrap();
let should_continue = false;
Ok(should_continue)
})
});
mock.expect_shutdown().times(1).returning(|| Ok(()));
Ok(mock)
});
mock
}
}
// Happy path: start reaches `Started`, stop reaches `Stopped`.
#[tokio::test]
async fn start_and_await_stop_and_await_works() {
let service = ServiceRunner::new(MockService::new_empty());
let state = service.start_and_await().await.unwrap();
assert!(state.started());
let state = service.stop_and_await().await.unwrap();
assert!(matches!(state, State::Stopped));
}
// `start` succeeds only from the not-started state.
#[tokio::test]
async fn double_start_fails() {
let service = ServiceRunner::new(MockService::new_empty());
assert!(service.start().is_ok());
assert!(service.start().is_err());
}
// Same as above, through `start_and_await`.
#[tokio::test]
async fn double_start_and_await_fails() {
let service = ServiceRunner::new(MockService::new_empty());
assert!(service.start_and_await().await.is_ok());
assert!(service.start_and_await().await.is_err());
}
// Stopping a never-started service still ends in `Stopped`.
#[tokio::test]
async fn stop_without_start() {
let service = ServiceRunner::new(MockService::new_empty());
service.stop_and_await().await.unwrap();
assert!(matches!(service.state(), State::Stopped));
}
// A panic in `RunnableTask::run` is reported as `StoppedWithError` with the
// panic message, and `shutdown` is still called once.
#[tokio::test]
async fn panic_during_run() {
let mut mock = MockService::default();
mock.expect_shared_data().returning(|| EmptyShared);
mock.expect_into_task().returning(|_| {
let mut mock = MockTask::default();
mock.expect_run().returning(|_| panic!("Should fail"));
mock.expect_shutdown().times(1).returning(|| Ok(()));
Ok(mock)
});
let service = ServiceRunner::new(mock);
let state = service.start_and_await().await.unwrap();
assert!(matches!(state, State::StoppedWithError(s) if s.contains("Should fail")));
let state = service.await_stop().await.unwrap();
assert!(matches!(state, State::StoppedWithError(s) if s.contains("Should fail")));
}
// A panic in `RunnableTask::shutdown` (with no panic in `run`) is reported as
// `StoppedWithError` with the shutdown panic message.
#[tokio::test]
async fn panic_during_shutdown() {
let mut mock = MockService::default();
mock.expect_shared_data().returning(|| EmptyShared);
mock.expect_into_task().returning(|_| {
let mut mock = MockTask::default();
mock.expect_run().returning(|_| {
Box::pin(async move {
let should_continue = false;
Ok(should_continue)
})
});
mock.expect_shutdown()
.times(1)
.returning(|| panic!("Shutdown should fail"));
Ok(mock)
});
let service = ServiceRunner::new(mock);
let state = service.start_and_await().await.unwrap();
assert!(
matches!(state, State::StoppedWithError(s) if s.contains("Shutdown should fail"))
);
let state = service.await_stop().await.unwrap();
assert!(
matches!(state, State::StoppedWithError(s) if s.contains("Shutdown should fail"))
);
}
// `await_stop` can be awaited repeatedly once the service is stopped.
#[tokio::test]
async fn double_await_stop_works() {
let service = ServiceRunner::new(MockService::new_empty());
service.start().unwrap();
service.stop();
let state = service.await_stop().await.unwrap();
assert!(matches!(state, State::Stopped));
let state = service.await_stop().await.unwrap();
assert!(matches!(state, State::Stopped));
}
// A second `stop_and_await` on a stopped service returns `Stopped` again.
#[tokio::test]
async fn double_stop_and_await_works() {
let service = ServiceRunner::new(MockService::new_empty());
service.start().unwrap();
let state = service.stop_and_await().await.unwrap();
assert!(matches!(state, State::Stopped));
let state = service.stop_and_await().await.unwrap();
assert!(matches!(state, State::Stopped));
}
// Dropping the runner (its `Drop` calls `stop`) moves the service to `Stopping`;
// the background task then publishes `Stopped`.
#[tokio::test]
async fn stop_unused_service() {
let mut receiver;
{
let service = ServiceRunner::new(MockService::new_empty());
service.start().unwrap();
receiver = service.state.subscribe();
}
receiver.changed().await.unwrap();
assert!(matches!(receiver.borrow().clone(), State::Stopping));
receiver.changed().await.unwrap();
assert!(matches!(receiver.borrow().clone(), State::Stopped));
}
}