broker_tokio/task/
task_local.rs1use pin_project_lite::pin_project;
2use std::cell::RefCell;
3use std::error::Error;
4use std::future::Future;
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use std::{fmt, thread};
8
/// Declares one or more task-local keys of type `LocalKey`.
///
/// Each `static NAME: TYPE;` declaration (optionally preceded by attributes
/// and a visibility) is forwarded to `__task_local_inner!`. The `;` after the
/// last declaration may be omitted.
#[macro_export]
macro_rules! task_local {
    // Base case: nothing left to declare.
    () => {};

    // One declaration terminated by `;`, followed by any remaining declarations.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $value_ty:ty; $($tail:tt)*) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $value_ty);
        $crate::task_local!($($tail)*);
    };

    // A final declaration written without a trailing `;`.
    ($(#[$meta:meta])* $visibility:vis static $key:ident: $value_ty:ty) => {
        $crate::__task_local_inner!($(#[$meta])* $visibility $key, $value_ty);
    }
}
47
// Implementation detail of `task_local!`: expands one declaration into a
// `static` `LocalKey` whose storage is a fresh thread-local
// `RefCell<Option<T>>`, initialised to `None`.
#[doc(hidden)]
#[macro_export]
macro_rules! __task_local_inner {
    ($(#[$attr:meta])* $vis:vis $name:ident, $t:ty) => {
        // Forward the caller's attributes (doc comments, `allow`, `cfg`, ...)
        // and visibility, so that e.g. `pub static KEY: T;` really produces a
        // public key instead of matching `pub` and then discarding it.
        $(#[$attr])*
        $vis static $name: $crate::task::LocalKey<$t> = {
            std::thread_local! {
                static __KEY: std::cell::RefCell<Option<$t>> = std::cell::RefCell::new(None);
            }

            $crate::task::LocalKey { inner: __KEY }
        };
    };
}
61
/// Key for task-local data, normally declared through the `task_local!` macro.
///
/// The key wraps a per-thread `RefCell<Option<T>>`. The macro initialises
/// that cell to `None`. The task's value is moved into it only while a future
/// wrapped by `LocalKey::scope` is being polled on the current thread.
pub struct LocalKey<T>
where
    T: 'static,
{
    // Public (but hidden from docs) because the exported `task_local!` macro
    // constructs keys with this field from the caller's crate.
    #[doc(hidden)]
    pub inner: thread::LocalKey<RefCell<Option<T>>>,
}
97
98impl<T: 'static> LocalKey<T> {
99 pub async fn scope<F>(&'static self, value: T, f: F) -> F::Output
117 where
118 F: Future,
119 {
120 TaskLocalFuture {
121 local: &self,
122 slot: Some(value),
123 future: f,
124 }
125 .await
126 }
127
128 pub fn with<F, R>(&'static self, f: F) -> R
135 where
136 F: FnOnce(&T) -> R,
137 {
138 self.try_with(f).expect(
139 "cannot access a Task Local Storage value \
140 without setting it via `LocalKey::set`",
141 )
142 }
143
144 pub fn try_with<F, R>(&'static self, f: F) -> Result<R, AccessError>
150 where
151 F: FnOnce(&T) -> R,
152 {
153 self.inner.with(|v| {
154 if let Some(val) = v.borrow().as_ref() {
155 Ok(f(val))
156 } else {
157 Err(AccessError { _private: () })
158 }
159 })
160 }
161}
162
163impl<T: Copy + 'static> LocalKey<T> {
164 pub fn get(&'static self) -> T {
167 self.with(|v| *v)
168 }
169}
170
171impl<T: 'static> fmt::Debug for LocalKey<T> {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.pad("LocalKey { .. }")
174 }
175}
176
177pin_project! {
178 struct TaskLocalFuture<T: StaticLifetime, F> {
179 local: &'static LocalKey<T>,
180 slot: Option<T>,
181 #[pin]
182 future: F,
183 }
184}
185
186impl<T: 'static, F: Future> Future for TaskLocalFuture<T, F> {
187 type Output = F::Output;
188
189 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
190 struct Guard<'a, T: 'static> {
191 local: &'static LocalKey<T>,
192 slot: &'a mut Option<T>,
193 prev: Option<T>,
194 }
195
196 impl<T> Drop for Guard<'_, T> {
197 fn drop(&mut self) {
198 let value = self.local.inner.with(|c| c.replace(self.prev.take()));
199 *self.slot = value;
200 }
201 }
202
203 let mut project = self.project();
204 let val = project.slot.take();
205
206 let prev = project.local.inner.with(|c| c.replace(val));
207
208 let _guard = Guard {
209 prev,
210 slot: &mut project.slot,
211 local: *project.local,
212 };
213
214 project.future.poll(cx)
215 }
216}
217
/// Marker trait implemented for exactly the `'static` types.
///
/// It lets a `'static` requirement be spelled as a trait bound. In this file
/// it bounds `TaskLocalFuture`'s value parameter inside `pin_project!`.
trait StaticLifetime: 'static {}

impl<T> StaticLifetime for T where T: 'static {}
221
/// Error returned by `LocalKey::try_with` when no task-local value can be read.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct AccessError {
    // Private unit field: the error can only be constructed inside this module.
    _private: (),
}

impl fmt::Debug for AccessError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.debug_struct("AccessError").finish()
    }
}

impl fmt::Display for AccessError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` is exactly what `<str as Display>::fmt` does, so width, fill
        // and precision flags apply to the message.
        formatter.pad("task-local value not set")
    }
}

impl Error for AccessError {}