spin_sdk/http/
router.rs

1// This router implementation is heavily inspired by the `Endpoint` type in the https://github.com/http-rs/tide project.
2
3use super::conversions::{IntoResponse, TryFromRequest, TryIntoRequest};
4use super::{responses, Method, Request, Response};
5use async_trait::async_trait;
6use routefinder::{Captures, Router as MethodRouter};
7use std::future::Future;
8use std::{collections::HashMap, fmt::Display};
9
/// An HTTP request handler.
///
/// A handler receives the request being dispatched together with the route [`Params`]
/// captured from the matched path, and produces the [`Response`] to send back.
///
/// This trait is automatically implemented for `Fn` types, and so is rarely implemented
/// directly by Spin users.
// `?Send`: the futures produced by `handle` are not required to be `Send`.
#[async_trait(?Send)]
pub trait Handler {
    /// Invoke the handler.
    ///
    /// `req` is the incoming request and `params` holds the captures extracted from the
    /// route that matched it.
    async fn handle(&self, req: Request, params: Params) -> Response;
}
19
20#[async_trait(?Send)]
21impl Handler for Box<dyn Handler> {
22    async fn handle(&self, req: Request, params: Params) -> Response {
23        self.as_ref().handle(req, params).await
24    }
25}
26
27#[async_trait(?Send)]
28impl<F, Fut> Handler for F
29where
30    F: Fn(Request, Params) -> Fut + 'static,
31    Fut: Future<Output = Response> + 'static,
32{
33    async fn handle(&self, req: Request, params: Params) -> Response {
34        let fut = (self)(req, params);
35        fut.await
36    }
37}
38
/// Route parameters extracted from a URI that match a route pattern.
///
/// This is `routefinder`'s [`Captures`] with both lifetimes fixed to `'static`; the router
/// converts a match's captures with `into_owned()` before handing them to a handler.
pub type Params = Captures<'static, 'static>;
41
/// The Spin SDK HTTP router.
///
/// Handlers are stored in two places: a per-method route table for handlers registered
/// against one specific HTTP method, and a separate table for handlers registered for
/// every method.
pub struct Router {
    // One route table per HTTP method that has had a handler registered for it.
    methods_map: HashMap<Method, MethodRouter<Box<dyn Handler>>>,
    // Route table for handlers registered for all methods; consulted when the
    // method-specific table yields no match.
    any_methods: MethodRouter<Box<dyn Handler>>,
}
47
48impl Default for Router {
49    fn default() -> Router {
50        Router::new()
51    }
52}
53
54impl Display for Router {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        writeln!(f, "Registered routes:")?;
57        for (method, router) in &self.methods_map {
58            for route in router.iter() {
59                writeln!(f, "- {}: {}", method, route.0)?;
60            }
61        }
62        Ok(())
63    }
64}
65
/// The outcome of route resolution: the handler to invoke plus the parameters captured
/// from the request path.
struct RouteMatch<'a> {
    // Owned captures of the matched route (default-constructed for the fallback handlers).
    params: Captures<'static, 'static>,
    // The handler selected for the request.
    handler: &'a dyn Handler,
}
70
71impl Router {
72    /// Synchronously dispatches a request to the appropriate handler along with the URI parameters.
73    pub fn handle<R>(&self, request: R) -> Response
74    where
75        R: TryIntoRequest,
76        R::Error: IntoResponse,
77    {
78        crate::http::executor::run(self.handle_async(request))
79    }
80
81    /// Asynchronously dispatches a request to the appropriate handler along with the URI parameters.
82    pub async fn handle_async<R>(&self, request: R) -> Response
83    where
84        R: TryIntoRequest,
85        R::Error: IntoResponse,
86    {
87        let request = match R::try_into_request(request) {
88            Ok(r) => r,
89            Err(e) => return e.into_response(),
90        };
91        let method = request.method.clone();
92        let path = &request.path();
93        let RouteMatch { params, handler } = self.find(path, method);
94        handler.handle(request, params).await
95    }
96
97    fn find(&self, path: &str, method: Method) -> RouteMatch<'_> {
98        let best_match = self
99            .methods_map
100            .get(&method)
101            .and_then(|r| r.best_match(path));
102
103        if let Some(m) = best_match {
104            let params = m.captures().into_owned();
105            let handler = m.handler();
106            return RouteMatch { handler, params };
107        }
108
109        let best_match = self.any_methods.best_match(path);
110
111        match best_match {
112            Some(m) => {
113                let params = m.captures().into_owned();
114                let handler = m.handler();
115                RouteMatch { handler, params }
116            }
117            None if method == Method::Head => {
118                // If it is a HTTP HEAD request then check if there is a callback in the methods map
119                // if not then fallback to the behavior of HTTP GET else proceed as usual
120                self.find(path, Method::Get)
121            }
122            None => {
123                // Handle the failure case where no match could be resolved.
124                self.fail(path, method)
125            }
126        }
127    }
128
129    // Helper function to handle the case where a best match couldn't be resolved.
130    fn fail(&self, path: &str, method: Method) -> RouteMatch<'_> {
131        // First, filter all routers to determine if the path can match but the provided method is not allowed.
132        let is_method_not_allowed = self
133            .methods_map
134            .iter()
135            .filter(|(k, _)| **k != method)
136            .any(|(_, r)| r.best_match(path).is_some());
137
138        if is_method_not_allowed {
139            // If this `path` can be handled by a callback registered with a different HTTP method
140            // should return 405 Method Not Allowed
141            RouteMatch {
142                handler: &method_not_allowed,
143                params: Captures::default(),
144            }
145        } else {
146            // ... Otherwise, nothing matched so 404.
147            RouteMatch {
148                handler: &not_found,
149                params: Captures::default(),
150            }
151        }
152    }
153
154    /// Register a handler at the path for all methods.
155    pub fn any<F, Req, Resp>(&mut self, path: &str, handler: F)
156    where
157        F: Fn(Req, Params) -> Resp + 'static,
158        Req: TryFromRequest + 'static,
159        Req::Error: IntoResponse + 'static,
160        Resp: IntoResponse + 'static,
161    {
162        let handler = move |req, params| {
163            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
164            async move {
165                match res {
166                    Ok(res) => res.into_response(),
167                    Err(e) => e.into_response(),
168                }
169            }
170        };
171
172        self.any_async(path, handler)
173    }
174
175    /// Register an async handler at the path for all methods.
176    pub fn any_async<F, Fut, I, O>(&mut self, path: &str, handler: F)
177    where
178        F: Fn(I, Params) -> Fut + 'static,
179        Fut: Future<Output = O> + 'static,
180        I: TryFromRequest + 'static,
181        I::Error: IntoResponse + 'static,
182        O: IntoResponse + 'static,
183    {
184        let handler = move |req, params| {
185            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
186            async move {
187                match res {
188                    Ok(f) => f.await.into_response(),
189                    Err(e) => e.into_response(),
190                }
191            }
192        };
193
194        self.any_methods.add(path, Box::new(handler)).unwrap();
195    }
196
197    /// Register a handler at the path for the specified HTTP method.
198    pub fn add<F, Req, Resp>(&mut self, path: &str, method: Method, handler: F)
199    where
200        F: Fn(Req, Params) -> Resp + 'static,
201        Req: TryFromRequest + 'static,
202        Req::Error: IntoResponse + 'static,
203        Resp: IntoResponse + 'static,
204    {
205        let handler = move |req, params| {
206            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
207            async move {
208                match res {
209                    Ok(res) => res.into_response(),
210                    Err(e) => e.into_response(),
211                }
212            }
213        };
214
215        self.add_async(path, method, handler)
216    }
217
218    /// Register an async handler at the path for the specified HTTP method.
219    pub fn add_async<F, Fut, I, O>(&mut self, path: &str, method: Method, handler: F)
220    where
221        F: Fn(I, Params) -> Fut + 'static,
222        Fut: Future<Output = O> + 'static,
223        I: TryFromRequest + 'static,
224        I::Error: IntoResponse + 'static,
225        O: IntoResponse + 'static,
226    {
227        let handler = move |req, params| {
228            let res = TryFromRequest::try_from_request(req).map(|r| handler(r, params));
229            async move {
230                match res {
231                    Ok(f) => f.await.into_response(),
232                    Err(e) => e.into_response(),
233                }
234            }
235        };
236
237        self.methods_map
238            .entry(method)
239            .or_default()
240            .add(path, Box::new(handler))
241            .unwrap();
242    }
243
244    /// Register a handler at the path for the HTTP GET method.
245    pub fn get<F, Req, Resp>(&mut self, path: &str, handler: F)
246    where
247        F: Fn(Req, Params) -> Resp + 'static,
248        Req: TryFromRequest + 'static,
249        Req::Error: IntoResponse + 'static,
250        Resp: IntoResponse + 'static,
251    {
252        self.add(path, Method::Get, handler)
253    }
254
255    /// Register an async handler at the path for the HTTP GET method.
256    pub fn get_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
257    where
258        F: Fn(Req, Params) -> Fut + 'static,
259        Fut: Future<Output = Resp> + 'static,
260        Req: TryFromRequest + 'static,
261        Req::Error: IntoResponse + 'static,
262        Resp: IntoResponse + 'static,
263    {
264        self.add_async(path, Method::Get, handler)
265    }
266
267    /// Register a handler at the path for the HTTP HEAD method.
268    pub fn head<F, Req, Resp>(&mut self, path: &str, handler: F)
269    where
270        F: Fn(Req, Params) -> Resp + 'static,
271        Req: TryFromRequest + 'static,
272        Req::Error: IntoResponse + 'static,
273        Resp: IntoResponse + 'static,
274    {
275        self.add(path, Method::Head, handler)
276    }
277
278    /// Register an async handler at the path for the HTTP HEAD method.
279    pub fn head_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
280    where
281        F: Fn(Req, Params) -> Fut + 'static,
282        Fut: Future<Output = Resp> + 'static,
283        Req: TryFromRequest + 'static,
284        Req::Error: IntoResponse + 'static,
285        Resp: IntoResponse + 'static,
286    {
287        self.add_async(path, Method::Head, handler)
288    }
289
290    /// Register a handler at the path for the HTTP POST method.
291    pub fn post<F, Req, Resp>(&mut self, path: &str, handler: F)
292    where
293        F: Fn(Req, Params) -> Resp + 'static,
294        Req: TryFromRequest + 'static,
295        Req::Error: IntoResponse + 'static,
296        Resp: IntoResponse + 'static,
297    {
298        self.add(path, Method::Post, handler)
299    }
300
301    /// Register an async handler at the path for the HTTP POST method.
302    pub fn post_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
303    where
304        F: Fn(Req, Params) -> Fut + 'static,
305        Fut: Future<Output = Resp> + 'static,
306        Req: TryFromRequest + 'static,
307        Req::Error: IntoResponse + 'static,
308        Resp: IntoResponse + 'static,
309    {
310        self.add_async(path, Method::Post, handler)
311    }
312
313    /// Register a handler at the path for the HTTP DELETE method.
314    pub fn delete<F, Req, Resp>(&mut self, path: &str, handler: F)
315    where
316        F: Fn(Req, Params) -> Resp + 'static,
317        Req: TryFromRequest + 'static,
318        Req::Error: IntoResponse + 'static,
319        Resp: IntoResponse + 'static,
320    {
321        self.add(path, Method::Delete, handler)
322    }
323
324    /// Register an async handler at the path for the HTTP DELETE method.
325    pub fn delete_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
326    where
327        F: Fn(Req, Params) -> Fut + 'static,
328        Fut: Future<Output = Resp> + 'static,
329        Req: TryFromRequest + 'static,
330        Req::Error: IntoResponse + 'static,
331        Resp: IntoResponse + 'static,
332    {
333        self.add_async(path, Method::Delete, handler)
334    }
335
336    /// Register a handler at the path for the HTTP PUT method.
337    pub fn put<F, Req, Resp>(&mut self, path: &str, handler: F)
338    where
339        F: Fn(Req, Params) -> Resp + 'static,
340        Req: TryFromRequest + 'static,
341        Req::Error: IntoResponse + 'static,
342        Resp: IntoResponse + 'static,
343    {
344        self.add(path, Method::Put, handler)
345    }
346
347    /// Register an async handler at the path for the HTTP PUT method.
348    pub fn put_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
349    where
350        F: Fn(Req, Params) -> Fut + 'static,
351        Fut: Future<Output = Resp> + 'static,
352        Req: TryFromRequest + 'static,
353        Req::Error: IntoResponse + 'static,
354        Resp: IntoResponse + 'static,
355    {
356        self.add_async(path, Method::Put, handler)
357    }
358
359    /// Register a handler at the path for the HTTP PATCH method.
360    pub fn patch<F, Req, Resp>(&mut self, path: &str, handler: F)
361    where
362        F: Fn(Req, Params) -> Resp + 'static,
363        Req: TryFromRequest + 'static,
364        Req::Error: IntoResponse + 'static,
365        Resp: IntoResponse + 'static,
366    {
367        self.add(path, Method::Patch, handler)
368    }
369
370    /// Register an async handler at the path for the HTTP PATCH method.
371    pub fn patch_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
372    where
373        F: Fn(Req, Params) -> Fut + 'static,
374        Fut: Future<Output = Resp> + 'static,
375        Req: TryFromRequest + 'static,
376        Req::Error: IntoResponse + 'static,
377        Resp: IntoResponse + 'static,
378    {
379        self.add_async(path, Method::Patch, handler)
380    }
381
382    /// Register a handler at the path for the HTTP OPTIONS method.
383    pub fn options<F, Req, Resp>(&mut self, path: &str, handler: F)
384    where
385        F: Fn(Req, Params) -> Resp + 'static,
386        Req: TryFromRequest + 'static,
387        Req::Error: IntoResponse + 'static,
388        Resp: IntoResponse + 'static,
389    {
390        self.add(path, Method::Options, handler)
391    }
392
393    /// Register an async handler at the path for the HTTP OPTIONS method.
394    pub fn options_async<F, Fut, Req, Resp>(&mut self, path: &str, handler: F)
395    where
396        F: Fn(Req, Params) -> Fut + 'static,
397        Fut: Future<Output = Resp> + 'static,
398        Req: TryFromRequest + 'static,
399        Req::Error: IntoResponse + 'static,
400        Resp: IntoResponse + 'static,
401    {
402        self.add_async(path, Method::Options, handler)
403    }
404
405    /// Construct a new Router.
406    pub fn new() -> Self {
407        Router {
408            methods_map: HashMap::default(),
409            any_methods: MethodRouter::new(),
410        }
411    }
412}
413
414async fn not_found(_req: Request, _params: Params) -> Response {
415    responses::not_found()
416}
417
418async fn method_not_allowed(_req: Request, _params: Params) -> Response {
419    responses::method_not_allowed()
420}
421
/// A macro to help with constructing a Router from a stream of tokens.
///
/// Each entry has the form `METHOD "path" => handler`, where `METHOD` is one of `HEAD`,
/// `GET`, `PUT`, `POST`, `PATCH`, `DELETE`, `OPTIONS`, or `_` to register the handler for
/// all methods. Entries are separated by commas, and a trailing comma is accepted.
///
/// The macro expands to a block that creates a new router, registers every entry in
/// order, and evaluates to the router.
///
/// ```ignore
/// let router = http_router!(
///     GET "/hello/:name" => hello,
///     _ "/*" => fallback,
/// );
/// ```
#[macro_export]
macro_rules! http_router {
    // Entry point: a comma-separated list of routes with an optional trailing comma.
    ($($method:tt $path:literal => $h:expr),* $(,)?) => {
        {
            let mut router = $crate::http::Router::new();
            $(
                $crate::http_router!(@build router $method $path => $h);
            )*
            router
        }
    };
    // Internal rules: map each method token onto the matching registration call.
    (@build $r:ident HEAD $path:literal => $h:expr) => {
        $r.head($path, $h);
    };
    (@build $r:ident GET $path:literal => $h:expr) => {
        $r.get($path, $h);
    };
    (@build $r:ident PUT $path:literal => $h:expr) => {
        $r.put($path, $h);
    };
    (@build $r:ident POST $path:literal => $h:expr) => {
        $r.post($path, $h);
    };
    (@build $r:ident PATCH $path:literal => $h:expr) => {
        $r.patch($path, $h);
    };
    (@build $r:ident DELETE $path:literal => $h:expr) => {
        $r.delete($path, $h);
    };
    (@build $r:ident OPTIONS $path:literal => $h:expr) => {
        $r.options($path, $h);
    };
    // `_` registers the handler for all methods.
    (@build $r:ident _ $path:literal => $h:expr) => {
        $r.any($path, $h);
    };
}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a request for `path` using `method`.
    fn make_request(method: Method, path: &str) -> Request {
        Request::new(method, path)
    }

    /// Handler that echoes the `x` route parameter, or answers "not found" when it is absent.
    fn echo_param(_req: Request, params: Params) -> Response {
        if let Some(value) = params.get("x") {
            Response::new(200, value)
        } else {
            responses::not_found()
        }
    }

    // A path registered only for GET must answer 405 to a POST.
    #[test]
    fn test_method_not_allowed() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let res = router.handle(make_request(Method::Post, "/foobar"));
        assert_eq!(res.status, hyperium::StatusCode::METHOD_NOT_ALLOWED);
    }

    // A route with a required parameter must not match when the parameter is missing.
    #[test]
    fn test_not_found() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, ()))
        }

        let mut router = Router::default();
        router.get("/h1/:param", h1);

        let res = router.handle(make_request(Method::Get, "/h1/"));
        assert_eq!(res.status, hyperium::StatusCode::NOT_FOUND);
    }

    // Several named parameters in one route are all captured.
    #[test]
    fn test_multi_param() {
        fn multiply(_req: Request, params: Params) -> anyhow::Result<Response> {
            let x = params.get("x").unwrap().parse::<i64>()?;
            let y = params.get("y").unwrap().parse::<i64>()?;
            Ok(Response::new(200, (x * y).to_string()))
        }

        let mut router = Router::default();
        router.get("/multiply/:x/:y", multiply);

        let res = router.handle(make_request(Method::Get, "/multiply/2/4"));

        assert_eq!(res.body, b"8".to_vec());
    }

    // A single named parameter is captured and passed to the handler.
    #[test]
    fn test_param() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let res = router.handle(make_request(Method::Get, "/y"));

        assert_eq!(res.body, b"y".to_vec());
    }

    // A wildcard route captures the remainder of the path.
    #[test]
    fn test_wildcard() {
        fn echo_wildcard(_req: Request, params: Params) -> Response {
            if let Some(rest) = params.wildcard() {
                Response::new(200, rest)
            } else {
                responses::not_found()
            }
        }

        let mut router = Router::default();
        router.get("/*", echo_wildcard);

        let res = router.handle(make_request(Method::Get, "/foo/bar"));
        assert_eq!(res.status, hyperium::StatusCode::OK);
        assert_eq!(res.body, b"foo/bar".to_vec());
    }

    // Named parameters ahead of a trailing wildcard are still captured.
    #[test]
    fn test_wildcard_last_segment() {
        let mut router = Router::default();
        router.get("/:x/*", echo_param);

        let res = router.handle(make_request(Method::Get, "/foo/bar"));
        assert_eq!(res.body, b"foo".to_vec());
    }

    // `Display` lists the registered routes under a header line.
    #[test]
    fn test_router_display() {
        let mut router = Router::default();
        router.get("/:x", echo_param);

        let expected = "Registered routes:\n- GET: /:x\n";
        assert_eq!(router.to_string(), expected);
    }

    // When both `/:one/:two` and `/posts/*` match, the `/posts/*` route is selected.
    #[test]
    fn test_ambiguous_wildcard_vs_star() {
        fn h1(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "one/two"))
        }

        fn h2(_req: Request, _params: Params) -> anyhow::Result<Response> {
            Ok(Response::new(200, "posts/*"))
        }

        let mut router = Router::default();
        router.get("/:one/:two", h1);
        router.get("/posts/*", h2);

        let res = router.handle(make_request(Method::Get, "/posts/2"));

        assert_eq!(res.body, b"posts/*".to_vec());
    }
}