1use crate::{
2 state::{
3 State,
4 StateWatcher,
5 },
6 Shared,
7};
8use anyhow::anyhow;
9use fuel_core_metrics::futures::{
10 future_tracker::FutureTracker,
11 FuturesMetrics,
12};
13use futures::FutureExt;
14use std::any::Any;
15use tokio::sync::watch;
16use tracing::Instrument;
17
/// Unit value usable as [`RunnableService::SharedData`] when a service has
/// nothing to share with its users.
///
/// It is zero-sized, so copying it is free.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct EmptyShared;
21
22#[async_trait::async_trait]
25pub trait Service {
26 fn start(&self) -> anyhow::Result<()>;
29
30 async fn start_and_await(&self) -> anyhow::Result<State>;
33
34 async fn await_start_or_stop(&self) -> anyhow::Result<State>;
36
37 fn stop(&self) -> bool;
40
41 async fn stop_and_await(&self) -> anyhow::Result<State>;
43
44 async fn await_stop(&self) -> anyhow::Result<State>;
46
47 fn state(&self) -> State;
49
50 fn state_watcher(&self) -> StateWatcher;
52}
53
54#[async_trait::async_trait]
56pub trait RunnableService: Send {
57 const NAME: &'static str;
59
60 type SharedData: Clone + Send + Sync;
65
66 type Task: RunnableTask;
68
69 type TaskParams: Send;
71
72 fn shared_data(&self) -> Self::SharedData;
74
75 async fn into_task(
80 self,
81 state_watcher: &StateWatcher,
82 params: Self::TaskParams,
83 ) -> anyhow::Result<Self::Task>;
84}
85
86#[derive(Debug)]
88pub enum TaskNextAction {
89 Continue,
91 Stop,
93 ErrorContinue(anyhow::Error),
95}
96
97impl TaskNextAction {
98 pub fn always_continue<T, E: Into<anyhow::Error>>(
100 res: Result<T, E>,
101 ) -> TaskNextAction {
102 match res {
103 Ok(_) => TaskNextAction::Continue,
104 Err(e) => TaskNextAction::ErrorContinue(e.into()),
105 }
106 }
107}
108
109impl From<Result<bool, anyhow::Error>> for TaskNextAction {
110 fn from(result: Result<bool, anyhow::Error>) -> Self {
111 match result {
112 Ok(should_continue) => {
113 if should_continue {
114 TaskNextAction::Continue
115 } else {
116 TaskNextAction::Stop
117 }
118 }
119 Err(e) => TaskNextAction::ErrorContinue(e),
120 }
121 }
122}
123
124pub trait RunnableTask: Send {
127 fn run(
137 &mut self,
138 watcher: &mut StateWatcher,
139 ) -> impl core::future::Future<Output = TaskNextAction> + Send;
140
141 fn shutdown(self) -> impl core::future::Future<Output = anyhow::Result<()>> + Send;
143}
144
145#[derive(Debug)]
148pub struct ServiceRunner<S>
149where
150 S: RunnableService + 'static,
151{
152 pub shared: S::SharedData,
154 state: Shared<watch::Sender<State>>,
155}
156
157impl<S> Drop for ServiceRunner<S>
158where
159 S: RunnableService + 'static,
160{
161 fn drop(&mut self) {
162 self.stop();
163 }
164}
165
166impl<S> ServiceRunner<S>
167where
168 S: RunnableService + 'static,
169 S::TaskParams: Default,
170{
171 pub fn new(service: S) -> Self {
173 Self::new_with_params(service, S::TaskParams::default())
174 }
175}
176
177impl<S> ServiceRunner<S>
178where
179 S: RunnableService + 'static,
180{
181 pub fn new_with_params(service: S, params: S::TaskParams) -> Self {
183 let shared = service.shared_data();
184 let metric = FuturesMetrics::obtain_futures_metrics(S::NAME);
185 let state = initialize_loop(service, params, metric);
186 Self { shared, state }
187 }
188
189 async fn _await_start_or_stop(
190 &self,
191 mut start: StateWatcher,
192 ) -> anyhow::Result<State> {
193 loop {
194 let state = start.borrow().clone();
195 if !state.starting() {
196 return Ok(state);
197 }
198 start.changed().await?;
199 }
200 }
201
202 async fn _await_stop(&self, mut stop: StateWatcher) -> anyhow::Result<State> {
203 loop {
204 let state = stop.borrow().clone();
205 if state.stopped() {
206 return Ok(state);
207 }
208 stop.changed().await?;
209 }
210 }
211}
212
213#[async_trait::async_trait]
214impl<S> Service for ServiceRunner<S>
215where
216 S: RunnableService + 'static,
217{
218 fn start(&self) -> anyhow::Result<()> {
219 let started = self.state.send_if_modified(|state| {
220 if state.not_started() {
221 *state = State::Starting;
222 true
223 } else {
224 false
225 }
226 });
227
228 if started {
229 Ok(())
230 } else {
231 Err(anyhow!(
232 "The service `{}` already has been started.",
233 S::NAME
234 ))
235 }
236 }
237
238 async fn start_and_await(&self) -> anyhow::Result<State> {
239 let start = self.state.subscribe().into();
240 self.start()?;
241 self._await_start_or_stop(start).await
242 }
243
244 async fn await_start_or_stop(&self) -> anyhow::Result<State> {
245 let start = self.state.subscribe().into();
246 self._await_start_or_stop(start).await
247 }
248
249 fn stop(&self) -> bool {
250 self.state.send_if_modified(|state| {
251 if state.not_started() || state.starting() || state.started() {
252 *state = State::Stopping;
253 true
254 } else {
255 false
256 }
257 })
258 }
259
260 async fn stop_and_await(&self) -> anyhow::Result<State> {
261 let stop = self.state.subscribe().into();
262 self.stop();
263 self._await_stop(stop).await
264 }
265
266 async fn await_stop(&self) -> anyhow::Result<State> {
267 let stop = self.state.subscribe().into();
268 self._await_stop(stop).await
269 }
270
271 fn state(&self) -> State {
272 self.state.borrow().clone()
273 }
274
275 fn state_watcher(&self) -> StateWatcher {
276 self.state.subscribe().into()
277 }
278}
279
280#[tracing::instrument(skip_all, fields(service = S::NAME))]
281fn initialize_loop<S>(
283 service: S,
284 params: S::TaskParams,
285 metric: FuturesMetrics,
286) -> Shared<watch::Sender<State>>
287where
288 S: RunnableService + 'static,
289{
290 let (sender, _) = watch::channel(State::NotStarted);
291 let state = Shared::new(sender);
292 let stop_sender = state.clone();
293 tokio::task::spawn(
295 async move {
296 tracing::debug!("running");
297 let run = std::panic::AssertUnwindSafe(run(
298 service,
299 stop_sender.clone(),
300 params,
301 metric,
302 ));
303 tracing::debug!("awaiting run");
304 let result = run.catch_unwind().await;
305
306 let stopped_state = if let Err(e) = result {
307 let panic_information = panic_to_string(e);
308 State::StoppedWithError(panic_information)
309 } else {
310 State::Stopped
311 };
312
313 tracing::debug!("shutting down {:?}", stopped_state);
314
315 let _ = stop_sender.send_if_modified(|state| {
316 if !state.stopped() {
317 *state = stopped_state.clone();
318 tracing::debug!("Wasn't stopped, so sent stop.");
319 true
320 } else {
321 tracing::debug!("Was already stopped.");
322 false
323 }
324 });
325
326 tracing::info!("The service {} is shut down", S::NAME);
327
328 if let State::StoppedWithError(err) = stopped_state {
329 std::panic::resume_unwind(Box::new(err));
330 }
331 }
332 .in_current_span(),
333 );
334 state
335}
336
337async fn run<S>(
339 service: S,
340 sender: Shared<watch::Sender<State>>,
341 params: S::TaskParams,
342 metric: FuturesMetrics,
343) where
344 S: RunnableService + 'static,
345{
346 let mut state: StateWatcher = sender.subscribe().into();
347 if state.borrow_and_update().not_started() {
348 state.changed().await.expect("The service is destroyed");
350 }
351
352 if !state.borrow().starting() {
354 return;
355 }
356
357 tracing::info!("Starting {} service", S::NAME);
359 let mut task = service
360 .into_task(&state, params)
361 .await
362 .expect("The initialization of the service failed");
363
364 sender.send_if_modified(|s| {
365 if s.starting() {
366 *s = State::Started;
367 true
368 } else {
369 false
370 }
371 });
372
373 let got_panic = run_task(&mut task, state, &metric).await;
374
375 let got_panic = shutdown_task(S::NAME, task, got_panic).await;
376
377 if let Some(panic) = got_panic {
378 std::panic::resume_unwind(panic)
379 }
380}
381
382async fn run_task<S: RunnableTask>(
383 task: &mut S,
384 mut state: StateWatcher,
385 metric: &FuturesMetrics,
386) -> Option<Box<dyn Any + Send>> {
387 let mut got_panic = None;
388
389 while state.borrow_and_update().started() {
390 let tracked_task = FutureTracker::new(task.run(&mut state));
391 let task = std::panic::AssertUnwindSafe(tracked_task);
392 let panic_result = task.catch_unwind().await;
393
394 if let Err(panic) = panic_result {
395 tracing::debug!("got a panic");
396 got_panic = Some(panic);
397 break;
398 }
399
400 let tracked_result = panic_result.expect("Checked the panic above");
401 let result = tracked_result.extract(metric);
402
403 match result {
404 TaskNextAction::Continue => {
405 tracing::debug!("run loop");
406 }
407 TaskNextAction::Stop => {
408 tracing::debug!("stopping");
409 break;
410 }
411 TaskNextAction::ErrorContinue(e) => {
412 let e: &dyn std::error::Error = &*e;
413 tracing::error!(e);
414 }
415 }
416 }
417 got_panic
418}
419
420async fn shutdown_task<S>(
421 name: &str,
422 task: S,
423 mut got_panic: Option<Box<dyn Any + Send>>,
424) -> Option<Box<dyn Any + Send>>
425where
426 S: RunnableTask,
427{
428 tracing::info!("Shutting down {} service", name);
429 let shutdown = std::panic::AssertUnwindSafe(task.shutdown());
430 match shutdown.catch_unwind().await {
431 Ok(Ok(_)) => {}
432 Ok(Err(e)) => {
433 tracing::error!("Got an error during shutdown of the task: {e}");
434 }
435 Err(e) => {
436 if got_panic.is_some() {
437 let panic_information = panic_to_string(e);
438 tracing::error!(
439 "Go a panic during execution and shutdown of the task. \
440 The error during shutdown: {panic_information}"
441 );
442 } else {
443 got_panic = Some(e);
444 }
445 }
446 }
447 got_panic
448}
449
/// Extracts a human-readable message from a panic payload.
///
/// Payloads from `panic!` are usually a `String` (formatted message) or a
/// `&'static str` (literal message). Any other payload type yields a fixed
/// placeholder.
fn panic_to_string(e: Box<dyn core::any::Any + Send>) -> String {
    // Move an owned `String` out of the box without copying it.
    let payload = match e.downcast::<String>() {
        Ok(owned) => return *owned,
        Err(other) => other,
    };
    payload
        .downcast_ref::<&str>()
        .map(|message| (*message).to_owned())
        .unwrap_or_else(|| "Unknown Source of Error".to_owned())
}
459
// Lifecycle tests for `ServiceRunner`, driven by `mockall` mocks of
// `RunnableService` and `RunnableTask`.
#[cfg(test)]
mod tests {
    use super::*;

