1use super::conversions::{IntoResponse, TryFromRequest, TryIntoRequest};
4use super::{responses, Method, Request, Response};
5use async_trait::async_trait;
6use routefinder::{Captures, Router as MethodRouter};
7use std::future::Future;
8use std::{collections::HashMap, fmt::Display};
9
/// An asynchronous request handler that a [`Router`] can dispatch to.
///
/// The trait is declared with `#[async_trait(?Send)]`, so the future returned by
/// [`Handler::handle`] is not required to be `Send`.
#[async_trait(?Send)]
pub trait Handler {
    /// Produces the response for `req`.
    ///
    /// `params` carries the values captured from the route that matched the request path.
    async fn handle(&self, req: Request, params: Params) -> Response;
}
19
20#[async_trait(?Send)]
21impl Handler for Box<dyn Handler> {
22 async fn handle(&self, req: Request, params: Params) -> Response {
23 self.as_ref().handle(req, params).await
24 }
25}
26
27#[async_trait(?Send)]
28impl<F, Fut> Handler for F
29where
30 F: Fn(Request, Params) -> Fut + 'static,
31 Fut: Future<Output = Response> + 'static,
32{
33 async fn handle(&self, req: Request, params: Params) -> Response {
34 let fut = (self)(req, params);
35 fut.await
36 }
37}
38
/// Route parameters handed to a [`Handler`]: a `routefinder` [`Captures`] value whose
/// lifetimes are both `'static`, i.e. it does not borrow from the router or the request path.
pub type Params = Captures<'static, 'static>;
41
/// Routes requests to handlers based on the request method and path.
pub struct Router {
    /// Method-specific routes: one `routefinder` router per registered [`Method`].
    methods_map: HashMap<Method, MethodRouter<Box<dyn Handler>>>,
    /// Routes registered without a specific method (see `Router::any`).
    any_methods: MethodRouter<Box<dyn Handler>>,
}
47
48impl Default for Router {
49 fn default() -> Router {
50 Router::new()
51 }
52}
53
54impl Display for Router {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 writeln!(f, "Registered routes:")?;
57 for (method, router) in &self.methods_map {
58 for route in router.iter() {
59 writeln!(f, "- {}: {}", method, route.0)?;
60 }
61 }
62 Ok(())
63 }
64}
65
/// The outcome of a route lookup: the handler to invoke together with the parameters
/// captured from the request path.
struct RouteMatch<'a> {
    /// Captures extracted from the path (empty for the fallback handlers).
    params: Captures<'static, 'static>,
    /// The handler selected for the request, borrowed for `'a`.
    handler: &'a dyn Handler,
}
70
71impl Router {
72 pub fn handle<R>(&self, request: R) -> Response
74 where
75 R: TryIntoRequest,
76 R::Error: IntoResponse,
77 {
78 crate::http::executor::run(self.handle_async(request))
79 }
80
81 pub async fn handle_async<R>(&self, request: R) -> Response
83 where
84 R: TryIntoRequest,
85 R::Error: IntoResponse,
86 {
87 let request = match R::try_into_request(request) {
88 Ok(r) => r,
89 Err(e) => return e.into_response(),
90 };
91 let method = request.method.clone();
92 let path = &request.path();
93 let RouteMatch { params, handler } = self.find(path, method);
94 handler.handle(request, params).await
95 }
96
97 fn find(&self, path: &str, method: Method) -> RouteMatch<'_> {
98 let best_match = self
99 .methods_map
100 .get(&method)
101 .and_then(|r| r.best_match(path));
102
103 if let Some(m) = best_match {
104 let params = m.captures().into_owned();
105 let handler = m.handler();
106 return RouteMatch { handler, params };
107 }
108
109 let best_match = self.any_methods.best_match(path);
110
111 match best_match {
112 Some(m) => {
113 let params = m.captures().into_owned();
114 let handler = m.handler();
115 RouteMatch { handler, params }
116 }
117 None if method == Method::Head => {
118 self.find(path, Method::Get)
121 }
122 None => {
123 self.fail(path, method)
125 }
126 }
127 }
128
129 fn fail(&self, path: &str, method: Method) -> RouteMatch<'_> {
131 let is_method_not_allowed = self
133 .methods_map
134 .iter()
135 .filter(|(k, _)| **k != method)
136 .any(|(_, r)| r.best_match(path).is_some());
137
138 if is_method_not_allowed {
139 RouteMatch {
142 handler: &method_not_allowed,
143 params: Captures::default(),
144 }
145 } else {
146 RouteMatch {
148 handler: ¬_found,
149 params: Captures::default(),
150 }
151 }
152 }
153
154 pub fn any<F, Req, Resp>(&mut self, path: &str, handler: F)
156 where
157 F: Fn(Req, Params) -> Resp + 'static,
158 Req: TryFromRequest + 'static,
159 Req::Error: IntoResponse + 'static,
160 Resp: IntoResponse + 'static,
161 {
162 let handler = move |req, params| {
163 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
164 async move {
165 match res {
166 Ok(res) => res.into_response(),
167 Err(e) => e.into_response(),
168 }
169 }
170 };
171
172 self.any_async(path, handler)
173 }
174
175 pub fn any_async<F, Fut, I, O>(&mut self, path: &str, handler: F)
177 where
178 F: Fn(I, Params) -> Fut + 'static,
179 Fut: Future<Output = O> + 'static,
180 I: TryFromRequest + 'static,
181 I::Error: IntoResponse + 'static,
182 O: IntoResponse + 'static,
183 {
184 let handler = move |req, params| {
185 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
186 async move {
187 match res {
188 Ok(f) => f.await.into_response(),
189 Err(e) => e.into_response(),
190 }
191 }
192 };
193
194 self.any_methods.add(path, Box::new(handler)).unwrap();
195 }
196
197 pub fn add<F, Req, Resp>(&mut self, path: &str, method: Method, handler: F)
199 where
200 F: Fn(Req, Params) -> Resp + 'static,
201 Req: TryFromRequest + 'static,
202 Req::Error: IntoResponse + 'static,
203 Resp: IntoResponse + 'static,
204 {
205 let handler = move |req, params| {
206 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
207 async move {
208 match res {
209 Ok(res) => res.into_response(),
210 Err(e) => e.into_response(),
211 }
212 }
213 };
214
215 self.add_async(path, method, handler)
216 }
217
218 pub fn add_async<F, Fut, I, O>(&mut self, path: &str, method: Method, handler: F)
220 where
221 F: Fn(I, Params) -> Fut + 'static,
222 Fut: Future<Output = O> + 'static,
223 I: TryFromRequest + 'static,
224 I::Error: IntoResponse + 'static,
225 O: IntoResponse + 'static,
226 {
227 let handler = move |req, params| {
228 let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
229 async move {
230 match res {
231 Ok(f) => f.await.into_response(),
232 Err(e) => e.into_response(),
233 }
234 }
235 };
236
237 self.methods_map
238 .entry(method)
239 .or_default()
240 .add(path, Box::new(handler))
241 .unwrap();
242 }
243
244 pub fn get<F, Req, Resp>(&mut self, path: &str, handler: F)
246 where
247 F: Fn(Req, Params) -> Resp + 'static,
248 Req: TryFromRequest + 'static,
249 Req::Error: IntoResponse + 'static,
250 Resp: IntoResponse + 'static,
251 {
252 self.add(path, Method::Get, handler)
253 }
254
255 pub fn get_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
257 where
258 F: Fn(Req, Params) -> Fut + 'static,
259 Fut: Future<Output = Resp> + 'static,
260 Req: TryFromRequest + 'static,
261 Req::Error: IntoResponse + 'static,
262 Resp: IntoResponse + 'static,
263 {
264 self.add_async(path, Method::Get, handler)
265 }
266
267 pub fn head<F, Req, Resp>(&mut self, path: &str, handler: F)
269 where
270 F: Fn(Req, Params) -> Resp + 'static,
271 Req: TryFromRequest + 'static,
272 Req::Error: IntoResponse + 'static,
273 Resp: IntoResponse + 'static,
274 {
275 self.add(path, Method::Head, handler)
276 }
277
278 pub fn head_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
280 where
281 F: Fn(Req, Params) -> Fut + 'static,
282 Fut: Future<Output = Resp> + 'static,
283 Req: TryFromRequest + 'static,
284 Req::Error: IntoResponse + 'static,
285 Resp: IntoResponse + 'static,
286 {
287 self.add_async(path, Method::Head, handler)
288 }
289
290 pub fn post<F, Req, Resp>(&mut self, path: &str, handler: F)
292 where
293 F: Fn(Req, Params) -> Resp + 'static,
294 Req: TryFromRequest + 'static,
295 Req::Error: IntoResponse + 'static,
296 Resp: IntoResponse + 'static,
297 {
298 self.add(path, Method::Post, handler)
299 }
300
301 pub fn post_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
303 where
304 F: Fn(Req, Params) -> Fut + 'static,
305 Fut: Future<Output = Resp> + 'static,
306 Req: TryFromRequest + 'static,
307 Req::Error: IntoResponse + 'static,
308 Resp: IntoResponse + 'static,
309 {
310 self.add_async(path, Method::Post, handler)
311 }
312
313 pub fn delete<F, Req, Resp>(&mut self, path: &str, handler: F)
315 where
316 F: Fn(Req, Params) -> Resp + 'static,
317 Req: TryFromRequest + 'static,
318 Req::Error: IntoResponse + 'static,
319 Resp: IntoResponse + 'static,
320 {
321 self.add(path, Method::Delete, handler)
322 }
323
324 pub fn delete_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
326 where
327 F: Fn(Req, Params) -> Fut + 'static,
328 Fut: Future<Output = Resp> + 'static,
329 Req: TryFromRequest + 'static,
330 Req::Error: IntoResponse + 'static,
331 Resp: IntoResponse + 'static,
332 {
333 self.add_async(path, Method::Delete, handler)
334 }
335
336 pub fn put<F, Req, Resp>(&mut self, path: &str, handler: F)
338 where
339 F: Fn(Req, Params) -> Resp + 'static,
340 Req: TryFromRequest + 'static,
341 Req::Error: IntoResponse + 'static,
342 Resp: IntoResponse + 'static,
343 {
344 self.add(path, Method::Put, handler)
345 }
346
347 pub fn put_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
349 where
350 F: Fn(Req, Params) -> Fut + 'static,
351 Fut: Future<Output = Resp> + 'static,
352 Req: TryFromRequest + 'static,
353 Req::Error: IntoResponse + 'static,
354 Resp: IntoResponse + 'static,
355 {
356 self.add_async(path, Method::Put, handler)
357 }
358
359 pub fn patch<F, Req, Resp>(&mut self, path: &str, handler: F)
361 where
362 F: Fn(Req, Params) -> Resp + 'static,
363 Req: TryFromRequest + 'static,
364 Req::Error: IntoResponse + 'static,
365 Resp: IntoResponse + 'static,
366 {
367 self.add(path, Method::Patch, handler)
368 }
369
370 pub fn patch_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
372 where
373 F: Fn(Req, Params) -> Fut + 'static,
374 Fut: Future<Output = Resp> + 'static,
375 Req: TryFromRequest + 'static,
376 Req::Error: IntoResponse + 'static,
377 Resp: IntoResponse + 'static,
378 {
379 self.add_async(path, Method::Patch, handler)
380 }
381
382 pub fn options<F, Req, Resp>(&mut self, path: &str, handler: F)
384 where
385 F: Fn(Req, Params) -> Resp + 'static,
386 Req: TryFromRequest + 'static,
387 Req::Error: IntoResponse + 'static,
388 Resp: IntoResponse + 'static,
389 {
390 self.add(path, Method::Options, handler)
391 }
392
393 pub fn options_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
395 where
396 F: Fn(Req, Params) -> Fut + 'static,
397 Fut: Future<Output = Resp> + 'static,
398 Req: TryFromRequest + 'static,
399 Req::Error: IntoResponse + 'static,
400 Resp: IntoResponse + 'static,
401 {
402 self.add_async(path, Method::Options, handler)
403 }
404
405 pub fn new() -> Self {
407 Router {
408 methods_map: HashMap::default(),
409 any_methods: MethodRouter::new(),
410 }
411 }
412}
413
414async fn not_found(_req: Request, _params: Params) -> Response {
415 responses::not_found()
416}
417
418async fn method_not_allowed(_req: Request, _params: Params) -> Response {
419 responses::method_not_allowed()
420}
421
/// Builds a [`Router`](crate::http::Router) from a comma-separated list of
/// `METHOD "path" => handler` entries.
///
/// `METHOD` is one of `HEAD`, `GET`, `PUT`, `POST`, `PATCH`, `DELETE` or `OPTIONS`, or `_`
/// to register the handler for any method. Each entry expands to a call of the matching
/// registration method (`head`, `get`, ..., `any`) on a freshly created router, in the
/// order written. A trailing comma after the last entry is accepted.
///
/// The `@build` arms are an implementation detail and are not meant to be invoked directly.
#[macro_export]
macro_rules! http_router {
    ($($method:tt $path:literal => $h:expr),* $(,)?) => {
        {
            let mut router = $crate::http::Router::new();
            $(
                $crate::http_router!(@build router $method $path => $h);
            )*
            router
        }
    };
    (@build $r:ident HEAD $path:literal => $h:expr) => {
        $r.head($path, $h);
    };
    (@build $r:ident GET $path:literal => $h:expr) => {
        $r.get($path, $h);
    };
    (@build $r:ident PUT $path:literal => $h:expr) => {
        $r.put($path, $h);
    };
    (@build $r:ident POST $path:literal => $h:expr) => {
        $r.post($path, $h);
    };
    (@build $r:ident PATCH $path:literal => $h:expr) => {
        $r.patch($path, $h);
    };
    (@build $r:ident DELETE $path:literal => $h:expr) => {
        $r.delete($path, $h);
    };
    (@build $r:ident OPTIONS $path:literal => $h:expr) => {
        $r.options($path, $h);
    };
    (@build $r:ident _ $path:literal => $h:expr) => {
        $r.any($path, $h);
    };
}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a request with the given method and path to feed into a router.
    fn make_request(method: Method, path: &str) -> Request {
        Request::new(method, path)
    }

    /// Responds 200 with the captured `x` parameter, or not-found when `x` was not captured.
    fn echo_param(_req: Request, params: Params) -> Response {
        if let Some(value) = params.get("x") {
            Response::new(200, value)
        } else {
            responses::not_found()
        }
    }

    /// A path that only matches under another method yields 405.
    #[test]
    fn test_method_not_allowed() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let response = router.handle(make_request(Method::Post, "/foobar"));
        assert_eq!(response.status, hyperium::StatusCode::METHOD_NOT_ALLOWED);
    }

    /// A path with an empty parameter segment does not match and yields 404.
    #[test]
    fn test_not_found() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, ()))
        }

        let mut router = Router::default();
        router.get("/h1/:param", h1);

        let response = router.handle(make_request(Method::Get, "/h1/"));
        assert_eq!(response.status, hyperium::StatusCode::NOT_FOUND);
    }

    /// Several named parameters in one route are all captured.
    #[test]
    fn test_multi_param() {
        fn multiply(_req: Request, params: Params) -> anyhow::Result<Response> {
            let x: i64 = params.get("x").unwrap().parse()?;
            let y: i64 = params.get("y").unwrap().parse()?;
            Ok(Response::new(200, (x * y).to_string()))
        }

        let mut router = Router::default();
        router.get("/multiply/:x/:y", multiply);

        let response = router.handle(make_request(Method::Get, "/multiply/2/4"));
        assert_eq!(response.body, b"8".to_vec());
    }

    /// A single named parameter is captured and echoed back.
    #[test]
    fn test_param() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let response = router.handle(make_request(Method::Get, "/y"));
        assert_eq!(response.body, b"y".to_vec());
    }

    /// A wildcard route captures the remainder of the path.
    #[test]
    fn test_wildcard() {
        fn echo_wildcard(_req: Request, params: Params) -> Response {
            if let Some(rest) = params.wildcard() {
                Response::new(200, rest)
            } else {
                responses::not_found()
            }
        }

        let mut router = Router::default();
        router.get("/*", echo_wildcard);

        let response = router.handle(make_request(Method::Get, "/foo/bar"));
        assert_eq!(response.status, hyperium::StatusCode::OK);
        assert_eq!(response.body, b"foo/bar".to_vec());
    }

    /// A named parameter ahead of a trailing wildcard is still captured.
    #[test]
    fn test_wildcard_last_segment() {
        let mut router = Router::default();
        router.get("/:x/*", echo_param);

        let response = router.handle(make_request(Method::Get, "/foo/bar"));
        assert_eq!(response.body, b"foo".to_vec());
    }

    /// `Display` lists the registered method-specific routes.
    #[test]
    fn test_router_display() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        assert_eq!(router.to_string(), "Registered routes:\n- GET: /:x\n");
    }

    /// When two routes match, the one with the literal prefix is chosen.
    #[test]
    fn test_ambiguous_wildcard_vs_star() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "one/two"))
        }

        fn h2(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "posts/*"))
        }

        let mut router = Router::default();
        router.get("/:one/:two", h1);
        router.get("/posts/*", h2);

        let response = router.handle(make_request(Method::Get, "/posts/2"));
        assert_eq!(response.body, b"posts/*".to_vec());
    }
}