axum_extra/extract/with_rejection.rs

1use axum::extract::{FromRequest, FromRequestParts, Request};
2use axum::response::IntoResponse;
3use http::request::Parts;
4use std::fmt::{Debug, Display};
5use std::marker::PhantomData;
6use std::ops::{Deref, DerefMut};
7
8#[cfg(feature = "typed-routing")]
9use crate::routing::TypedPath;
10
/// Extractor for customizing extractor rejections.
///
/// `WithRejection` wraps another extractor and gives you the result. If the
/// extraction fails, the `Rejection` is transformed into `R` and returned as a
/// response.
///
/// `E` is expected to implement [`FromRequest`] or [`FromRequestParts`];
/// `WithRejection` implements whichever of the two `E` implements.
///
/// `R` is expected to implement [`IntoResponse`] and [`From<E::Rejection>`].
///
/// The second field is a `PhantomData<R>` marker that only records the
/// rejection type; construct it with the value `PhantomData`.
///
/// # Example
///
/// ```rust
/// use axum::extract::rejection::JsonRejection;
/// use axum::response::{Response, IntoResponse};
/// use axum::Json;
/// use axum_extra::extract::WithRejection;
/// use serde::Deserialize;
///
/// struct MyRejection { /* ... */ }
///
/// impl From<JsonRejection> for MyRejection {
///     fn from(rejection: JsonRejection) -> MyRejection {
///         // ...
///         # todo!()
///     }
/// }
///
/// impl IntoResponse for MyRejection {
///     fn into_response(self) -> Response {
///         // ...
///         # todo!()
///     }
/// }
/// #[derive(Debug, Deserialize)]
/// struct Person { /* ... */ }
///
/// async fn handler(
///     // If the `Json` extractor ever fails, `MyRejection` will be sent to the
///     // client using the `IntoResponse` impl
///     WithRejection(Json(Person), _): WithRejection<Json<Person>, MyRejection>
/// ) { /* ... */ }
/// # let _: axum::Router = axum::Router::new().route("/", axum::routing::get(handler));
/// ```
///
/// [`FromRequest`]: axum::extract::FromRequest
/// [`FromRequestParts`]: axum::extract::FromRequestParts
/// [`IntoResponse`]: axum::response::IntoResponse
/// [`From<E::Rejection>`]: std::convert::From
pub struct WithRejection<E, R>(pub E, pub PhantomData<R>);
61
62impl<E, R> WithRejection<E, R> {
63    /// Returns the wrapped extractor
64    pub fn into_inner(self) -> E {
65        self.0
66    }
67}
68
69impl<E, R> Debug for WithRejection<E, R>
70where
71    E: Debug,
72{
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        f.debug_tuple("WithRejection")
75            .field(&self.0)
76            .field(&self.1)
77            .finish()
78    }
79}
80
81impl<E, R> Clone for WithRejection<E, R>
82where
83    E: Clone,
84{
85    fn clone(&self) -> Self {
86        Self(self.0.clone(), self.1)
87    }
88}
89
90impl<E, R> Copy for WithRejection<E, R> where E: Copy {}
91
92impl<E: Default, R> Default for WithRejection<E, R> {
93    fn default() -> Self {
94        Self(Default::default(), Default::default())
95    }
96}
97
98impl<E, R> Deref for WithRejection<E, R> {
99    type Target = E;
100
101    fn deref(&self) -> &Self::Target {
102        &self.0
103    }
104}
105
106impl<E, R> DerefMut for WithRejection<E, R> {
107    fn deref_mut(&mut self) -> &mut Self::Target {
108        &mut self.0
109    }
110}
111
112impl<E, R, S> FromRequest<S> for WithRejection<E, R>
113where
114    S: Send + Sync,
115    E: FromRequest<S>,
116    R: From<E::Rejection> + IntoResponse,
117{
118    type Rejection = R;
119
120    async fn from_request(req: Request, state: &S) -> Result<Self, Self::Rejection> {
121        let extractor = E::from_request(req, state).await?;
122        Ok(WithRejection(extractor, PhantomData))
123    }
124}
125
126impl<E, R, S> FromRequestParts<S> for WithRejection<E, R>
127where
128    S: Send + Sync,
129    E: FromRequestParts<S>,
130    R: From<E::Rejection> + IntoResponse,
131{
132    type Rejection = R;
133
134    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
135        let extractor = E::from_request_parts(parts, state).await?;
136        Ok(WithRejection(extractor, PhantomData))
137    }
138}
139
#[cfg(feature = "typed-routing")]
impl<E: TypedPath, R> TypedPath for WithRejection<E, R> {
    /// Delegates `PATH` to the wrapped type `E`.
    const PATH: &'static str = <E as TypedPath>::PATH;
}
147
148impl<E, R> Display for WithRejection<E, R>
149where
150    E: Display,
151{
152    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
153        write!(f, "{}", self.0)
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use axum::response::Response;

    #[tokio::test]
    async fn extractor_rejection_is_transformed() {
        // An extractor that never succeeds; its rejection is the unit type.
        struct TestExtractor;
        // The custom rejection that `WithRejection` should convert `()` into.
        struct TestRejection;

        impl From<()> for TestRejection {
            fn from((): ()) -> Self {
                Self
            }
        }

        impl IntoResponse for TestRejection {
            fn into_response(self) -> Response {
                ().into_response()
            }
        }

        impl<S: Send + Sync> FromRequestParts<S> for TestExtractor {
            type Rejection = ();

            async fn from_request_parts(_: &mut Parts, _: &S) -> Result<Self, Self::Rejection> {
                Err(())
            }
        }

        type Subject = WithRejection<TestExtractor, TestRejection>;

        // Whole-request path (`FromRequest`).
        let via_request = Subject::from_request(Request::new(Body::empty()), &()).await;
        assert!(matches!(via_request, Err(TestRejection)));

        // Parts-only path (`FromRequestParts`).
        let (mut parts, ()) = Request::new(()).into_parts();
        let via_parts = Subject::from_request_parts(&mut parts, &()).await;
        assert!(matches!(via_parts, Err(TestRejection)));
    }
}