madsim_real_tokio/task/
task_local.rs

1use pin_project_lite::pin_project;
2use std::cell::RefCell;
3use std::error::Error;
4use std::future::Future;
5use std::marker::PhantomPinned;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::{fmt, mem, thread};
9
/// Declares one or more task-local keys of type [`tokio::task::LocalKey`].
///
/// # Syntax
///
/// The macro accepts any number of `static NAME: Type;` declarations and makes
/// each of them local to the current task. The visibility and the attributes
/// written on every static are preserved.
///
/// # Examples
///
/// ```
/// # use tokio::task_local;
/// task_local! {
///     pub static ONE: u32;
///
///     #[allow(unused)]
///     static TWO: f32;
/// }
/// # fn main() {}
/// ```
///
/// See [`LocalKey` documentation][`tokio::task::LocalKey`] for more
/// information.
///
/// [`tokio::task::LocalKey`]: struct@crate::task::LocalKey
#[macro_export]
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
macro_rules! task_local {
    // Nothing left to declare: this rule terminates the recursion.
    () => {};

    // One declaration terminated by `;`, optionally followed by further input
    // that is handled by recursing into this macro.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $value_ty:ty; $($tail:tt)*) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $value_ty);
        $crate::task_local!($($tail)*);
    };

    // A final declaration written without the trailing `;`.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $value_ty:ty) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $value_ty);
    }
}
49
// Implementation detail of `task_local!`: expands a single declaration into a
// `static` of type `LocalKey<T>` backed by its own thread-local cell.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$meta:meta])* $visibility:vis $key:ident, $value_ty:ty) => {
        $(#[$meta])*
        $visibility static $key: $crate::task::LocalKey<$value_ty> = {
            // The backing cell starts out empty (`None`) and is private to
            // this initializer block.
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$value_ty>> = const { std::cell::RefCell::new(None) };
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
64
/// A key that identifies a piece of task-local data.
///
/// Keys of this type are generated by the [`task_local!`] macro.
///
/// In contrast to [`std::thread::LocalKey`], a `tokio::task::LocalKey` does
/// _not_ lazily initialize its value on first access. The key holds no value
/// until one is supplied through a scope, and the value is first made
/// available when the future containing the task-local is first polled by a
/// futures executor, like Tokio.
///
/// # Examples
///
/// ```
/// # async fn dox() {
/// tokio::task_local! {
///     static NUMBER: u32;
/// }
///
/// NUMBER.scope(1, async move {
///     assert_eq!(NUMBER.get(), 1);
/// }).await;
///
/// NUMBER.scope(2, async move {
///     assert_eq!(NUMBER.get(), 2);
///
///     NUMBER.scope(3, async move {
///         assert_eq!(NUMBER.get(), 3);
///     }).await;
/// }).await;
/// # }
/// ```
///
/// [`std::thread::LocalKey`]: struct@std::thread::LocalKey
/// [`task_local!`]: ../macro.task_local.html
#[cfg_attr(docsrs, doc(cfg(feature = "rt")))]
pub struct LocalKey<T>
where
    T: 'static,
{
    // Thread-local cell backing the key. It is public only so that the
    // `task_local!` expansion can construct the key; `None` means "not set".
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
103
104impl<T: 'static> LocalKey<T> {
105    /// Sets a value `T` as the task-local value for the future `F`.
106    ///
107    /// On completion of `scope`, the task-local will be dropped.
108    ///
109    /// ### Panics
110    ///
111    /// If you poll the returned future inside a call to [`with`] or
112    /// [`try_with`] on the same `LocalKey`, then the call to `poll` will panic.
113    ///
114    /// ### Examples
115    ///
116    /// ```
117    /// # async fn dox() {
118    /// tokio::task_local! {
119    ///     static NUMBER: u32;
120    /// }
121    ///
122    /// NUMBER.scope(1, async move {
123    ///     println!("task local value: {}", NUMBER.get());
124    /// }).await;
125    /// # }
126    /// ```
127    ///
128    /// [`with`]: fn@Self::with
129    /// [`try_with`]: fn@Self::try_with
130    pub fn scope<F>(&'static self, value: T, f: F) -> TaskLocalFuture<T, F>
131    where
132        F: Future,
133    {
134        TaskLocalFuture {
135            local: self,
136            slot: Some(value),
137            future: Some(f),
138            _pinned: PhantomPinned,
139        }
140    }
141
142    /// Sets a value `T` as the task-local value for the closure `F`.
143    ///
144    /// On completion of `sync_scope`, the task-local will be dropped.
145    ///
146    /// ### Panics
147    ///
148    /// This method panics if called inside a call to [`with`] or [`try_with`]
149    /// on the same `LocalKey`.
150    ///
151    /// ### Examples
152    ///
153    /// ```
154    /// # async fn dox() {
155    /// tokio::task_local! {
156    ///     static NUMBER: u32;
157    /// }
158    ///
159    /// NUMBER.sync_scope(1, || {
160    ///     println!("task local value: {}", NUMBER.get());
161    /// });
162    /// # }
163    /// ```
164    ///
165    /// [`with`]: fn@Self::with
166    /// [`try_with`]: fn@Self::try_with
167    #[track_caller]
168    pub fn sync_scope<F, R>(&'static self, value: T, f: F) -> R
169    where
170        F: FnOnce() -> R,
171    {
172        let mut value = Some(value);
173        match self.scope_inner(&mut value, f) {
174            Ok(res) => res,
175            Err(err) => err.panic(),
176        }
177    }
178
179    fn scope_inner<F, R>(&'static self, slot: &mut Option<T>, f: F) -> Result<R, ScopeInnerErr>
180    where
181        F: FnOnce() -> R,
182    {
183        struct Guard<'a, T: 'static> {
184            local: &'static LocalKey<T>,
185            slot: &'a mut Option<T>,
186        }
187
188        impl<'a, T: 'static> Drop for Guard<'a, T> {
189            fn drop(&mut self) {
190                // This should not panic.
191                //
192                // We know that the RefCell was not borrowed before the call to
193                // `scope_inner`, so the only way for this to panic is if the
194                // closure has created but not destroyed a RefCell guard.
195                // However, we never give user-code access to the guards, so
196                // there's no way for user-code to forget to destroy a guard.
197                //
198                // The call to `with` also should not panic, since the
199                // thread-local wasn't destroyed when we first called
200                // `scope_inner`, and it shouldn't have gotten destroyed since
201                // then.
202                self.local.inner.with(|inner| {
203                    let mut ref_mut = inner.borrow_mut();
204                    mem::swap(self.slot, &mut *ref_mut);
205                });
206            }
207        }
208
209        self.inner.try_with(|inner| {
210            inner
211                .try_borrow_mut()
212                .map(|mut ref_mut| mem::swap(slot, &mut *ref_mut))
213        })??;
214
215        let guard = Guard { local: self, slot };
216
217        let res = f();
218
219        drop(guard);
220
221        Ok(res)
222    }
223
224    /// Accesses the current task-local and runs the provided closure.
225    ///
226    /// # Panics
227    ///
228    /// This function will panic if the task local doesn't have a value set.
229    #[track_caller]
230    pub fn with<F, R>(&'static self, f: F) -> R
231    where
232        F: FnOnce(&T) -> R,
233    {
234        match self.try_with(f) {
235            Ok(res) => res,
236            Err(_) => panic!("cannot access a task-local storage value without setting it first"),
237        }
238    }
239
240    /// Accesses the current task-local and runs the provided closure.
241    ///
242    /// If the task-local with the associated key is not present, this
243    /// method will return an `AccessError`. For a panicking variant,
244    /// see `with`.
245    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
246    where
247        F: FnOnce(&T) -> R,
248    {
249        // If called after the thread-local storing the task-local is destroyed,
250        // then we are outside of a closure where the task-local is set.
251        //
252        // Therefore, it is correct to return an AccessError if `try_with`
253        // returns an error.
254        let try_with_res = self.inner.try_with(|v| {
255            // This call to `borrow` cannot panic because no user-defined code
256            // runs while a `borrow_mut` call is active.
257            v.borrow().as_ref().map(f)
258        });
259
260        match try_with_res {
261            Ok(Some(res)) => Ok(res),
262            Ok(None) | Err(_) => Err(AccessError { _private: () }),
263        }
264    }
265}
266
267impl<T: Copy + 'static> LocalKey<T> {
268    /// Returns a copy of the task-local value
269    /// if the task-local value implements `Copy`.
270    ///
271    /// # Panics
272    ///
273    /// This function will panic if the task local doesn't have a value set.
274    #[track_caller]
275    pub fn get(&'static self) -> T {
276        self.with(|v| *v)
277    }
278}
279
280impl<T: 'static> fmt::Debug for LocalKey<T> {
281    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
282        f.pad("LocalKey { .. }")
283    }
284}
285
286pin_project! {
287    /// A future that sets a value `T` of a task local for the future `F` during
288    /// its execution.
289    ///
290    /// The value of the task-local must be `'static` and will be dropped on the
291    /// completion of the future.
292    ///
293    /// Created by the function [`LocalKey::scope`](self::LocalKey::scope).
294    ///
295    /// ### Examples
296    ///
297    /// ```
298    /// # async fn dox() {
299    /// tokio::task_local! {
300    ///     static NUMBER: u32;
301    /// }
302    ///
303    /// NUMBER.scope(1, async move {
304    ///     println!("task local value: {}", NUMBER.get());
305    /// }).await;
306    /// # }
307    /// ```
308    pub struct TaskLocalFuture<T, F>
309    where
310        T: 'static,
311    {
312        local: &'static LocalKey<T>,
313        slot: Option<T>,
314        #[pin]
315        future: Option<F>,
316        #[pin]
317        _pinned: PhantomPinned,
318    }
319
320    impl<T: 'static, F> PinnedDrop for TaskLocalFuture<T, F> {
321        fn drop(this: Pin<&mut Self>) {
322            let this = this.project();
323            if mem::needs_drop::<F>() && this.future.is_some() {
324                // Drop the future while the task-local is set, if possible. Otherwise
325                // the future is dropped normally when the `Option<F>` field drops.
326                let mut future = this.future;
327                let _ = this.local.scope_inner(this.slot, || {
328                    future.set(None);
329                });
330            }
331        }
332    }
333}
334
335impl<T, F> TaskLocalFuture<T, F>
336where
337    T: 'static,
338{
339    /// Returns the value stored in the task local by this `TaskLocalFuture`.
340    ///
341    /// The function returns:
342    ///
343    /// * `Some(T)` if the task local value exists.
344    /// * `None` if the task local value has already been taken.
345    ///
346    /// Note that this function attempts to take the task local value even if
347    /// the future has not yet completed. In that case, the value will no longer
348    /// be available via the task local after the call to `take_value`.
349    ///
350    /// # Examples
351    ///
352    /// ```
353    /// # async fn dox() {
354    /// tokio::task_local! {
355    ///     static KEY: u32;
356    /// }
357    ///
358    /// let fut = KEY.scope(42, async {
359    ///     // Do some async work
360    /// });
361    ///
362    /// let mut pinned = Box::pin(fut);
363    ///
364    /// // Complete the TaskLocalFuture
365    /// let _ = pinned.as_mut().await;
366    ///
367    /// // And here, we can take task local value
368    /// let value = pinned.as_mut().take_value();
369    ///
370    /// assert_eq!(value, Some(42));
371    /// # }
372    /// ```
373    pub fn take_value(self: Pin<&mut Self>) -> Option<T> {
374        let this = self.project();
375        this.slot.take()
376    }
377}
378
379impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
380    type Output = F::Output;
381
382    #[track_caller]
383    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
384        let this = self.project();
385        let mut future_opt = this.future;
386
387        let res = this
388            .local
389            .scope_inner(this.slot, || match future_opt.as_mut().as_pin_mut() {
390                Some(fut) => {
391                    let res = fut.poll(cx);
392                    if res.is_ready() {
393                        future_opt.set(None);
394                    }
395                    Some(res)
396                }
397                None => None,
398            });
399
400        match res {
401            Ok(Some(res)) => res,
402            Ok(None) => panic!("`TaskLocalFuture` polled after completion"),
403            Err(err) => err.panic(),
404        }
405    }
406}
407
408impl<T: 'static, F> fmt::Debug for TaskLocalFuture<T, F>
409where
410    T: fmt::Debug,
411{
412    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413        /// Format the Option without Some.
414        struct TransparentOption<'a, T> {
415            value: &'a Option<T>,
416        }
417        impl<'a, T: fmt::Debug> fmt::Debug for TransparentOption<'a, T> {
418            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
419                match self.value.as_ref() {
420                    Some(value) => value.fmt(f),
421                    // Hitting the None branch should not be possible.
422                    None => f.pad("<missing>"),
423                }
424            }
425        }
426
427        f.debug_struct("TaskLocalFuture")
428            .field("value", &TransparentOption { value: &self.slot })
429            .finish()
430    }
431}
432
/// An error returned by [`LocalKey::try_with`](method@LocalKey::try_with).
///
/// It signals that no task-local value was available for the key.
#[derive(Copy, Clone, PartialEq, Eq)]
pub struct AccessError {
    // Private unit field: keeps the type from being constructed outside this
    // module.
    _private: (),
}

