async_std/task/
task_local.rs

1use std::cell::UnsafeCell;
2use std::error::Error;
3use std::fmt;
4use std::sync::atomic::{AtomicU32, Ordering};
5
6use crate::task::TaskLocalsWrapper;
7
/// Handle through which a task-local value is reached.
///
/// Each task sees its own value for a given key: the value is created lazily the first time the
/// task accesses it and is destroyed when that task completes.
#[derive(Debug)]
pub struct LocalKey<T>
where
    T: Send + 'static,
{
    /// Function that produces a task's initial value for this key.
    #[doc(hidden)]
    pub __init: fn() -> T,

    /// Numeric identifier of this key; the value `0` means no identifier has been assigned yet.
    #[doc(hidden)]
    pub __key: AtomicU32,
}
20
21impl<T: Send + 'static> LocalKey<T> {
22    /// Gets a reference to the task-local value with this key.
23    ///
24    /// The passed closure receives a reference to the task-local value.
25    ///
26    /// The task-local value will be lazily initialized if this task has not accessed it before.
27    ///
28    /// # Panics
29    ///
30    /// This function will panic if not called within the context of a task created by
31    /// [`block_on`], [`spawn`], or [`Builder::spawn`].
32    ///
33    /// [`block_on`]: fn.block_on.html
34    /// [`spawn`]: fn.spawn.html
35    /// [`Builder::spawn`]: struct.Builder.html#method.spawn
36    ///
37    /// # Examples
38    ///
39    /// ```
40    /// #
41    /// use std::cell::Cell;
42    ///
43    /// use async_std::task;
44    /// use async_std::prelude::*;
45    ///
46    /// task_local! {
47    ///     static NUMBER: Cell<u32> = Cell::new(5);
48    /// }
49    ///
50    /// task::block_on(async {
51    ///     let v = NUMBER.with(|c| c.get());
52    ///     assert_eq!(v, 5);
53    /// });
54    /// ```
55    pub fn with<F, R>(&'static self, f: F) -> R
56    where
57        F: FnOnce(&T) -> R,
58    {
59        self.try_with(f)
60            .expect("`LocalKey::with` called outside the context of a task")
61    }
62
63    /// Attempts to get a reference to the task-local value with this key.
64    ///
65    /// The passed closure receives a reference to the task-local value.
66    ///
67    /// The task-local value will be lazily initialized if this task has not accessed it before.
68    ///
69    /// This function returns an error if not called within the context of a task created by
70    /// [`block_on`], [`spawn`], or [`Builder::spawn`].
71    ///
72    /// [`block_on`]: fn.block_on.html
73    /// [`spawn`]: fn.spawn.html
74    /// [`Builder::spawn`]: struct.Builder.html#method.spawn
75    ///
76    /// # Examples
77    ///
78    /// ```
79    /// #
80    /// use std::cell::Cell;
81    ///
82    /// use async_std::task;
83    /// use async_std::prelude::*;
84    ///
85    /// task_local! {
86    ///     static VAL: Cell<u32> = Cell::new(5);
87    /// }
88    ///
89    /// task::block_on(async {
90    ///     let v = VAL.try_with(|c| c.get());
91    ///     assert_eq!(v, Ok(5));
92    /// });
93    ///
94    /// // Returns an error because not called within the context of a task.
95    /// assert!(VAL.try_with(|c| c.get()).is_err());
96    /// ```
97    pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
98    where
99        F: FnOnce(&T) -> R,
100    {
101        TaskLocalsWrapper::get_current(|task| unsafe {
102            // Prepare the numeric key, initialization function, and the map of task-locals.
103            let key = self.key();
104            let init = || Box::new((self.__init)()) as Box<dyn Send>;
105
106            // Get the value in the map of task-locals, or initialize and insert one.
107            let value: *const dyn Send = task.locals().get_or_insert(key, init);
108
109            // Call the closure with the value passed as an argument.
110            f(&*(value as *const T))
111        })
112        .ok_or(AccessError { _private: () })
113    }
114
115    /// Returns the numeric key associated with this task-local.
116    #[inline]
117    fn key(&self) -> u32 {
118        #[cold]
119        fn init(key: &AtomicU32) -> u32 {
120            static COUNTER: AtomicU32 = AtomicU32::new(1);
121
122            let counter = COUNTER.fetch_add(1, Ordering::Relaxed);
123            if counter > u32::max_value() / 2 {
124                std::process::abort();
125            }
126
127            match key.compare_exchange(0, counter, Ordering::AcqRel, Ordering::Acquire) {
128                Ok(_) => counter,
129                Err(k) => k,
130            }
131        }
132
133        match self.__key.load(Ordering::Acquire) {
134            0 => init(&self.__key),
135            k => k,
136        }
137    }
138}
139
/// Error value produced by [`LocalKey::try_with`] when it cannot reach a task-local, for example
/// because it was called outside the context of a task.
///
/// [`LocalKey::try_with`]: struct.LocalKey.html#method.try_with
#[derive(Clone, Copy, Eq, PartialEq)]
pub struct AccessError {
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same output as an empty `debug_struct`: only the type name, with no braces.
        f.write_str("AccessError")
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/precision flags exactly like `str`'s own `Display` impl.
        f.pad("already destroyed or called outside the context of a task")
    }
}

