axum_extra/routing/
mod.rs

1//! Additional types for defining routes.
2
3use axum::{
4    extract::{OriginalUri, Request},
5    response::{IntoResponse, Redirect, Response},
6    routing::{any, MethodRouter},
7    Router,
8};
9use http::{uri::PathAndQuery, StatusCode, Uri};
10use std::{borrow::Cow, convert::Infallible};
11use tower_service::Service;
12
13mod resource;
14
15#[cfg(feature = "typed-routing")]
16mod typed;
17
18pub use self::resource::Resource;
19
20#[cfg(feature = "typed-routing")]
21pub use self::typed::WithQueryParams;
22#[cfg(feature = "typed-routing")]
23pub use axum_macros::TypedPath;
24
25#[cfg(feature = "typed-routing")]
26pub use self::typed::{SecondElementIs, TypedPath};
27
28/// Extension trait that adds additional methods to [`Router`].
29pub trait RouterExt<S>: sealed::Sealed {
30    /// Add a typed `GET` route to the router.
31    ///
32    /// The path will be inferred from the first argument to the handler function which must
33    /// implement [`TypedPath`].
34    ///
35    /// See [`TypedPath`] for more details and examples.
36    #[cfg(feature = "typed-routing")]
37    fn typed_get<H, T, P>(self, handler: H) -> Self
38    where
39        H: axum::handler::Handler<T, S>,
40        T: SecondElementIs<P> + 'static,
41        P: TypedPath;
42
43    /// Add a typed `DELETE` route to the router.
44    ///
45    /// The path will be inferred from the first argument to the handler function which must
46    /// implement [`TypedPath`].
47    ///
48    /// See [`TypedPath`] for more details and examples.
49    #[cfg(feature = "typed-routing")]
50    fn typed_delete<H, T, P>(self, handler: H) -> Self
51    where
52        H: axum::handler::Handler<T, S>,
53        T: SecondElementIs<P> + 'static,
54        P: TypedPath;
55
56    /// Add a typed `HEAD` route to the router.
57    ///
58    /// The path will be inferred from the first argument to the handler function which must
59    /// implement [`TypedPath`].
60    ///
61    /// See [`TypedPath`] for more details and examples.
62    #[cfg(feature = "typed-routing")]
63    fn typed_head<H, T, P>(self, handler: H) -> Self
64    where
65        H: axum::handler::Handler<T, S>,
66        T: SecondElementIs<P> + 'static,
67        P: TypedPath;
68
69    /// Add a typed `OPTIONS` route to the router.
70    ///
71    /// The path will be inferred from the first argument to the handler function which must
72    /// implement [`TypedPath`].
73    ///
74    /// See [`TypedPath`] for more details and examples.
75    #[cfg(feature = "typed-routing")]
76    fn typed_options<H, T, P>(self, handler: H) -> Self
77    where
78        H: axum::handler::Handler<T, S>,
79        T: SecondElementIs<P> + 'static,
80        P: TypedPath;
81
82    /// Add a typed `PATCH` route to the router.
83    ///
84    /// The path will be inferred from the first argument to the handler function which must
85    /// implement [`TypedPath`].
86    ///
87    /// See [`TypedPath`] for more details and examples.
88    #[cfg(feature = "typed-routing")]
89    fn typed_patch<H, T, P>(self, handler: H) -> Self
90    where
91        H: axum::handler::Handler<T, S>,
92        T: SecondElementIs<P> + 'static,
93        P: TypedPath;
94
95    /// Add a typed `POST` route to the router.
96    ///
97    /// The path will be inferred from the first argument to the handler function which must
98    /// implement [`TypedPath`].
99    ///
100    /// See [`TypedPath`] for more details and examples.
101    #[cfg(feature = "typed-routing")]
102    fn typed_post<H, T, P>(self, handler: H) -> Self
103    where
104        H: axum::handler::Handler<T, S>,
105        T: SecondElementIs<P> + 'static,
106        P: TypedPath;
107
108    /// Add a typed `PUT` route to the router.
109    ///
110    /// The path will be inferred from the first argument to the handler function which must
111    /// implement [`TypedPath`].
112    ///
113    /// See [`TypedPath`] for more details and examples.
114    #[cfg(feature = "typed-routing")]
115    fn typed_put<H, T, P>(self, handler: H) -> Self
116    where
117        H: axum::handler::Handler<T, S>,
118        T: SecondElementIs<P> + 'static,
119        P: TypedPath;
120
121    /// Add a typed `TRACE` route to the router.
122    ///
123    /// The path will be inferred from the first argument to the handler function which must
124    /// implement [`TypedPath`].
125    ///
126    /// See [`TypedPath`] for more details and examples.
127    #[cfg(feature = "typed-routing")]
128    fn typed_trace<H, T, P>(self, handler: H) -> Self
129    where
130        H: axum::handler::Handler<T, S>,
131        T: SecondElementIs<P> + 'static,
132        P: TypedPath;
133
134    /// Add a typed `CONNECT` route to the router.
135    ///
136    /// The path will be inferred from the first argument to the handler function which must
137    /// implement [`TypedPath`].
138    ///
139    /// See [`TypedPath`] for more details and examples.
140    #[cfg(feature = "typed-routing")]
141    fn typed_connect<H, T, P>(self, handler: H) -> Self
142    where
143        H: axum::handler::Handler<T, S>,
144        T: SecondElementIs<P> + 'static,
145        P: TypedPath;
146
147    /// Add another route to the router with an additional "trailing slash redirect" route.
148    ///
149    /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a
150    /// route for `/foo/` that redirects to `/foo`.
151    ///
152    /// If you add a route _with_ a trailing slash, such as `/bar/`, this method will also add a
153    /// route for `/bar` that redirects to `/bar/`.
154    ///
155    /// This is similar to what axum 0.5.x did by default, except this explicitly adds another
156    /// route, so trying to add a `/foo/` route after calling `.route_with_tsr("/foo", /* ... */)`
157    /// will result in a panic due to route overlap.
158    ///
159    /// # Example
160    ///
161    /// ```
162    /// use axum::{Router, routing::get};
163    /// use axum_extra::routing::RouterExt;
164    ///
165    /// let app = Router::new()
166    ///     // `/foo/` will redirect to `/foo`
167    ///     .route_with_tsr("/foo", get(|| async {}))
168    ///     // `/bar` will redirect to `/bar/`
169    ///     .route_with_tsr("/bar/", get(|| async {}));
170    /// # let _: Router = app;
171    /// ```
172    fn route_with_tsr(self, path: &str, method_router: MethodRouter<S>) -> Self
173    where
174        Self: Sized;
175
176    /// Add another route to the router with an additional "trailing slash redirect" route.
177    ///
178    /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`].
179    fn route_service_with_tsr<T>(self, path: &str, service: T) -> Self
180    where
181        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
182        T::Response: IntoResponse,
183        T::Future: Send + 'static,
184        Self: Sized;
185}
186
187impl<S> RouterExt<S> for Router<S>
188where
189    S: Clone + Send + Sync + 'static,
190{
191    #[cfg(feature = "typed-routing")]
192    fn typed_get<H, T, P>(self, handler: H) -> Self
193    where
194        H: axum::handler::Handler<T, S>,
195        T: SecondElementIs<P> + 'static,
196        P: TypedPath,
197    {
198        self.route(P::PATH, axum::routing::get(handler))
199    }
200
201    #[cfg(feature = "typed-routing")]
202    fn typed_delete<H, T, P>(self, handler: H) -> Self
203    where
204        H: axum::handler::Handler<T, S>,
205        T: SecondElementIs<P> + 'static,
206        P: TypedPath,
207    {
208        self.route(P::PATH, axum::routing::delete(handler))
209    }
210
211    #[cfg(feature = "typed-routing")]
212    fn typed_head<H, T, P>(self, handler: H) -> Self
213    where
214        H: axum::handler::Handler<T, S>,
215        T: SecondElementIs<P> + 'static,
216        P: TypedPath,
217    {
218        self.route(P::PATH, axum::routing::head(handler))
219    }
220
221    #[cfg(feature = "typed-routing")]
222    fn typed_options<H, T, P>(self, handler: H) -> Self
223    where
224        H: axum::handler::Handler<T, S>,
225        T: SecondElementIs<P> + 'static,
226        P: TypedPath,
227    {
228        self.route(P::PATH, axum::routing::options(handler))
229    }
230
231    #[cfg(feature = "typed-routing")]
232    fn typed_patch<H, T, P>(self, handler: H) -> Self
233    where
234        H: axum::handler::Handler<T, S>,
235        T: SecondElementIs<P> + 'static,
236        P: TypedPath,
237    {
238        self.route(P::PATH, axum::routing::patch(handler))
239    }
240
241    #[cfg(feature = "typed-routing")]
242    fn typed_post<H, T, P>(self, handler: H) -> Self
243    where
244        H: axum::handler::Handler<T, S>,
245        T: SecondElementIs<P> + 'static,
246        P: TypedPath,
247    {
248        self.route(P::PATH, axum::routing::post(handler))
249    }
250
251    #[cfg(feature = "typed-routing")]
252    fn typed_put<H, T, P>(self, handler: H) -> Self
253    where
254        H: axum::handler::Handler<T, S>,
255        T: SecondElementIs<P> + 'static,
256        P: TypedPath,
257    {
258        self.route(P::PATH, axum::routing::put(handler))
259    }
260
261    #[cfg(feature = "typed-routing")]
262    fn typed_trace<H, T, P>(self, handler: H) -> Self
263    where
264        H: axum::handler::Handler<T, S>,
265        T: SecondElementIs<P> + 'static,
266        P: TypedPath,
267    {
268        self.route(P::PATH, axum::routing::trace(handler))
269    }
270
271    #[cfg(feature = "typed-routing")]
272    fn typed_connect<H, T, P>(self, handler: H) -> Self
273    where
274        H: axum::handler::Handler<T, S>,
275        T: SecondElementIs<P> + 'static,
276        P: TypedPath,
277    {
278        self.route(P::PATH, axum::routing::connect(handler))
279    }
280
281    #[track_caller]
282    fn route_with_tsr(mut self, path: &str, method_router: MethodRouter<S>) -> Self
283    where
284        Self: Sized,
285    {
286        validate_tsr_path(path);
287        self = self.route(path, method_router);
288        add_tsr_redirect_route(self, path)
289    }
290
291    #[track_caller]
292    fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
293    where
294        T: Service<Request, Error = Infallible> + Clone + Send + Sync + 'static,
295        T::Response: IntoResponse,
296        T::Future: Send + 'static,
297        Self: Sized,
298    {
299        validate_tsr_path(path);
300        self = self.route_service(path, service);
301        add_tsr_redirect_route(self, path)
302    }
303}
304
/// Reject paths for which no trailing-slash redirect route can be derived.
///
/// The root path `/` has no slash-less counterpart, so it cannot get a redirect route.
///
/// # Panics
///
/// Panics when `path` is exactly `/`.
#[track_caller]
fn validate_tsr_path(path: &str) {
    assert!(path != "/", "Cannot add a trailing slash redirect route for `/`");
}
311
312fn add_tsr_redirect_route<S>(router: Router<S>, path: &str) -> Router<S>
313where
314    S: Clone + Send + Sync + 'static,
315{
316    async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response {
317        let new_uri = map_path(uri, |path| {
318            path.strip_suffix('/')
319                .map(Cow::Borrowed)
320                .unwrap_or_else(|| Cow::Owned(format!("{path}/")))
321        });
322
323        if let Some(new_uri) = new_uri {
324            Redirect::permanent(&new_uri.to_string()).into_response()
325        } else {
326            StatusCode::BAD_REQUEST.into_response()
327        }
328    }
329
330    if let Some(path_without_trailing_slash) = path.strip_suffix('/') {
331        router.route(path_without_trailing_slash, any(redirect_handler))
332    } else {
333        router.route(&format!("{path}/"), any(redirect_handler))
334    }
335}
336
337/// Map the path of a `Uri`.
338///
339/// Returns `None` if the `Uri` cannot be put back together with the new path.
340fn map_path<F>(original_uri: Uri, f: F) -> Option<Uri>
341where
342    F: FnOnce(&str) -> Cow<'_, str>,
343{
344    let mut parts = original_uri.into_parts();
345    let path_and_query = parts.path_and_query.as_ref()?;
346
347    let new_path = f(path_and_query.path());
348
349    let new_path_and_query = if let Some(query) = &path_and_query.query() {
350        format!("{new_path}?{query}").parse::<PathAndQuery>().ok()?
351    } else {
352        new_path.parse::<PathAndQuery>().ok()?
353    };
354    parts.path_and_query = Some(new_path_and_query);
355
356    Uri::from_parts(parts).ok()
357}
358
359mod sealed {
360    pub trait Sealed {}
361    impl<S> Sealed for axum::Router<S> {}
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{extract::Path, routing::get};

    #[tokio::test]
    async fn test_tsr() {
        let app = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));

        let client = TestClient::new(app);

        let res = client.get("/foo").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo");

        let res = client.get("/bar/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/bar").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/bar/");
    }

    #[tokio::test]
    async fn tsr_with_params() {
        let app = Router::new()
            .route_with_tsr(
                "/a/{a}",
                get(|Path(param): Path<String>| async move { param }),
            )
            .route_with_tsr(
                "/b/{b}/",
                get(|Path(param): Path<String>| async move { param }),
            );

        let client = TestClient::new(app);

        let res = client.get("/a/foo").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");

        let res = client.get("/a/foo/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/a/foo");

        let res = client.get("/b/foo/").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "foo");

        let res = client.get("/b/foo").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/b/foo/");
    }

    #[tokio::test]
    async fn tsr_maintains_query_params() {
        let app = Router::new()
            .route_with_tsr("/foo", get(|| async {}))
            .route_with_tsr("/bar/", get(|| async {}));

        let client = TestClient::new(app);

        // Slash removed, query kept.
        let res = client.get("/foo/?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/foo?a=a");

        // Slash added, query kept.
        let res = client.get("/bar?a=a").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/bar/?a=a");
    }

    #[tokio::test]
    async fn tsr_works_in_nested_router() {
        let app = Router::new().nest(
            "/neko",
            Router::new().route_with_tsr("/nyan/", get(|| async {})),
        );

        let client = TestClient::new(app);
        let res = client.get("/neko/nyan/").await;
        assert_eq!(res.status(), StatusCode::OK);

        let res = client.get("/neko/nyan").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/neko/nyan/");
    }

    #[tokio::test]
    async fn tsr_with_service() {
        let service: MethodRouter<()> = get(|| async { "service" });
        let app: Router = Router::new().route_service_with_tsr("/svc", service);

        let client = TestClient::new(app);

        let res = client.get("/svc").await;
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(res.text().await, "service");

        let res = client.get("/svc/").await;
        assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT);
        assert_eq!(res.headers()["location"], "/svc");
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_at_root() {
        let _: Router = Router::new().route_with_tsr("/", get(|| async move {}));
    }

    #[test]
    #[should_panic = "Cannot add a trailing slash redirect route for `/`"]
    fn tsr_service_at_root() {
        let service: MethodRouter<()> = get(|| async move {});
        let _: Router = Router::new().route_service_with_tsr("/", service);
    }
}