1use axum::{
4 extract::{OriginalUri, Request},
5 response::{IntoResponse, Redirect, Response},
6 routing::{any, MethodRouter},
7 Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28pub trait RouterExt<S>: sealed::Sealed {
30 #[cfg(feature = "typed-routing")]
37 fn typed_get<H, T, P>(self, handler: H) -> Self
38 where
39 H: axum::handler::Handler<T, S>,
40 T: SecondElementIs<P> + 'static,
41 P: TypedPath;
42
43 #[cfg(feature = "typed-routing")]
50 fn typed_delete<H, T, P>(self, handler: H) -> Self
51 where
52 H: axum::handler::Handler<T, S>,
53 T: SecondElementIs<P> + 'static,
54 P: TypedPath;
55
56 #[cfg(feature = "typed-routing")]
63 fn typed_head<H, T, P>(self, handler: H) -> Self
64 where
65 H: axum::handler::Handler<T, S>,
66 T: SecondElementIs<P> + 'static,
67 P: TypedPath;
68
69 #[cfg(feature = "typed-routing")]
76 fn typed_options<H, T, P>(self, handler: H) -> Self
77 where
78 H: axum::handler::Handler<T, S>,
79 T: SecondElementIs<P> + 'static,
80 P: TypedPath;
81
82 #[cfg(feature = "typed-routing")]
89 fn typed_patch<H, T, P>(self, handler: H) -> Self
90 where
91 H: axum::handler::Handler<T, S>,
92 T: SecondElementIs<P> + 'static,
93 P: TypedPath;
94
95 #[cfg(feature = "typed-routing")]
102 fn typed_post<H, T, P>(self, handler: H) -> Self
103 where
104 H: axum::handler::Handler<T, S>,
105 T: SecondElementIs<P> + 'static,
106 P: TypedPath;
107
108 #[cfg(feature = "typed-routing")]
115 fn typed_put<H, T, P>(self, handler: H) -> Self
116 where
117 H: axum::handler::Handler<T, S>,
118 T: SecondElementIs<P> + 'static,
119 P: TypedPath;
120
121 #[cfg(feature = "typed-routing")]
128 fn typed_trace<H, T, P>(self, handler: H) -> Self
129 where
130 H: axum::handler::Handler<T, S>,
131 T: SecondElementIs<P> + 'static,
132 P: TypedPath;
133
134 #[cfg(feature = "typed-routing")]
141 fn typed_connect<H, T, P>(self, handler: H) -> Self
142 where
143 H: axum::handler::Handler<T, S>,
144 T: SecondElementIs<P> + 'static,
145 P: TypedPath;
146
147 fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
173 where
174 Self: Sized;
175
176 fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
180 where
181 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
182 T::Response: IntoResponse,
183 T::Future: Send + 'static,
184 Self: Sized;
185}
186
187impl<S> RouterExt<S> for Router<S>
188where
189 S: Clone + Send + Sync + 'static,
190{
191 #[cfg(feature = "typed-routing")]
192 fn typed_get<H, T, P>(self, handler: H) -> Self
193 where
194 H: axum::handler::Handler<T, S>,
195 T: SecondElementIs<P> + 'static,
196 P: TypedPath,
197 {
198 self.route(P::PATH, axum::routing::get(handler))
199 }
200
201 #[cfg(feature = "typed-routing")]
202 fn typed_delete<H, T, P>(self, handler: H) -> Self
203 where
204 H: axum::handler::Handler<T, S>,
205 T: SecondElementIs<P> + 'static,
206 P: TypedPath,
207 {
208 self.route(P::PATH, axum::routing::delete(handler))
209 }
210
211 #[cfg(feature = "typed-routing")]
212 fn typed_head<H, T, P>(self, handler: H) -> Self
213 where
214 H: axum::handler::Handler<T, S>,
215 T: SecondElementIs<P> + 'static,
216 P: TypedPath,
217 {
218 self.route(P::PATH, axum::routing::head(handler))
219 }
220
221 #[cfg(feature = "typed-routing")]
222 fn typed_options<H, T, P>(self, handler: H) -> Self
223 where
224 H: axum::handler::Handler<T, S>,
225 T: SecondElementIs<P> + 'static,
226 P: TypedPath,
227 {
228 self.route(P::PATH, axum::routing::options(handler))
229 }
230
231 #[cfg(feature = "typed-routing")]
232 fn typed_patch<H, T, P>(self, handler: H) -> Self
233 where
234 H: axum::handler::Handler<T, S>,
235 T: SecondElementIs<P> + 'static,
236 P: TypedPath,
237 {
238 self.route(P::PATH, axum::routing::patch(handler))
239 }
240
241 #[cfg(feature = "typed-routing")]
242 fn typed_post<H, T, P>(self, handler: H) -> Self
243 where
244 H: axum::handler::Handler<T, S>,
245 T: SecondElementIs<P> + 'static,
246 P: TypedPath,
247 {
248 self.route(P::PATH, axum::routing::post(handler))
249 }
250
251 #[cfg(feature = "typed-routing")]
252 fn typed_put<H, T, P>(self, handler: H) -> Self
253 where
254 H: axum::handler::Handler<T, S>,
255 T: SecondElementIs<P> + 'static,
256 P: TypedPath,
257 {
258 self.route(P::PATH, axum::routing::put(handler))
259 }
260
261 #[cfg(feature = "typed-routing")]
262 fn typed_trace<H, T, P>(self, handler: H) -> Self
263 where
264 H: axum::handler::Handler<T, S>,
265 T: SecondElementIs<P> + 'static,
266 P: TypedPath,
267 {
268 self.route(P::PATH, axum::routing::trace(handler))
269 }
270
271 #[cfg(feature = "typed-routing")]
272 fn typed_connect<H, T, P>(self, handler: H) -> Self
273 where
274 H: axum::handler::Handler<T, S>,
275 T: SecondElementIs<P> + 'static,
276 P: TypedPath,
277 {
278 self.route(P::PATH, axum::routing::connect(handler))
279 }
280
281 #[track_caller]
282 fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
283 where
284 Self: Sized,
285 {
286 validate_tsr_path(path);
287 self = self.route(path, method_router);
288 add_tsr_redirect_route(self, path)
289 }
290
291 #[track_caller]
292 fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
293 where
294 T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
295 T::Response: IntoResponse,
296 T::Future: Send + 'static,
297 Self: Sized,
298 {
299 validate_tsr_path(path);
300 self = self.route_service(path, service);
301 add_tsr_redirect_route(self, path)
302 }
303}
304
/// Guard used by the trailing-slash-redirect helpers.
///
/// # Panics
///
/// Panics when `path` is exactly `/`. `#[track_caller]` makes the panic location
/// point at the caller rather than at this function.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(
        path != "/",
        "Cannot add a trailing slash redirect route for `/`"
    );
}
311
312fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
313where
314 S: Clone + Send + Sync + 'static,
315{
316 async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response {
317 let new_uri = map_path(uri, |path| {
318 path.strip_suffix('/')
319 .map(Cow::Borrowed)
320 .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
321 });
322
323 if let Some(new_uri) = new_uri {
324 Redirect::permanent(&new_uri.to_string()).into_response()
325 } else {
326 StatusCode::BAD_REQUEST.into_response()
327 }
328 }
329
330 if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
331 router.route(path_without_trailing_slash, any(redirect_handler))
332 } else {
333 router.route(&format!("{path}/"), any(redirect_handler))
334 }
335}
336
337fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
341where
342 F: FnOnce(&str) -> Cow<'_, str>,
343{
344 let mut parts = original_uri.into_parts();
345 let path_and_query = parts.path_and_query.as_ref()?;
346
347 let new_path = f(path_and_query.path());
348
349 let new_path_and_query = if let Some(query) = &path_and_query.query() {
350 format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
351 } else {
352 new_path.parse::<PathAndQuery>().ok()?
353 };
354 parts.path_and_query = Some(new_path_and_query);
355
356 Uri::from_parts(parts).ok()
357}
358
359mod sealed {
360 pub trait Sealed {}
361 impl<S> Sealed for axum::Router<S> {}
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    /// Both directions of the redirect: slash-less canonical path and slash-terminated one.
    #[tokio::test]
    async fn test_tsr() {
        let router = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));
        let client = TestClient::new(router);

        // `/foo` is canonical, `/foo/` redirects to it.
        let canonical_foo = client.get("/foo").await;
        assert_eq!(canonical_foo.status(), StatusCode::OK);

        let redirected_foo = client.get("/foo/").await;
        assert_eq!(redirected_foo.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected_foo.headers()["location"], "/foo");

        // `/bar/` is canonical, `/bar` redirects to it.
        let canonical_bar = client.get("/bar/").await;
        assert_eq!(canonical_bar.status(), StatusCode::OK);

        let redirected_bar = client.get("/bar").await;
        assert_eq!(redirected_bar.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected_bar.headers()["location"], "/bar/");
    }

    /// Redirect targets are built from the concrete request path, not the route pattern.
    #[tokio::test]
    async fn tsr_with_params() {
        /// Echo the captured path parameter back as the response body.
        async fn echo(Path(param): Path<String>) -> String {
            param
        }

        let router = Router::new()
            .route_with_tsr("/a/{a}", get(echo))
            .route_with_tsr("/b/{b}/", get(echo));
        let client = TestClient::new(router);

        let canonical_a = client.get("/a/foo").await;
        assert_eq!(canonical_a.status(), StatusCode::OK);
        assert_eq!(canonical_a.text().await, "foo");

        let redirected_a = client.get("/a/foo/").await;
        assert_eq!(redirected_a.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected_a.headers()["location"], "/a/foo");

        let canonical_b = client.get("/b/foo/").await;
        assert_eq!(canonical_b.status(), StatusCode::OK);
        assert_eq!(canonical_b.text().await, "foo");

        let redirected_b = client.get("/b/foo").await;
        assert_eq!(redirected_b.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected_b.headers()["location"], "/b/foo/");
    }

    /// The query string survives the redirect.
    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let router = Router::new().route_with_tsr("/foo", get(|| async {}));
        let client = TestClient::new(router);

        let redirected = client.get("/foo/?a=a").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/foo?a=a");
    }

    /// The redirect target keeps the prefix under which the router was nested.
    #[tokio::test]
    async fn tsr_works_in_nested_router() {
        let inner = Router::new().route_with_tsr("/nyan/", get(|| async {}));
        let router = Router::new().nest("/neko", inner);
        let client = TestClient::new(router);

        let canonical = client.get("/neko/nyan/").await;
        assert_eq!(canonical.status(), StatusCode::OK);

        let redirected = client.get("/neko/nyan").await;
        assert_eq!(redirected.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(redirected.headers()["location"], "/neko/nyan/");
    }

    /// Registering a TSR route for `/` must panic.
    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }
}