impl Error for AccessError {}
161
/// A key-value entry in a map of task-locals.
struct Entry {
    /// Key identifying the task-local variable.
    key: u32,

    /// Value stored in this entry, boxed and type-erased.
    value: Box<dyn Send>,
}

/// A map that holds task-locals.
///
/// The entries sit in an `UnsafeCell` so values can be inserted through a shared reference.
/// `None` means the map has been cleared (the task is being dropped); any later access panics.
pub(crate) struct LocalsMap {
    /// A list of key-value entries sorted by the key.
    entries: UnsafeCell<Option<Vec<Entry>>>,
}

impl LocalsMap {
    /// Creates an empty map of task-locals.
    pub fn new() -> LocalsMap {
        LocalsMap {
            entries: UnsafeCell::new(Some(Vec::new())),
        }
    }

    /// Returns a task-local value associated with `key` or inserts one constructed by `init`.
    ///
    /// `init` may itself access other task-locals of the same task, which inserts into this map.
    /// For that reason no borrow of the entries is held while `init` runs.
    ///
    /// # Panics
    ///
    /// Panics if called after [`LocalsMap::clear`], i.e. while the task is being dropped.
    #[inline]
    pub fn get_or_insert(&self, key: u32, init: impl FnOnce() -> Box<dyn Send>) -> &dyn Send {
        // Fast path: the value has already been initialized.
        if let Some(value) = self.get(key) {
            return value;
        }

        // Run the initializer before borrowing the entries mutably. If it touches another
        // task-local, that access inserts into this map. Holding a `&mut Vec` across the call
        // would alias that access, and an insertion index computed beforehand could be stale,
        // breaking the sort order the binary search depends on.
        let value = init();

        // SAFETY: no other reference into `entries` is live here. `init` has returned, and values
        // handed out earlier point into their own boxes, not into the `Vec`.
        let entries = match unsafe { (*self.entries.get()).as_mut() } {
            None => panic!("can't access task-locals while the task is being dropped"),
            Some(entries) => entries,
        };

        match entries.binary_search_by_key(&key, |e| e.key) {
            Err(index) => {
                entries.insert(index, Entry { key, value });
                &*entries[index].value
            }
            Ok(_) => {
                // `init` recursively initialized this very key. Keep the stored value, which
                // others may already reference, and discard ours. It is dropped after the last
                // use of `entries`, because its destructor may access task-locals as well.
                drop(value);
                self.get(key)
                    .expect("task-local entry disappeared during initialization")
            }
        }
    }

    /// Returns the value stored under `key`, if any.
    ///
    /// # Panics
    ///
    /// Panics if called after [`LocalsMap::clear`].
    #[inline]
    fn get(&self, key: u32) -> Option<&dyn Send> {
        // SAFETY: a short-lived shared borrow; no mutable borrow of `entries` is live across any
        // call into user code (see `get_or_insert`).
        let entries = match unsafe { (*self.entries.get()).as_ref() } {
            None => panic!("can't access task-locals while the task is being dropped"),
            Some(entries) => entries,
        };
        entries
            .binary_search_by_key(&key, |e| e.key)
            .ok()
            .map(|index| &*entries[index].value)
    }

    /// Clears the map and drops all task-locals.
    ///
    /// # Safety
    ///
    /// Only call this at the end of the task, when no references returned by
    /// [`LocalsMap::get_or_insert`] are still in use.
    pub unsafe fn clear(&self) {
        // Destructors may attempt to access task-locals, so we mustn't hold a mutable reference
        // to the `Vec` while dropping them. Instead, take the `Vec` out first and then drop it;
        // such accesses will then hit the "being dropped" panic rather than alias the `Vec`.
        let entries = (*self.entries.get()).take();
        drop(entries);
    }
}