1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use std::borrow::Borrow;
use std::fmt::Debug;
use std::str::FromStr;
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use futures_util::stream::{Stream, StreamExt};
use h2;
use http::header::CONTENT_LENGTH;
use http::{Method, Request};
use log::debug;
use crate::https::HttpsError;
/// Extracts the raw DNS message bytes from an incoming DNS-over-HTTPS request.
///
/// The request is logged at debug level and then checked with
/// `crate::https::request::verify` against `this_server_name` (the exact checks live in
/// that function). If a `Content-Length` header is present it is parsed and forwarded to
/// [`message_from_post`] so the body reader knows how many bytes to expect.
///
/// # Errors
///
/// Returns an [`HttpsError`] when:
/// - verification fails;
/// - the `Content-Length` value cannot be converted to a string or parsed as a `usize`;
/// - the method is `GET` (not implemented yet) or anything other than `POST`;
/// - reading the POST body fails.
pub async fn message_from<R>(
    this_server_name: Arc<str>,
    request: Request<R>,
) -> Result<BytesMut, HttpsError>
where
    R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
{
    debug!("Received request: {:#?}", request);

    // Run the shared request checks before looking at headers or the body.
    crate::https::request::verify(this_server_name.borrow(), &request)?;

    // `Content-Length` is optional; when present it must be a valid decimal usize.
    let content_length = match request.headers().get(CONTENT_LENGTH) {
        None => None,
        Some(header) => {
            let declared = usize::from_str(header.to_str()?)?;
            debug!("got message length: {}", declared);
            Some(declared)
        }
    };

    match *request.method() {
        Method::POST => message_from_post(request.into_body(), content_length).await,
        Method::GET => Err(format!("GET unimplemented: {}", request.method()).into()),
        _ => Err(format!("bad method: {}", request.method()).into()),
    }
}
/// Reads the body of a POST request into a single contiguous buffer.
///
/// `length` is the declared `Content-Length`, if the request carried one.
///
/// - When `length` is known, reading stops as soon as exactly that many bytes have
///   arrived, and the body is rejected as soon as more than that many have arrived.
///   A stream that ends before reaching `length` is also rejected.
/// - Without a declared length, the whole stream is collected.
///
/// # Errors
///
/// Returns an [`HttpsError`] if:
/// - the underlying stream yields an error;
/// - the stream ends before `length` bytes were received;
/// - the stream delivers more than `length` bytes.
pub(crate) async fn message_from_post<R>(
    mut request_stream: R,
    length: Option<usize>,
) -> Result<BytesMut, HttpsError>
where
    R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
{
    // The clamp keeps the initial reservation between 512 and 4096 bytes whatever the
    // declared length is; the buffer still grows on demand below.
    let mut bytes = BytesMut::with_capacity(length.unwrap_or(0).clamp(512, 4096));

    // NOTE(review): with no declared length there is no upper bound on how much is
    // buffered here.
    while let Some(frame) = request_stream.next().await {
        bytes.extend_from_slice(&frame?);

        if let Some(expected) = length {
            if bytes.len() == expected {
                return Ok(bytes);
            }
            // The buffer only ever grows, so once it overshoots the declared length it
            // can never match; fail now instead of buffering the rest of the stream.
            if bytes.len() > expected {
                return Err("received more bytes than content-length".into());
            }
        }
    }

    // Stream finished: with a declared length, anything short of it is incomplete.
    match length {
        Some(expected) if bytes.len() != expected => Err("not all bytes received".into()),
        _ => Ok(bytes),
    }
}
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use crate::https::request;
    use crate::op::Message;

    use super::*;

    /// Finite in-memory body stream that is always ready and yields its items
    /// front-to-back, i.e. in the order they appear in the vector.
    #[derive(Debug)]
    struct TestBytesStream(Vec<Result<Bytes, h2::Error>>);

    impl Stream for TestBytesStream {
        type Item = Result<Bytes, h2::Error>;

        fn poll_next(
            mut self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
        ) -> Poll<Option<Self::Item>> {
            // `pop()` would hand chunks out back-to-front and scramble multi-chunk bodies;
            // take from the front instead. Test vectors are tiny, so `remove(0)` is cheap.
            let next = if self.0.is_empty() {
                None
            } else {
                Some(self.0.remove(0))
            };
            Poll::Ready(next)
        }
    }

    #[test]
    fn test_from_post() {
        let message = Message::new();
        let msg_bytes = message.to_vec().unwrap();
        let len = msg_bytes.len();
        let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes))]);
        let request = request::new("ns.example.com", len).unwrap();
        let request = request.map(|()| stream);

        let from_post = message_from(Arc::from("ns.example.com"), request);
        let bytes = match block_on(from_post) {
            Ok(bytes) => bytes,
            e => panic!("{:#?}", e),
        };

        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    #[test]
    fn test_from_post_chunks_in_order() {
        let stream = TestBytesStream(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]);
        let bytes = block_on(message_from_post(stream, Some(4))).expect("post failed");
        assert_eq!(&bytes[..], &b"abcd"[..]);
    }

    #[test]
    fn test_from_post_without_length() {
        let stream = TestBytesStream(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"c")),
        ]);
        let bytes = block_on(message_from_post(stream, None)).expect("post failed");
        assert_eq!(&bytes[..], &b"abc"[..]);
    }

    #[test]
    fn test_from_post_short_body() {
        let stream = TestBytesStream(vec![Ok(Bytes::from_static(b"ab"))]);
        assert!(block_on(message_from_post(stream, Some(4))).is_err());
    }

    #[test]
    fn test_from_post_extra_bytes() {
        let stream = TestBytesStream(vec![Ok(Bytes::from_static(b"abc"))]);
        assert!(block_on(message_from_post(stream, Some(2))).is_err());
    }
}