// axum_extra/extract/cookie/mod.rs
use axum::{
    extract::FromRequestParts,
    response::{IntoResponse, IntoResponseParts, Response, ResponseParts},
};
use http::{
    header::{COOKIE, SET_COOKIE},
    request::Parts,
    HeaderMap, HeaderValue,
};
use std::convert::Infallible;

#[cfg(feature = "cookie-private")]
mod private;
#[cfg(feature = "cookie-signed")]
mod signed;

#[cfg(feature = "cookie-private")]
pub use self::private::PrivateCookieJar;
#[cfg(feature = "cookie-signed")]
pub use self::signed::SignedCookieJar;

pub use cookie::{Cookie, Expiration, SameSite};

#[cfg(any(feature = "cookie-signed", feature = "cookie-private"))]
pub use cookie::Key;

31#[derive(Debug, Default, Clone)]
88pub struct CookieJar {
89 jar: cookie::CookieJar,
90}
91
92impl<S> FromRequestParts<S> for CookieJar
93where
94 S: Send + Sync,
95{
96 type Rejection = Infallible;
97
98 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
99 Ok(Self::from_headers(&parts.headers))
100 }
101}
102
103fn cookies_from_request(headers: &HeaderMap) -> impl Iterator<Item = Cookie<'static>> + '_ {
104 headers
105 .get_all(COOKIE)
106 .into_iter()
107 .filter_map(|value| value.to_str().ok())
108 .flat_map(|value| value.split(';'))
109 .filter_map(|cookie| Cookie::parse_encoded(cookie.to_owned()).ok())
110}
111
112impl CookieJar {
113 pub fn from_headers(headers: &HeaderMap) -> Self {
122 let mut jar = cookie::CookieJar::new();
123 for cookie in cookies_from_request(headers) {
124 jar.add_original(cookie);
125 }
126 Self { jar }
127 }
128
129 pub fn new() -> Self {
139 Self::default()
140 }
141
142 pub fn get(&self, name: &str) -> Option<&Cookie<'static>> {
157 self.jar.get(name)
158 }
159
160 #[must_use]
173 pub fn remove<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
174 self.jar.remove(cookie);
175 self
176 }
177
178 #[must_use]
193 #[allow(clippy::should_implement_trait)]
194 pub fn add<C: Into<Cookie<'static>>>(mut self, cookie: C) -> Self {
195 self.jar.add(cookie);
196 self
197 }
198
199 pub fn iter(&self) -> impl Iterator<Item = &'_ Cookie<'static>> {
201 self.jar.iter()
202 }
203}
204
205impl IntoResponseParts for CookieJar {
206 type Error = Infallible;
207
208 fn into_response_parts(self, mut res: ResponseParts) -> Result<ResponseParts, Self::Error> {
209 set_cookies(self.jar, res.headers_mut());
210 Ok(res)
211 }
212}
213
214impl IntoResponse for CookieJar {
215 fn into_response(self) -> Response {
216 (self, ()).into_response()
217 }
218}
219
220fn set_cookies(jar: cookie::CookieJar, headers: &mut HeaderMap) {
221 for cookie in jar.delta() {
222 if let Ok(header_value) = cookie.encoded().to_string().parse() {
223 headers.append(SET_COOKIE, header_value);
224 }
225 }
226
227 }
230
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, extract::FromRef, http::Request, routing::get, Router};
    use http_body_util::BodyExt;
    use tower::ServiceExt;

    /// Build a fresh application state with randomly generated keys.
    fn test_state() -> AppState {
        AppState {
            key: Key::generate(),
            custom_key: CustomKey(Key::generate()),
        }
    }

    /// Send a `GET` request for `uri` to a clone of `app` and return the response.
    ///
    /// When `cookie` is given, it is attached as the raw `Cookie` header.
    async fn send_request(app: &Router, uri: &str, cookie: Option<&str>) -> Response {
        let mut builder = Request::builder().uri(uri);
        if let Some(cookie_header) = cookie {
            builder = builder.header("cookie", cookie_header);
        }
        let request = builder.body(Body::empty()).unwrap();
        app.clone().oneshot(request).await.unwrap()
    }

    /// Generates a set/get/remove round-trip test for the given jar type.
    macro_rules! cookie_test {
        ($name:ident, $jar:ty) => {
            #[tokio::test]
            async fn $name() {
                async fn set_cookie(jar: $jar) -> impl IntoResponse {
                    jar.add(Cookie::new("key", "value"))
                }

                async fn get_cookie(jar: $jar) -> impl IntoResponse {
                    jar.get("key").unwrap().value().to_owned()
                }

                async fn remove_cookie(jar: $jar) -> impl IntoResponse {
                    jar.remove(Cookie::from("key"))
                }

                let app = Router::new()
                    .route("/set", get(set_cookie))
                    .route("/get", get(get_cookie))
                    .route("/remove", get(remove_cookie))
                    .with_state(test_state());

                let set_res = send_request(&app, "/set", None).await;
                let cookie_value = set_res.headers()["set-cookie"].to_str().unwrap();

                let get_res = send_request(&app, "/get", Some(cookie_value)).await;
                let fetched = body_text(get_res).await;
                assert_eq!(fetched, "value");

                let remove_res = send_request(&app, "/remove", Some(cookie_value)).await;
                let removal = remove_res.headers()["set-cookie"].to_str().unwrap();
                assert!(removal.contains("key=;"));
            }
        };
    }

    cookie_test!(plaintext_cookies, CookieJar);

    #[cfg(feature = "cookie-signed")]
    cookie_test!(signed_cookies, SignedCookieJar);
    #[cfg(feature = "cookie-signed")]
    cookie_test!(signed_cookies_with_custom_key, SignedCookieJar<CustomKey>);

    #[cfg(feature = "cookie-private")]
    cookie_test!(private_cookies, PrivateCookieJar);
    #[cfg(feature = "cookie-private")]
    cookie_test!(private_cookies_with_custom_key, PrivateCookieJar<CustomKey>);

    #[derive(Clone)]
    struct AppState {
        key: Key,
        custom_key: CustomKey,
    }

    impl FromRef<AppState> for Key {
        fn from_ref(state: &AppState) -> Key {
            state.key.clone()
        }
    }

    impl FromRef<AppState> for CustomKey {
        fn from_ref(state: &AppState) -> CustomKey {
            state.custom_key.clone()
        }
    }

    #[derive(Clone)]
    struct CustomKey(Key);

    impl From<CustomKey> for Key {
        fn from(custom: CustomKey) -> Self {
            custom.0
        }
    }

    #[cfg(feature = "cookie-signed")]
    #[tokio::test]
    async fn signed_cannot_access_invalid_cookies() {
        async fn get_cookie(jar: SignedCookieJar) -> impl IntoResponse {
            format!("{:?}", jar.get("key"))
        }

        let app = Router::new()
            .route("/get", get(get_cookie))
            .with_state(test_state());

        let res = send_request(&app, "/get", Some("key=value")).await;
        let body = body_text(res).await;
        assert_eq!(body, "None");
    }

    /// Collect a body into a UTF-8 `String`, panicking on any error.
    async fn body_text<B>(body: B) -> String
    where
        B: axum::body::HttpBody,
        B::Error: std::fmt::Debug,
    {
        let collected = body.collect().await.unwrap();
        let bytes = collected.to_bytes();
        String::from_utf8(bytes.to_vec()).unwrap()
    }
}