axum_extra/extract/cookie/
mod.rs

1//! Cookie parsing and cookie jar management.
2//!
3//! See [`CookieJar`], [`SignedCookieJar`], and [`PrivateCookieJar`] for more details.
4
5use axum::{
6    extract::FromRequestParts,
7    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
8};
9use http::{
10    header::{COOKIE, SET_COOKIE},
11    request::Parts,
12    HeaderMap,
13};
14use std::convert::Infallible;
15
16#[cfg(feature = "cookie-private")]
17mod private;
18#[cfg(feature = "cookie-signed")]
19mod signed;
20
21#[cfg(feature = "cookie-private")]
22pub use self::private::PrivateCookieJar;
23#[cfg(feature = "cookie-signed")]
24pub use self::signed::SignedCookieJar;
25
26pub use cookie::{Cookie, Expiration, SameSite};
27
28#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
29pub use cookie::Key;
30
31/// Extractor that grabs cookies from the request and manages the jar.
32///
33/// Note that methods like [`CookieJar::add`], [`CookieJar::remove`], etc updates the [`CookieJar`]
34/// and returns it. This value _must_ be returned from the handler as part of the response for the
35/// changes to be propagated.
36///
37/// # Example
38///
39/// ```rust
40/// use axum::{
41///     Router,
42///     routing::{post, get},
43///     response::{IntoResponse, Redirect},
44///     http::StatusCode,
45/// };
46/// use axum_extra::{
47///     TypedHeader,
48///     headers::authorization::{Authorization, Bearer},
49///     extract::cookie::{CookieJar, Cookie},
50/// };
51///
52/// async fn create_session(
53///     TypedHeader(auth): TypedHeader<Authorization<Bearer>>,
54///     jar: CookieJar,
55/// ) -> Result<(CookieJar, Redirect), StatusCode> {
56///     if let Some(session_id) = authorize_and_create_session(auth.token()).await {
57///         Ok((
58///             // the updated jar must be returned for the changes
59///             // to be included in the response
60///             jar.add(Cookie::new("session_id", session_id)),
61///             Redirect::to("/me"),
62///         ))
63///     } else {
64///         Err(StatusCode::UNAUTHORIZED)
65///     }
66/// }
67///
68/// async fn me(jar: CookieJar) -> Result<(), StatusCode> {
69///     if let Some(session_id) = jar.get("session_id") {
70///         // fetch and render user...
71///         # Ok(())
72///     } else {
73///         Err(StatusCode::UNAUTHORIZED)
74///     }
75/// }
76///
77/// async fn authorize_and_create_session(token: &str) -> Option<String> {
78///     // authorize the user and create a session...
79///     # todo!()
80/// }
81///
82/// let app = Router::new()
83///     .route("/sessions", post(create_session))
84///     .route("/me", get(me));
85/// # let app: Router = app;
86/// ```
87#[derive(Debug, Default, Clone)]
88pub struct CookieJar {
89    jar: cookie::CookieJar,
90}
91
92impl<S> FromRequestParts<S> for CookieJar
93where
94    S: Send + Sync,
95{
96    type Rejection = Infallible;
97
98    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
99        Ok(Self::from_headers(&parts.headers))
100    }
101}
102
103fn cookies_from_request(headers: &HeaderMap) -> impl Iterator<Item = Cookie<'static>> + '_ {
104    headers
105        .get_all(COOKIE)
106        .into_iter()
107        .filter_map(|value| value.to_str().ok())
108        .flat_map(|value| value.split(';'))
109        .filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok())
110}
111
112impl CookieJar {
113    /// Create a new `CookieJar` from a map of request headers.
114    ///
115    /// The cookies in `headers` will be added to the jar.
116    ///
117    /// This is intended to be used in middleware and other places where it might be difficult to
118    /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
119    ///
120    /// [`FromRequestParts`]: axum::extract::FromRequestParts
121    pub fn from_headers(headers: &HeaderMap) -> Self {
122        let mut jar = cookie::CookieJar::new();
123        for cookie in cookies_from_request(headers) {
124            jar.add_original(cookie);
125        }
126        Self { jar }
127    }
128
129    /// Create a new empty `CookieJar`.
130    ///
131    /// This is intended to be used in middleware and other places where it might be difficult to
132    /// run extractors. Normally you should create `CookieJar`s through [`FromRequestParts`].
133    ///
134    /// If you need a jar that contains the headers from a request use `impl From<&HeaderMap> for
135    /// CookieJar`.
136    ///
137    /// [`FromRequestParts`]: axum::extract::FromRequestParts
138    pub fn new() -> Self {
139        Self::default()
140    }
141
142    /// Get a cookie from the jar.
143    ///
144    /// # Example
145    ///
146    /// ```rust
147    /// use axum_extra::extract::cookie::CookieJar;
148    /// use axum::response::IntoResponse;
149    ///
150    /// async fn handle(jar: CookieJar) {
151    ///     let value: Option<String> = jar
152    ///         .get("foo")
153    ///         .map(|cookie| cookie.value().to_owned());
154    /// }
155    /// ```
156    pub fn get(&self, name: &str) -> Option<&Cookie<'static>> {
157        self.jar.get(name)
158    }
159
160    /// Remove a cookie from the jar.
161    ///
162    /// # Example
163    ///
164    /// ```rust
165    /// use axum_extra::extract::cookie::{CookieJar, Cookie};
166    /// use axum::response::IntoResponse;
167    ///
168    /// async fn handle(jar: CookieJar) -> CookieJar {
169    ///     jar.remove(Cookie::from("foo"))
170    /// }
171    /// ```
172    #[must_use]
173    pub fn remove<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
174        self.jar.remove(cookie);
175        self
176    }
177
178    /// Add a cookie to the jar.
179    ///
180    /// The value will automatically be percent-encoded.
181    ///
182    /// # Example
183    ///
184    /// ```rust
185    /// use axum_extra::extract::cookie::{CookieJar, Cookie};
186    /// use axum::response::IntoResponse;
187    ///
188    /// async fn handle(jar: CookieJar) -> CookieJar {
189    ///     jar.add(Cookie::new("foo", "bar"))
190    /// }
191    /// ```
192    #[must_use]
193    #[allow(clippy::should_implement_trait)]
194    pub fn add<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
195        self.jar.add(cookie);
196        self
197    }
198
199    /// Get an iterator over all cookies in the jar.
200    pub fn iter(&self) -> impl Iterator<Item = &'_ Cookie<'static>> {
201        self.jar.iter()
202    }
203}
204
205impl IntoResponseParts for CookieJar {
206    type Error = Infallible;
207
208    fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
209        set_cookies(self.jar, res.headers_mut());
210        Ok(res)
211    }
212}
213
214impl IntoResponse for CookieJar {
215    fn into_response(self) -> Response {
216        (self, ()).into_response()
217    }
218}
219
220fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
221    for cookie in jar.delta() {
222        if let Ok(header_value) = cookie.encoded().to_string().parse() {
223            headers.append(SET_COOKIE, header_value);
224        }
225    }
226
227    // we don't need to call `jar.reset_delta()` because `into_response_parts` consumes the cookie
228    // jar so it cannot be called multiple times.
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Generates a set -> get -> remove round-trip test for the given jar type.
    macro_rules! cookie_test {
        ($name:ident, $jar:ty) => {
            #[tokio::test]
            async fn $name() {
                async fn set_cookie(jar: $jar) -> impl IntoResponse {
                    jar.add(Cookie::new("key", "value"))
                }

                async fn get_cookie(jar: $jar) -> impl IntoResponse {
                    jar.get("key").unwrap().value().to_owned()
                }

                async fn remove_cookie(jar: $jar) -> impl IntoResponse {
                    jar.remove(Cookie::from("key"))
                }

                let app = Router::new()
                    .route("/set", get(set_cookie))
                    .route("/get", get(get_cookie))
                    .route("/remove", get(remove_cookie))
                    .with_state(AppState::new());

                let res = app
                    .clone()
                    .oneshot(Request::builder().uri("/set").body(Body::empty()).unwrap())
                    .await
                    .unwrap();
                let cookie_value = res.headers()["set-cookie"].to_str().unwrap();

                let res = app
                    .clone()
                    .oneshot(
                        Request::builder()
                            .uri("/get")
                            .header("cookie", cookie_value)
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();
                let body = body_text(res).await;
                assert_eq!(body, "value");

                let res = app
                    .clone()
                    .oneshot(
                        Request::builder()
                            .uri("/remove")
                            .header("cookie", cookie_value)
                            .body(Body::empty())
                            .unwrap(),
                    )
                    .await
                    .unwrap();
                assert!(res.headers()["set-cookie"]
                    .to_str()
                    .unwrap()
                    .contains("key=;"));
            }
        };
    }

    cookie_test!(plaintext_cookies, CookieJar);

    #[cfg(feature = "cookie-signed")]
    cookie_test!(signed_cookies, SignedCookieJar);
    #[cfg(feature = "cookie-signed")]
    cookie_test!(signed_cookies_with_custom_key, SignedCookieJar<CustomKey>);

    #[cfg(feature = "cookie-private")]
    cookie_test!(private_cookies, PrivateCookieJar);
    #[cfg(feature = "cookie-private")]
    cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);

    // `Key` is only re-exported when `cookie-signed` or `cookie-private` is enabled, so every
    // key-related item below is gated on those features. Otherwise the plaintext test could not
    // compile with just the `cookie` feature.
    #[derive(Clone)]
    struct AppState {
        #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
        key: Key,
        #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
        custom_key: CustomKey,
    }

    impl AppState {
        /// Builds a state with freshly generated keys (when key support is compiled in).
        fn new() -> Self {
            Self {
                #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
                key: Key::generate(),
                #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
                custom_key: CustomKey(Key::generate()),
            }
        }
    }

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    impl FromRef<AppState> for Key {
        fn from_ref(state: &AppState) -> Key {
            state.key.clone()
        }
    }

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    impl FromRef<AppState> for CustomKey {
        fn from_ref(state: &AppState) -> CustomKey {
            state.custom_key.clone()
        }
    }

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    #[derive(Clone)]
    struct CustomKey(Key);

    #[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
    impl From<CustomKey> for Key {
        fn from(custom: CustomKey) -> Self {
            custom.0
        }
    }

    #[cfg(feature = "cookie-signed")]
    #[tokio::test]
    async fn signed_cannot_access_invalid_cookies() {
        async fn get_cookie(jar: SignedCookieJar) -> impl IntoResponse {
            format!("{:?}", jar.get("key"))
        }

        let app = Router::new()
            .route("/get", get(get_cookie))
            .with_state(AppState::new());

        let res = app
            .clone()
            .oneshot(
                Request::builder()
                    .uri("/get")
                    .header("cookie", "key=value")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        let body = body_text(res).await;
        assert_eq!(body, "None");
    }

    /// Collects a response body into a UTF-8 `String`, panicking on I/O or decoding errors.
    async fn body_text<B>(body: B) -> String
    where
        B: axum::body::HttpBody,
        B::Error: std::fmt::Debug,
    {
        let bytes = body.collect().await.unwrap().to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }
}