1use std::{
4 fmt,
5 future::Future,
6 marker::PhantomData,
7 ops,
8 pin::Pin,
9 sync::Arc,
10 task::{Context, Poll},
11};
12
13use actix_http::Payload;
14use bytes::BytesMut;
15use futures_core::{ready, Stream as _};
16use serde::{de::DeserializeOwned, Serialize};
17
18#[cfg(feature = "__compress")]
19use crate::dev::Decompress;
20use crate::{
21 body::EitherBody,
22 error::{Error, JsonPayloadError},
23 extract::FromRequest,
24 http::header::{ContentLength, Header as _},
25 request::HttpRequest,
26 web, HttpMessage, HttpResponse, Responder,
27};
28
/// JSON extractor and responder wrapping a value of type `T`.
///
/// As an extractor, the request body is buffered and deserialized into `T`; its behavior is
/// tuned with [`JsonConfig`]. As a responder, `T` is serialized with `serde_json` and sent with
/// an `application/json` content type.
///
/// The wrapper is transparent: it derefs to `T` and formats exactly like `T`.
#[derive(Debug)]
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Consumes the wrapper and returns the inner value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}

impl<T> ops::Deref for Json<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let Json(inner) = self;
        inner
    }
}

impl<T> ops::DerefMut for Json<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Json(inner) = self;
        inner
    }
}

