axum/routing/
method_routing.rs

1//! Route to services and handlers based on HTTP methods.
2
3use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7    body::{Body, Bytes, HttpBody},
8    boxed::BoxedIntoRoute,
9    error_handling::{HandleError, HandleErrorLayer},
10    handler::Handler,
11    http::{Method, StatusCode},
12    response::Response,
13    routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18    convert::Infallible,
19    fmt,
20    task::{Context, Poll},
21};
22use tower::{service_fn, util::MapResponseLayer};
23use tower_layer::Layer;
24use tower_service::Service;
25
macro_rules! top_level_service_fn {
    // `GET`: attach the long-form docs (including a runnable example) and
    // recurse into the generator arm at the bottom.
    ($fn_name:ident, GET) => {
        top_level_service_fn!(
            /// Route `GET` requests to the given service.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::get_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `GET /` will go to `service`.
            /// let app = Router::new().route("/", get_service(service));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $fn_name,
            GET
        );
    };

    // `CONNECT`: short docs that point at `MethodFilter::CONNECT`.
    ($fn_name:ident, CONNECT) => {
        top_level_service_fn!(
            /// Route `CONNECT` requests to the given service.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get_service`] for an example.
            $fn_name,
            CONNECT
        );
    };

    // Every other method: derive the summary line from the method name.
    ($fn_name:ident, $verb:ident) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($verb) ,"` requests to the given service.")]
            ///
            /// See [`get_service`] for an example.
            $fn_name,
            $verb
        );
    };

    // Generator arm: the arms above always end up here, with at least one
    // attribute (the docs) in front of the function name.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        pub fn $fn_name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$verb, svc)
        }
    };
}
103
macro_rules! top_level_handler_fn {
    // `GET`: attach the long-form docs (including a runnable example) and
    // recurse into the generator arm at the bottom.
    ($fn_name:ident, GET) => {
        top_level_handler_fn!(
            /// Route `GET` requests to the given handler.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     routing::get,
            ///     Router,
            /// };
            ///
            /// async fn handler() {}
            ///
            /// // Requests to `GET /` will go to `handler`.
            /// let app = Router::new().route("/", get(handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $fn_name,
            GET
        );
    };

    // `CONNECT`: short docs that point at `MethodFilter::CONNECT`.
    ($fn_name:ident, CONNECT) => {
        top_level_handler_fn!(
            /// Route `CONNECT` requests to the given handler.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`get`] for an example.
            $fn_name,
            CONNECT
        );
    };

    // Every other method: derive the summary line from the method name.
    ($fn_name:ident, $verb:ident) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($verb) ,"` requests to the given handler.")]
            ///
            /// See [`get`] for an example.
            $fn_name,
            $verb
        );
    };

    // Generator arm: emits the public free function for one HTTP method.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        pub fn $fn_name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$verb, handler)
        }
    };
}
174
macro_rules! chained_service_fn {
    // `GET`: attach the long-form docs (including a runnable example) and
    // recurse into the generator arm at the bottom.
    ($fn_name:ident, GET) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{
            ///     extract::Request,
            ///     Router,
            ///     routing::post_service,
            ///     body::Body,
            /// };
            /// use http::Response;
            /// use std::convert::Infallible;
            ///
            /// let service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// let other_service = tower::service_fn(|request: Request| async {
            ///     Ok::<_, Infallible>(Response::new(Body::empty()))
            /// });
            ///
            /// // Requests to `POST /` will go to `service` and `GET /` will go to
            /// // `other_service`.
            /// let app = Router::new().route("/", post_service(service).get_service(other_service));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $fn_name,
            GET
        );
    };

    // `CONNECT`: short docs that point at `MethodFilter::CONNECT`.
    ($fn_name:ident, CONNECT) => {
        chained_service_fn!(
            /// Chain an additional service that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get_service`] for an example.
            $fn_name,
            CONNECT
        );
    };

    // Every other method: derive the summary line from the method name.
    ($fn_name:ident, $verb:ident) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($verb),"` requests.")]
            ///
            /// See [`MethodRouter::get_service`] for an example.
            $fn_name,
            $verb
        );
    };

    // Generator arm: emits a builder method. `E` is the error type parameter
    // of the `impl` block in which the macro is invoked.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $fn_name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$verb, svc)
        }
    };
}
261
macro_rules! chained_handler_fn {
    // `GET`: attach the long-form docs (including a runnable example) and
    // recurse into the generator arm at the bottom.
    ($fn_name:ident, GET) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `GET` requests.
            ///
            /// # Example
            ///
            /// ```rust
            /// use axum::{routing::post, Router};
            ///
            /// async fn handler() {}
            ///
            /// async fn other_handler() {}
            ///
            /// // Requests to `POST /` will go to `handler` and `GET /` will go to
            /// // `other_handler`.
            /// let app = Router::new().route("/", post(handler).get(other_handler));
            /// # let _: Router = app;
            /// ```
            ///
            /// Note that `get` routes will also be called for `HEAD` requests but will have
            /// the response body removed. Make sure to add explicit `HEAD` routes
            /// afterwards.
            $fn_name,
            GET
        );
    };

    // `CONNECT`: short docs that point at `MethodFilter::CONNECT`.
    ($fn_name:ident, CONNECT) => {
        chained_handler_fn!(
            /// Chain an additional handler that will only accept `CONNECT` requests.
            ///
            /// See [`MethodFilter::CONNECT`] for when you'd want to use this,
            /// and [`MethodRouter::get`] for an example.
            $fn_name,
            CONNECT
        );
    };

    // Every other method: derive the summary line from the method name.
    ($fn_name:ident, $verb:ident) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($verb),"` requests.")]
            ///
            /// See [`MethodRouter::get`] for an example.
            $fn_name,
            $verb
        );
    };

    // Generator arm: emits a builder method. `S` is the state type parameter
    // of the `impl` block in which the macro is invoked.
    (
        $(#[$attr:meta])+
        $fn_name:ident, $verb:ident
    ) => {
        $(#[$attr])+
        #[track_caller]
        pub fn $fn_name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$verb, handler)
        }
    };
}
333
334top_level_service_fn!(connect_service, CONNECT);
335top_level_service_fn!(delete_service, DELETE);
336top_level_service_fn!(get_service, GET);
337top_level_service_fn!(head_service, HEAD);
338top_level_service_fn!(options_service, OPTIONS);
339top_level_service_fn!(patch_service, PATCH);
340top_level_service_fn!(post_service, POST);
341top_level_service_fn!(put_service, PUT);
342top_level_service_fn!(trace_service, TRACE);
343
344/// Route requests with the given method to the service.
345///
346/// # Example
347///
348/// ```rust
349/// use axum::{
350///     extract::Request,
351///     routing::on,
352///     Router,
353///     body::Body,
354///     routing::{MethodFilter, on_service},
355/// };
356/// use http::Response;
357/// use std::convert::Infallible;
358///
359/// let service = tower::service_fn(|request: Request| async {
360///     Ok::<_, Infallible>(Response::new(Body::empty()))
361/// });
362///
363/// // Requests to `POST /` will go to `service`.
364/// let app = Router::new().route("/", on_service(MethodFilter::POST, service));
365/// # let _: Router = app;
366/// ```
367pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
368where
369    T: Service<Request> + Clone + Send + Sync + 'static,
370    T::Response: IntoResponse + 'static,
371    T::Future: Send + 'static,
372    S: Clone,
373{
374    MethodRouter::new().on_service(filter, svc)
375}
376
377/// Route requests to the given service regardless of its method.
378///
379/// # Example
380///
381/// ```rust
382/// use axum::{
383///     extract::Request,
384///     Router,
385///     routing::any_service,
386///     body::Body,
387/// };
388/// use http::Response;
389/// use std::convert::Infallible;
390///
391/// let service = tower::service_fn(|request: Request| async {
392///     Ok::<_, Infallible>(Response::new(Body::empty()))
393/// });
394///
395/// // All requests to `/` will go to `service`.
396/// let app = Router::new().route("/", any_service(service));
397/// # let _: Router = app;
398/// ```
399///
400/// Additional methods can still be chained:
401///
402/// ```rust
403/// use axum::{
404///     extract::Request,
405///     Router,
406///     routing::any_service,
407///     body::Body,
408/// };
409/// use http::Response;
410/// use std::convert::Infallible;
411///
412/// let service = tower::service_fn(|request: Request| async {
413///     # Ok::<_, Infallible>(Response::new(Body::empty()))
414///     // ...
415/// });
416///
417/// let other_service = tower::service_fn(|request: Request| async {
418///     # Ok::<_, Infallible>(Response::new(Body::empty()))
419///     // ...
420/// });
421///
422/// // `POST /` goes to `other_service`. All other requests go to `service`
423/// let app = Router::new().route("/", any_service(service).post_service(other_service));
424/// # let _: Router = app;
425/// ```
426pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
427where
428    T: Service<Request> + Clone + Send + Sync + 'static,
429    T::Response: IntoResponse + 'static,
430    T::Future: Send + 'static,
431    S: Clone,
432{
433    MethodRouter::new()
434        .fallback_service(svc)
435        .skip_allow_header()
436}
437
438top_level_handler_fn!(connect, CONNECT);
439top_level_handler_fn!(delete, DELETE);
440top_level_handler_fn!(get, GET);
441top_level_handler_fn!(head, HEAD);
442top_level_handler_fn!(options, OPTIONS);
443top_level_handler_fn!(patch, PATCH);
444top_level_handler_fn!(post, POST);
445top_level_handler_fn!(put, PUT);
446top_level_handler_fn!(trace, TRACE);
447
448/// Route requests with the given method to the handler.
449///
450/// # Example
451///
452/// ```rust
453/// use axum::{
454///     routing::on,
455///     Router,
456///     routing::MethodFilter,
457/// };
458///
459/// async fn handler() {}
460///
461/// // Requests to `POST /` will go to `handler`.
462/// let app = Router::new().route("/", on(MethodFilter::POST, handler));
463/// # let _: Router = app;
464/// ```
465pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
466where
467    H: Handler<T, S>,
468    T: 'static,
469    S: Clone + Send + Sync + 'static,
470{
471    MethodRouter::new().on(filter, handler)
472}
473
474/// Route requests with the given handler regardless of the method.
475///
476/// # Example
477///
478/// ```rust
479/// use axum::{
480///     routing::any,
481///     Router,
482/// };
483///
484/// async fn handler() {}
485///
486/// // All requests to `/` will go to `handler`.
487/// let app = Router::new().route("/", any(handler));
488/// # let _: Router = app;
489/// ```
490///
491/// Additional methods can still be chained:
492///
493/// ```rust
494/// use axum::{
495///     routing::any,
496///     Router,
497/// };
498///
499/// async fn handler() {}
500///
501/// async fn other_handler() {}
502///
503/// // `POST /` goes to `other_handler`. All other requests go to `handler`
504/// let app = Router::new().route("/", any(handler).post(other_handler));
505/// # let _: Router = app;
506/// ```
507pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
508where
509    H: Handler<T, S>,
510    T: 'static,
511    S: Clone + Send + Sync + 'static,
512{
513    MethodRouter::new().fallback(handler).skip_allow_header()
514}
515
/// A [`Service`] that accepts requests based on a [`MethodFilter`] and
/// allows chaining additional handlers and services.
///
/// The router keeps one endpoint slot per HTTP method, plus a fallback that
/// starts out as a `405 Method Not Allowed` responder (see
/// [`MethodRouter::new`]).
///
/// # When does `MethodRouter` implement [`Service`]?
///
/// Whether or not `MethodRouter` implements [`Service`] depends on the state type it requires.
///
/// ```
/// use tower::Service;
/// use axum::{routing::get, extract::{State, Request}, body::Body};
///
/// // this `MethodRouter` doesn't require any state, i.e. the state is `()`,
/// let method_router = get(|| async {});
/// // and thus it implements `Service`
/// assert_service(method_router);
///
/// // this requires a `String` and doesn't implement `Service`
/// let method_router = get(|_: State<String>| async {});
/// // until you provide the `String` with `.with_state(...)`
/// let method_router_with_state = method_router.with_state(String::new());
/// // and then it implements `Service`
/// assert_service(method_router_with_state);
///
/// // helper to check that a value implements `Service`
/// fn assert_service<S>(service: S)
/// where
///     S: Service<Request>,
/// {}
/// ```
#[must_use]
pub struct MethodRouter<S = (), E = Infallible> {
    /// Endpoint registered for `GET`.
    get: MethodEndpoint<S, E>,
    /// Endpoint registered for `HEAD`.
    head: MethodEndpoint<S, E>,
    /// Endpoint registered for `DELETE`.
    delete: MethodEndpoint<S, E>,
    /// Endpoint registered for `OPTIONS`.
    options: MethodEndpoint<S, E>,
    /// Endpoint registered for `PATCH`.
    patch: MethodEndpoint<S, E>,
    /// Endpoint registered for `POST`.
    post: MethodEndpoint<S, E>,
    /// Endpoint registered for `PUT`.
    put: MethodEndpoint<S, E>,
    /// Endpoint registered for `TRACE`.
    trace: MethodEndpoint<S, E>,
    /// Endpoint registered for `CONNECT`.
    connect: MethodEndpoint<S, E>,
    /// Fallback set via `fallback` / `fallback_service`; a `405` responder by default.
    fallback: Fallback<S, E>,
    /// `Allow` header state built up as endpoints are registered.
    allow_header: AllowHeader,
}
559
/// State of the `Allow` header value tracked by a [`MethodRouter`].
#[derive(Clone, Debug)]
enum AllowHeader {
    /// No `Allow` header value has been built-up yet. This is the default state
    None,
    /// Don't set an `Allow` header. This is used when `any` or `any_service` are called.
    Skip,
    /// The current value of the `Allow` header, as a comma-separated list of
    /// method names.
    Bytes(BytesMut),
}
569
570impl AllowHeader {
571    fn merge(self, other: Self) -> Self {
572        match (self, other) {
573            (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
574            (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
575            (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
576            (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
577            (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
578                a.extend_from_slice(b",");
579                a.extend_from_slice(&b);
580                AllowHeader::Bytes(a)
581            }
582        }
583    }
584}
585
586impl<S, E> fmt::Debug for MethodRouter<S, E> {
587    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588        f.debug_struct("MethodRouter")
589            .field("get", &self.get)
590            .field("head", &self.head)
591            .field("delete", &self.delete)
592            .field("options", &self.options)
593            .field("patch", &self.patch)
594            .field("post", &self.post)
595            .field("put", &self.put)
596            .field("trace", &self.trace)
597            .field("connect", &self.connect)
598            .field("fallback", &self.fallback)
599            .field("allow_header", &self.allow_header)
600            .finish()
601    }
602}
603
604impl<S> MethodRouter<S, Infallible>
605where
606    S: Clone,
607{
608    /// Chain an additional handler that will accept requests matching the given
609    /// `MethodFilter`.
610    ///
611    /// # Example
612    ///
613    /// ```rust
614    /// use axum::{
615    ///     routing::get,
616    ///     Router,
617    ///     routing::MethodFilter
618    /// };
619    ///
620    /// async fn handler() {}
621    ///
622    /// async fn other_handler() {}
623    ///
624    /// // Requests to `GET /` will go to `handler` and `DELETE /` will go to
625    /// // `other_handler`
626    /// let app = Router::new().route("/", get(handler).on(MethodFilter::DELETE, other_handler));
627    /// # let _: Router = app;
628    /// ```
629    #[track_caller]
630    pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
631    where
632        H: Handler<T, S>,
633        T: 'static,
634        S: Send + Sync + 'static,
635    {
636        self.on_endpoint(
637            filter,
638            MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
639        )
640    }
641
642    chained_handler_fn!(connect, CONNECT);
643    chained_handler_fn!(delete, DELETE);
644    chained_handler_fn!(get, GET);
645    chained_handler_fn!(head, HEAD);
646    chained_handler_fn!(options, OPTIONS);
647    chained_handler_fn!(patch, PATCH);
648    chained_handler_fn!(post, POST);
649    chained_handler_fn!(put, PUT);
650    chained_handler_fn!(trace, TRACE);
651
652    /// Add a fallback [`Handler`] to the router.
653    pub fn fallback<H, T>(mut self, handler: H) -> Self
654    where
655        H: Handler<T, S>,
656        T: 'static,
657        S: Send + Sync + 'static,
658    {
659        self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
660        self
661    }
662
663    /// Add a fallback [`Handler`] if no custom one has been provided.
664    pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
665    where
666        H: Handler<T, S>,
667        T: 'static,
668        S: Send + Sync + 'static,
669    {
670        match self.fallback {
671            Fallback::Default(_) => self.fallback(handler),
672            _ => self,
673        }
674    }
675}
676
677impl MethodRouter<(), Infallible> {
678    /// Convert the router into a [`MakeService`].
679    ///
680    /// This allows you to serve a single `MethodRouter` if you don't need any
681    /// routing based on the path:
682    ///
683    /// ```rust
684    /// use axum::{
685    ///     handler::Handler,
686    ///     http::{Uri, Method},
687    ///     response::IntoResponse,
688    ///     routing::get,
689    /// };
690    /// use std::net::SocketAddr;
691    ///
692    /// async fn handler(method: Method, uri: Uri, body: String) -> String {
693    ///     format!("received `{method} {uri}` with body `{body:?}`")
694    /// }
695    ///
696    /// let router = get(handler).post(handler);
697    ///
698    /// # async {
699    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
700    /// axum::serve(listener, router.into_make_service()).await.unwrap();
701    /// # };
702    /// ```
703    ///
704    /// [`MakeService`]: tower::make::MakeService
705    pub fn into_make_service(self) -> IntoMakeService<Self> {
706        IntoMakeService::new(self.with_state(()))
707    }
708
709    /// Convert the router into a [`MakeService`] which stores information
710    /// about the incoming connection.
711    ///
712    /// See [`Router::into_make_service_with_connect_info`] for more details.
713    ///
714    /// ```rust
715    /// use axum::{
716    ///     handler::Handler,
717    ///     response::IntoResponse,
718    ///     extract::ConnectInfo,
719    ///     routing::get,
720    /// };
721    /// use std::net::SocketAddr;
722    ///
723    /// async fn handler(ConnectInfo(addr): ConnectInfo<SocketAddr>) -> String {
724    ///     format!("Hello {addr}")
725    /// }
726    ///
727    /// let router = get(handler).post(handler);
728    ///
729    /// # async {
730    /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap();
731    /// axum::serve(listener, router.into_make_service()).await.unwrap();
732    /// # };
733    /// ```
734    ///
735    /// [`MakeService`]: tower::make::MakeService
736    /// [`Router::into_make_service_with_connect_info`]: crate::routing::Router::into_make_service_with_connect_info
737    #[cfg(feature = "tokio")]
738    pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
739        IntoMakeServiceWithConnectInfo::new(self.with_state(()))
740    }
741}
742
743impl<S, E> MethodRouter<S, E>
744where
745    S: Clone,
746{
747    /// Create a default `MethodRouter` that will respond with `405 Method Not Allowed` to all
748    /// requests.
749    pub fn new() -> Self {
750        let fallback = Route::new(service_fn(|_: Request| async {
751            Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
752        }));
753
754        Self {
755            get: MethodEndpoint::None,
756            head: MethodEndpoint::None,
757            delete: MethodEndpoint::None,
758            options: MethodEndpoint::None,
759            patch: MethodEndpoint::None,
760            post: MethodEndpoint::None,
761            put: MethodEndpoint::None,
762            trace: MethodEndpoint::None,
763            connect: MethodEndpoint::None,
764            allow_header: AllowHeader::None,
765            fallback: Fallback::Default(fallback),
766        }
767    }
768
769    /// Provide the state for the router.
770    pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
771        MethodRouter {
772            get: self.get.with_state(&state),
773            head: self.head.with_state(&state),
774            delete: self.delete.with_state(&state),
775            options: self.options.with_state(&state),
776            patch: self.patch.with_state(&state),
777            post: self.post.with_state(&state),
778            put: self.put.with_state(&state),
779            trace: self.trace.with_state(&state),
780            connect: self.connect.with_state(&state),
781            allow_header: self.allow_header,
782            fallback: self.fallback.with_state(state),
783        }
784    }
785
786    /// Chain an additional service that will accept requests matching the given
787    /// `MethodFilter`.
788    ///
789    /// # Example
790    ///
791    /// ```rust
792    /// use axum::{
793    ///     extract::Request,
794    ///     Router,
795    ///     routing::{MethodFilter, on_service},
796    ///     body::Body,
797    /// };
798    /// use http::Response;
799    /// use std::convert::Infallible;
800    ///
801    /// let service = tower::service_fn(|request: Request| async {
802    ///     Ok::<_, Infallible>(Response::new(Body::empty()))
803    /// });
804    ///
805    /// // Requests to `DELETE /` will go to `service`
806    /// let app = Router::new().route("/", on_service(MethodFilter::DELETE, service));
807    /// # let _: Router = app;
808    /// ```
809    #[track_caller]
810    pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
811    where
812        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
813        T::Response: IntoResponse + 'static,
814        T::Future: Send + 'static,
815    {
816        self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
817    }
818
819    #[track_caller]
820    fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
821        // written as a separate function to generate less IR
822        #[track_caller]
823        fn set_endpoint<S, E>(
824            method_name: &str,
825            out: &mut MethodEndpoint<S, E>,
826            endpoint: &MethodEndpoint<S, E>,
827            endpoint_filter: MethodFilter,
828            filter: MethodFilter,
829            allow_header: &mut AllowHeader,
830            methods: &[&'static str],
831        ) where
832            MethodEndpoint<S, E>: Clone,
833            S: Clone,
834        {
835            if endpoint_filter.contains(filter) {
836                if out.is_some() {
837                    panic!(
838                        "Overlapping method route. Cannot add two method routes that both handle \
839                         `{method_name}`",
840                    )
841                }
842                *out = endpoint.clone();
843                for method in methods {
844                    append_allow_header(allow_header, method);
845                }
846            }
847        }
848
849        set_endpoint(
850            "GET",
851            &mut self.get,
852            &endpoint,
853            filter,
854            MethodFilter::GET,
855            &mut self.allow_header,
856            &["GET", "HEAD"],
857        );
858
859        set_endpoint(
860            "HEAD",
861            &mut self.head,
862            &endpoint,
863            filter,
864            MethodFilter::HEAD,
865            &mut self.allow_header,
866            &["HEAD"],
867        );
868
869        set_endpoint(
870            "TRACE",
871            &mut self.trace,
872            &endpoint,
873            filter,
874            MethodFilter::TRACE,
875            &mut self.allow_header,
876            &["TRACE"],
877        );
878
879        set_endpoint(
880            "PUT",
881            &mut self.put,
882            &endpoint,
883            filter,
884            MethodFilter::PUT,
885            &mut self.allow_header,
886            &["PUT"],
887        );
888
889        set_endpoint(
890            "POST",
891            &mut self.post,
892            &endpoint,
893            filter,
894            MethodFilter::POST,
895            &mut self.allow_header,
896            &["POST"],
897        );
898
899        set_endpoint(
900            "PATCH",
901            &mut self.patch,
902            &endpoint,
903            filter,
904            MethodFilter::PATCH,
905            &mut self.allow_header,
906            &["PATCH"],
907        );
908
909        set_endpoint(
910            "OPTIONS",
911            &mut self.options,
912            &endpoint,
913            filter,
914            MethodFilter::OPTIONS,
915            &mut self.allow_header,
916            &["OPTIONS"],
917        );
918
919        set_endpoint(
920            "DELETE",
921            &mut self.delete,
922            &endpoint,
923            filter,
924            MethodFilter::DELETE,
925            &mut self.allow_header,
926            &["DELETE"],
927        );
928
929        set_endpoint(
930            "CONNECT",
931            &mut self.options,
932            &endpoint,
933            filter,
934            MethodFilter::CONNECT,
935            &mut self.allow_header,
936            &["CONNECT"],
937        );
938
939        self
940    }
941
942    chained_service_fn!(connect_service, CONNECT);
943    chained_service_fn!(delete_service, DELETE);
944    chained_service_fn!(get_service, GET);
945    chained_service_fn!(head_service, HEAD);
946    chained_service_fn!(options_service, OPTIONS);
947    chained_service_fn!(patch_service, PATCH);
948    chained_service_fn!(post_service, POST);
949    chained_service_fn!(put_service, PUT);
950    chained_service_fn!(trace_service, TRACE);
951
952    #[doc = include_str!("../docs/method_routing/fallback.md")]
953    pub fn fallback_service<T>(mut self, svc: T) -> Self
954    where
955        T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
956        T::Response: IntoResponse + 'static,
957        T::Future: Send + 'static,
958    {
959        self.fallback = Fallback::Service(Route::new(svc));
960        self
961    }
962
963    #[doc = include_str!("../docs/method_routing/layer.md")]
964    pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
965    where
966        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
967        L::Service: Service<Request> + Clone + Send + Sync + 'static,
968        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
969        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
970        <L::Service as Service<Request>>::Future: Send + 'static,
971        E: 'static,
972        S: 'static,
973        NewError: 'static,
974    {
975        let layer_fn = move |route: Route<E>| route.layer(layer.clone());
976
977        MethodRouter {
978            get: self.get.map(layer_fn.clone()),
979            head: self.head.map(layer_fn.clone()),
980            delete: self.delete.map(layer_fn.clone()),
981            options: self.options.map(layer_fn.clone()),
982            patch: self.patch.map(layer_fn.clone()),
983            post: self.post.map(layer_fn.clone()),
984            put: self.put.map(layer_fn.clone()),
985            trace: self.trace.map(layer_fn.clone()),
986            connect: self.connect.map(layer_fn.clone()),
987            fallback: self.fallback.map(layer_fn),
988            allow_header: self.allow_header,
989        }
990    }
991
992    #[doc = include_str!("../docs/method_routing/route_layer.md")]
993    #[track_caller]
994    pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
995    where
996        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
997        L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
998        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
999        <L::Service as Service<Request>>::Future: Send + 'static,
1000        E: 'static,
1001        S: 'static,
1002    {
1003        if self.get.is_none()
1004            && self.head.is_none()
1005            && self.delete.is_none()
1006            && self.options.is_none()
1007            && self.patch.is_none()
1008            && self.post.is_none()
1009            && self.put.is_none()
1010            && self.trace.is_none()
1011            && self.connect.is_none()
1012        {
1013            panic!(
1014                "Adding a route_layer before any routes is a no-op. \
1015                 Add the routes you want the layer to apply to first."
1016            );
1017        }
1018
1019        let layer_fn = move |svc| {
1020            let svc = layer.layer(svc);
1021            let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
1022            Route::new(svc)
1023        };
1024
1025        self.get = self.get.map(layer_fn.clone());
1026        self.head = self.head.map(layer_fn.clone());
1027        self.delete = self.delete.map(layer_fn.clone());
1028        self.options = self.options.map(layer_fn.clone());
1029        self.patch = self.patch.map(layer_fn.clone());
1030        self.post = self.post.map(layer_fn.clone());
1031        self.put = self.put.map(layer_fn.clone());
1032        self.trace = self.trace.map(layer_fn.clone());
1033        self.connect = self.connect.map(layer_fn);
1034
1035        self
1036    }
1037
1038    #[track_caller]
1039    pub(crate) fn merge_for_path(mut self, path: Option<&str>, other: MethodRouter<S, E>) -> Self {
1040        // written using inner functions to generate less IR
1041        #[track_caller]
1042        fn merge_inner<S, E>(
1043            path: Option<&str>,
1044            name: &str,
1045            first: MethodEndpoint<S, E>,
1046            second: MethodEndpoint<S, E>,
1047        ) -> MethodEndpoint<S, E> {
1048            match (first, second) {
1049                (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1050                (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1051                _ => {
1052                    if let Some(path) = path {
1053                        panic!(
1054                            "Overlapping method route. Handler for `{name} {path}` already exists"
1055                        );
1056                    } else {
1057                        panic!(
1058                            "Overlapping method route. Cannot merge two method routes that both \
1059                             define `{name}`"
1060                        );
1061                    }
1062                }
1063            }
1064        }
1065
1066        self.get = merge_inner(path, "GET", self.get, other.get);
1067        self.head = merge_inner(path, "HEAD", self.head, other.head);
1068        self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1069        self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1070        self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1071        self.post = merge_inner(path, "POST", self.post, other.post);
1072        self.put = merge_inner(path, "PUT", self.put, other.put);
1073        self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1074        self.connect = merge_inner(path, "CONNECT", self.connect, other.connect);
1075
1076        self.fallback = self
1077            .fallback
1078            .merge(other.fallback)
1079            .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1080
1081        self.allow_header = self.allow_header.merge(other.allow_header);
1082
1083        self
1084    }
1085
1086    #[doc = include_str!("../docs/method_routing/merge.md")]
1087    #[track_caller]
1088    pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1089        self.merge_for_path(None, other)
1090    }
1091
1092    /// Apply a [`HandleErrorLayer`].
1093    ///
1094    /// This is a convenience method for doing `self.layer(HandleErrorLayer::new(f))`.
1095    pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1096    where
1097        F: Clone + Send + Sync + 'static,
1098        HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1099        <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1100        <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1101        T: 'static,
1102        E: 'static,
1103        S: 'static,
1104    {
1105        self.layer(HandleErrorLayer::new(f))
1106    }
1107
1108    fn skip_allow_header(mut self) -> Self {
1109        self.allow_header = AllowHeader::Skip;
1110        self
1111    }
1112
1113    pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1114        macro_rules! call {
1115            (
1116                $req:expr,
1117                $method_variant:ident,
1118                $svc:expr
1119            ) => {
1120                if *req.method() == Method::$method_variant {
1121                    match $svc {
1122                        MethodEndpoint::None => {}
1123                        MethodEndpoint::Route(route) => {
1124                            return route.clone().oneshot_inner_owned($req);
1125                        }
1126                        MethodEndpoint::BoxedHandler(handler) => {
1127                            let route = handler.clone().into_route(state);
1128                            return route.oneshot_inner_owned($req);
1129                        }
1130                    }
1131                }
1132            };
1133        }
1134
1135        // written with a pattern match like this to ensure we call all routes
1136        let Self {
1137            get,
1138            head,
1139            delete,
1140            options,
1141            patch,
1142            post,
1143            put,
1144            trace,
1145            connect,
1146            fallback,
1147            allow_header,
1148        } = self;
1149
1150        call!(req, HEAD, head);
1151        call!(req, HEAD, get);
1152        call!(req, GET, get);
1153        call!(req, POST, post);
1154        call!(req, OPTIONS, options);
1155        call!(req, PATCH, patch);
1156        call!(req, PUT, put);
1157        call!(req, DELETE, delete);
1158        call!(req, TRACE, trace);
1159        call!(req, CONNECT, connect);
1160
1161        let future = fallback.clone().call_with_state(req, state);
1162
1163        match allow_header {
1164            AllowHeader::None => future.allow_header(Bytes::new()),
1165            AllowHeader::Skip => future,
1166            AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1167        }
1168    }
1169}
1170
1171fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1172    match allow_header {
1173        AllowHeader::None => {
1174            *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1175        }
1176        AllowHeader::Skip => {}
1177        AllowHeader::Bytes(allow_header) => {
1178            if let Ok(s) = std::str::from_utf8(allow_header) {
1179                if !s.contains(method) {
1180                    allow_header.extend_from_slice(b",");
1181                    allow_header.extend_from_slice(method.as_bytes());
1182                }
1183            } else {
1184                #[cfg(debug_assertions)]
1185                panic!("`allow_header` contained invalid uft-8. This should never happen")
1186            }
1187        }
1188    }
1189}
1190
1191impl<S, E> Clone for MethodRouter<S, E> {
1192    fn clone(&self) -> Self {
1193        Self {
1194            get: self.get.clone(),
1195            head: self.head.clone(),
1196            delete: self.delete.clone(),
1197            options: self.options.clone(),
1198            patch: self.patch.clone(),
1199            post: self.post.clone(),
1200            put: self.put.clone(),
1201            trace: self.trace.clone(),
1202            connect: self.connect.clone(),
1203            fallback: self.fallback.clone(),
1204            allow_header: self.allow_header.clone(),
1205        }
1206    }
1207}
1208
1209impl<S, E> Default for MethodRouter<S, E>
1210where
1211    S: Clone,
1212{
1213    fn default() -> Self {
1214        Self::new()
1215    }
1216}
1217
1218enum MethodEndpoint<S, E> {
1219    None,
1220    Route(Route<E>),
1221    BoxedHandler(BoxedIntoRoute<S, E>),
1222}
1223
1224impl<S, E> MethodEndpoint<S, E>
1225where
1226    S: Clone,
1227{
1228    fn is_some(&self) -> bool {
1229        matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1230    }
1231
1232    fn is_none(&self) -> bool {
1233        matches!(self, Self::None)
1234    }
1235
1236    fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1237    where
1238        S: 'static,
1239        E: 'static,
1240        F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1241        E2: 'static,
1242    {
1243        match self {
1244            Self::None => MethodEndpoint::None,
1245            Self::Route(route) => MethodEndpoint::Route(f(route)),
1246            Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1247        }
1248    }
1249
1250    fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1251        match self {
1252            MethodEndpoint::None => MethodEndpoint::None,
1253            MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1254            MethodEndpoint::BoxedHandler(handler) => {
1255                MethodEndpoint::Route(handler.into_route(state.clone()))
1256            }
1257        }
1258    }
1259}
1260
1261impl<S, E> Clone for MethodEndpoint<S, E> {
1262    fn clone(&self) -> Self {
1263        match self {
1264            Self::None => Self::None,
1265            Self::Route(inner) => Self::Route(inner.clone()),
1266            Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1267        }
1268    }
1269}
1270
1271impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1273        match self {
1274            Self::None => f.debug_tuple("None").finish(),
1275            Self::Route(inner) => inner.fmt(f),
1276            Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1277        }
1278    }
1279}
1280
1281impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1282where
1283    B: HttpBody<Data = Bytes> + Send + 'static,
1284    B::Error: Into<BoxError>,
1285{
1286    type Response = Response;
1287    type Error = E;
1288    type Future = RouteFuture<E>;
1289
1290    #[inline]
1291    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1292        Poll::Ready(Ok(()))
1293    }
1294
1295    #[inline]
1296    fn call(&mut self, req: Request<B>) -> Self::Future {
1297        let req = req.map(Body::new);
1298        self.call_with_state(req, ())
1299    }
1300}
1301
1302impl<S> Handler<(), S> for MethodRouter<S>
1303where
1304    S: Clone + 'static,
1305{
1306    type Future = InfallibleRouteFuture;
1307
1308    fn call(self, req: Request, state: S) -> Self::Future {
1309        InfallibleRouteFuture::new(self.call_with_state(req, state))
1310    }
1311}
1312
// for `axum::serve(listener, router)`
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve;

    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        /// Always reports ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Returns a clone of this router with the unit state applied; the incoming stream
        /// itself is not inspected.
        fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
            let router: Self = self.clone().with_state(());
            std::future::ready(Ok(router))
        }
    }
};
1335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    // An empty router answers every method with 405 and an empty body.
    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    // A plain `tower` service can be mounted for `GET`.
    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    // A handler function can be mounted for `GET`.
    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    // A `GET` route also answers `HEAD`, with the body stripped.
    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    // An explicit `HEAD` route wins over the `GET` route for `HEAD` requests.
    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    // Merging two routers keeps the routes of both.
    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    // `layer` applies to routed methods and to the fallback alike.
    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    // `route_layer` applies to routed methods only; the fallback is left alone.
    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Compile-only check that the builder methods compose; it is intentionally never run,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            // use the all the things 💣️
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    // The `Allow` header lists the registered methods in registration order.
    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    // Registering `GET` and then `HEAD` does not list `HEAD` twice.
    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // With no routes the `Allow` header is present but empty.
    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    // Merging concatenates the `Allow` values of both routers.
    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    // `any` never emits an `Allow` header.
    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    // A custom fallback that answers 405 still gets the computed `Allow` header.
    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // An `Allow` header set by the fallback itself is not overwritten.
    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    // Applying a no-op layer keeps the computed `Allow` header.
    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    // Registering two handlers for the same method panics.
    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    // Registering two services for the same method panics.
    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    // `GET` followed by `HEAD` is not an overlap.
    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    // `HEAD` followed by `GET` is not an overlap.
    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    // Handlers can extract the state supplied through `with_state`.
    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    // The fallback can extract the state as well.
    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    // Both halves of a merged router see the state.
    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Capture the POST body too; otherwise the assertion below would only re-check the
        // body of the GET response.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    // Send an empty-bodied request for `/` with the given method through `svc` and return
    // the response status, headers and body decoded as UTF-8.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    // Handler answering `200 OK` with the body `ok`.
    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    // Handler answering `201 Created` with the body `created`.
    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}