use crate::error::GenericTransportError;
use futures_util::stream::StreamExt;
/// Reads a JSON-RPC request body from `body`, enforcing `max_request_body_size`.
///
/// Returns the raw bytes together with a flag that is `true` when the payload is a single
/// request (its first non-whitespace byte is `{`) and `false` when it is a batch (its first
/// non-whitespace byte is `[`).
///
/// # Errors
///
/// - [`GenericTransportError::TooLarge`] if the `Content-Length` header, or the number of bytes
///   actually received, exceeds `max_request_body_size`.
/// - [`GenericTransportError::Malformed`] if the body is empty, contains only ASCII whitespace,
///   or its first non-whitespace byte is neither `{` nor `[`.
/// - [`GenericTransportError::Inner`] if reading a chunk from `body` fails.
pub async fn read_body(
    headers: &hyper::HeaderMap,
    mut body: hyper::Body,
    max_request_body_size: u32,
) -> Result<(Vec<u8>, bool), GenericTransportError<hyper::Error>> {
    let max_size = max_request_body_size as usize;

    // A missing, repeated or unparsable `Content-Length` falls back to 0: nothing is
    // pre-allocated, and the limit is enforced on the bytes actually received below.
    let body_size = read_header_content_length(headers).unwrap_or(0);
    if body_size > max_request_body_size {
        return Err(GenericTransportError::TooLarge);
    }

    // `body_size` was bounded by the limit above, so a client cannot force a huge allocation.
    let mut received_data = Vec::with_capacity(body_size as usize);
    // `None` until the first non-whitespace byte has been seen. Chunk boundaries are arbitrary,
    // so leading whitespace (or empty chunks) may span several chunks.
    let mut single = None;

    while let Some(chunk) = body.next().await {
        let chunk = chunk.map_err(GenericTransportError::Inner)?;
        if received_data.len() + chunk.len() > max_size {
            return Err(GenericTransportError::TooLarge);
        }
        if single.is_none() {
            single = match chunk.iter().find(|byte| !byte.is_ascii_whitespace()) {
                Some(b'{') => Some(true),
                Some(b'[') => Some(false),
                // Reject as early as possible instead of buffering the rest of an invalid body.
                Some(_) => return Err(GenericTransportError::Malformed),
                // Only whitespace so far: decide on a later chunk.
                None => None,
            };
        }
        received_data.extend_from_slice(&chunk);
    }

    match single {
        Some(single) => Ok((received_data, single)),
        // Empty or whitespace-only bodies cannot be JSON-RPC requests.
        None => Err(GenericTransportError::Malformed),
    }
}
/// Parses the `Content-Length` header as a `u32`.
///
/// Yields `None` when the header is absent, occurs more than once, is not visible ASCII,
/// or its value does not parse as a `u32` (for example because it is too large).
fn read_header_content_length(headers: &hyper::header::HeaderMap) -> Option<u32> {
    read_header_value(headers, hyper::header::CONTENT_LENGTH).and_then(|raw| raw.parse().ok())
}
/// Returns the value of `header_name` as a string slice, provided the header occurs exactly once.
///
/// Yields `None` if the header is missing, repeated, or its value is not visible ASCII.
pub fn read_header_value(
    headers: &hyper::header::HeaderMap,
    header_name: hyper::header::HeaderName,
) -> Option<&str> {
    let mut occurrences = headers.get_all(header_name).iter();
    let first = occurrences.next()?;
    match occurrences.next() {
        // A repeated header is ambiguous, so it is treated as if it were absent.
        Some(_) => None,
        None => first.to_str().ok(),
    }
}
/// Returns every value stored for `header_name` in `headers`, in insertion order.
///
/// The returned view borrows from `headers` and is empty if the header is not present.
pub fn read_header_values<'a>(
    headers: &'a hyper::header::HeaderMap,
    header_name: &str,
) -> hyper::header::GetAll<'a, hyper::header::HeaderValue> {
    hyper::header::HeaderMap::get_all(headers, header_name)
}
#[cfg(test)]
mod tests {
    use super::{read_body, read_header_content_length, GenericTransportError};

    #[tokio::test]
    async fn body_to_bytes_size_limit_works() {
        let headers = hyper::header::HeaderMap::new();
        // Start with a valid `{` so the failure can only come from the size check. A body of
        // zero bytes would also be rejected as malformed, which would make the test vacuous.
        let mut payload = b"{".to_vec();
        payload.resize(128, b' ');
        let body = hyper::Body::from(payload);
        assert!(matches!(read_body(&headers, body, 127).await, Err(GenericTransportError::TooLarge)));
    }

    #[tokio::test]
    async fn content_length_above_limit_is_rejected() {
        let mut headers = hyper::header::HeaderMap::new();
        headers.insert(hyper::header::CONTENT_LENGTH, "128".parse().unwrap());
        let body = hyper::Body::from("{}");
        assert!(matches!(read_body(&headers, body, 127).await, Err(GenericTransportError::TooLarge)));
    }

    #[tokio::test]
    async fn detects_single_and_batch_requests() {
        let headers = hyper::header::HeaderMap::new();

        let object = "  {\"jsonrpc\":\"2.0\",\"method\":\"a\",\"id\":1}";
        match read_body(&headers, hyper::Body::from(object), 1024).await {
            Ok((data, is_single)) => {
                assert!(is_single);
                assert_eq!(data, object.as_bytes());
            }
            Err(_) => panic!("single request was rejected"),
        }

        let batch = "\n[{\"jsonrpc\":\"2.0\",\"method\":\"a\",\"id\":1}]";
        match read_body(&headers, hyper::Body::from(batch), 1024).await {
            Ok((data, is_single)) => {
                assert!(!is_single);
                assert_eq!(data, batch.as_bytes());
            }
            Err(_) => panic!("batch request was rejected"),
        }
    }

    #[tokio::test]
    async fn empty_whitespace_and_non_json_bodies_are_malformed() {
        let headers = hyper::header::HeaderMap::new();
        for payload in ["", "   ", "\"not an object\""] {
            let body = hyper::Body::from(payload);
            assert!(
                matches!(read_body(&headers, body, 1024).await, Err(GenericTransportError::Malformed)),
                "payload {:?} should be malformed",
                payload
            );
        }
    }

    #[test]
    fn read_content_length_works() {
        let mut headers = hyper::header::HeaderMap::new();
        headers.insert(hyper::header::CONTENT_LENGTH, "177".parse().unwrap());
        assert_eq!(read_header_content_length(&headers), Some(177));
        // A repeated header is ambiguous and must be ignored.
        headers.append(hyper::header::CONTENT_LENGTH, "999".parse().unwrap());
        assert_eq!(read_header_content_length(&headers), None);
    }

    #[test]
    fn read_content_length_too_big_value() {
        let mut headers = hyper::header::HeaderMap::new();
        // u64::MAX + 1: far outside the `u32` range, so parsing must fail.
        headers.insert(hyper::header::CONTENT_LENGTH, "18446744073709551616".parse().unwrap());
        assert_eq!(read_header_content_length(&headers), None);
    }
}