rama_http/service/web/
router.rs

1use std::{convert::Infallible, sync::Arc};
2
3use crate::{
4    Request, Response,
5    matcher::{HttpMatcher, MethodMatcher, UriParams},
6};
7
8use matchit::Router as MatchitRouter;
9use rama_core::{
10    Context,
11    context::Extensions,
12    matcher::Matcher,
13    service::{BoxService, Service},
14};
15use rama_http_types::{Body, StatusCode};
16
17use super::IntoEndpointService;
18
19/// A basic router that can be used to route requests to different services based on the request path.
20///
21/// This router uses `matchit::Router` to efficiently match incoming requests
22/// to predefined routes. Each route is associated with an `HttpMatcher`
23/// and a corresponding service handler.
24pub struct Router<State> {
25    routes: MatchitRouter<
26        Vec<(
27            HttpMatcher<State, Body>,
28            BoxService<State, Request, Response, Infallible>,
29        )>,
30    >,
31    not_found: Option<BoxService<State, Request, Response, Infallible>>,
32}
33
34impl<State> std::fmt::Debug for Router<State> {
35    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("Router").finish()
37    }
38}
39
40impl<State> Router<State>
41where
42    State: Clone + Send + Sync + 'static,
43{
44    /// create a new router.
45    pub fn new() -> Self {
46        Self {
47            routes: MatchitRouter::new(),
48            not_found: None,
49        }
50    }
51
52    /// add a GET route to the router.
53    /// the path can contain parameters, e.g. `/users/{id}`.
54    /// the path can also contain a catch call, e.g. `/assets/{*path}`.
55    pub fn get<I, T>(self, path: &str, service: I) -> Self
56    where
57        I: IntoEndpointService<State, T>,
58    {
59        let matcher = HttpMatcher::method(MethodMatcher::GET);
60        self.match_route(path, matcher, service)
61    }
62
63    /// add a POST route to the router.
64    pub fn post<I, T>(self, path: &str, service: I) -> Self
65    where
66        I: IntoEndpointService<State, T>,
67    {
68        let matcher = HttpMatcher::method(MethodMatcher::POST);
69        self.match_route(path, matcher, service)
70    }
71
72    /// add a PUT route to the router.
73    pub fn put<I, T>(self, path: &str, service: I) -> Self
74    where
75        I: IntoEndpointService<State, T>,
76    {
77        let matcher = HttpMatcher::method(MethodMatcher::PUT);
78        self.match_route(path, matcher, service)
79    }
80
81    /// add a DELETE route to the router.
82    pub fn delete<I, T>(self, path: &str, service: I) -> Self
83    where
84        I: IntoEndpointService<State, T>,
85    {
86        let matcher = HttpMatcher::method(MethodMatcher::DELETE);
87        self.match_route(path, matcher, service)
88    }
89
90    /// add a PATCH route to the router.
91    pub fn patch<I, T>(self, path: &str, service: I) -> Self
92    where
93        I: IntoEndpointService<State, T>,
94    {
95        let matcher = HttpMatcher::method(MethodMatcher::PATCH);
96        self.match_route(path, matcher, service)
97    }
98
99    /// add a HEAD route to the router.
100    pub fn head<I, T>(self, path: &str, service: I) -> Self
101    where
102        I: IntoEndpointService<State, T>,
103    {
104        let matcher = HttpMatcher::method(MethodMatcher::HEAD);
105        self.match_route(path, matcher, service)
106    }
107
108    /// add a OPTIONS route to the router.
109    pub fn options<I, T>(self, path: &str, service: I) -> Self
110    where
111        I: IntoEndpointService<State, T>,
112    {
113        let matcher = HttpMatcher::method(MethodMatcher::OPTIONS);
114        self.match_route(path, matcher, service)
115    }
116
117    /// add a TRACE route to the router.
118    pub fn trace<I, T>(self, path: &str, service: I) -> Self
119    where
120        I: IntoEndpointService<State, T>,
121    {
122        let matcher = HttpMatcher::method(MethodMatcher::TRACE);
123        self.match_route(path, matcher, service)
124    }
125
126    /// add a CONNECT route to the router.
127    pub fn connect<I, T>(self, path: &str, service: I) -> Self
128    where
129        I: IntoEndpointService<State, T>,
130    {
131        let matcher = HttpMatcher::method(MethodMatcher::CONNECT);
132        self.match_route(path, matcher, service)
133    }
134
135    /// register a nested router under a prefix.
136    ///
137    /// The prefix is used to match the request path and strip it from the request URI.
138    pub fn sub<I, T>(self, prefix: &str, service: I) -> Self
139    where
140        I: IntoEndpointService<State, T>,
141    {
142        let path = format!("{}/{}", prefix.trim().trim_end_matches(['/']), "{*nest}");
143        let nested = Arc::new(service.into_endpoint_service().boxed());
144
145        let nested_router_service = NestedRouterService {
146            prefix: Arc::from(prefix),
147            nested,
148        };
149
150        self.match_route(
151            prefix,
152            HttpMatcher::custom(true),
153            nested_router_service.clone(),
154        )
155        .match_route(&path, HttpMatcher::custom(true), nested_router_service)
156    }
157
158    /// add a route to the router with it's matcher and service.
159    pub fn match_route<I, T>(
160        mut self,
161        path: &str,
162        matcher: HttpMatcher<State, Body>,
163        service: I,
164    ) -> Self
165    where
166        I: IntoEndpointService<State, T>,
167    {
168        let service = service.into_endpoint_service().boxed();
169
170        let mut path = path.trim().trim_end_matches('/');
171        if path.is_empty() {
172            path = "/"
173        }
174
175        if let Ok(matched) = self.routes.at_mut(path) {
176            matched.value.push((matcher, service));
177        } else {
178            self.routes
179                .insert(path, vec![(matcher, service)])
180                .expect("Failed to add route");
181        }
182
183        self
184    }
185
186    /// use the provided service when no route matches the request.
187    pub fn not_found<I, T>(mut self, service: I) -> Self
188    where
189        I: IntoEndpointService<State, T>,
190    {
191        self.not_found = Some(service.into_endpoint_service().boxed());
192        self
193    }
194}
195
196#[derive(Debug, Clone)]
197struct NestedRouterService<State> {
198    #[expect(unused)]
199    prefix: Arc<str>,
200    nested: Arc<BoxService<State, Request, Response, Infallible>>,
201}
202
203impl<State> Service<State, Request> for NestedRouterService<State>
204where
205    State: Clone + Send + Sync + 'static,
206{
207    type Response = Response;
208    type Error = Infallible;
209
210    async fn serve(
211        &self,
212        mut ctx: Context<State>,
213        mut req: Request,
214    ) -> Result<Self::Response, Self::Error> {
215        let params: UriParams = match ctx.remove::<UriParams>() {
216            Some(params) => {
217                let nested_path = params.get("nest").unwrap_or_default();
218
219                let filtered_params: UriParams =
220                    params.iter().filter(|(key, _)| *key != "nest").collect();
221
222                // build the nested path and update the request URI
223                let path = format!("/{}", nested_path);
224                *req.uri_mut() = path.parse().unwrap();
225
226                filtered_params
227            }
228            None => UriParams::default(),
229        };
230
231        ctx.insert(params);
232
233        self.nested.serve(ctx, req).await
234    }
235}
236
237impl<State> Default for Router<State>
238where
239    State: Clone + Send + Sync + 'static,
240{
241    fn default() -> Self {
242        Self::new()
243    }
244}
245
246impl<State> Service<State, Request> for Router<State>
247where
248    State: Clone + Send + Sync + 'static,
249{
250    type Response = Response;
251    type Error = Infallible;
252
253    async fn serve(
254        &self,
255        mut ctx: Context<State>,
256        req: Request,
257    ) -> Result<Self::Response, Self::Error> {
258        let mut ext = Extensions::new();
259
260        if let Ok(matched) = self.routes.at(req.uri().path()) {
261            let uri_params = matched.params.iter();
262
263            let params: UriParams = match ctx.remove::<UriParams>() {
264                Some(mut params) => {
265                    params.extend(uri_params);
266                    params
267                }
268                None => uri_params.collect(),
269            };
270            ctx.insert(params);
271
272            for (matcher, service) in matched.value.iter() {
273                if matcher.matches(Some(&mut ext), &ctx, &req) {
274                    ctx.extend(ext);
275                    return service.serve(ctx, req).await;
276                }
277                ext.clear();
278            }
279        }
280
281        if let Some(not_found) = &self.not_found {
282            not_found.serve(ctx, req).await
283        } else {
284            Ok(Response::builder()
285                .status(StatusCode::NOT_FOUND)
286                .body(Body::from("Not Found"))
287                .unwrap())
288        }
289    }
290}
291
#[cfg(test)]
mod tests {
    use crate::matcher::UriParams;