    // Mock service; each test configures what `into_task` returns.
    mockall::mock! {
        Service {}

        #[async_trait::async_trait]
        impl RunnableService for Service {
            const NAME: &'static str = "MockService";

            type SharedData = EmptyShared;
            type Task = MockTask;
            type TaskParams = ();

            fn shared_data(&self) -> EmptyShared;

            async fn into_task(self, state: &StateWatcher, params: <MockService as RunnableService>::TaskParams) -> anyhow::Result<MockTask>;
        }
    }

    // Mock task; each test configures `run` and `shutdown`.
    mockall::mock! {
        Task {}

        impl RunnableTask for Task {
            fn run(
                &mut self,
                state: &mut StateWatcher
            ) -> impl core::future::Future<Output = TaskNextAction> + Send;

            async fn shutdown(self) -> anyhow::Result<()>;
        }
    }

    impl MockService {
        // Builds a well-behaved service. Its task waits on
        // `StateWatcher::while_started` (presumably until the service leaves
        // `Started`) and then returns `Stop`. `shutdown` must be called
        // exactly once.
        fn new_empty() -> Self {
            let mut mock = MockService::default();
            mock.expect_shared_data().returning(|| EmptyShared);
            mock.expect_into_task().returning(|_, _| {
                let mut mock = MockTask::default();
                mock.expect_run().returning(|watcher| {
                    let mut watcher = watcher.clone();
                    Box::pin(async move {
                        watcher.while_started().await.unwrap();
                        TaskNextAction::Stop
                    })
                });
                mock.expect_shutdown().times(1).returning(|| Ok(()));
                Ok(mock)
            });
            mock
        }
    }

