use std::{convert::Infallible, sync::Arc};

use crate::{
    Request, Response,
    matcher::{HttpMatcher, MethodMatcher, UriParams},
};

use matchit::Router as MatchitRouter;
use rama_core::{
    Context,
    context::Extensions,
    matcher::Matcher,
    service::{BoxService, Service},
};
use rama_http_types::{Body, StatusCode, Uri};

use super::IntoEndpointService;
18
19pub struct Router<State> {
25 routes: MatchitRouter<
26 Vec<(
27 HttpMatcher<State, Body>,
28 BoxService<State, Request, Response, Infallible>,
29 )>,
30 >,
31 not_found: Option<BoxService<State, Request, Response, Infallible>>,
32}
33
34impl<State> std::fmt::Debug for Router<State> {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("Router").finish()
37 }
38}
39
40impl<State> Router<State>
41where
42 State: Clone + Send + Sync + 'static,
43{
44 pub fn new() -> Self {
46 Self {
47 routes: MatchitRouter::new(),
48 not_found: None,
49 }
50 }
51
52 pub fn get<I, T>(self, path: &str, service: I) -> Self
56 where
57 I: IntoEndpointService<State, T>,
58 {
59 let matcher = HttpMatcher::method(MethodMatcher::GET);
60 self.match_route(path, matcher, service)
61 }
62
63 pub fn post<I, T>(self, path: &str, service: I) -> Self
65 where
66 I: IntoEndpointService<State, T>,
67 {
68 let matcher = HttpMatcher::method(MethodMatcher::POST);
69 self.match_route(path, matcher, service)
70 }
71
72 pub fn put<I, T>(self, path: &str, service: I) -> Self
74 where
75 I: IntoEndpointService<State, T>,
76 {
77 let matcher = HttpMatcher::method(MethodMatcher::PUT);
78 self.match_route(path, matcher, service)
79 }
80
81 pub fn delete<I, T>(self, path: &str, service: I) -> Self
83 where
84 I: IntoEndpointService<State, T>,
85 {
86 let matcher = HttpMatcher::method(MethodMatcher::DELETE);
87 self.match_route(path, matcher, service)
88 }
89
90 pub fn patch<I, T>(self, path: &str, service: I) -> Self
92 where
93 I: IntoEndpointService<State, T>,
94 {
95 let matcher = HttpMatcher::method(MethodMatcher::PATCH);
96 self.match_route(path, matcher, service)
97 }
98
99 pub fn head<I, T>(self, path: &str, service: I) -> Self
101 where
102 I: IntoEndpointService<State, T>,
103 {
104 let matcher = HttpMatcher::method(MethodMatcher::HEAD);
105 self.match_route(path, matcher, service)
106 }
107
108 pub fn options<I, T>(self, path: &str, service: I) -> Self
110 where
111 I: IntoEndpointService<State, T>,
112 {
113 let matcher = HttpMatcher::method(MethodMatcher::OPTIONS);
114 self.match_route(path, matcher, service)
115 }
116
117 pub fn trace<I, T>(self, path: &str, service: I) -> Self
119 where
120 I: IntoEndpointService<State, T>,
121 {
122 let matcher = HttpMatcher::method(MethodMatcher::TRACE);
123 self.match_route(path, matcher, service)
124 }
125
126 pub fn connect<I, T>(self, path: &str, service: I) -> Self
128 where
129 I: IntoEndpointService<State, T>,
130 {
131 let matcher = HttpMatcher::method(MethodMatcher::CONNECT);
132 self.match_route(path, matcher, service)
133 }
134
135 pub fn sub<I, T>(self, prefix: &str, service: I) -> Self
139 where
140 I: IntoEndpointService<State, T>,
141 {
142 let path = format!("{}/{}", prefix.trim().trim_end_matches(['/']), "{*nest}");
143 let nested = Arc::new(service.into_endpoint_service().boxed());
144
145 let nested_router_service = NestedRouterService {
146 prefix: Arc::from(prefix),
147 nested,
148 };
149
150 self.match_route(
151 prefix,
152 HttpMatcher::custom(true),
153 nested_router_service.clone(),
154 )
155 .match_route(&path, HttpMatcher::custom(true), nested_router_service)
156 }
157
158 pub fn match_route<I, T>(
160 mut self,
161 path: &str,
162 matcher: HttpMatcher<State, Body>,
163 service: I,
164 ) -> Self
165 where
166 I: IntoEndpointService<State, T>,
167 {
168 let service = service.into_endpoint_service().boxed();
169
170 let mut path = path.trim().trim_end_matches('/');
171 if path.is_empty() {
172 path = "/"
173 }
174
175 if let Ok(matched) = self.routes.at_mut(path) {
176 matched.value.push((matcher, service));
177 } else {
178 self.routes
179 .insert(path, vec![(matcher, service)])
180 .expect("Failed to add route");
181 }
182
183 self
184 }
185
186 pub fn not_found<I, T>(mut self, service: I) -> Self
188 where
189 I: IntoEndpointService<State, T>,
190 {
191 self.not_found = Some(service.into_endpoint_service().boxed());
192 self
193 }
194}
195
/// Adapter registered by [`Router::sub`] that forwards matching requests to a mounted service.
#[derive(Debug, Clone)]
struct NestedRouterService<State> {
    /// The prefix exactly as passed to `Router::sub`; currently never read.
    #[expect(unused)]
    prefix: Arc<str>,
    /// The mounted service. It is behind an `Arc` because `sub` registers one clone of this
    /// adapter for the bare prefix and one for the `prefix/{*nest}` catch-all.
    nested: Arc<BoxService<State, Request, Response, Infallible>>,
}
202
203impl<State> Service<State, Request> for NestedRouterService<State>
204where
205 State: Clone + Send + Sync + 'static,
206{
207 type Response = Response;
208 type Error = Infallible;
209
210 async fn serve(
211 &self,
212 mut ctx: Context<State>,
213 mut req: Request,
214 ) -> Result<Self::Response, Self::Error> {
215 let params: UriParams = match ctx.remove::<UriParams>() {
216 Some(params) => {
217 let nested_path = params.get("nest").unwrap_or_default();
218
219 let filtered_params: UriParams =
220 params.iter().filter(|(key, _)| *key != "nest").collect();
221
222 let path = format!("/{}", nested_path);
224 *req.uri_mut() = path.parse().unwrap();
225
226 filtered_params
227 }
228 None => UriParams::default(),
229 };
230
231 ctx.insert(params);
232
233 self.nested.serve(ctx, req).await
234 }
235}
236
237impl<State> Default for Router<State>
238where
239 State: Clone + Send + Sync + 'static,
240{
241 fn default() -> Self {
242 Self::new()
243 }
244}
245
246impl<State> Service<State, Request> for Router<State>
247where
248 State: Clone + Send + Sync + 'static,
249{
250 type Response = Response;
251 type Error = Infallible;
252
253 async fn serve(
254 &self,
255 mut ctx: Context<State>,
256 req: Request,
257 ) -> Result<Self::Response, Self::Error> {
258 let mut ext = Extensions::new();
259
260 if let Ok(matched) = self.routes.at(req.uri().path()) {
261 let uri_params = matched.params.iter();
262
263 let params: UriParams = match ctx.remove::<UriParams>() {
264 Some(mut params) => {
265 params.extend(uri_params);
266 params
267 }
268 None => uri_params.collect(),
269 };
270 ctx.insert(params);
271
272 for (matcher, service) in matched.value.iter() {
273 if matcher.matches(Some(&mut ext), &ctx, &req) {
274 ctx.extend(ext);
275 return service.serve(ctx, req).await;
276 }
277 ext.clear();
278 }
279 }
280
281 if let Some(not_found) = &self.not_found {
282 not_found.serve(ctx, req).await
283 } else {
284 Ok(Response::builder()
285 .status(StatusCode::NOT_FOUND)
286 .body(Body::from("Not Found"))
287 .unwrap())
288 }
289 }
290}
291
#[cfg(test)]
mod tests {
    use crate::matcher::UriParams;

