rama_http/layer/validate_request/
validate_request_header.rs1use super::{AcceptHeader, BoxValidateRequestFn, ValidateRequest};
2use crate::{Request, Response};
3use rama_core::{Context, Layer, Service};
4use rama_utils::macros::define_inner_service_accessors;
5use std::fmt;
6
/// Layer that wraps services in a [`ValidateRequestHeader`], which validates every request
/// before it reaches the wrapped service.
///
/// Build one with [`ValidateRequestHeaderLayer::accept`], [`ValidateRequestHeaderLayer::custom`]
/// or [`ValidateRequestHeaderLayer::custom_fn`].
pub struct ValidateRequestHeaderLayer<T> {
    /// Validator handed to each [`ValidateRequestHeader`] produced by this layer.
    pub(crate) validate: T,
}

impl<T> fmt::Debug for ValidateRequestHeaderLayer<T>
where
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ValidateRequestHeaderLayer");
        builder.field("validate", &self.validate);
        builder.finish()
    }
}

// Hand-written so that only the validator `T` needs to be `Clone`.
impl<T: Clone> Clone for ValidateRequestHeaderLayer<T> {
    fn clone(&self) -> Self {
        let validate = self.validate.clone();
        Self { validate }
    }
}
32
33impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
34 pub fn accept(value: &str) -> Self
54 where
55 ResBody: Default,
56 {
57 Self::custom(AcceptHeader::new(value))
58 }
59}
60
61impl<T> ValidateRequestHeaderLayer<T> {
62 pub fn custom(validate: T) -> Self {
64 Self { validate }
65 }
66}
67
68impl<F, A> ValidateRequestHeaderLayer<BoxValidateRequestFn<F, A>> {
69 pub fn custom_fn(validate: F) -> Self {
71 Self {
72 validate: BoxValidateRequestFn::new(validate),
73 }
74 }
75}
76
77impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
78where
79 T: Clone,
80{
81 type Service = ValidateRequestHeader<S, T>;
82
83 fn layer(&self, inner: S) -> Self::Service {
84 ValidateRequestHeader::new(inner, self.validate.clone())
85 }
86
87 fn into_layer(self, inner: S) -> Self::Service {
88 ValidateRequestHeader::new(inner, self.validate)
89 }
90}
91
/// Middleware that runs a validator over every request before handing it to `inner`.
///
/// If the validator accepts the request, the context and request it returns are forwarded to
/// `inner`. If it rejects the request, the validator's response is returned unchanged and
/// `inner` is not called.
pub struct ValidateRequestHeader<S, T> {
    /// Wrapped service that handles requests which pass validation.
    inner: S,
    /// Validator consulted before every call to `inner`.
    pub(crate) validate: T,
}

impl<S, T> fmt::Debug for ValidateRequestHeader<S, T>
where
    S: fmt::Debug,
    T: fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { inner, validate } = self;
        f.debug_struct("ValidateRequestHeader")
            .field("inner", inner)
            .field("validate", validate)
            .finish()
    }
}

// Hand-written so the bounds are exactly `S: Clone` and `T: Clone`.
impl<S: Clone, T: Clone> Clone for ValidateRequestHeader<S, T> {
    fn clone(&self) -> Self {
        let Self { inner, validate } = self;
        Self {
            inner: inner.clone(),
            validate: validate.clone(),
        }
    }
}
121
122impl<S, T> ValidateRequestHeader<S, T> {
123 fn new(inner: S, validate: T) -> Self {
124 Self::custom(inner, validate)
125 }
126
127 define_inner_service_accessors!();
128}
129
130impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
131 pub fn accept(inner: S, value: &str) -> Self
140 where
141 ResBody: Default,
142 {
143 Self::custom(inner, AcceptHeader::new(value))
144 }
145}
146
147impl<S, T> ValidateRequestHeader<S, T> {
148 pub fn custom(inner: S, validate: T) -> Self {
150 Self { inner, validate }
151 }
152}
153
154impl<S, F, A> ValidateRequestHeader<S, BoxValidateRequestFn<F, A>> {
155 pub fn custom_fn(inner: S, validate: F) -> Self {
157 Self {
158 inner,
159 validate: BoxValidateRequestFn::new(validate),
160 }
161 }
162}
163
164impl<ReqBody, ResBody, State, S, V> Service<State, Request<ReqBody>> for ValidateRequestHeader<S, V>
165where
166 ReqBody: Send + 'static,
167 ResBody: Send + 'static,
168 State: Clone + Send + Sync + 'static,
169 V: ValidateRequest<State, ReqBody, ResponseBody = ResBody>,
170 S: Service<State, Request<ReqBody>, Response = Response<ResBody>>,
171{
172 type Response = Response<ResBody>;
173 type Error = S::Error;
174
175 async fn serve(
176 &self,
177 ctx: Context<State>,
178 req: Request<ReqBody>,
179 ) -> Result<Self::Response, Self::Error> {
180 match self.validate.validate(ctx, req).await {
181 Ok((ctx, req)) => self.inner.serve(ctx, req).await,
182 Err(res) => Ok(res),
183 }
184 }
185}
186
#[cfg(test)]
mod tests {
    // Kept from the original. NOTE(review): unclear which import this silences; confirm
    // and drop it if it is no longer needed.
    #[allow(unused_imports)]
    use super::*;

    use crate::{Body, StatusCode, header};
    use rama_core::{Layer, error::BoxError, service::service_fn};

    /// Serve one GET request through a layer that only accepts `accepted`.
    ///
    /// The request carries one `Accept` header per entry of `accept_values`. Returns the
    /// response status.
    async fn status_for(accepted: &str, accept_values: &[&str]) -> StatusCode {
        let service = ValidateRequestHeaderLayer::accept(accepted).into_layer(service_fn(echo));

        let mut builder = Request::get("/");
        for value in accept_values {
            builder = builder.header(header::ACCEPT, *value);
        }
        let request = builder.body(Body::empty()).unwrap();

        let res = service.serve(Context::default(), request).await.unwrap();
        res.status()
    }

    #[tokio::test]
    async fn valid_accept_header() {
        let status = status_for("application/json", &["application/json"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let status = status_for("application/json", &["application/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let status = status_for("application/json", &["*/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let status = status_for("application/json", &["invalid"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let status = status_for("application/json", &["application/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let status = status_for("application/json", &["text/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let values = ["text/strings", "invalid, application/json"];
        let status = status_for("application/json", &values).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let values = ["text/strings, invalid, application/json"];
        let status = status_for("application/json", &values).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let status = status_for("application/xml", &[value]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let status = status_for("text/html", &[value]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    /// Inner service that always answers `200 OK`, echoing the request body back.
    async fn echo<B>(req: Request<B>) -> Result<Response<B>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}