    // Happy path: start reaches `Started`, stop reaches `Stopped`.
    #[tokio::test]
    async fn start_and_await_stop_and_await_works() {
        let service = ServiceRunner::new(MockService::new_empty());
        let state = service.start_and_await().await.unwrap();
        assert!(state.started());
        let state = service.stop_and_await().await.unwrap();
        assert!(matches!(state, State::Stopped));
    }

    // `start` only succeeds from `NotStarted`.
    #[tokio::test]
    async fn double_start_fails() {
        let service = ServiceRunner::new(MockService::new_empty());
        assert!(service.start().is_ok());
        assert!(service.start().is_err());
    }

    // The awaiting variant propagates the same "already started" error.
    #[tokio::test]
    async fn double_start_and_await_fails() {
        let service = ServiceRunner::new(MockService::new_empty());
        assert!(service.start_and_await().await.is_ok());
        assert!(service.start_and_await().await.is_err());
    }

    // Stopping a never-started service ends in `Stopped` without building a task.
    #[tokio::test]
    async fn stop_without_start() {
        let service = ServiceRunner::new(MockService::new_empty());
        service.stop_and_await().await.unwrap();
        assert!(matches!(service.state(), State::Stopped));
    }

    // A panic inside `run` is reported as `StoppedWithError` with the panic
    // message, and `shutdown` is still called once.
    #[tokio::test]
    async fn panic_during_run() {
        let mut mock = MockService::default();
        mock.expect_shared_data().returning(|| EmptyShared);
        mock.expect_into_task().returning(|_, _| {
            let mut mock = MockTask::default();
            mock.expect_run().returning(|_| panic!("Should fail"));
            mock.expect_shutdown().times(1).returning(|| Ok(()));
            Ok(mock)
        });
        let service = ServiceRunner::new(mock);
        let state = service.start_and_await().await.unwrap();
        assert!(matches!(state, State::StoppedWithError(s) if s.contains("Should fail")));

        let state = service.await_stop().await.unwrap();
        assert!(matches!(state, State::StoppedWithError(s) if s.contains("Should fail")));
    }