// Printed as the bare type name; the private field is not shown.
impl fmt::Debug for AccessError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut repr = formatter.debug_struct("AccessError");
        repr.finish()
    }
}

// Delegates to `str`'s `Display` so width and alignment flags are honoured.
impl fmt::Display for AccessError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        <str as fmt::Display>::fmt("task-local value not set", formatter)
    }
}

impl Error for AccessError {}
452
/// Reasons why a task-local scope could not be entered.
enum ScopeInnerErr {
    /// The cell holding the task-local was already borrowed.
    BorrowError,
    /// The underlying thread-local could not be accessed.
    AccessError,
}

impl ScopeInnerErr {
    /// Panics with the message that corresponds to this error.
    ///
    /// Thanks to `#[track_caller]` the panic is attributed to the caller.
    #[track_caller]
    fn panic(&self) -> ! {
        match *self {
            ScopeInnerErr::BorrowError => panic!("cannot enter a task-local scope while the task-local storage is borrowed"),
            ScopeInnerErr::AccessError => panic!("cannot enter a task-local scope during or after destruction of the underlying thread-local"),
        }
    }
}

// Lets `?` turn a failed `try_borrow_mut` into a scope error.
impl From<std::cell::BorrowMutError> for ScopeInnerErr {
    fn from(_err: std::cell::BorrowMutError) -> Self {
        ScopeInnerErr::BorrowError
    }
}

// Lets `?` turn a failed thread-local access into a scope error.
impl From<std::thread::AccessError> for ScopeInnerErr {
    fn from(_err: std::thread::AccessError) -> Self {
        ScopeInnerErr::AccessError
    }
}
478}