impl<T: fmt::Display> fmt::Display for Json<T> {
    /// Delegates to `T`'s `Display`, forwarding the formatter so width/fill/precision apply.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <T as fmt::Display>::fmt(&**self, f)
    }
}
105
106impl<T: Serialize> Serialize for Json<T> {
107 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
108 where
109 S: serde::Serializer,
110 {
111 self.0.serialize(serializer)
112 }
113}
114
115impl<T: Serialize> Responder for Json<T> {
119 type Body = EitherBody<String>;
120
121 fn respond_to(self, _: &HttpRequest) -> HttpResponse<Self::Body> {
122 match serde_json::to_string(&self.0) {
123 Ok(body) => match HttpResponse::Ok()
124 .content_type(mime::APPLICATION_JSON)
125 .message_body(body)
126 {
127 Ok(res) => res.map_into_left_body(),
128 Err(err) => HttpResponse::from_error(err).map_into_right_body(),
129 },
130
131 Err(err) => {
132 HttpResponse::from_error(JsonPayloadError::Serialize(err)).map_into_right_body()
133 }
134 }
135 }
136}
137
138impl<T: DeserializeOwned> FromRequest for Json<T> {
140 type Error = Error;
141 type Future = JsonExtractFut<T>;
142
143 #[inline]
144 fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future {
145 let config = JsonConfig::from_req(req);
146
147 let limit = config.limit;
148 let ctype_required = config.content_type_required;
149 let ctype_fn = config.content_type.as_deref();
150 let err_handler = config.err_handler.clone();
151
152 JsonExtractFut {
153 req: Some(req.clone()),
154 fut: JsonBody::new(req, payload, ctype_fn, ctype_required).limit(limit),
155 err_handler,
156 }
157 }
158}
159
/// Optional, shareable callback that converts a [`JsonPayloadError`] (together with the
/// originating request) into the [`Error`] returned by the `Json` extractor.
///
/// `None` means the default `From<JsonPayloadError>` conversion is used.
type JsonErrorHandler = Option<Arc<dyn Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync>>;
161
/// Future returned by the `Json` extractor's `from_request`.
///
/// Wraps a [`JsonBody`] and, on failure, applies the configured error handler.
pub struct JsonExtractFut<T> {
    /// The request, held for the error path (logging and the error handler); taken on failure.
    req: Option<HttpRequest>,
    /// Inner future that buffers and deserializes the payload.
    fut: JsonBody<T>,
    /// Error handler copied from the active `JsonConfig`.
    err_handler: JsonErrorHandler,
}
167
168impl<T: DeserializeOwned> Future for JsonExtractFut<T> {
169 type Output = Result<Json<T>, Error>;
170
171 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
172 let this = self.get_mut();
173
174 let res = ready!(Pin::new(&mut this.fut).poll(cx));
175
176 let res = match res {
177 Err(err) => {
178 let req = this.req.take().unwrap();
179 log::debug!(
180 "Failed to deserialize Json from payload. \
181 Request path: {}",
182 req.path()
183 );
184
185 if let Some(err_handler) = this.err_handler.as_ref() {
186 Err((*err_handler)(err, &req))
187 } else {
188 Err(err.into())
189 }
190 }
191 Ok(data) => Ok(Json(data)),
192 };
193
194 Poll::Ready(res)
195 }
196}
197
/// Configuration for the [`Json`] extractor.
///
/// The extractor looks it up in app data, either registered directly or wrapped in `web::Data`.
/// When neither is present, defaults apply:
/// - a 2_097_152-byte limit;
/// - no custom error handler;
/// - no extra content-type predicate;
/// - a JSON content type is required.
#[derive(Clone)]
pub struct JsonConfig {
    /// Maximum number of payload bytes buffered before failing with an overflow error.
    limit: usize,
    /// Optional mapping from `JsonPayloadError` to the extractor's `Error`.
    err_handler: JsonErrorHandler,
    /// Optional predicate that accepts additional (non-JSON) content types.
    content_type: Option<Arc<dyn Fn(mime::Mime) -> bool + Send + Sync>>,
    /// Whether the request's content type is validated at all.
    content_type_required: bool,
}
238
239impl JsonConfig {
240 pub fn limit(mut self, limit: usize) -> Self {
242 self.limit = limit;
243 self
244 }
245
246 pub fn error_handler<F>(mut self, f: F) -> Self
248 where
249 F: Fn(JsonPayloadError, &HttpRequest) -> Error + Send + Sync + 'static,
250 {
251 self.err_handler = Some(Arc::new(f));
252 self
253 }
254
255 pub fn content_type<F>(mut self, predicate: F) -> Self
257 where
258 F: Fn(mime::Mime) -> bool + Send + Sync + 'static,
259 {
260 self.content_type = Some(Arc::new(predicate));
261 self
262 }
263
264 pub fn content_type_required(mut self, content_type_required: bool) -> Self {
266 self.content_type_required = content_type_required;
267 self
268 }
269
270 fn from_req(req: &HttpRequest) -> &Self {
273 req.app_data::<Self>()
274 .or_else(|| req.app_data::<web::Data<Self>>().map(|d| d.as_ref()))
275 .unwrap_or(&DEFAULT_CONFIG)
276 }
277}
278
279const DEFAULT_LIMIT: usize = 2_097_152; const DEFAULT_CONFIG: JsonConfig = JsonConfig {
283 limit: DEFAULT_LIMIT,
284 err_handler: None,
285 content_type: None,
286 content_type_required: true,
287};
288
289impl Default for JsonConfig {
290 fn default() -> Self {
291 DEFAULT_CONFIG
292 }
293}
294
/// Future that buffers a request payload and deserializes it as JSON into `T`.
///
/// Created by [`JsonBody::new`] and typically narrowed with [`JsonBody::limit`]. It resolves to
/// `T`, or to a [`JsonPayloadError`] for one of these failures:
/// - a content-type mismatch;
/// - a size overflow;
/// - a payload stream error;
/// - a deserialization failure.
pub enum JsonBody<T> {
    /// Construction already failed; the error is yielded on the first poll.
    ///
    /// Wrapped in `Option` so `poll` can move it out with `take()`.
    Error(Option<JsonPayloadError>),
    /// Payload is being read into `buf`.
    Body {
        /// Maximum number of bytes that may be buffered.
        limit: usize,
        /// `Content-Length` header value, if present and parsable.
        length: Option<usize>,
        /// Request payload; with the `__compress` feature it is wrapped in `Decompress`
        /// built from the request headers.
        #[cfg(feature = "__compress")]
        payload: Decompress<Payload>,
        #[cfg(not(feature = "__compress"))]
        payload: Payload,
        /// Bytes received so far.
        buf: BytesMut,
        /// Marker for the output type; no `T` is stored.
        _res: PhantomData<T>,
    },
}

