axum_extra/extract/cookie/
signed.rs

1use super::{cookies_from_request, set_cookies};
2use axum::{
3    extract::{FromRef, FromRequestParts},
4    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
5};
6use cookie::SignedJar;
7use cookie::{Cookie, Key};
8use http::{request::Parts, HeaderMap};
9use std::{convert::Infallible, fmt, marker::PhantomData};
10
11/// Extractor that grabs signed cookies from the request and manages the jar.
12///
13/// All cookies will be signed and verified with a [`Key`]. Do not use this to store private data
14/// as the values are still transmitted in plaintext.
15///
16/// Note that methods like [`SignedCookieJar::add`], [`SignedCookieJar::remove`], etc updates the
17/// [`SignedCookieJar`] and returns it. This value _must_ be returned from the handler as part of
18/// the response for the changes to be propagated.
19///
20/// # Example
21///
22/// ```rust
23/// use axum::{
24///     Router,
25///     routing::{post, get},
26///     extract::FromRef,
27///     response::{IntoResponse, Redirect},
28///     http::StatusCode,
29/// };
30/// use axum_extra::{
31///     TypedHeader,
32///     headers::authorization::{Authorization, Bearer},
33///     extract::cookie::{SignedCookieJar, Cookie, Key},
34/// };
35///
36/// async fn create_session(
37///     TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
38///     jar: SignedCookieJar,
39/// ) -> Result<(SignedCookieJar, Redirect), StatusCode> {
40///     if let Some(session_id) = authorize_and_create_session(auth.token()).await {
41///         Ok((
42///             // the updated jar must be returned for the changes
43///             // to be included in the response
44///             jar.add(Cookie::new("session_id", session_id)),
45///             Redirect::to("/me"),
46///         ))
47///     } else {
48///         Err(StatusCode::UNAUTHORIZED)
49///     }
50/// }
51///
52/// async fn me(jar: SignedCookieJar) -> Result<(), StatusCode> {
53///     if let Some(session_id) = jar.get("session_id") {
54///         // fetch and render user...
55///         # Ok(())
56///     } else {
57///         Err(StatusCode::UNAUTHORIZED)
58///     }
59/// }
60///
61/// async fn authorize_and_create_session(token: &str) -> Option<String> {
62///     // authorize the user and create a session...
63///     # todo!()
64/// }
65///
66/// // our application state
67/// #[derive(Clone)]
68/// struct AppState {
69///     // that holds the key used to sign cookies
70///     key: Key,
71/// }
72///
73/// // this impl tells `SignedCookieJar` how to access the key from our state
74/// impl FromRef<AppState> for Key {
75///     fn from_ref(state: &AppState) -> Self {
76///         state.key.clone()
77///     }
78/// }
79///
80/// let state = AppState {
81///     // Generate a secure key
82///     //
83///     // You probably don't wanna generate a new one each time the app starts though
84///     key: Key::generate(),
85/// };
86///
87/// let app = Router::new()
88///     .route("/sessions", post(create_session))
89///     .route("/me", get(me))
90///     .with_state(state);
91/// # let _: axum::Router = app;
92/// ```
93/// If you have been using `Arc<AppState>` you cannot implement `FromRef<Arc<AppState>> for Key`.
94/// You can use a new type instead:
95///
96/// ```rust
97/// # use axum::extract::FromRef;
98/// # use axum_extra::extract::cookie::{PrivateCookieJar, Cookie, Key};
99/// use std::sync::Arc;
100/// use std::ops::Deref;
101///
102/// #[derive(Clone)]
103/// struct AppState(Arc<InnerState>);
104///
105/// // deref so you can still access the inner fields easily
106/// impl Deref for AppState {
107///     type Target = InnerState;
108///
109///     fn deref(&self) -> &Self::Target {
110///         &*self.0
111///     }
112/// }
113///
114/// struct InnerState {
115///     key: Key
116/// }
117///
118/// impl FromRef<AppState> for Key {
119///     fn from_ref(state: &AppState) -> Self {
120///         state.0.key.clone()
121///     }
122/// }
123/// ```
124pub struct SignedCookieJar<K = Key> {
125    jar: cookie::CookieJar,
126    key: Key,
127    // The key used to extract the key. Allows users to use multiple keys for different
128    // jars. Maybe a library wants its own key.
129    _marker: PhantomData<K>,
130}
131
132impl<K> fmt::Debug for SignedCookieJar<K> {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        f.debug_struct("SignedCookieJar")
135            .field("jar", &self.jar)
136            .field("key", &"REDACTED")
137            .finish()
138    }
139}
140
141impl<S, K> FromRequestParts<S> for SignedCookieJar<K>
142where
143    S: Send + Sync,
144    K: FromRef<S> + Into<Key>,
145{
146    type Rejection = Infallible;
147
148    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
149        let k = K::from_ref(state);
150        let key = k.into();
151        let SignedCookieJar {
152            jar,
153            key,
154            _marker: _,
155        } = SignedCookieJar::from_headers(&parts.headers, key);
156        Ok(SignedCookieJar {
157            jar,
158            key,
159            _marker: PhantomData,
160        })
161    }
162}
163
164impl SignedCookieJar {
165    /// Create a new `SignedCookieJar` from a map of request headers.
166    ///
167    /// The valid cookies in `headers` will be added to the jar.
168    ///
169    /// This is intended to be used in middleware and other places where it might be difficult to
170    /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`].
171    ///
172    /// [`FromRequestParts`]: axum::extract::FromRequestParts
173    pub fn from_headers(headers: &HeaderMap, key: Key) -> Self {
174        let mut jar = cookie::CookieJar::new();
175        let mut signed_jar = jar.signed_mut(&key);
176        for cookie in cookies_from_request(headers) {
177            if let Some(cookie) = signed_jar.verify(cookie) {
178                signed_jar.add_original(cookie);
179            }
180        }
181
182        Self {
183            jar,
184            key,
185            _marker: PhantomData,
186        }
187    }
188
189    /// Create a new empty `SignedCookieJar`.
190    ///
191    /// This is intended to be used in middleware and other places where it might be difficult to
192    /// run extractors. Normally you should create `SignedCookieJar`s through [`FromRequestParts`].
193    ///
194    /// [`FromRequestParts`]: axum::extract::FromRequestParts
195    pub fn new(key: Key) -> Self {
196        Self {
197            jar: Default::default(),
198            key,
199            _marker: PhantomData,
200        }
201    }
202}
203
204impl<K> SignedCookieJar<K> {
205    /// Get a cookie from the jar.
206    ///
207    /// If the cookie exists and its authenticity and integrity can be verified then it is returned
208    /// in plaintext.
209    ///
210    /// # Example
211    ///
212    /// ```rust
213    /// use axum_extra::extract::cookie::SignedCookieJar;
214    /// use axum::response::IntoResponse;
215    ///
216    /// async fn handle(jar: SignedCookieJar) {
217    ///     let value: Option<String> = jar
218    ///         .get("foo")
219    ///         .map(|cookie| cookie.value().to_owned());
220    /// }
221    /// ```
222    pub fn get(&self, name: &str) -> Option<Cookie<'static>> {
223        self.signed_jar().get(name)
224    }
225
226    /// Remove a cookie from the jar.
227    ///
228    /// # Example
229    ///
230    /// ```rust
231    /// use axum_extra::extract::cookie::{SignedCookieJar, Cookie};
232    /// use axum::response::IntoResponse;
233    ///
234    /// async fn handle(jar: SignedCookieJar) -> SignedCookieJar {
235    ///     jar.remove(Cookie::from("foo"))
236    /// }
237    /// ```
238    #[must_use]
239    pub fn remove<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
240        self.signed_jar_mut().remove(cookie);
241        self
242    }
243
244    /// Add a cookie to the jar.
245    ///
246    /// The value will automatically be percent-encoded.
247    ///
248    /// # Example
249    ///
250    /// ```rust
251    /// use axum_extra::extract::cookie::{SignedCookieJar, Cookie};
252    /// use axum::response::IntoResponse;
253    ///
254    /// async fn handle(jar: SignedCookieJar) -> SignedCookieJar {
255    ///     jar.add(Cookie::new("foo", "bar"))
256    /// }
257    /// ```
258    #[must_use]
259    #[allow(clippy::should_implement_trait)]
260    pub fn add<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
261        self.signed_jar_mut().add(cookie);
262        self
263    }
264
265    /// Verifies the authenticity and integrity of `cookie`, returning the plaintext version if
266    /// verification succeeds or `None` otherwise.
267    pub fn verify(&self, cookie: Cookie<'static>) -> Option<Cookie<'static>> {
268        self.signed_jar().verify(cookie)
269    }
270
271    /// Get an iterator over all cookies in the jar.
272    ///
273    /// Only cookies with valid authenticity and integrity are yielded by the iterator.
274    pub fn iter(&self) -> impl Iterator<Item = Cookie<'static>> + '_ {
275        SignedCookieJarIter {
276            jar: self,
277            iter: self.jar.iter(),
278        }
279    }
280
281    fn signed_jar(&self) -> SignedJar<&'_ cookie::CookieJar> {
282        self.jar.signed(&self.key)
283    }
284
285    fn signed_jar_mut(&mut self) -> SignedJar<&'_ mut cookie::CookieJar> {
286        self.jar.signed_mut(&self.key)
287    }
288}
289
290impl<K> IntoResponseParts for SignedCookieJar<K> {
291    type Error = Infallible;
292
293    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
294        set_cookies(self.jar, res.headers_mut());
295        Ok(res)
296    }
297}
298
299impl<K> IntoResponse for SignedCookieJar<K> {
300    fn into_response(self) -> Response {
301        (self, ()).into_response()
302    }
303}
304
305struct SignedCookieJarIter<'a, K> {
306    jar: &'a SignedCookieJar<K>,
307    iter: cookie::Iter<'a>,
308}
309
310impl<K> Iterator for SignedCookieJarIter<'_, K> {
311    type Item = Cookie<'static>;
312
313    fn next(&mut self) -> Option<Self::Item> {
314        loop {
315            let cookie = self.iter.next()?;
316
317            if let Some(cookie) = self.jar.get(cookie.name()) {
318                return Some(cookie);
319            }
320        }
321    }
322}
323
324impl<K> Clone for SignedCookieJar<K> {
325    fn clone(&self) -> Self {
326        Self {
327            jar: self.jar.clone(),
328            key: self.key.clone(),
329            _marker: self._marker,
330        }
331    }
332}