use std::rc::Rc;
use std::{fmt, ops};
use bytes::{Bytes, BytesMut};
use futures::{Future, Poll, Stream};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json;
use actix_http::http::{header::CONTENT_LENGTH, StatusCode};
use actix_http::{HttpMessage, Payload, Response};
use crate::error::{Error, JsonPayloadError, PayloadError};
use crate::extract::FromRequest;
use crate::request::HttpRequest;
use crate::responder::Responder;
/// Typed JSON wrapper.
///
/// Used as a response (see its `Responder` impl) it serializes the inner
/// value; used as an extractor (see its `FromRequest` impl) it deserializes
/// the request body into `T`. It dereferences to the wrapped value.
pub struct Json<T>(pub T);

impl<T> Json<T> {
    /// Unwraps the `Json` wrapper, handing back ownership of the inner value.
    pub fn into_inner(self) -> T {
        let Json(inner) = self;
        inner
    }
}

impl<T> ops::Deref for Json<T> {
    type Target = T;

    /// Borrows the wrapped value.
    fn deref(&self) -> &T {
        let Json(ref inner) = *self;
        inner
    }
}

impl<T> ops::DerefMut for Json<T> {
    /// Mutably borrows the wrapped value.
    fn deref_mut(&mut self) -> &mut T {
        let Json(ref mut inner) = *self;
        inner
    }
}

impl<T> fmt::Debug for Json<T>
where
    T: fmt::Debug,
{
    /// Formats as `Json: ` followed by the inner value's plain `{:?}` output.
    /// The caller's formatter flags are not forwarded to the inner value.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Json(ref inner) = *self;
        write!(f, "Json: {:?}", inner)
    }
}

impl<T> fmt::Display for Json<T>
where
    T: fmt::Display,
{
    /// Formats exactly like the inner value, forwarding the formatter
    /// (so width/alignment flags apply to the inner value).
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <T as fmt::Display>::fmt(&self.0, f)
    }
}
/// Turns a `Json<T>` into an HTTP response.
impl<T: Serialize> Responder for Json<T> {
    type Error = Error;
    type Future = Result<Response, Error>;

