// axum_extra/handler/or.rs

1use super::HandlerCallWithExtractors;
2use crate::either::Either;
3use axum::{
4    extract::{FromRequest, FromRequestParts, Request},
5    handler::Handler,
6    response::{IntoResponse, Response},
7};
8use futures_util::future::{BoxFuture, Either as EitherFuture, FutureExt, Map};
9use std::{future::Future, marker::PhantomData};
10
11/// [`Handler`] that runs one [`Handler`] and if that rejects it'll fallback to another
12/// [`Handler`].
13///
14/// Created with [`HandlerCallWithExtractors::or`](super::HandlerCallWithExtractors::or).
15#[allow(missing_debug_implementations)]
16pub struct Or<L, R, Lt, Rt, S> {
17    pub(super) lhs: L,
18    pub(super) rhs: R,
19    pub(super) _marker: PhantomData<fn() -> (Lt, Rt, S)>,
20}
21
22impl<S, L, R, Lt, Rt> HandlerCallWithExtractors<Either<Lt, Rt>, S> for Or<L, R, Lt, Rt, S>
23where
24    L: HandlerCallWithExtractors<Lt, S> + Send + 'static,
25    R: HandlerCallWithExtractors<Rt, S> + Send + 'static,
26    Rt: Send + 'static,
27    Lt: Send + 'static,
28{
29    // this puts `futures_util` in our public API but thats fine in axum-extra
30    type Future = EitherFuture<
31        Map<L::Future, fn(<L::Future as Future>::Output) -> Response>,
32        Map<R::Future, fn(<R::Future as Future>::Output) -> Response>,
33    >;
34
35    fn call(
36        self,
37        extractors: Either<Lt, Rt>,
38        state: S,
39    ) -> <Self as HandlerCallWithExtractors<Either<Lt, Rt>, S>>::Future {
40        match extractors {
41            Either::E1(lt) => self
42                .lhs
43                .call(lt, state)
44                .map(IntoResponse::into_response as _)
45                .left_future(),
46            Either::E2(rt) => self
47                .rhs
48                .call(rt, state)
49                .map(IntoResponse::into_response as _)
50                .right_future(),
51        }
52    }
53}
54
55impl<S, L, R, Lt, Rt, M> Handler<(M, Lt, Rt), S> for Or<L, R, Lt, Rt, S>
56where
57    L: HandlerCallWithExtractors<Lt, S> + Clone + Send + Sync + 'static,
58    R: HandlerCallWithExtractors<Rt, S> + Clone + Send + Sync + 'static,
59    Lt: FromRequestParts<S> + Send + 'static,
60    Rt: FromRequest<S, M> + Send + 'static,
61    Lt::Rejection: Send,
62    Rt::Rejection: Send,
63    S: Send + Sync + 'static,
64{
65    // this puts `futures_util` in our public API but thats fine in axum-extra
66    type Future = BoxFuture<'static, Response>;
67
68    fn call(self, req: Request, state: S) -> Self::Future {
69        Box::pin(async move {
70            let (mut parts, body) = req.into_parts();
71
72            if let Ok(lt) = Lt::from_request_parts(&mut parts, &state).await {
73                return self.lhs.call(lt, state).await;
74            }
75
76            let req = Request::from_parts(parts, body);
77
78            match Rt::from_request(req, &state).await {
79                Ok(rt) => self.rhs.call(rt, state).await,
80                Err(rejection) => rejection.into_response(),
81            }
82        })
83    }
84}
85
86impl<L, R, Lt, Rt, S> Copy for Or<L, R, Lt, Rt, S>
87where
88    L: Copy,
89    R: Copy,
90{
91}
92
93impl<L, R, Lt, Rt, S> Clone for Or<L, R, Lt, Rt, S>
94where
95    L: Clone,
96    R: Clone,
97{
98    fn clone(&self) -> Self {
99        Self {
100            lhs: self.lhs.clone(),
101            rhs: self.rhs.clone(),
102            _marker: self._marker,
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_helpers::*;
    use axum::{
        extract::{Path, Query},
        routing::get,
        Router,
    };
    use serde::Deserialize;

    #[tokio::test]
    async fn works() {
        #[derive(Deserialize)]
        struct Params {
            a: String,
        }

        // Accepts the request only when the `{id}` segment extracts as a `u32`.
        async fn one(Path(id): Path<u32>) -> String {
            id.to_string()
        }

        // Accepts the request only when the query string deserializes into `Params`.
        async fn two(Query(params): Query<Params>) -> String {
            params.a
        }

        // Takes no extractors, so it is the final catch-all.
        async fn three() -> &'static str {
            "fallback"
        }

        let handler = one.or(two).or(three);
        let client = TestClient::new(Router::new().route("/{id}", get(handler)));

        // (request URI, body produced by whichever handler accepted it)
        let cases = [
            ("/123", "123"),
            ("/foo?a=bar", "bar"),
            ("/foo", "fallback"),
        ];

        for (uri, expected) in cases {
            let res = client.get(uri).await;
            assert_eq!(res.text().await, expected);
        }
    }
}