
1use self::RejectionKind::*;
2use super::Response;
3use crate::request::{Context, RequestContext};
4use zino_core::{SharedString, error::Error, trace::TraceContext, validation::Validation, warn};
6/// A rejection response type.
8pub struct Rejection {
9    /// Rejection kind.
10    kind: RejectionKind,
11    /// Optional context.
12    context: Option<Context>,
13    /// Optional trace context.
14    trace_context: Option<TraceContext>,
17/// Rejection kind.
20enum RejectionKind {
21    /// 400 Bad Request
22    BadRequest(Validation),
23    /// 401 Unauthorized
24    Unauthorized(Error),
25    /// 403 Forbidden
26    Forbidden(Error),
27    /// 404 NotFound
28    NotFound(Error),
29    /// 405 Method Not Allowed
30    MethodNotAllowed(Error),
31    /// 409 Conflict
32    Conflict(Error),
33    /// 500 Internal Server Error
34    InternalServerError(Error),
35    /// 503 Service Unavailable
36    ServiceUnavailable(Error),
39impl Rejection {
40    /// Creates a `400 Bad Request` rejection.
41    #[inline]
42    pub fn bad_request(validation: Validation) -> Self {
43        Self {
44            kind: BadRequest(validation),
45            context: None,
46            trace_context: None,
47        }
48    }
50    /// Creates a `401 Unauthorized` rejection.
51    #[inline]
52    pub fn unauthorized(err: impl Into<Error>) -> Self {
53        Self {
54            kind: Unauthorized(err.into()),
55            context: None,
56            trace_context: None,
57        }
58    }
60    /// Creates a `403 Forbidden` rejection.
61    #[inline]
62    pub fn forbidden(err: impl Into<Error>) -> Self {
63        Self {
64            kind: Forbidden(err.into()),
65            context: None,
66            trace_context: None,
67        }
68    }
70    /// Creates a `404 Not Found` rejection.
71    #[inline]
72    pub fn not_found(err: impl Into<Error>) -> Self {
73        Self {
74            kind: NotFound(err.into()),
75            context: None,
76            trace_context: None,
77        }
78    }
80    /// Creates a `405 Method Not Allowed` rejection.
81    #[inline]
82    pub fn method_not_allowed(err: impl Into<Error>) -> Self {
83        Self {
84            kind: MethodNotAllowed(err.into()),
85            context: None,
86            trace_context: None,
87        }
88    }
90    /// Creates a `409 Conflict` rejection.
91    #[inline]
92    pub fn conflict(err: impl Into<Error>) -> Self {
93        Self {
94            kind: Conflict(err.into()),
95            context: None,
96            trace_context: None,
97        }
98    }
100    /// Creates a `500 Internal Server Error` rejection.
101    #[inline]
102    pub fn internal_server_error(err: impl Into<Error>) -> Self {
103        Self {
104            kind: InternalServerError(err.into()),
105            context: None,
106            trace_context: None,
107        }
108    }
110    /// Creates a `503 Service Unavailable` rejection.
111    #[inline]
112    pub fn service_unavailable(err: impl Into<Error>) -> Self {
113        Self {
114            kind: ServiceUnavailable(err.into()),
115            context: None,
116            trace_context: None,
117        }
118    }
120    /// Creates a new instance with the validation entry.
121    #[inline]
122    pub fn from_validation_entry(key: impl Into<SharedString>, err: impl Into<Error>) -> Self {
123        let validation = Validation::from_entry(key, err);
124        Self::bad_request(validation)
125    }
127    /// Creates a new instance from an error classified by the error message.
128    pub fn from_error(err: impl Into<Error>) -> Self {
129        fn inner(err: Error) -> Rejection {
130            let message = err.message();
131            if message.starts_with("401 Unauthorized") {
132                Rejection::unauthorized(err)
133            } else if message.starts_with("403 Forbidden") {
134                Rejection::forbidden(err)
135            } else if message.starts_with("404 Not Found") {
136                Rejection::not_found(err)
137            } else if message.starts_with("405 Method Not Allowed") {
138                Rejection::method_not_allowed(err)
139            } else if message.starts_with("409 Conflict") {
140                Rejection::conflict(err)
141            } else if message.starts_with("503 Service Unavailable") {
142                Rejection::service_unavailable(err)
143            } else {
144                Rejection::internal_server_error(err)
145            }
146        }
147        inner(err.into())
148    }
150    /// Creates a new instance with the error message.
151    #[inline]
152    pub fn with_message(message: impl Into<SharedString>) -> Self {
153        Self::from_error(Error::new(message))
154    }
156    /// Provides the request context for the rejection.
157    #[inline]
158    pub fn context<T: RequestContext + ?Sized>(mut self, ctx: &T) -> Self {
159        self.context = ctx.get_context();
160        self.trace_context = Some(ctx.new_trace_context());
161        self
162    }
164    /// Returns the status code as `u16`.
165    #[inline]
166    pub fn status_code(&self) -> u16 {
167        match &self.kind {
168            BadRequest(_) => 400,
169            Unauthorized(_) => 401,
170            Forbidden(_) => 403,
171            NotFound(_) => 404,
172            MethodNotAllowed(_) => 405,
173            Conflict(_) => 409,
174            InternalServerError(_) => 500,
175            ServiceUnavailable(_) => 503,
176        }
177    }
/// Implements `From<Rejection>` for `Response<$Ty>`, where `$Ty` is a status-code
/// type exposing the standard associated status constants.
macro_rules! impl_from_rejection {
    ($Ty:ty) => {
        impl From<Rejection> for Response<$Ty> {
            fn from(rejection: Rejection) -> Self {
                // All kinds except `BadRequest` render the same way: a response with
                // the kind's status code and the wrapped error as its message.
                fn error_response(status: $Ty, err: Error) -> Response<$Ty> {
                    let mut response = Response::new(status);
                    response.set_error_message(err);
                    response
                }

                let Rejection {
                    kind,
                    context,
                    trace_context,
                } = rejection;
                let mut response = match kind {
                    BadRequest(validation) => {
                        let mut response = Response::new(<$Ty>::BAD_REQUEST);
                        response.set_validation_data(validation);
                        response
                    }
                    Unauthorized(err) => error_response(<$Ty>::UNAUTHORIZED, err),
                    Forbidden(err) => error_response(<$Ty>::FORBIDDEN, err),
                    NotFound(err) => error_response(<$Ty>::NOT_FOUND, err),
                    MethodNotAllowed(err) => error_response(<$Ty>::METHOD_NOT_ALLOWED, err),
                    Conflict(err) => error_response(<$Ty>::CONFLICT, err),
                    InternalServerError(err) => {
                        error_response(<$Ty>::INTERNAL_SERVER_ERROR, err)
                    }
                    ServiceUnavailable(err) => {
                        error_response(<$Ty>::SERVICE_UNAVAILABLE, err)
                    }
                };
                // Copy the request metadata over only when a context was attached.
                if let Some(ctx) = context {
                    response.set_instance(ctx.instance().to_owned());
                    response.set_start_time(ctx.start_time());
                    response.set_request_id(ctx.request_id());
                }
                response.set_trace_context(trace_context);
                response
            }
        }
    };
}
240#[cfg(feature = "http02")]
243/// Trait for extracting rejections.
244pub trait ExtractRejection<T> {
245    /// Extracts a rejection with the request context.
246    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection>;
249impl<T> ExtractRejection<T> for Option<T> {
250    #[inline]
251    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
252        self.ok_or_else(|| Rejection::not_found(warn!("resource does not exist")).context(ctx))
253    }
256impl<T, E: Into<Error>> ExtractRejection<T> for Result<T, E> {
257    #[inline]
258    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
259        self.map_err(|err| Rejection::from_error(err).context(ctx))
260    }
263impl<T, E: Into<Error>> ExtractRejection<T> for Result<Option<T>, E> {
264    #[inline]
265    fn extract<Ctx: RequestContext>(self, ctx: &Ctx) -> Result<T, Rejection> {
266        self.map_err(|err| Rejection::from_error(err).context(ctx))?
267            .ok_or_else(|| Rejection::not_found(warn!("resource does not exist")).context(ctx))
268    }
271/// Returns early with a [`Rejection`].
273macro_rules! reject {
274    ($ctx:ident, $validation:expr $(,)?) => {{
275        return Err(Rejection::bad_request($validation).context(&$ctx).into());
276    }};
277    ($ctx:ident, $key:literal, $message:literal $(,)?) => {{
278        let err = Error::new($message);
279        warn!("invalid value for `{}`: {}", $key, $message);
280        return Err(Rejection::from_validation_entry($key, err).context(&$ctx).into());
281    }};
282    ($ctx:ident, $key:literal, $err:expr $(,)?) => {{
283        return Err(Rejection::from_validation_entry($key, $err).context(&$ctx).into());
284    }};
285    ($ctx:ident, $kind:ident, $message:literal $(,)?) => {{
286        let err = warn!($message);
287        return Err(Rejection::$kind(err).context(&$ctx).into());
288    }};
289    ($ctx:ident, $kind:ident, $err:expr $(,)?) => {{
290        return Err(Rejection::$kind($err).context(&$ctx).into());
291    }};
292    ($ctx:ident, $kind:ident, $fmt:expr, $($arg:tt)+) => {{
293        let err = warn!($fmt, $($arg)+);
294        return Err(Rejection::$kind(err).context(&$ctx).into());
295    }};