    use super::*;
    use rama_core::service::service_fn;
    use rama_http_types::{Body, Method, Request, StatusCode, dep::http_body_util::BodyExt};

    /// Builds a response with the given status and body.
    fn respond(status: StatusCode, body: Body) -> Response {
        Response::builder().status(status).body(body).unwrap()
    }

    /// Reads the URI parameter `name` from the context, panicking when it is absent.
    fn uri_param(ctx: &Context<()>, name: &str) -> String {
        let params = ctx.get::<UriParams>().unwrap();
        params.get(name).unwrap().to_owned()
    }

    /// Builds a body-less request; only the methods exercised by these tests are supported.
    fn request(method: &Method, path: &str) -> Request {
        let builder = match *method {
            Method::GET => Request::get(path),
            Method::POST => Request::post(path),
            Method::PUT => Request::put(path),
            Method::DELETE => Request::delete(path),
            _ => panic!("Unsupported HTTP method"),
        };
        builder.body(Body::empty()).unwrap()
    }

    fn root_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|_ctx, _req| async {
            Ok(respond(StatusCode::OK, Body::from("Hello, World!")))
        })
    }

    fn create_user_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|_ctx, _req| async { Ok(respond(StatusCode::OK, Body::from("Create User"))) })
    }

    fn get_users_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|_ctx, _req| async { Ok(respond(StatusCode::OK, Body::from("List Users"))) })
    }

    fn get_user_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|ctx: Context<()>, _req| async move {
            let user_id = uri_param(&ctx, "user_id");
            Ok(respond(
                StatusCode::OK,
                Body::from(format!("Get User: {user_id}")),
            ))
        })
    }

    fn delete_user_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|ctx: Context<()>, _req| async move {
            let user_id = uri_param(&ctx, "user_id");
            Ok(respond(
                StatusCode::OK,
                Body::from(format!("Delete User: {user_id}")),
            ))
        })
    }

    fn serve_assets_service() -> impl Service<(), Request, Response = Response, Error = Infallible>
    {
        service_fn(|ctx: Context<()>, _req| async move {
            let asset = uri_param(&ctx, "path");
            Ok(respond(
                StatusCode::OK,
                Body::from(format!("Serve Assets: /{asset}")),
            ))
        })
    }

    fn not_found_service() -> impl Service<(), Request, Response = Response, Error = Infallible> {
        service_fn(|_ctx, _req| async {
            Ok(respond(StatusCode::NOT_FOUND, Body::from("Not Found")))
        })
    }

    fn get_user_order_service() -> impl Service<(), Request, Response = Response, Error = Infallible>
    {
        service_fn(|ctx: Context<()>, _req| async move {
            let user_id = uri_param(&ctx, "user_id");
            let order_id = uri_param(&ctx, "order_id");
            Ok(respond(
                StatusCode::OK,
                Body::from(format!("Get Order: {order_id} for User: {user_id}")),
            ))
        })
    }

    #[tokio::test]
    async fn test_router() {
        let router = Router::new()
            .get("/", root_service())
            .get("/users", get_users_service())
            .post("/users", create_user_service())
            .get("/users/{user_id}", get_user_service())
            .delete("/users/{user_id}", delete_user_service())
            .get(
                "/users/{user_id}/orders/{order_id}",
                get_user_order_service(),
            )
            .get("/assets/{*path}", serve_assets_service())
            .not_found(not_found_service());

        let cases = [
            (Method::GET, "/", "Hello, World!", StatusCode::OK),
            (Method::GET, "/users", "List Users", StatusCode::OK),
            (Method::POST, "/users", "Create User", StatusCode::OK),
            (Method::GET, "/users/123", "Get User: 123", StatusCode::OK),
            (
                Method::DELETE,
                "/users/123",
                "Delete User: 123",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/users/123/orders/456",
                "Get Order: 456 for User: 123",
                StatusCode::OK,
            ),
            (
                Method::PUT,
                "/users/123",
                "Not Found",
                StatusCode::NOT_FOUND,
            ),
            (
                Method::GET,
                "/assets/css/style.css",
                "Serve Assets: /css/style.css",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/not-found",
                "Not Found",
                StatusCode::NOT_FOUND,
            ),
        ];

        for (method, path, expected_body, expected_status) in cases {
            let res = router
                .serve(Context::default(), request(&method, path))
                .await
                .unwrap();
            assert_eq!(res.status(), expected_status);
            let body = res.into_body().collect().await.unwrap().to_bytes();
            assert_eq!(body, expected_body);
        }
    }

    #[tokio::test]
    async fn test_router_nest() {
        let api_router = Router::new()
            .get("/users", get_users_service())
            .post("/users", create_user_service())
            .delete("/users/{user_id}", delete_user_service())
            .sub(
                "/users/{user_id}",
                Router::new()
                    .get("/", get_user_service())
                    .get("/orders/{order_id}", get_user_order_service()),
            );

        let app = Router::new()
            .sub("/api", api_router)
            .get("/", root_service());

        let cases = [
            (Method::GET, "/", "Hello, World!", StatusCode::OK),
            (Method::GET, "/api/users", "List Users", StatusCode::OK),
            (Method::POST, "/api/users", "Create User", StatusCode::OK),
            (
                Method::DELETE,
                "/api/users/123",
                "Delete User: 123",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/api/users/123",
                "Get User: 123",
                StatusCode::OK,
            ),
            (
                Method::GET,
                "/api/users/123/orders/456",
                "Get Order: 456 for User: 123",
                StatusCode::OK,
            ),
        ];

        for (method, path, expected_body, expected_status) in cases {
            let res = app
                .serve(Context::default(), request(&method, path))
                .await
                .unwrap();
            assert_eq!(
                res.status(),
                expected_status,
                "method: {method} ; path = {path}"
            );
            let body = res.into_body().collect().await.unwrap().to_bytes();
            assert_eq!(body, expected_body, "method: {method} ; path = {path}");
        }
    }
}