axum_extra/extract/cached.rs
1use axum::extract::{Extension, FromRequestParts};
2use http::request::Parts;
3
/// Cache results of other extractors.
///
/// `Cached` wraps another extractor and caches its result in [request extensions].
///
/// This is useful if you have a tree of extractors that share common sub-extractors that
/// you only want to run once, perhaps because they're expensive.
///
/// The cache is purely type based, so you can only cache one value of each type. The cache
/// is also local to the current request and not reused across requests.
///
/// The wrapped extractor must implement [`Clone`], because a copy of the extracted value is
/// kept in the request extensions while another copy is handed back to the caller.
///
/// # Example
///
/// ```rust
/// use axum_extra::extract::Cached;
/// use axum::{
///     extract::FromRequestParts,
///     response::{IntoResponse, Response},
///     http::{StatusCode, request::Parts},
/// };
///
/// #[derive(Clone)]
/// struct Session { /* ... */ }
///
/// impl<S> FromRequestParts<S> for Session
/// where
///     S: Send + Sync,
/// {
///     type Rejection = (StatusCode, String);
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // load session...
///         # unimplemented!()
///     }
/// }
///
/// struct CurrentUser { /* ... */ }
///
/// impl<S> FromRequestParts<S> for CurrentUser
/// where
///     S: Send + Sync,
/// {
///     type Rejection = Response;
///
///     async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
///         // loading a `CurrentUser` requires first loading the `Session`
///         //
///         // by using `Cached<Session>` we avoid extracting the session more than
///         // once, in case other extractors for the same request also load the session
///         let session: Session = Cached::<Session>::from_request_parts(parts, state)
///             .await
///             .map_err(|err| err.into_response())?
///             .0;
///
///         // load user from session...
///         # unimplemented!()
///     }
/// }
///
/// // handler that extracts the current user and the session
/// //
/// // the session will only be loaded once, even though `CurrentUser`
/// // also loads it
/// async fn handler(
///     current_user: CurrentUser,
///     // we have to use `Cached<Session>` here otherwise the
///     // cached session would not be used
///     Cached(session): Cached<Session>,
/// ) {
///     // ...
/// }
/// ```
///
/// [request extensions]: http::Extensions
#[derive(Debug, Clone, Default)]
pub struct Cached<T>(pub T);
79
/// Private wrapper under which a cached value is stored in the request extensions.
///
/// The `Cached<T>` extractor inserts a `CachedEntry<T>` after running `T` and looks up
/// `CachedEntry<T>` (not a bare `T`) on later extractions. Presumably this keeps the cache
/// slot separate from any `T` that other code stores in the extensions directly — confirm
/// before relying on it.
#[derive(Clone)]
struct CachedEntry<T>(T);
82
83impl<S, T> FromRequestParts<S> for Cached<T>
84where
85 S: Send + Sync,
86 T: FromRequestParts<S> + Clone + Send + Sync + 'static,
87{
88 type Rejection = T::Rejection;
89
90 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
91 match Extension::<CachedEntry<T>>::from_request_parts(parts, state).await {
92 Ok(Extension(CachedEntry(value))) => Ok(Self(value)),
93 Err(_) => {
94 let value = T::from_request_parts(parts, state).await?;
95 parts.extensions.insert(CachedEntry(value.clone()));
96 Ok(Self(value))
97 }
98 }
99 }
100}
101
// Invokes `axum_core`'s internal `__impl_deref!` helper macro for `Cached`.
// NOTE(review): judging by the macro's name this presumably generates `Deref`/`DerefMut`
// impls targeting the wrapped value — confirm against the macro definition in `axum_core`.
axum_core::__impl_deref!(Cached);
103
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{http::Request, routing::get, Router};
    use std::{
        convert::Infallible,
        sync::atomic::{AtomicU32, Ordering},
        time::Instant,
    };

    /// Extracting `Cached<T>` twice from the same request parts must invoke the inner
    /// extractor only once and yield the same value both times.
    #[tokio::test]
    async fn works() {
        // Number of times the inner extractor has actually been run.
        static RUNS: AtomicU32 = AtomicU32::new(0);

        // Carries the moment of extraction, so a second (uncached) run would produce a
        // value that compares unequal to the first.
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Stamp(Instant);

        impl<S> FromRequestParts<S> for Stamp
        where
            S: Send + Sync,
        {
            type Rejection = Infallible;

            async fn from_request_parts(
                _parts: &mut Parts,
                _state: &S,
            ) -> Result<Self, Self::Rejection> {
                RUNS.fetch_add(1, Ordering::SeqCst);
                Ok(Self(Instant::now()))
            }
        }

        let mut parts = Request::new(()).into_parts().0;

        // First extraction: the inner extractor runs and its result is cached.
        let Cached(initial) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        // Second extraction: served from the cache, so the run count stays put.
        let Cached(repeated) = Cached::<Stamp>::from_request_parts(&mut parts, &())
            .await
            .unwrap();
        assert_eq!(RUNS.load(Ordering::SeqCst), 1);

        assert_eq!(initial, repeated);
    }

    // Deliberately not a #[test]: this only has to type-check, proving that `Cached<_>`
    // is accepted as the final argument of a handler.
    async fn _last_handler_argument() {
        async fn handler(_: http::Method, _: Cached<http::HeaderMap>) {}
        let _router: Router = Router::new().route("/", get(handler));
    }
}
158}