1use super::{future::InfallibleRouteFuture, IntoMakeService};
4#[cfg(feature = "tokio")]
5use crate::extract::connect_info::IntoMakeServiceWithConnectInfo;
6use crate::{
7 body::{Body, Bytes, HttpBody},
8 boxed::BoxedIntoRoute,
9 error_handling::{HandleError, HandleErrorLayer},
10 handler::Handler,
11 http::{Method, StatusCode},
12 response::Response,
13 routing::{future::RouteFuture, Fallback, MethodFilter, Route},
14};
15use axum_core::{extract::Request, response::IntoResponse, BoxError};
16use bytes::BytesMut;
17use std::{
18 convert::Infallible,
19 fmt,
20 task::{Context, Poll},
21};
22use tower::{service_fn, util::MapResponseLayer};
23use tower_layer::Layer;
24use tower_service::Service;
25
// Generates a top-level `pub fn $name(svc) -> MethodRouter` that routes requests matching
// `MethodFilter::$method` to the given service.
//
// Arms are tried in order. The `GET` and `CONNECT` arms attach method-specific docs and the
// generic `$method` arm attaches a one-line doc; all of them forward to the final attributed
// arm. The forwarded call MUST carry at least one attribute: without one it would match its
// own arm again and recurse until the macro recursion limit is hit.
macro_rules! top_level_service_fn {
    (
        $name:ident, GET
    ) => {
        top_level_service_fn!(
            #[doc = "Route `GET` requests to the given service."]
            #[doc = ""]
            #[doc = "Note that `get` routes will also be called for `HEAD` requests but will have"]
            #[doc = "the response body removed. Make sure to add explicit `HEAD` routes"]
            #[doc = "afterwards."]
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_service_fn!(
            #[doc = "Route `CONNECT` requests to the given service."]
            #[doc = ""]
            #[doc = "See [`MethodFilter::CONNECT`] for when you'd want to use this,"]
            #[doc = "and [`get_service`] for an example."]
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_service_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given service.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<T, S>(svc: T) -> MethodRouter<S, T::Error>
        where
            T: Service<Request> + Clone + Send + Sync + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
            S: Clone,
        {
            on_service(MethodFilter::$method, svc)
        }
    };
}
103
// Generates a top-level `pub fn $name(handler) -> MethodRouter` that routes requests matching
// `MethodFilter::$method` to the given handler.
//
// Arms are tried in order. The `GET` and `CONNECT` arms attach method-specific docs and the
// generic `$method` arm attaches a one-line doc; all of them forward to the final attributed
// arm. The forwarded call MUST carry at least one attribute: without one it would match its
// own arm again and recurse until the macro recursion limit is hit.
macro_rules! top_level_handler_fn {
    (
        $name:ident, GET
    ) => {
        top_level_handler_fn!(
            #[doc = "Route `GET` requests to the given handler."]
            #[doc = ""]
            #[doc = "Note that `get` routes will also be called for `HEAD` requests but will have"]
            #[doc = "the response body removed. Make sure to add explicit `HEAD` routes"]
            #[doc = "afterwards."]
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        top_level_handler_fn!(
            #[doc = "Route `CONNECT` requests to the given handler."]
            #[doc = ""]
            #[doc = "See [`MethodFilter::CONNECT`] for when you'd want to use this,"]
            #[doc = "and [`get`] for an example."]
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        top_level_handler_fn!(
            #[doc = concat!("Route `", stringify!($method) ,"` requests to the given handler.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        pub fn $name<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
        where
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            on(MethodFilter::$method, handler)
        }
    };
}
174
// Generates a `MethodRouter` method `$name(self, svc) -> Self` that chains a service accepting
// only `MethodFilter::$method` requests.
//
// Arms are tried in order. The `GET` and `CONNECT` arms attach method-specific docs and the
// generic `$method` arm attaches a one-line doc; all of them forward to the final attributed
// arm. The forwarded call MUST carry at least one attribute: without one it would match its
// own arm again and recurse until the macro recursion limit is hit.
macro_rules! chained_service_fn {
    (
        $name:ident, GET
    ) => {
        chained_service_fn!(
            #[doc = "Chain an additional service that will only accept `GET` requests."]
            #[doc = ""]
            #[doc = "Note that `get` routes will also be called for `HEAD` requests but will have"]
            #[doc = "the response body removed. Make sure to add explicit `HEAD` routes"]
            #[doc = "afterwards."]
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_service_fn!(
            #[doc = "Chain an additional service that will only accept `CONNECT` requests."]
            #[doc = ""]
            #[doc = "See [`MethodFilter::CONNECT`] for when you'd want to use this,"]
            #[doc = "and [`MethodRouter::get_service`] for an example."]
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_service_fn!(
            #[doc = concat!("Chain an additional service that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<T>(self, svc: T) -> Self
        where
            T: Service<Request, Error = E>
                + Clone
                + Send
                + Sync
                + 'static,
            T::Response: IntoResponse + 'static,
            T::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
261
// Generates a `MethodRouter` method `$name(self, handler) -> Self` that chains a handler
// accepting only `MethodFilter::$method` requests.
//
// Arms are tried in order. The `GET` and `CONNECT` arms attach method-specific docs and the
// generic `$method` arm attaches a one-line doc; all of them forward to the final attributed
// arm. The forwarded call MUST carry at least one attribute: without one it would match its
// own arm again and recurse until the macro recursion limit is hit.
macro_rules! chained_handler_fn {
    (
        $name:ident, GET
    ) => {
        chained_handler_fn!(
            #[doc = "Chain an additional handler that will only accept `GET` requests."]
            #[doc = ""]
            #[doc = "Note that `get` routes will also be called for `HEAD` requests but will have"]
            #[doc = "the response body removed. Make sure to add explicit `HEAD` routes"]
            #[doc = "afterwards."]
            $name,
            GET
        );
    };

    (
        $name:ident, CONNECT
    ) => {
        chained_handler_fn!(
            #[doc = "Chain an additional handler that will only accept `CONNECT` requests."]
            #[doc = ""]
            #[doc = "See [`MethodFilter::CONNECT`] for when you'd want to use this,"]
            #[doc = "and [`MethodRouter::get`] for an example."]
            $name,
            CONNECT
        );
    };

    (
        $name:ident, $method:ident
    ) => {
        chained_handler_fn!(
            #[doc = concat!("Chain an additional handler that will only accept `", stringify!($method),"` requests.")]
            $name,
            $method
        );
    };

    (
        $(#[$m:meta])+
        $name:ident, $method:ident
    ) => {
        $(#[$m])+
        #[track_caller]
        pub fn $name<H, T>(self, handler: H) -> Self
        where
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
333
334top_level_service_fn!(connect_service, CONNECT);
335top_level_service_fn!(delete_service, DELETE);
336top_level_service_fn!(get_service, GET);
337top_level_service_fn!(head_service, HEAD);
338top_level_service_fn!(options_service, OPTIONS);
339top_level_service_fn!(patch_service, PATCH);
340top_level_service_fn!(post_service, POST);
341top_level_service_fn!(put_service, PUT);
342top_level_service_fn!(trace_service, TRACE);
343
344pub fn on_service<T, S>(filter: MethodFilter, svc: T) -> MethodRouter<S, T::Error>
368where
369 T: Service<Request> + Clone + Send + Sync + 'static,
370 T::Response: IntoResponse + 'static,
371 T::Future: Send + 'static,
372 S: Clone,
373{
374 MethodRouter::new().on_service(filter, svc)
375}
376
377pub fn any_service<T, S>(svc: T) -> MethodRouter<S, T::Error>
427where
428 T: Service<Request> + Clone + Send + Sync + 'static,
429 T::Response: IntoResponse + 'static,
430 T::Future: Send + 'static,
431 S: Clone,
432{
433 MethodRouter::new()
434 .fallback_service(svc)
435 .skip_allow_header()
436}
437
438top_level_handler_fn!(connect, CONNECT);
439top_level_handler_fn!(delete, DELETE);
440top_level_handler_fn!(get, GET);
441top_level_handler_fn!(head, HEAD);
442top_level_handler_fn!(options, OPTIONS);
443top_level_handler_fn!(patch, PATCH);
444top_level_handler_fn!(post, POST);
445top_level_handler_fn!(put, PUT);
446top_level_handler_fn!(trace, TRACE);
447
448pub fn on<H, T, S>(filter: MethodFilter, handler: H) -> MethodRouter<S, Infallible>
466where
467 H: Handler<T, S>,
468 T: 'static,
469 S: Clone + Send + Sync + 'static,
470{
471 MethodRouter::new().on(filter, handler)
472}
473
474pub fn any<H, T, S>(handler: H) -> MethodRouter<S, Infallible>
508where
509 H: Handler<T, S>,
510 T: 'static,
511 S: Clone + Send + Sync + 'static,
512{
513 MethodRouter::new().fallback(handler).skip_allow_header()
514}
515
516#[must_use]
546pub struct MethodRouter<S = (), E = Infallible> {
547 get: MethodEndpoint<S, E>,
548 head: MethodEndpoint<S, E>,
549 delete: MethodEndpoint<S, E>,
550 options: MethodEndpoint<S, E>,
551 patch: MethodEndpoint<S, E>,
552 post: MethodEndpoint<S, E>,
553 put: MethodEndpoint<S, E>,
554 trace: MethodEndpoint<S, E>,
555 connect: MethodEndpoint<S, E>,
556 fallback: Fallback<S, E>,
557 allow_header: AllowHeader,
558}
559
560#[derive(Clone, Debug)]
561enum AllowHeader {
562 None,
564 Skip,
566 Bytes(BytesMut),
568}
569
570impl AllowHeader {
571 fn merge(self, other: Self) -> Self {
572 match (self, other) {
573 (AllowHeader::Skip, _) | (_, AllowHeader::Skip) => AllowHeader::Skip,
574 (AllowHeader::None, AllowHeader::None) => AllowHeader::None,
575 (AllowHeader::None, AllowHeader::Bytes(pick)) => AllowHeader::Bytes(pick),
576 (AllowHeader::Bytes(pick), AllowHeader::None) => AllowHeader::Bytes(pick),
577 (AllowHeader::Bytes(mut a), AllowHeader::Bytes(b)) => {
578 a.extend_from_slice(b",");
579 a.extend_from_slice(&b);
580 AllowHeader::Bytes(a)
581 }
582 }
583 }
584}
585
586impl<S, E> fmt::Debug for MethodRouter<S, E> {
587 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
588 f.debug_struct("MethodRouter")
589 .field("get", &self.get)
590 .field("head", &self.head)
591 .field("delete", &self.delete)
592 .field("options", &self.options)
593 .field("patch", &self.patch)
594 .field("post", &self.post)
595 .field("put", &self.put)
596 .field("trace", &self.trace)
597 .field("connect", &self.connect)
598 .field("fallback", &self.fallback)
599 .field("allow_header", &self.allow_header)
600 .finish()
601 }
602}
603
604impl<S> MethodRouter<S, Infallible>
605where
606 S: Clone,
607{
608 #[track_caller]
630 pub fn on<H, T>(self, filter: MethodFilter, handler: H) -> Self
631 where
632 H: Handler<T, S>,
633 T: 'static,
634 S: Send + Sync + 'static,
635 {
636 self.on_endpoint(
637 filter,
638 MethodEndpoint::BoxedHandler(BoxedIntoRoute::from_handler(handler)),
639 )
640 }
641
642 chained_handler_fn!(connect, CONNECT);
643 chained_handler_fn!(delete, DELETE);
644 chained_handler_fn!(get, GET);
645 chained_handler_fn!(head, HEAD);
646 chained_handler_fn!(options, OPTIONS);
647 chained_handler_fn!(patch, PATCH);
648 chained_handler_fn!(post, POST);
649 chained_handler_fn!(put, PUT);
650 chained_handler_fn!(trace, TRACE);
651
652 pub fn fallback<H, T>(mut self, handler: H) -> Self
654 where
655 H: Handler<T, S>,
656 T: 'static,
657 S: Send + Sync + 'static,
658 {
659 self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler));
660 self
661 }
662
663 pub(crate) fn default_fallback<H, T>(self, handler: H) -> Self
665 where
666 H: Handler<T, S>,
667 T: 'static,
668 S: Send + Sync + 'static,
669 {
670 match self.fallback {
671 Fallback::Default(_) => self.fallback(handler),
672 _ => self,
673 }
674 }
675}
676
677impl MethodRouter<(), Infallible> {
678 pub fn into_make_service(self) -> IntoMakeService<Self> {
706 IntoMakeService::new(self.with_state(()))
707 }
708
709 #[cfg(feature = "tokio")]
738 pub fn into_make_service_with_connect_info<C>(self) -> IntoMakeServiceWithConnectInfo<Self, C> {
739 IntoMakeServiceWithConnectInfo::new(self.with_state(()))
740 }
741}
742
743impl<S, E> MethodRouter<S, E>
744where
745 S: Clone,
746{
747 pub fn new() -> Self {
750 let fallback = Route::new(service_fn(|_: Request| async {
751 Ok(StatusCode::METHOD_NOT_ALLOWED.into_response())
752 }));
753
754 Self {
755 get: MethodEndpoint::None,
756 head: MethodEndpoint::None,
757 delete: MethodEndpoint::None,
758 options: MethodEndpoint::None,
759 patch: MethodEndpoint::None,
760 post: MethodEndpoint::None,
761 put: MethodEndpoint::None,
762 trace: MethodEndpoint::None,
763 connect: MethodEndpoint::None,
764 allow_header: AllowHeader::None,
765 fallback: Fallback::Default(fallback),
766 }
767 }
768
769 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
771 MethodRouter {
772 get: self.get.with_state(&state),
773 head: self.head.with_state(&state),
774 delete: self.delete.with_state(&state),
775 options: self.options.with_state(&state),
776 patch: self.patch.with_state(&state),
777 post: self.post.with_state(&state),
778 put: self.put.with_state(&state),
779 trace: self.trace.with_state(&state),
780 connect: self.connect.with_state(&state),
781 allow_header: self.allow_header,
782 fallback: self.fallback.with_state(state),
783 }
784 }
785
786 #[track_caller]
810 pub fn on_service<T>(self, filter: MethodFilter, svc: T) -> Self
811 where
812 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
813 T::Response: IntoResponse + 'static,
814 T::Future: Send + 'static,
815 {
816 self.on_endpoint(filter, MethodEndpoint::Route(Route::new(svc)))
817 }
818
819 #[track_caller]
820 fn on_endpoint(mut self, filter: MethodFilter, endpoint: MethodEndpoint<S, E>) -> Self {
821 #[track_caller]
823 fn set_endpoint<S, E>(
824 method_name: &str,
825 out: &mut MethodEndpoint<S, E>,
826 endpoint: &MethodEndpoint<S, E>,
827 endpoint_filter: MethodFilter,
828 filter: MethodFilter,
829 allow_header: &mut AllowHeader,
830 methods: &[&'static str],
831 ) where
832 MethodEndpoint<S, E>: Clone,
833 S: Clone,
834 {
835 if endpoint_filter.contains(filter) {
836 if out.is_some() {
837 panic!(
838 "Overlapping method route. Cannot add two method routes that both handle \
839 `{method_name}`",
840 )
841 }
842 *out = endpoint.clone();
843 for method in methods {
844 append_allow_header(allow_header, method);
845 }
846 }
847 }
848
849 set_endpoint(
850 "GET",
851 &mut self.get,
852 &endpoint,
853 filter,
854 MethodFilter::GET,
855 &mut self.allow_header,
856 &["GET", "HEAD"],
857 );
858
859 set_endpoint(
860 "HEAD",
861 &mut self.head,
862 &endpoint,
863 filter,
864 MethodFilter::HEAD,
865 &mut self.allow_header,
866 &["HEAD"],
867 );
868
869 set_endpoint(
870 "TRACE",
871 &mut self.trace,
872 &endpoint,
873 filter,
874 MethodFilter::TRACE,
875 &mut self.allow_header,
876 &["TRACE"],
877 );
878
879 set_endpoint(
880 "PUT",
881 &mut self.put,
882 &endpoint,
883 filter,
884 MethodFilter::PUT,
885 &mut self.allow_header,
886 &["PUT"],
887 );
888
889 set_endpoint(
890 "POST",
891 &mut self.post,
892 &endpoint,
893 filter,
894 MethodFilter::POST,
895 &mut self.allow_header,
896 &["POST"],
897 );
898
899 set_endpoint(
900 "PATCH",
901 &mut self.patch,
902 &endpoint,
903 filter,
904 MethodFilter::PATCH,
905 &mut self.allow_header,
906 &["PATCH"],
907 );
908
909 set_endpoint(
910 "OPTIONS",
911 &mut self.options,
912 &endpoint,
913 filter,
914 MethodFilter::OPTIONS,
915 &mut self.allow_header,
916 &["OPTIONS"],
917 );
918
919 set_endpoint(
920 "DELETE",
921 &mut self.delete,
922 &endpoint,
923 filter,
924 MethodFilter::DELETE,
925 &mut self.allow_header,
926 &["DELETE"],
927 );
928
929 set_endpoint(
930 "CONNECT",
931 &mut self.options,
932 &endpoint,
933 filter,
934 MethodFilter::CONNECT,
935 &mut self.allow_header,
936 &["CONNECT"],
937 );
938
939 self
940 }
941
942 chained_service_fn!(connect_service, CONNECT);
943 chained_service_fn!(delete_service, DELETE);
944 chained_service_fn!(get_service, GET);
945 chained_service_fn!(head_service, HEAD);
946 chained_service_fn!(options_service, OPTIONS);
947 chained_service_fn!(patch_service, PATCH);
948 chained_service_fn!(post_service, POST);
949 chained_service_fn!(put_service, PUT);
950 chained_service_fn!(trace_service, TRACE);
951
952 #[doc = include_str!("../docs/method_routing/fallback.md")]
953 pub fn fallback_service<T>(mut self, svc: T) -> Self
954 where
955 T: Service<Request, Error = E> + Clone + Send + Sync + 'static,
956 T::Response: IntoResponse + 'static,
957 T::Future: Send + 'static,
958 {
959 self.fallback = Fallback::Service(Route::new(svc));
960 self
961 }
962
963 #[doc = include_str!("../docs/method_routing/layer.md")]
964 pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
965 where
966 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
967 L::Service: Service<Request> + Clone + Send + Sync + 'static,
968 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
969 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
970 <L::Service as Service<Request>>::Future: Send + 'static,
971 E: 'static,
972 S: 'static,
973 NewError: 'static,
974 {
975 let layer_fn = move |route: Route<E>| route.layer(layer.clone());
976
977 MethodRouter {
978 get: self.get.map(layer_fn.clone()),
979 head: self.head.map(layer_fn.clone()),
980 delete: self.delete.map(layer_fn.clone()),
981 options: self.options.map(layer_fn.clone()),
982 patch: self.patch.map(layer_fn.clone()),
983 post: self.post.map(layer_fn.clone()),
984 put: self.put.map(layer_fn.clone()),
985 trace: self.trace.map(layer_fn.clone()),
986 connect: self.connect.map(layer_fn.clone()),
987 fallback: self.fallback.map(layer_fn),
988 allow_header: self.allow_header,
989 }
990 }
991
992 #[doc = include_str!("../docs/method_routing/route_layer.md")]
993 #[track_caller]
994 pub fn route_layer<L>(mut self, layer: L) -> MethodRouter<S, E>
995 where
996 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
997 L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
998 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
999 <L::Service as Service<Request>>::Future: Send + 'static,
1000 E: 'static,
1001 S: 'static,
1002 {
1003 if self.get.is_none()
1004 && self.head.is_none()
1005 && self.delete.is_none()
1006 && self.options.is_none()
1007 && self.patch.is_none()
1008 && self.post.is_none()
1009 && self.put.is_none()
1010 && self.trace.is_none()
1011 && self.connect.is_none()
1012 {
1013 panic!(
1014 "Adding a route_layer before any routes is a no-op. \
1015 Add the routes you want the layer to apply to first."
1016 );
1017 }
1018
1019 let layer_fn = move |svc| {
1020 let svc = layer.layer(svc);
1021 let svc = MapResponseLayer::new(IntoResponse::into_response).layer(svc);
1022 Route::new(svc)
1023 };
1024
1025 self.get = self.get.map(layer_fn.clone());
1026 self.head = self.head.map(layer_fn.clone());
1027 self.delete = self.delete.map(layer_fn.clone());
1028 self.options = self.options.map(layer_fn.clone());
1029 self.patch = self.patch.map(layer_fn.clone());
1030 self.post = self.post.map(layer_fn.clone());
1031 self.put = self.put.map(layer_fn.clone());
1032 self.trace = self.trace.map(layer_fn.clone());
1033 self.connect = self.connect.map(layer_fn);
1034
1035 self
1036 }
1037
1038 #[track_caller]
1039 pub(crate) fn merge_for_path(mut self, path: Option<&str>, other: MethodRouter<S, E>) -> Self {
1040 #[track_caller]
1042 fn merge_inner<S, E>(
1043 path: Option<&str>,
1044 name: &str,
1045 first: MethodEndpoint<S, E>,
1046 second: MethodEndpoint<S, E>,
1047 ) -> MethodEndpoint<S, E> {
1048 match (first, second) {
1049 (MethodEndpoint::None, MethodEndpoint::None) => MethodEndpoint::None,
1050 (pick, MethodEndpoint::None) | (MethodEndpoint::None, pick) => pick,
1051 _ => {
1052 if let Some(path) = path {
1053 panic!(
1054 "Overlapping method route. Handler for `{name} {path}` already exists"
1055 );
1056 } else {
1057 panic!(
1058 "Overlapping method route. Cannot merge two method routes that both \
1059 define `{name}`"
1060 );
1061 }
1062 }
1063 }
1064 }
1065
1066 self.get = merge_inner(path, "GET", self.get, other.get);
1067 self.head = merge_inner(path, "HEAD", self.head, other.head);
1068 self.delete = merge_inner(path, "DELETE", self.delete, other.delete);
1069 self.options = merge_inner(path, "OPTIONS", self.options, other.options);
1070 self.patch = merge_inner(path, "PATCH", self.patch, other.patch);
1071 self.post = merge_inner(path, "POST", self.post, other.post);
1072 self.put = merge_inner(path, "PUT", self.put, other.put);
1073 self.trace = merge_inner(path, "TRACE", self.trace, other.trace);
1074 self.connect = merge_inner(path, "CONNECT", self.connect, other.connect);
1075
1076 self.fallback = self
1077 .fallback
1078 .merge(other.fallback)
1079 .expect("Cannot merge two `MethodRouter`s that both have a fallback");
1080
1081 self.allow_header = self.allow_header.merge(other.allow_header);
1082
1083 self
1084 }
1085
1086 #[doc = include_str!("../docs/method_routing/merge.md")]
1087 #[track_caller]
1088 pub fn merge(self, other: MethodRouter<S, E>) -> Self {
1089 self.merge_for_path(None, other)
1090 }
1091
1092 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
1096 where
1097 F: Clone + Send + Sync + 'static,
1098 HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
1099 <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
1100 <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
1101 T: 'static,
1102 E: 'static,
1103 S: 'static,
1104 {
1105 self.layer(HandleErrorLayer::new(f))
1106 }
1107
1108 fn skip_allow_header(mut self) -> Self {
1109 self.allow_header = AllowHeader::Skip;
1110 self
1111 }
1112
1113 pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture<E> {
1114 macro_rules! call {
1115 (
1116 $req:expr,
1117 $method_variant:ident,
1118 $svc:expr
1119 ) => {
1120 if *req.method() == Method::$method_variant {
1121 match $svc {
1122 MethodEndpoint::None => {}
1123 MethodEndpoint::Route(route) => {
1124 return route.clone().oneshot_inner_owned($req);
1125 }
1126 MethodEndpoint::BoxedHandler(handler) => {
1127 let route = handler.clone().into_route(state);
1128 return route.oneshot_inner_owned($req);
1129 }
1130 }
1131 }
1132 };
1133 }
1134
1135 let Self {
1137 get,
1138 head,
1139 delete,
1140 options,
1141 patch,
1142 post,
1143 put,
1144 trace,
1145 connect,
1146 fallback,
1147 allow_header,
1148 } = self;
1149
1150 call!(req, HEAD, head);
1151 call!(req, HEAD, get);
1152 call!(req, GET, get);
1153 call!(req, POST, post);
1154 call!(req, OPTIONS, options);
1155 call!(req, PATCH, patch);
1156 call!(req, PUT, put);
1157 call!(req, DELETE, delete);
1158 call!(req, TRACE, trace);
1159 call!(req, CONNECT, connect);
1160
1161 let future = fallback.clone().call_with_state(req, state);
1162
1163 match allow_header {
1164 AllowHeader::None => future.allow_header(Bytes::new()),
1165 AllowHeader::Skip => future,
1166 AllowHeader::Bytes(allow_header) => future.allow_header(allow_header.clone().freeze()),
1167 }
1168 }
1169}
1170
1171fn append_allow_header(allow_header: &mut AllowHeader, method: &'static str) {
1172 match allow_header {
1173 AllowHeader::None => {
1174 *allow_header = AllowHeader::Bytes(BytesMut::from(method));
1175 }
1176 AllowHeader::Skip => {}
1177 AllowHeader::Bytes(allow_header) => {
1178 if let Ok(s) = std::str::from_utf8(allow_header) {
1179 if !s.contains(method) {
1180 allow_header.extend_from_slice(b",");
1181 allow_header.extend_from_slice(method.as_bytes());
1182 }
1183 } else {
1184 #[cfg(debug_assertions)]
1185 panic!("`allow_header` contained invalid uft-8. This should never happen")
1186 }
1187 }
1188 }
1189}
1190
1191impl<S, E> Clone for MethodRouter<S, E> {
1192 fn clone(&self) -> Self {
1193 Self {
1194 get: self.get.clone(),
1195 head: self.head.clone(),
1196 delete: self.delete.clone(),
1197 options: self.options.clone(),
1198 patch: self.patch.clone(),
1199 post: self.post.clone(),
1200 put: self.put.clone(),
1201 trace: self.trace.clone(),
1202 connect: self.connect.clone(),
1203 fallback: self.fallback.clone(),
1204 allow_header: self.allow_header.clone(),
1205 }
1206 }
1207}
1208
1209impl<S, E> Default for MethodRouter<S, E>
1210where
1211 S: Clone,
1212{
1213 fn default() -> Self {
1214 Self::new()
1215 }
1216}
1217
1218enum MethodEndpoint<S, E> {
1219 None,
1220 Route(Route<E>),
1221 BoxedHandler(BoxedIntoRoute<S, E>),
1222}
1223
1224impl<S, E> MethodEndpoint<S, E>
1225where
1226 S: Clone,
1227{
1228 fn is_some(&self) -> bool {
1229 matches!(self, Self::Route(_) | Self::BoxedHandler(_))
1230 }
1231
1232 fn is_none(&self) -> bool {
1233 matches!(self, Self::None)
1234 }
1235
1236 fn map<F, E2>(self, f: F) -> MethodEndpoint<S, E2>
1237 where
1238 S: 'static,
1239 E: 'static,
1240 F: FnOnce(Route<E>) -> Route<E2> + Clone + Send + Sync + 'static,
1241 E2: 'static,
1242 {
1243 match self {
1244 Self::None => MethodEndpoint::None,
1245 Self::Route(route) => MethodEndpoint::Route(f(route)),
1246 Self::BoxedHandler(handler) => MethodEndpoint::BoxedHandler(handler.map(f)),
1247 }
1248 }
1249
1250 fn with_state<S2>(self, state: &S) -> MethodEndpoint<S2, E> {
1251 match self {
1252 MethodEndpoint::None => MethodEndpoint::None,
1253 MethodEndpoint::Route(route) => MethodEndpoint::Route(route),
1254 MethodEndpoint::BoxedHandler(handler) => {
1255 MethodEndpoint::Route(handler.into_route(state.clone()))
1256 }
1257 }
1258 }
1259}
1260
1261impl<S, E> Clone for MethodEndpoint<S, E> {
1262 fn clone(&self) -> Self {
1263 match self {
1264 Self::None => Self::None,
1265 Self::Route(inner) => Self::Route(inner.clone()),
1266 Self::BoxedHandler(inner) => Self::BoxedHandler(inner.clone()),
1267 }
1268 }
1269}
1270
1271impl<S, E> fmt::Debug for MethodEndpoint<S, E> {
1272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1273 match self {
1274 Self::None => f.debug_tuple("None").finish(),
1275 Self::Route(inner) => inner.fmt(f),
1276 Self::BoxedHandler(_) => f.debug_tuple("BoxedHandler").finish(),
1277 }
1278 }
1279}
1280
1281impl<B, E> Service<Request<B>> for MethodRouter<(), E>
1282where
1283 B: HttpBody<Data = Bytes> + Send + 'static,
1284 B::Error: Into<BoxError>,
1285{
1286 type Response = Response;
1287 type Error = E;
1288 type Future = RouteFuture<E>;
1289
1290 #[inline]
1291 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
1292 Poll::Ready(Ok(()))
1293 }
1294
1295 #[inline]
1296 fn call(&mut self, req: Request<B>) -> Self::Future {
1297 let req = req.map(Body::new);
1298 self.call_with_state(req, ())
1299 }
1300}
1301
1302impl<S> Handler<(), S> for MethodRouter<S>
1303where
1304 S: Clone + 'static,
1305{
1306 type Future = InfallibleRouteFuture;
1307
1308 fn call(self, req: Request, state: S) -> Self::Future {
1309 InfallibleRouteFuture::new(self.call_with_state(req, state))
1310 }
1311}
1312
// Implements `Service<serve::IncomingStream>` so a `MethodRouter<()>` can act as its own
// make-service: every incoming stream receives a fresh clone of the router.
#[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))]
const _: () = {
    use crate::serve;

    impl<L> Service<serve::IncomingStream<'_, L>> for MethodRouter<()>
    where
        L: serve::Listener,
    {
        type Response = Self;
        type Error = Infallible;
        type Future = std::future::Ready<Result<Self::Response, Self::Error>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, _req: serve::IncomingStream<'_, L>) -> Self::Future {
            let router: Self = self.clone().with_state(());
            std::future::ready(Ok(router))
        }
    }
};
1335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{extract::State, handler::HandlerWithoutStateExt};
    use http::{header::ALLOW, HeaderMap};
    use http_body_util::BodyExt;
    use std::time::Duration;
    use tower::ServiceExt;
    use tower_http::{
        services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer,
    };

    #[crate::test]
    async fn method_not_allowed_by_default() {
        let mut svc = MethodRouter::new();
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn get_service_fn() {
        async fn handle(_req: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::from("ok")))
        }

        let mut svc = get_service(service_fn(handle));

        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_handler() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ok");
    }

    #[crate::test]
    async fn get_accepts_head() {
        let mut svc = MethodRouter::new().get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn head_takes_precedence_over_get() {
        let mut svc = MethodRouter::new().head(created).get(ok);
        let (status, _, body) = call(Method::HEAD, &mut svc).await;
        assert_eq!(status, StatusCode::CREATED);
        assert!(body.is_empty());
    }

    #[crate::test]
    async fn merge() {
        let mut svc = get(ok).merge(post(ok));

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[crate::test]
    async fn layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route: `layer` also wraps the fallback
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[crate::test]
    async fn route_layer() {
        let mut svc = MethodRouter::new()
            .get(|| async { std::future::pending::<()>().await })
            .route_layer(ValidateRequestHeaderLayer::bearer("password"));

        // method with route
        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);

        // method without route: `route_layer` leaves the fallback alone
        let (status, _, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
    }

    // Only needs to type-check; it is never run.
    #[allow(dead_code)]
    async fn building_complex_router() {
        let app = crate::Router::new().route(
            "/",
            get(ok)
                .post(ok)
                .route_layer(ValidateRequestHeaderLayer::bearer("password"))
                .merge(delete_service(ServeDir::new(".")))
                .fallback(|| async { StatusCode::NOT_FOUND })
                .put(ok)
                .layer(TimeoutLayer::new(Duration::from_secs(10))),
        );

        let listener = tokio::net::TcpListener::bind("0.0.0.0:0").await.unwrap();
        crate::serve(listener, app).await.unwrap();
    }

    #[crate::test]
    async fn sets_allow_header() {
        let mut svc = MethodRouter::new().put(ok).patch(ok);
        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH");
    }

    #[crate::test]
    async fn sets_allow_header_get_head() {
        let mut svc = MethodRouter::new().get(ok).head(ok);
        let (status, headers, _) = call(Method::PUT, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn empty_allow_header_by_default() {
        let mut svc = MethodRouter::new();
        let (status, headers, _) = call(Method::PATCH, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "");
    }

    #[crate::test]
    async fn allow_header_when_merging() {
        let a = put(ok).patch(ok);
        let b = get(ok).head(ok);
        let mut svc = a.merge(b);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "PUT,PATCH,GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_any() {
        let mut svc = any(ok);

        let (status, headers, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!headers.contains_key(ALLOW));
    }

    #[crate::test]
    async fn allow_header_with_fallback() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .fallback(|| async { (StatusCode::METHOD_NOT_ALLOWED, "Method not allowed") });

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    async fn allow_header_with_fallback_that_sets_allow() {
        async fn fallback(method: Method) -> Response {
            if method == Method::POST {
                "OK".into_response()
            } else {
                (
                    StatusCode::METHOD_NOT_ALLOWED,
                    [(ALLOW, "GET,POST")],
                    "Method not allowed",
                )
                    .into_response()
            }
        }

        let mut svc = MethodRouter::new().get(ok).fallback(fallback);

        let (status, _, _) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, _, _) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,POST");
    }

    #[crate::test]
    async fn allow_header_noop_middleware() {
        let mut svc = MethodRouter::new()
            .get(ok)
            .layer(tower::layer::util::Identity::new());

        let (status, headers, _) = call(Method::DELETE, &mut svc).await;
        assert_eq!(status, StatusCode::METHOD_NOT_ALLOWED);
        assert_eq!(headers[ALLOW], "GET,HEAD");
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `GET`"
    )]
    async fn handler_overlaps() {
        let _: MethodRouter<()> = get(ok).get(ok);
    }

    #[crate::test]
    #[should_panic(
        expected = "Overlapping method route. Cannot add two method routes that both handle `POST`"
    )]
    async fn service_overlaps() {
        let _: MethodRouter<()> = post_service(ok.into_service()).post_service(ok.into_service());
    }

    #[crate::test]
    async fn get_head_does_not_overlap() {
        let _: MethodRouter<()> = get(ok).head(ok);
    }

    #[crate::test]
    async fn head_get_does_not_overlap() {
        let _: MethodRouter<()> = head(ok).get(ok);
    }

    #[crate::test]
    async fn accessing_state() {
        let mut svc = MethodRouter::new()
            .get(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn fallback_accessing_state() {
        let mut svc = MethodRouter::new()
            .fallback(|State(state): State<&'static str>| async move { state })
            .with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;

        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    #[crate::test]
    async fn merge_accessing_state() {
        let one = get(|State(state): State<&'static str>| async move { state });
        let two = post(|State(state): State<&'static str>| async move { state });

        let mut svc = one.merge(two).with_state("state");

        let (status, _, text) = call(Method::GET, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");

        // Check the POST handler's own body rather than re-asserting the GET response.
        let (status, _, text) = call(Method::POST, &mut svc).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(text, "state");
    }

    /// Send an empty-bodied request with `method` to `/` and collect the response.
    async fn call<S>(method: Method, svc: &mut S) -> (StatusCode, HeaderMap, String)
    where
        S: Service<Request, Error = Infallible>,
        S::Response: IntoResponse,
    {
        let request = Request::builder()
            .uri("/")
            .method(method)
            .body(Body::empty())
            .unwrap();
        let response = svc
            .ready()
            .await
            .unwrap()
            .call(request)
            .await
            .unwrap()
            .into_response();
        let (parts, body) = response.into_parts();
        let body =
            String::from_utf8(BodyExt::collect(body).await.unwrap().to_bytes().to_vec()).unwrap();
        (parts.status, parts.headers, body)
    }

    async fn ok() -> (StatusCode, &'static str) {
        (StatusCode::OK, "ok")
    }

    async fn created() -> (StatusCode, &'static str) {
        (StatusCode::CREATED, "created")
    }
}