    // A panic inside `shutdown` (with no earlier panic) is reported as
    // `StoppedWithError` carrying the shutdown panic message.
    #[tokio::test]
    async fn panic_during_shutdown() {
        let mut mock = MockService::default();
        mock.expect_shared_data().returning(|| EmptyShared);
        mock.expect_into_task().returning(|_, _| {
            let mut mock = MockTask::default();
            mock.expect_run()
                .returning(|_| Box::pin(async move { TaskNextAction::Stop }));
            mock.expect_shutdown()
                .times(1)
                .returning(|| panic!("Shutdown should fail"));
            Ok(mock)
        });
        let service = ServiceRunner::new(mock);
        let state = service.start_and_await().await.unwrap();
        assert!(
            matches!(state, State::StoppedWithError(s) if s.contains("Shutdown should fail"))
        );

        let state = service.await_stop().await.unwrap();
        assert!(
            matches!(state, State::StoppedWithError(s) if s.contains("Shutdown should fail"))
        );
    }

    // `await_stop` is idempotent once the service has stopped.
    #[tokio::test]
    async fn double_await_stop_works() {
        let service = ServiceRunner::new(MockService::new_empty());
        service.start().unwrap();
        service.stop();

        let state = service.await_stop().await.unwrap();
        assert!(matches!(state, State::Stopped));
        let state = service.await_stop().await.unwrap();
        assert!(matches!(state, State::Stopped));
    }

    // A second `stop_and_await` on a stopped service returns `Stopped` again.
    #[tokio::test]
    async fn double_stop_and_await_works() {
        let service = ServiceRunner::new(MockService::new_empty());
        service.start().unwrap();

        let state = service.stop_and_await().await.unwrap();
        assert!(matches!(state, State::Stopped));
        let state = service.stop_and_await().await.unwrap();
        assert!(matches!(state, State::Stopped));
    }

    // Dropping the runner (via `Drop`) moves the service to `Stopping`, after
    // which the background task publishes `Stopped`.
    #[tokio::test]
    async fn stop_unused_service() {
        let mut receiver;
        {
            let service = ServiceRunner::new(MockService::new_empty());
            service.start().unwrap();
            receiver = service.state.subscribe();
        }

        receiver.changed().await.unwrap();
        assert!(matches!(receiver.borrow().clone(), State::Stopping));
        receiver.changed().await.unwrap();
        assert!(matches!(receiver.borrow().clone(), State::Stopped));
    }
}
624}