// `T` only appears inside `PhantomData` and is never stored or pinned. The future can therefore
// be `Unpin` regardless of `T`, which `poll` relies on when calling `Pin::get_mut`.
impl<T> Unpin for JsonBody<T> {}
320
321impl<T: DeserializeOwned> JsonBody<T> {
322 #[allow(clippy::borrow_interior_mutable_const)]
324 pub fn new(
325 req: &HttpRequest,
326 payload: &mut Payload,
327 ctype_fn: Option<&(dyn Fn(mime::Mime) -> bool + Send + Sync)>,
328 ctype_required: bool,
329 ) -> Self {
330 let can_parse_json = match (ctype_required, req.mime_type()) {
332 (true, Ok(Some(mime))) => {
333 mime.subtype() == mime::JSON
334 || mime.suffix() == Some(mime::JSON)
335 || ctype_fn.is_some_and(|predicate| predicate(mime))
336 }
337
338 (true, _) => false,
340
341 (false, _) => true,
344 };
345
346 if !can_parse_json {
347 return JsonBody::Error(Some(JsonPayloadError::ContentType));
348 }
349
350 let length = ContentLength::parse(req).ok().map(|x| x.0);
351
352 let payload = {
357 cfg_if::cfg_if! {
358 if #[cfg(feature = "__compress")] {
359 Decompress::from_headers(payload.take(), req.headers())
360 } else {
361 payload.take()
362 }
363 }
364 };
365
366 JsonBody::Body {
367 limit: DEFAULT_LIMIT,
368 length,
369 payload,
370 buf: BytesMut::with_capacity(8192),
371 _res: PhantomData,
372 }
373 }
374
375 pub fn limit(self, limit: usize) -> Self {
377 match self {
378 JsonBody::Body {
379 length,
380 payload,
381 buf,
382 ..
383 } => {
384 if let Some(len) = length {
385 if len > limit {
386 return JsonBody::Error(Some(JsonPayloadError::OverflowKnownLength {
387 length: len,
388 limit,
389 }));
390 }
391 }
392
393 JsonBody::Body {
394 limit,
395 length,
396 payload,
397 buf,
398 _res: PhantomData,
399 }
400 }
401 JsonBody::Error(err) => JsonBody::Error(err),
402 }
403 }
404}
405
406impl<T: DeserializeOwned> Future for JsonBody<T> {
407 type Output = Result<T, JsonPayloadError>;
408
409 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
410 let this = self.get_mut();
411
412 match this {
413 JsonBody::Body {
414 limit,
415 buf,
416 payload,
417 ..
418 } => loop {
419 let res = ready!(Pin::new(&mut *payload).poll_next(cx));
420 match res {
421 Some(chunk) => {
422 let chunk = chunk?;
423 let buf_len = buf.len() + chunk.len();
424 if buf_len > *limit {
425 return Poll::Ready(Err(JsonPayloadError::Overflow { limit: *limit }));
426 } else {
427 buf.extend_from_slice(&chunk);
428 }
429 }
430 None => {
431 let json = serde_json::from_slice::<T>(buf)
432 .map_err(JsonPayloadError::Deserialize)?;
433 return Poll::Ready(Ok(json));
434 }
435 }
436 },
437 JsonBody::Error(err) => Poll::Ready(Err(err.take().unwrap())),
438 }
439 }
440}
441
#[cfg(test)]
mod tests {
    use bytes::Bytes;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::{
        body,
        error::InternalError,
        http::{
            header::{self, CONTENT_LENGTH, CONTENT_TYPE},
            StatusCode,
        },
        test::{assert_body_eq, TestRequest},
    };

    /// Simple payload type used across the tests.
    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct MyObject {
        name: String,
    }

