use tokio::sync::watch;
/// Lifecycle state of a service, as published through a `watch` channel.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum State {
    /// The service has not been started yet.
    NotStarted,
    /// The service is in the process of starting.
    Starting,
    /// The service is running.
    Started,
    /// The service is in the process of stopping.
    Stopping,
    /// The service has stopped.
    Stopped,
    /// The service has stopped with an error; the `String` presumably holds
    /// the error description (NOTE(review): confirm with producers).
    StoppedWithError(String),
}

impl State {
    /// Returns `true` if the state is [`State::NotStarted`].
    pub fn not_started(&self) -> bool {
        matches!(self, State::NotStarted)
    }

    /// Returns `true` if the state is [`State::Starting`].
    pub fn starting(&self) -> bool {
        matches!(self, State::Starting)
    }

    /// Returns `true` if the state is [`State::Started`].
    pub fn started(&self) -> bool {
        matches!(self, State::Started)
    }

    /// Returns `true` if the state is [`State::Stopping`].
    ///
    /// This completes the set of per-variant predicates so callers do not
    /// need to compare against `State::Stopping` by hand.
    pub fn stopping(&self) -> bool {
        matches!(self, State::Stopping)
    }

    /// Returns `true` for both [`State::Stopped`] and
    /// [`State::StoppedWithError`], regardless of the error payload.
    pub fn stopped(&self) -> bool {
        matches!(self, State::Stopped | State::StoppedWithError(_))
    }
}
/// Handle for observing a service's [`State`] through a
/// [`tokio::sync::watch`] channel.
///
/// Thin newtype over `watch::Receiver<State>`. Cloning it clones the inner
/// receiver, which (per tokio's docs) yields another receiver on the same
/// channel.
#[derive(Clone)]
pub struct StateWatcher(watch::Receiver<State>);
#[cfg(feature = "test-helpers")]
impl Default for StateWatcher {
    /// Builds a watcher whose state is [`State::NotStarted`].
    ///
    /// The sending half is dropped right away, so the channel is closed from
    /// the start. `borrow` keeps yielding `NotStarted`, and per tokio's
    /// `watch` docs `changed`/`has_changed` report a `RecvError`.
    fn default() -> Self {
        let (sender, receiver) = watch::channel(State::NotStarted);
        drop(sender);
        Self(receiver)
    }
}
impl StateWatcher {
pub fn borrow(&self) -> watch::Ref<'_, State> {
self.0.borrow()
}
pub fn borrow_and_update(&mut self) -> watch::Ref<'_, State> {
self.0.borrow_and_update()
}
pub fn has_changed(&self) -> Result<bool, watch::error::RecvError> {
self.0.has_changed()
}
pub async fn changed(&mut self) -> Result<(), watch::error::RecvError> {
self.0.changed().await
}
pub fn same_channel(&self, other: &Self) -> bool {
self.0.same_channel(&other.0)
}
}
impl StateWatcher {
    /// Waits until the observed state is anything other than [`State::Started`].
    ///
    /// If the current state is not `Started` (including `NotStarted` or
    /// `Starting`), a clone of it is returned immediately. Otherwise this
    /// waits for change notifications until a non-`Started` state is
    /// observed, and returns a clone of that state.
    ///
    /// # Errors
    ///
    /// Returns an error if waiting for a change fails, i.e. `changed()` yields
    /// a `RecvError` (tokio reports this once the sending half is dropped).
    #[tracing::instrument(level = "debug", skip(self), err, ret)]
    pub async fn while_started(&mut self) -> anyhow::Result<State> {
        // Clone a snapshot so the `watch::Ref` read guard is released before
        // awaiting; a guard held across `.await` could block the sender.
        let mut current = self.borrow().clone();
        while current.started() {
            tracing::debug!("Service is started, waiting for the next state...");
            self.changed().await?;
            current = self.borrow().clone();
        }
        Ok(current)
    }
}
impl From<watch::Receiver<State>> for StateWatcher {
fn from(receiver: watch::Receiver<State>) -> Self {
Self(receiver)
}
}