1use http::{header, Request, Response, StatusCode};
120use mime::{Mime, MimeIter};
121use pin_project_lite::pin_project;
122use std::{
123 fmt,
124 future::Future,
125 marker::PhantomData,
126 pin::Pin,
127 sync::Arc,
128 task::{Context, Poll},
129};
130use tower_layer::Layer;
131use tower_service::Service;
132
/// Layer that applies [`ValidateRequestHeader`] to an inner service.
///
/// `T` is the validator; every service produced by this layer receives its own clone of it.
#[derive(Clone, Debug)]
pub struct ValidateRequestHeaderLayer<T> {
    validate: T,
}
140
141impl<ResBody> ValidateRequestHeaderLayer<AcceptHeader<ResBody>> {
142 pub fn accept(value: &str) -> Self
164 where
165 ResBody: Default,
166 {
167 Self::custom(AcceptHeader::new(value))
168 }
169}
170
171impl<T> ValidateRequestHeaderLayer<T> {
172 pub fn custom(validate: T) -> ValidateRequestHeaderLayer<T> {
174 Self { validate }
175 }
176}
177
178impl<S, T> Layer<S> for ValidateRequestHeaderLayer<T>
179where
180 T: Clone,
181{
182 type Service = ValidateRequestHeader<S, T>;
183
184 fn layer(&self, inner: S) -> Self::Service {
185 ValidateRequestHeader::new(inner, self.validate.clone())
186 }
187}
188
/// Middleware that runs a validator on every request before forwarding it to `inner`.
///
/// Rejected requests never reach `inner`; the validator's response is returned instead.
#[derive(Debug, Clone)]
pub struct ValidateRequestHeader<S, T> {
    inner: S,
    validate: T,
}
197
198impl<S, T> ValidateRequestHeader<S, T> {
199 fn new(inner: S, validate: T) -> Self {
200 Self::custom(inner, validate)
201 }
202
203 define_inner_service_accessors!();
204}
205
206impl<S, ResBody> ValidateRequestHeader<S, AcceptHeader<ResBody>> {
207 pub fn accept(inner: S, value: &str) -> Self
216 where
217 ResBody: Default,
218 {
219 Self::custom(inner, AcceptHeader::new(value))
220 }
221}
222
223impl<S, T> ValidateRequestHeader<S, T> {
224 pub fn custom(inner: S, validate: T) -> ValidateRequestHeader<S, T> {
226 Self { inner, validate }
227 }
228}
229
230impl<ReqBody, ResBody, S, V> Service<Request<ReqBody>> for ValidateRequestHeader<S, V>
231where
232 V: ValidateRequest<ReqBody, ResponseBody = ResBody>,
233 S: Service<Request<ReqBody>, Response = Response<ResBody>>,
234{
235 type Response = Response<ResBody>;
236 type Error = S::Error;
237 type Future = ResponseFuture<S::Future, ResBody>;
238
239 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240 self.inner.poll_ready(cx)
241 }
242
243 fn call(&mut self, mut req: Request<ReqBody>) -> Self::Future {
244 match self.validate.validate(&mut req) {
245 Ok(_) => ResponseFuture::future(self.inner.call(req)),
246 Err(res) => ResponseFuture::invalid_header_value(res),
247 }
248 }
249}
250
251pin_project! {
252 pub struct ResponseFuture<F, B> {
254 #[pin]
255 kind: Kind<F, B>,
256 }
257}
258
259impl<F, B> ResponseFuture<F, B> {
260 fn future(future: F) -> Self {
261 Self {
262 kind: Kind::Future { future },
263 }
264 }
265
266 fn invalid_header_value(res: Response<B>) -> Self {
267 Self {
268 kind: Kind::Error {
269 response: Some(res),
270 },
271 }
272 }
273}
274
275pin_project! {
276 #[project = KindProj]
277 enum Kind<F, B> {
278 Future {
279 #[pin]
280 future: F,
281 },
282 Error {
283 response: Option<Response<B>>,
284 },
285 }
286}
287
288impl<F, B, E> Future for ResponseFuture<F, B>
289where
290 F: Future<Output = Result<Response<B>, E>>,
291{
292 type Output = F::Output;
293
294 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
295 match self.project().kind.project() {
296 KindProj::Future { future } => future.poll(cx),
297 KindProj::Error { response } => {
298 let response = response.take().expect("future polled after completion");
299 Poll::Ready(Ok(response))
300 }
301 }
302 }
303}
304
305pub trait ValidateRequest<B> {
307 type ResponseBody;
309
310 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>>;
314}
315
316impl<B, F, ResBody> ValidateRequest<B> for F
317where
318 F: FnMut(&mut Request<B>) -> Result<(), Response<ResBody>>,
319{
320 type ResponseBody = ResBody;
321
322 fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
323 self(request)
324 }
325}
326
327pub struct AcceptHeader<ResBody> {
329 header_value: Arc<Mime>,
330 _ty: PhantomData<fn() -> ResBody>,
331}
332
333impl<ResBody> AcceptHeader<ResBody> {
334 fn new(header_value: &str) -> Self
340 where
341 ResBody: Default,
342 {
343 Self {
344 header_value: Arc::new(
345 header_value
346 .parse::<Mime>()
347 .expect("value is not a valid header value"),
348 ),
349 _ty: PhantomData,
350 }
351 }
352}
353
354impl<ResBody> Clone for AcceptHeader<ResBody> {
355 fn clone(&self) -> Self {
356 Self {
357 header_value: self.header_value.clone(),
358 _ty: PhantomData,
359 }
360 }
361}
362
363impl<ResBody> fmt::Debug for AcceptHeader<ResBody> {
364 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
365 f.debug_struct("AcceptHeader")
366 .field("header_value", &self.header_value)
367 .finish()
368 }
369}
370
371impl<B, ResBody> ValidateRequest<B> for AcceptHeader<ResBody>
372where
373 ResBody: Default,
374{
375 type ResponseBody = ResBody;
376
377 fn validate(&mut self, req: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
378 if !req.headers().contains_key(header::ACCEPT) {
379 return Ok(());
380 }
381 if req
382 .headers()
383 .get_all(header::ACCEPT)
384 .into_iter()
385 .filter_map(|header| header.to_str().ok())
386 .any(|h| {
387 MimeIter::new(h)
388 .map(|mim| {
389 if let Ok(mim) = mim {
390 let typ = self.header_value.type_();
391 let subtype = self.header_value.subtype();
392 match (mim.type_(), mim.subtype()) {
393 (t, s) if t == typ && s == subtype => true,
394 (t, mime::STAR) if t == typ => true,
395 (mime::STAR, mime::STAR) => true,
396 _ => false,
397 }
398 } else {
399 false
400 }
401 })
402 .reduce(|acc, mim| acc || mim)
403 .unwrap_or(false)
404 })
405 {
406 return Ok(());
407 }
408 let mut res = Response::new(ResBody::default());
409 *res.status_mut() = StatusCode::NOT_ACCEPTABLE;
410 Err(res)
411 }
412}
413
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use super::*;
    use crate::test_helpers::Body;
    use http::header;
    use tower::{BoxError, ServiceBuilder, ServiceExt};

    /// Sends `GET /` through a service guarded by `ValidateRequestHeaderLayer::accept(required)`.
    ///
    /// One `Accept` header is added per entry of `accept_values`, in order. Returns the
    /// status of the resulting response.
    async fn status_for(required: &str, accept_values: &[&str]) -> StatusCode {
        let mut service = ServiceBuilder::new()
            .layer(ValidateRequestHeaderLayer::accept(required))
            .service_fn(echo);

        let mut builder = Request::get("/");
        for value in accept_values {
            builder = builder.header(header::ACCEPT, *value);
        }
        let request = builder.body(Body::empty()).unwrap();

        let res = service.ready().await.unwrap().call(request).await.unwrap();
        res.status()
    }

    #[tokio::test]
    async fn valid_accept_header() {
        let status = status_for("application/json", &["application/json"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all_json() {
        let status = status_for("application/json", &["application/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn valid_accept_header_accept_all() {
        let status = status_for("application/json", &["*/*"]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn invalid_accept_header() {
        let status = status_for("application/json", &["invalid"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header_subtype() {
        let status = status_for("application/json", &["application/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn not_accepted_accept_header() {
        let status = status_for("application/json", &["text/strings"]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    #[tokio::test]
    async fn accepted_multiple_header_value() {
        let status = status_for(
            "application/json",
            &["text/strings", "invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_inner_header_value() {
        let status = status_for(
            "application/json",
            &["text/strings, invalid, application/json"],
        )
        .await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_valid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\", application/*";
        let status = status_for("application/xml", &[value]).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn accepted_header_with_quotes_invalid() {
        let value = "foo/bar; parisien=\"baguette, text/html, jambon, fromage\"";
        let status = status_for("text/html", &[value]).await;
        assert_eq!(status, StatusCode::NOT_ACCEPTABLE);
    }

    async fn echo(req: Request<Body>) -> Result<Response<Body>, BoxError> {
        Ok(Response::new(req.into_body()))
    }
}