broker_tokio/task/
task_local.rs

1use pin_project_lite::pin_project;
2use std::cell::RefCell;
3use std::error::Error;
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::{fmt, thread};
8
/// Declares one or more task-local keys of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro accepts any number of `static NAME: Type;` declarations and
/// turns each of them into a key that is local to the current task. The
/// visibility and attributes written on a declaration are forwarded, together
/// with its name and type, to the code that generates the key. The semicolon
/// after the last declaration may be omitted.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [LocalKey documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: ../tokio/task/struct.LocalKey.html
#[macro_export]
macro_rules! task_local {
    // Nothing left to declare: this is where the recursion stops.
    () => {};

    // A declaration terminated by `;`: emit it, then recurse on what follows.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $ty:ty; $($tail:tt)*) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $ty);
        $crate::task_local!($($tail)*);
    };

    // The last declaration, written without its trailing `;`.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $ty:ty) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $ty);
    };
}
47
/// Implementation detail of `task_local!`: expands a single declaration.
///
/// Emits `static $name` of type `$crate::task::LocalKey<$t>`, carrying the
/// attributes and the visibility supplied by the caller. The key wraps a
/// thread-local `RefCell<Option<$t>>` that starts out as `None`.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        // Re-emit the caller's attributes and visibility; without them every
        // key would be private and things like `#[allow(..)]` or `#[cfg(..)]`
        // written on the declaration would be silently discarded.
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            // Absolute `::std` paths keep the expansion independent of any
            // item named `std` that happens to be in scope at the call site.
            ::std::thread_local! {
                static __KEY: ::std::cell::RefCell<Option<$t>> = ::std::cell::RefCell::new(None);
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
61
/// A key that gives access to a piece of task-local data.
///
/// Keys of this type are produced by the `task_local!` macro.
///
/// In contrast to [`std::thread::LocalKey`], a `tokio::task::LocalKey` does
/// _not_ lazily create its value on first access. The value only becomes
/// available once the future it was attached to is polled by a futures
/// executor, like Tokio.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
/// [`std::thread::LocalKey`]: https://doc.rust-lang.org/std/thread/struct.LocalKey.html
pub struct LocalKey<T>
where
    T: 'static,
{
    // Thread-local cell holding the value that is currently in scope, or
    // `None` when there is none. Public only so that the code generated by
    // `task_local!` can fill it in.
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
97
98impl<T: 'static> LocalKey<T> {
99    /// Sets a value `T` as the task-local value for the future `F`.
100    ///
101    /// On completion of `scope`, the task-local will be dropped.
102    ///
103    /// ### Examples
104    ///
105    /// ```
106    /// # async fn dox() {
107    /// tokio::task_local! {
108    ///     static NUMBER: u32;
109    /// }
110    ///
111    /// NUMBER.scope(1, async move {
112    ///     println!("task local value: {}", NUMBER.get());
113    /// }).await;
114    /// # }
115    /// ```
116    pub async fn scope<F>(&'static self, value: T, f: F) -> F::Output
117    where
118        F: Future,
119    {
120        TaskLocalFuture {
121            local: &self,
122            slot: Some(value),
123            future: f,
124        }
125        .await
126    }
127
128    /// Accesses the current task-local and runs the provided closure.
129    ///
130    /// # Panics
131    ///
132    /// This function will panic if not called within the context
133    /// of a future containing a task-local with the corresponding key.
134    pub fn with<F, R>(&'static self, f: F) -> R
135    where
136        F: FnOnce(&T) -> R,
137    {
138        self.try_with(f).expect(
139            "cannot access a Task Local Storage value \
140             without setting it via `LocalKey::set`",
141        )
142    }
143
144    /// Accesses the current task-local and runs the provided closure.
145    ///
146    /// If the task-local with the accociated key is not present, this
147    /// method will return an `AccessError`. For a panicking variant,
148    /// see `with`.
149    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
150    where
151        F: FnOnce(&T) -> R,
152    {
153        self.inner.with(|v| {
154            if let Some(val) = v.borrow().as_ref() {
155                Ok(f(val))
156            } else {
157                Err(AccessError { _private: () })
158            }
159        })
160    }
161}
162
163impl<T: Copy + 'static> LocalKey<T> {
164    /// Returns a copy of the task-local value
165    /// if the task-local value implements `Copy`.
166    pub fn get(&'static self) -> T {
167        self.with(|v| *v)
168    }
169}
170
171impl<T: 'static> fmt::Debug for LocalKey<T> {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        f.pad("LocalKey { .. }")
174    }
175}
176
177pin_project! {
178    struct TaskLocalFuture<T: StaticLifetime, F> {
179        local: &'static LocalKey<T>,
180        slot: Option<T>,
181        #[pin]
182        future: F,
183    }
184}
185
186impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
187    type Output = F::Output;
188
189    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
190        struct Guard<'a, T: 'static> {
191            local: &'static LocalKey<T>,
192            slot: &'a mut Option<T>,
193            prev: Option<T>,
194        }
195
196        impl<T> Drop for Guard<'_, T> {
197            fn drop(&mut self) {
198                let value = self.local.inner.with(|c| c.replace(self.prev.take()));
199                *self.slot = value;
200            }
201        }
202
203        let mut project = self.project();
204        let val = project.slot.take();
205
206        let prev = project.local.inner.with(|c| c.replace(val));
207
208        let _guard = Guard {
209            prev,
210            slot: &mut project.slot,
211            local: *project.local,
212        };
213
214        project.future.poll(cx)
215    }
216}
217
// Marker trait implemented for every `'static` type. `TaskLocalFuture` uses
// it to state its `'static` requirement as a trait bound, which is needed to
// make `pin_project` happy (presumably the macro does not accept a bare
// lifetime bound in that position).
trait StaticLifetime: 'static {}

impl<T> StaticLifetime for T where T: 'static {}
221
/// The error returned by
/// [`LocalKey::try_with`](struct.LocalKey.html#method.try_with) when the
/// task-local value cannot be accessed.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct AccessError {
    // Private unit field: the struct cannot be built with a struct literal
    // outside of this module.
    _private: (),
}
227
228impl fmt::Debug for AccessError {
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        f.debug_struct("AccessError").finish()
231    }
232}
233
234impl fmt::Display for AccessError {
235    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
236        fmt::Display::fmt("task-local value not set", f)
237    }
238}
239
240impl Error for AccessError {}