    /// Compares two payload errors by variant only, ignoring field values.
    ///
    /// Only `Overflow`, `OverflowKnownLength` and `ContentType` are recognized; any other variant
    /// of `err` yields `false`.
    fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
        match err {
            JsonPayloadError::Overflow { .. } => {
                matches!(other, JsonPayloadError::Overflow { .. })
            }
            JsonPayloadError::OverflowKnownLength { .. } => {
                matches!(other, JsonPayloadError::OverflowKnownLength { .. })
            }
            JsonPayloadError::ContentType => matches!(other, JsonPayloadError::ContentType),
            _ => false,
        }
    }

    // Responder: 200 OK, `application/json` content type, compact serialized body.
    #[actix_rt::test]
    async fn test_responder() {
        let req = TestRequest::default().to_http_request();

        let j = Json(MyObject {
            name: "test".to_string(),
        });
        let res = j.respond_to(&req);
        assert_eq!(res.status(), StatusCode::OK);
        assert_eq!(
            res.headers().get(header::CONTENT_TYPE).unwrap(),
            header::HeaderValue::from_static("application/json")
        );
        assert_body_eq!(res, b"{\"name\":\"test\"}");
    }

    // Content-Length (16) exceeds the configured limit (10), so the custom error handler runs
    // and its response (400 with a JSON body) is what the extractor's error renders as.
    #[actix_rt::test]
    async fn test_custom_error_responder() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().limit(10).error_handler(|err, _| {
                let msg = MyObject {
                    name: "invalid request".to_string(),
                };
                let resp = HttpResponse::BadRequest().body(serde_json::to_string(&msg).unwrap());
                InternalError::from_response(err, resp).into()
            }))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        let resp = HttpResponse::from_error(s.unwrap_err());
        assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

        let body = body::to_bytes(resp.into_body()).await.unwrap();
        let msg: MyObject = serde_json::from_slice(&body).unwrap();
        assert_eq!(msg.name, "invalid request");
    }

    #[actix_rt::test]
    async fn test_extract() {
        // Valid JSON with default config deserializes successfully.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await.unwrap();
        assert_eq!(s.name, "test");
        assert_eq!(
            s.into_inner(),
            MyObject {
                name: "test".to_string()
            }
        );

        // Known length above the limit fails with the known-length overflow message.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().limit(10))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(format!("{}", s.err().unwrap())
            .contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes)."));

        // The error handler's output replaces the original error.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(
                JsonConfig::default()
                    .limit(10)
                    .error_handler(|_, _| JsonPayloadError::ContentType.into()),
            )
            .to_http_parts();
        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(format!("{}", s.err().unwrap()).contains("Content type error"));
    }

    #[actix_rt::test]
    async fn test_json_body() {
        // No content type while one is required -> ContentType error.
        let (req, mut pl) = TestRequest::default().to_http_parts();
        let json = JsonBody::<MyObject>::new(&req, &mut pl, None, true).await;
        assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));

        // Non-JSON content type -> ContentType error.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/text"),
            ))
            .to_http_parts();
        let json = JsonBody::<MyObject>::new(&req, &mut pl, None, true).await;
        assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));

        // Content-Length above the limit is rejected up front by `limit`.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("10000"),
            ))
            .to_http_parts();

        let json = JsonBody::<MyObject>::new(&req, &mut pl, None, true)
            .limit(100)
            .await;
        assert!(json_eq(
            json.err().unwrap(),
            JsonPayloadError::OverflowKnownLength {
                length: 10000,
                limit: 100
            }
        ));

        // Without Content-Length, the streamed size check triggers Overflow.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .set_payload(Bytes::from_static(&[0u8; 1000]))
            .to_http_parts();

        let json = JsonBody::<MyObject>::new(&req, &mut pl, None, true)
            .limit(100)
            .await;

        assert!(json_eq(
            json.err().unwrap(),
            JsonPayloadError::Overflow { limit: 100 }
        ));

        // Valid JSON within the default limit deserializes.
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("application/json"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .to_http_parts();

        let json = JsonBody::<MyObject>::new(&req, &mut pl, None, true).await;
        assert_eq!(
            json.ok().unwrap(),
            MyObject {
                name: "test".to_owned()
            }
        );
    }

    // `text/plain` is not JSON and no predicate is configured -> error.
    #[actix_rt::test]
    async fn test_with_json_and_bad_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/plain"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().limit(4096))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(s.is_err())
    }

    // A content-type predicate accepting `text/plain` lets the payload through.
    #[actix_rt::test]
    async fn test_with_json_and_good_custom_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/plain"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().content_type(|mime: mime::Mime| {
                mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN
            }))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(s.is_ok())
    }

    // The same predicate rejects `text/html`.
    #[actix_rt::test]
    async fn test_with_json_and_bad_custom_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("text/html"),
            ))
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().content_type(|mime: mime::Mime| {
                mime.type_() == mime::TEXT && mime.subtype() == mime::PLAIN
            }))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(s.is_err())
    }

    // Missing content type is fine when validation is disabled.
    #[actix_rt::test]
    async fn test_json_with_no_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().content_type_required(false))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(s.is_ok())
    }

    // A non-JSON content type is ignored when validation is disabled.
    #[actix_rt::test]
    async fn test_json_ignoring_content_type() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((
                header::CONTENT_LENGTH,
                header::HeaderValue::from_static("16"),
            ))
            .insert_header((
                header::CONTENT_TYPE,
                header::HeaderValue::from_static("invalid/value"),
            ))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(JsonConfig::default().content_type_required(false))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(s.is_ok());
    }

    // Config registered as `web::Data<JsonConfig>` is honored (limit 10 < length 16).
    #[actix_rt::test]
    async fn test_with_config_in_data_wrapper() {
        let (req, mut pl) = TestRequest::default()
            .insert_header((CONTENT_TYPE, mime::APPLICATION_JSON))
            .insert_header((CONTENT_LENGTH, 16))
            .set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
            .app_data(web::Data::new(JsonConfig::default().limit(10)))
            .to_http_parts();

        let s = Json::<MyObject>::from_request(&req, &mut pl).await;
        assert!(s.is_err());

        let err_str = s.err().unwrap().to_string();
        assert!(
            err_str.contains("JSON payload (16 bytes) is larger than allowed (limit: 10 bytes).")
        );
    }
}