// axum_extra/extract/cached.rs

1use axum::extract::{Extension, FromRequestParts};
2use http::request::Parts;
3
/// Cache results of other extractors.
///
/// `Cached` wraps another extractor and caches its result in [request extensions].
///
/// This is useful if you have a tree of extractors that share common sub-extractors that
/// you only want to run once, perhaps because they're expensive.
///
/// The cache is purely type-based, so you can only cache one value of each type. The cache is
/// also local to the current request and not reused across requests.
///
/// Only extractions that go through `Cached<T>` read from or write to the cache; extracting a
/// plain `T` always runs `T`'s own extractor. The cached value is handed out by cloning it,
/// which is why `T` must implement [`Clone`].
///
/// If the wrapped extractor fails, its rejection is returned unchanged and nothing is cached,
/// so a later `Cached<T>` extraction on the same request runs the extractor again.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
///     extract::FromRequestParts,
///     response::{IntoResponse, Response},
///     http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// impl<S> FromRequestParts<S> for Session
/// where
///     S: Send + Sync,
/// {
///     type Rejection = (StatusCode, String);
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // load session...
///         # unimplemented!()
///     }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
///     S: Send + Sync,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // loading a `CurrentUser` requires first loading the `Session`
///         //
///         // by using `Cached<Session>` we avoid extracting the session more than
///         // once, in case other extractors for the same request also load the session
///         let session: Session = Cached::<Session>::from_request_parts(parts, state)
///             .await
///             .map_err(|err| err.into_response())?
///             .0;
///
///         // load user from session...
///         # unimplemented!()
///     }
/// }
///
/// // handler that extracts the current user and the session
/// //
/// // the session will only be loaded once, even though `CurrentUser`
/// // also loads it
/// async fn handler(
///     current_user: CurrentUser,
///     // we have to use `Cached<Session>` here, otherwise the
///     // cached session would not be used
///     Cached(session): Cached<Session>,
/// ) {
///     // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Debug, Clone, Default)]
pub struct Cached<T>(pub T);
79
80#[derive(Clone)]
81struct CachedEntry<T>(T);
82
83impl<S, T> FromRequestParts<S> for Cached<T>
84where
85    S: Send + Sync,
86    T: FromRequestParts<S> + Clone + Send + Sync + 'static,
87{
88    type Rejection = T::Rejection;
89
90    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
91        match Extension::<CachedEntry<T>>::from_request_parts(parts, state).await {
92            Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
93            Err(_) => {
94                let value = T::from_request_parts(parts, state).await?;
95                parts.extensions.insert(CachedEntry(value.clone()));
96                Ok(Self(value))
97            }
98        }
99    }
100}
101
102axum_core::__impl_deref!(Cached);
103
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    /// A second `Cached<T>` extraction on the same request must reuse the first result
    /// instead of running `T`'s extractor again.
    #[tokio::test]
    async fn works() {
        // Counts how many times the inner extractor actually runs.
        static COUNTER: AtomicU32 = AtomicU32::new(0);

        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Extractor(Instant);

        impl<S> FromRequestParts<S> for Extractor
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                COUNTER.fetch_add(1, Ordering::SeqCst);
                Ok(Self(Instant::now()))
            }
        }

        let (mut parts, _) = Request::new(()).into_parts();

        let first = Cached::<Extractor>::from_request_parts(&mut parts, &())
            .await
            .unwrap()
            .0;
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        let second = Cached::<Extractor>::from_request_parts(&mut parts, &())
            .await
            .unwrap()
            .0;
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);

        assert_eq!(first, second);
    }

    /// A failed extraction must not be cached: every `Cached<T>` attempt re-runs `T`'s
    /// extractor and surfaces its rejection.
    #[tokio::test]
    async fn rejection_is_not_cached() {
        // Counts how many times the failing extractor is invoked.
        static CALLS: AtomicU32 = AtomicU32::new(0);

        #[derive(Clone)]
        struct Failing;

        impl<S> FromRequestParts<S> for Failing
        where
            S: Send + Sync,
        {
            type Rejection = ();

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                CALLS.fetch_add(1, Ordering::SeqCst);
                Err(())
            }
        }

        let (mut parts, _) = Request::new(()).into_parts();

        assert!(Cached::<Failing>::from_request_parts(&mut parts, &())
            .await
            .is_err());
        assert_eq!(CALLS.load(Ordering::SeqCst), 1);

        // Nothing was cached, so the extractor must run (and fail) again.
        assert!(Cached::<Failing>::from_request_parts(&mut parts, &())
            .await
            .is_err());
        assert_eq!(CALLS.load(Ordering::SeqCst), 2);
    }

    // Not a #[test], we just want to know this compiles
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _r: Router = Router::new().route("/", get(handler));
    }
}