    use super::*;
    use rama_core::service::service_fn;
    use rama_http_types::{Body, Method, Request, StatusCode, dep::http_body_util::BodyExt};

    /// Build a `200 OK` response carrying `body`.
    fn ok_text(body: impl Into<Body>) -> Response {
        Response::builder()
            .status(StatusCode::OK)
            .body(body.into())
            .unwrap()
    }

    /// Service that always answers `200 OK` with the fixed `text`.
    fn static_text_service(
        text: &'static str,
    ) -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(move |_ctx, _req| async move { Ok(ok_text(text)) })
    }

    fn root_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        static_text_service("Hello, World!")
    }

    fn create_user_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        static_text_service("Create User")
    }

    fn get_users_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        static_text_service("List Users")
    }

    fn get_user_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|ctx: Context<()>, _req| async move {
            let params = ctx.get::<UriParams>().unwrap();
            Ok(ok_text(format!("Get User: {}", params.get("user_id").unwrap())))
        })
    }

    fn delete_user_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|ctx: Context<()>, _req| async move {
            let params = ctx.get::<UriParams>().unwrap();
            Ok(ok_text(format!("Delete User: {}", params.get("user_id").unwrap())))
        })
    }

    fn serve_assets_service() -> impl Service<(), Request, Response = Response, Error = Infallible>
    {
        service_fn(|ctx: Context<()>, _req| async move {
            let params = ctx.get::<UriParams>().unwrap();
            Ok(ok_text(format!("Serve Assets: /{}", params.get("path").unwrap())))
        })
    }

    fn not_found_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|_ctx, _req| async {
            Ok(Response::builder()
                .status(StatusCode::NOT_FOUND)
                .body(Body::from("Not Found"))
                .unwrap())
        })
    }

    fn get_user_order_service() -> impl Service<(), Request, Response = Response, Error = Infallible>
    {
        service_fn(|ctx: Context<()>, _req| async move {
            let params = ctx.get::<UriParams>().unwrap();
            let user_id = params.get("user_id").unwrap();
            let order_id = params.get("order_id").unwrap();
            Ok(ok_text(format!(
                "Get Order: {} for User: {}",
                order_id, user_id
            )))
        })
    }

    /// Build an empty-bodied request; only the methods used by these tests are supported.
    fn build_request(method: &Method, path: &str) -> Request {
        let builder = match *method {
            Method::GET => Request::get(path),
            Method::POST => Request::post(path),
            Method::PUT => Request::put(path),
            Method::DELETE => Request::delete(path),
            _ => panic!("Unsupported HTTP method"),
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Send every `(method, path, expected_body, expected_status)` case through `router`.
    async fn check_cases(router: &Router<()>, cases: &[(Method, &str, &str, StatusCode)]) {
        for (method, path, expected_body, expected_status) in cases {
            let res = router
                .serve(Context::default(), build_request(method, path))
                .await
                .unwrap();
            assert_eq!(
                res.status(),
                *expected_status,
                "method: {method} ; path = {path}"
            );
            let body = res.into_body().collect().await.unwrap().to_bytes();
            assert_eq!(body, *expected_body, "method: {method} ; path = {path}");
        }
    }

    #[tokio::test]
    async fn test_router() {
        let router = Router::new()
            .get("/", root_service())
            .get("/users", get_users_service())
            .post("/users", create_user_service())
            .get("/users/{user_id}", get_user_service())
            .delete("/users/{user_id}", delete_user_service())
            .get(
                "/users/{user_id}/orders/{order_id}",
                get_user_order_service(),
            )
            .get("/assets/{*path}", serve_assets_service())
            .not_found(not_found_service());

        let cases = [
            (Method::GET, "/", "Hello, World!", StatusCode::OK),
            (Method::GET, "/users", "List Users", StatusCode::OK),
            (Method::POST, "/users", "Create User", StatusCode::OK),
            (Method::GET, "/users/123", "Get User: 123", StatusCode::OK),
            (
                Method::DELETE,
                "/users/123",
                "Delete User: 123",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/users/123/orders/456",
                "Get Order: 456 for User: 123",
                StatusCode::OK,
            ),
            (
                Method::PUT,
                "/users/123",
                "Not Found",
                StatusCode::NOT_FOUND,
            ),
            (
                Method::GET,
                "/assets/css/style.css",
                "Serve Assets: /css/style.css",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/not-found",
                "Not Found",
                StatusCode::NOT_FOUND,
            ),
        ];

        check_cases(&router, &cases).await;
    }

    #[tokio::test]
    async fn test_router_nest() {
        let api_router = Router::new()
            .get("/users", get_users_service())
            .post("/users", create_user_service())
            .delete("/users/{user_id}", delete_user_service())
            .sub(
                "/users/{user_id}",
                Router::new()
                    .get("/", get_user_service())
                    .get("/orders/{order_id}", get_user_order_service()),
            );

        let app = Router::new()
            .sub("/api", api_router)
            .get("/", root_service());

        let cases = [
            (Method::GET, "/", "Hello, World!", StatusCode::OK),
            (Method::GET, "/api/users", "List Users", StatusCode::OK),
            (Method::POST, "/api/users", "Create User", StatusCode::OK),
            (
                Method::DELETE,
                "/api/users/123",
                "Delete User: 123",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/api/users/123",
                "Get User: 123",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/api/users/123/orders/456",
                "Get Order: 456 for User: 123",
                StatusCode::OK,
            ),
        ];

        check_cases(&app, &cases).await;
    }
}