1use super::HandlerCallWithExtractors;
2use crate::either::Either;
3use axum::{
4 extract::{FromRequest, FromRequestParts, Request},
5 handler::Handler,
6 response::{IntoResponse, Response},
7};
8use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
9use std::{future::Future, marker::PhantomData};
10
/// [`Handler`] that first tries `lhs` and, if the left-hand extractors reject the
/// request, falls back to `rhs`.
///
/// Values of this type are built by chaining `.or(...)` on handlers, e.g.
/// `one.or(two).or(three)`.
// NOTE(review): presumably the lint is allowed because `L`/`R` are usually fn items
// or closures that do not implement `Debug` — confirm.
#[allow(missing_debug_implementations)]
pub struct Or<L, R, Lt, Rt, S> {
    /// Handler tried first.
    pub(super) lhs: L,
    /// Fallback handler, used when the left-hand extractors reject the request.
    pub(super) rhs: R,
    // `fn() -> (...)` ties the extractor and state types to `Or` without storing them;
    // the marker itself is always `Send`, `Sync` and `Copy`, whatever `Lt`, `Rt` and `S` are.
    pub(super) _marker: PhantomData<fn() -> (Lt, Rt, S)>,
}
21
22impl<S, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, S> for Or<L, R, Lt, Rt, S>
23where
24 L: HandlerCallWithExtractors<Lt, S> + Send + 'static,
25 R: HandlerCallWithExtractors<Rt, S> + Send + 'static,
26 Rt: Send + 'static,
27 Lt: Send + 'static,
28{
29 type Future = EitherFuture<
31 Map<L::Future, fn(<L::Future as Future>::Output) -> Response>,
32 Map<R::Future, fn(<R::Future as Future>::Output) -> Response>,
33 >;
34
35 fn call(
36 self,
37 extractors: Either<Lt, Rt>,
38 state: S,
39 ) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S>>::Future {
40 match extractors {
41 Either::E1(lt) => self
42 .lhs
43 .call(lt, state)
44 .map(IntoResponse::into_response as _)
45 .left_future(),
46 Either::E2(rt) => self
47 .rhs
48 .call(rt, state)
49 .map(IntoResponse::into_response as _)
50 .right_future(),
51 }
52 }
53}
54
55impl<S, L, R, Lt, Rt, M> Handler<(M, Lt, Rt), S> for Or<L, R, Lt, Rt, S>
56where
57 L: HandlerCallWithExtractors<Lt, S> + Clone + Send + Sync + 'static,
58 R: HandlerCallWithExtractors<Rt, S> + Clone + Send + Sync + 'static,
59 Lt: FromRequestParts<S> + Send + 'static,
60 Rt: FromRequest<S, M> + Send + 'static,
61 Lt::Rejection: Send,
62 Rt::Rejection: Send,
63 S: Send + Sync + 'static,
64{
65 type Future = BoxFuture<'static, Response>;
67
68 fn call(self, req: Request, state: S) -> Self::Future {
69 Box::pin(async move {
70 let (mut parts, body) = req.into_parts();
71
72 if let Ok(lt) = Lt::from_request_parts(&mut parts, &state).await {
73 return self.lhs.call(lt, state).await;
74 }
75
76 let req = Request::from_parts(parts, body);
77
78 match Rt::from_request(req, &state).await {
79 Ok(rt) => self.rhs.call(rt, state).await,
80 Err(rejection) => rejection.into_response(),
81 }
82 })
83 }
84}
85
86impl<L, R, Lt, Rt, S> Copy for Or<L, R, Lt, Rt, S>
87where
88 L: Copy,
89 R: Copy,
90{
91}
92
93impl<L, R, Lt, Rt, S> Clone for Or<L, R, Lt, Rt, S>
94where
95 L: Clone,
96 R: Clone,
97{
98 fn clone(&self) -> Self {
99 Self {
100 lhs: self.lhs.clone(),
101 rhs: self.rhs.clone(),
102 _marker: self._marker,
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{
        extract::{Path, Query},
        routing::get,
        Router,
    };
    use serde::Deserialize;

    #[tokio::test]
    async fn works() {
        #[derive(Deserialize)]
        struct Params {
            a: String,
        }

        // Succeeds only when the `{id}` segment parses as a `u32`.
        async fn one(Path(id): Path<u32>) -> String {
            id.to_string()
        }

        // Succeeds only when the query string carries an `a` parameter.
        async fn two(Query(params): Query<Params>) -> String {
            params.a
        }

        // Takes no extractors, so it always succeeds.
        async fn three() -> &'static str {
            "fallback"
        }

        let handler = one.or(two).or(three);
        let client = TestClient::new(Router::new().route("/{id}", get(handler)));

        // (request URI, expected body): each case is served by a different link of the chain.
        let cases = [
            ("/123", "123"),
            ("/foo?a=bar", "bar"),
            ("/foo", "fallback"),
        ];

        for (uri, expected) in cases {
            let res = client.get(uri).await;
            assert_eq!(res.text().await, expected);
        }
    }
}