    /// Serializes the wrapped value with `serde_json` and returns a `200 OK`
    /// response with `Content-Type: application/json`.
    ///
    /// # Errors
    /// Returns the `serde_json` serialization error converted into [`Error`]
    /// if `T` cannot be serialized.
    fn respond_to(self, _: &HttpRequest) -> Self::Future {
        let body = serde_json::to_string(&self.0)?;
        Ok(Response::build(StatusCode::OK)
            .content_type("application/json")
            .body(body))
    }
}
impl<T, P> FromRequest<P> for Json<T>
where
T: DeserializeOwned + 'static,
P: Stream<Item = Bytes, Error = crate::error::PayloadError> + 'static,
{
type Error = Error;
type Future = Box<Future<Item = Self, Error = Error>>;
#[inline]
fn from_request(req: &HttpRequest, payload: &mut Payload<P>) -> Self::Future {
let req2 = req.clone();
let (limit, err) = req
.route_data::<JsonConfig>()
.map(|c| (c.limit, c.ehandler.clone()))
.unwrap_or((32768, None));
let path = req.path().to_string();
Box::new(
JsonBody::new(req, payload)
.limit(limit)
.map_err(move |e| {
log::debug!(
"Failed to deserialize Json from payload. \
Request path: {:?}",
path
);
if let Some(err) = err {
(*err)(e, &req2)
} else {
e.into()
}
})
.map(Json),
)
}
}
#[derive(Clone)]
pub struct JsonConfig {
limit: usize,
ehandler: Option<Rc<Fn(JsonPayloadError, &HttpRequest) -> Error>>,
}
impl JsonConfig {
pub fn limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
pub fn error_handler<F>(mut self, f: F) -> Self
where
F: Fn(JsonPayloadError, &HttpRequest) -> Error + 'static,
{
self.ehandler = Some(Rc::new(f));
self
}
}
impl Default for JsonConfig {
fn default() -> Self {
JsonConfig {
limit: 32768,
ehandler: None,
}
}
}
pub struct JsonBody<P, U> {
limit: usize,
length: Option<usize>,
stream: Payload<P>,
err: Option<JsonPayloadError>,
fut: Option<Box<Future<Item = U, Error = JsonPayloadError>>>,
}
impl<P, U> JsonBody<P, U>
where
    P: Stream<Item = Bytes, Error = PayloadError> + 'static,
    U: DeserializeOwned + 'static,
{
    /// Creates a JSON body reader for `req` with a default limit of
    /// `262_144` bytes.
    ///
    /// The request counts as JSON when its MIME subtype is `json` or its MIME
    /// suffix is `+json`. If it is not JSON, `payload` is left untouched and
    /// the future fails with [`JsonPayloadError::ContentType`] on its first
    /// poll. Otherwise the payload is taken out of `payload`. A
    /// `Content-Length` header, when it parses as `usize`, is recorded for an
    /// up-front size check.
    pub fn new(req: &HttpRequest, payload: &mut Payload<P>) -> Self {
        let is_json = match req.mime_type() {
            Ok(Some(mime)) => {
                mime.subtype() == mime::JSON || mime.suffix() == Some(mime::JSON)
            }
            _ => false,
        };

        if !is_json {
            return Self {
                limit: 262_144,
                length: None,
                stream: Payload::None,
                err: Some(JsonPayloadError::ContentType),
                fut: None,
            };
        }

        // A missing, non-ASCII, or non-numeric header just means "unknown length".
        let length = req
            .headers()
            .get(&CONTENT_LENGTH)
            .and_then(|value| value.to_str().ok())
            .and_then(|text| text.parse::<usize>().ok());

        Self {
            limit: 262_144,
            length,
            stream: payload.take(),
            err: None,
            fut: None,
        }
    }

    /// Returns this reader with the maximum body size set to `limit` bytes.
    pub fn limit(self, limit: usize) -> Self {
        Self { limit, ..self }
    }
}
impl<P, U> Future for JsonBody<P, U>
where
    P: Stream<Item = Bytes, Error = PayloadError> + 'static,
    U: DeserializeOwned + 'static,
{
    type Item = U;
    type Error = JsonPayloadError;

    /// Drives the read-and-deserialize pipeline.
    ///
    /// The first poll reports any construction-time error. It then rejects a
    /// declared `Content-Length` above the limit. Otherwise it builds a future
    /// that accumulates the payload, enforcing the limit on every chunk, and
    /// parses the result with `serde_json`. Later polls delegate to that future.
    fn poll(&mut self) -> Poll<U, JsonPayloadError> {
        if let Some(pending) = self.fut.as_mut() {
            return pending.poll();
        }

        if let Some(early) = self.err.take() {
            return Err(early);
        }

        let max = self.limit;
        match self.length.take() {
            Some(declared) if declared > max => return Err(JsonPayloadError::Overflow),
            _ => {}
        }

        // Move the payload out of `self` so the boxed future can own it.
        let payload = std::mem::replace(&mut self.stream, Payload::None);
        let collected = payload.from_err().fold(
            BytesMut::with_capacity(8192),
            move |mut acc, chunk| {
                if acc.len() + chunk.len() > max {
                    Err(JsonPayloadError::Overflow)
                } else {
                    acc.extend_from_slice(&chunk);
                    Ok(acc)
                }
            },
        );
        let parsed = collected.and_then(|acc| Ok(serde_json::from_slice::<U>(&acc)?));

        self.fut = Some(Box::new(parsed));
        self.poll()
    }
}
// Unit tests for the `Json` responder/extractor, `JsonConfig` and `JsonBody`.
#[cfg(test)]
mod tests {
use bytes::Bytes;
use serde_derive::{Deserialize, Serialize};
use super::*;
use crate::error::InternalError;
use crate::http::header;
use crate::test::{block_on, TestRequest};
use crate::HttpResponse;
// Minimal payload type used as the JSON (de)serialization target in every test.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct MyObject {
name: String,
}
// Compares two `JsonPayloadError`s by variant. Only `Overflow` and
// `ContentType` are recognized; any other variant compares unequal,
// even to itself.
fn json_eq(err: JsonPayloadError, other: JsonPayloadError) -> bool {
match err {
JsonPayloadError::Overflow => match other {
JsonPayloadError::Overflow => true,
_ => false,
},
JsonPayloadError::ContentType => match other {
JsonPayloadError::ContentType => true,
_ => false,
},
_ => false,
}
}
// `Json` as a responder: 200 OK, `application/json` content type, and the
// serde_json encoding of the value as the body.
#[test]
fn test_responder() {
let req = TestRequest::default().to_http_request();
let j = Json(MyObject {
name: "test".to_string(),
});
let resp = j.respond_to(&req).unwrap();
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get(header::CONTENT_TYPE).unwrap(),
header::HeaderValue::from_static("application/json")
);
// NOTE(review): presumably brings the `bin_ref()` helper used below into scope.
use crate::responder::tests::BodyTest;
assert_eq!(resp.body().bin_ref(), b"{\"name\":\"test\"}");
}
// A configured error handler replaces the default error. The 16-byte body
// exceeds the limit of 10, and the handler turns that failure into a
// 400 response carrying a JSON body.
#[test]
fn test_custom_error_responder() {
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.route_data(JsonConfig::default().limit(10).error_handler(|err, _| {
let msg = MyObject {
name: "invalid request".to_string(),
};
let resp = HttpResponse::BadRequest()
.body(serde_json::to_string(&msg).unwrap());
InternalError::from_response(err, resp).into()
}))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
let mut resp = Response::from_error(s.err().unwrap().into());
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
let body = block_on(resp.take_body().concat2()).unwrap();
let msg: MyObject = serde_json::from_slice(&body).unwrap();
assert_eq!(msg.name, "invalid request");
}
// Extractor behavior. Case 1: a valid body succeeds with the default config.
// Case 2: a limit of 10 below the 16-byte body yields the overflow error
// message. Case 3: a custom handler's error replaces the overflow error.
#[test]
fn test_extract() {
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl)).unwrap();
assert_eq!(s.name, "test");
assert_eq!(
s.into_inner(),
MyObject {
name: "test".to_string()
}
);
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.route_data(JsonConfig::default().limit(10))
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(format!("{}", s.err().unwrap())
.contains("Json payload size is bigger than allowed"));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.route_data(
JsonConfig::default()
.limit(10)
.error_handler(|_, _| JsonPayloadError::ContentType.into()),
)
.to_http_parts();
let s = block_on(Json::<MyObject>::from_request(&req, &mut pl));
assert!(format!("{}", s.err().unwrap()).contains("Content type error"));
}
// `JsonBody` directly. The cases are: no content type -> ContentType error;
// non-JSON content type -> ContentType error; Content-Length above the
// limit -> Overflow; valid JSON body -> deserialized value.
#[test]
fn test_json_body() {
let (req, mut pl) = TestRequest::default().to_http_parts();
let json = block_on(JsonBody::<_, MyObject>::new(&req, &mut pl));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/text"),
)
.to_http_parts();
let json = block_on(JsonBody::<_, MyObject>::new(&req, &mut pl));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::ContentType));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("10000"),
)
.to_http_parts();
let json = block_on(JsonBody::<_, MyObject>::new(&req, &mut pl).limit(100));
assert!(json_eq(json.err().unwrap(), JsonPayloadError::Overflow));
let (req, mut pl) = TestRequest::default()
.header(
header::CONTENT_TYPE,
header::HeaderValue::from_static("application/json"),
)
.header(
header::CONTENT_LENGTH,
header::HeaderValue::from_static("16"),
)
.set_payload(Bytes::from_static(b"{\"name\": \"test\"}"))
.to_http_parts();
let json = block_on(JsonBody::<_, MyObject>::new(&req, &mut pl));
assert_eq!(
json.ok().unwrap(),
MyObject {
name: "test".to_owned()
}
);
}
}