hickory_proto/h2/
h2_server.rs1use std::fmt::Debug;
11use std::str::FromStr;
12use std::sync::Arc;
13
14use bytes::{Bytes, BytesMut};
15use futures_util::stream::{Stream, StreamExt};
16use h2;
17use http::header::CONTENT_LENGTH;
18use http::{Method, Request};
19use tracing::debug;
20
21use crate::h2::HttpsError;
22use crate::http::Version;
23
24pub async fn message_from<R>(
29 this_server_name: Option<Arc<str>>,
30 request: Request<R>,
31) -> Result<BytesMut, HttpsError>
32where
33 R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
34{
35 debug!("Received request: {:#?}", request);
36
37 let this_server_name = this_server_name.as_deref();
38 match crate::http::request::verify(Version::Http2, this_server_name, &request) {
39 Ok(_) => (),
40 Err(err) => return Err(err),
41 }
42
43 let mut content_length = None;
45 if let Some(length) = request.headers().get(CONTENT_LENGTH) {
46 let length = usize::from_str(length.to_str()?)?;
47 debug!("got message length: {}", length);
48 content_length = Some(length);
49 }
50
51 match *request.method() {
52 Method::GET => Err(format!("GET unimplemented: {}", request.method()).into()),
53 Method::POST => message_from_post(request.into_body(), content_length).await,
54 _ => Err(format!("bad method: {}", request.method()).into()),
55 }
56}
57
58pub(crate) async fn message_from_post<R>(
60 mut request_stream: R,
61 length: Option<usize>,
62) -> Result<BytesMut, HttpsError>
63where
64 R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
65{
66 let mut bytes = BytesMut::with_capacity(length.unwrap_or(0).clamp(512, 4096));
67
68 loop {
69 match request_stream.next().await {
70 Some(Ok(mut frame)) => bytes.extend_from_slice(&frame.split_off(0)),
71 Some(Err(err)) => return Err(err.into()),
72 None => {
73 return if let Some(length) = length {
74 if bytes.len() == length {
76 Ok(bytes)
77 } else {
78 Err("not all bytes received".into())
79 }
80 } else {
81 Ok(bytes)
82 };
83 }
84 };
85
86 if let Some(length) = length {
87 if bytes.len() == length {
89 return Ok(bytes);
90 }
91 }
92 }
93}
94
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_executor::block_on;

    use crate::http::request;
    use crate::op::Message;

    use super::*;

    /// A fully buffered body stream that yields its chunks in the order they were given.
    #[derive(Debug)]
    struct TestBytesStream(VecDeque<Result<Bytes, h2::Error>>);

    impl TestBytesStream {
        fn new(chunks: Vec<Result<Bytes, h2::Error>>) -> Self {
            Self(chunks.into())
        }
    }

    impl Stream for TestBytesStream {
        type Item = Result<Bytes, h2::Error>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Every chunk is already in memory, so the stream is always ready.
            // `pop_front` keeps chunk order; `Vec::pop` would yield them back-to-front.
            Poll::Ready(self.0.pop_front())
        }
    }

    /// Wire bytes of an empty DNS message, used as a request body.
    fn empty_message_bytes() -> Vec<u8> {
        Message::new().to_vec().unwrap()
    }

    #[test]
    fn test_from_post() {
        let message = Message::new();
        let msg_bytes = message.to_vec().unwrap();
        let len = msg_bytes.len();
        let stream = TestBytesStream::new(vec![Ok(Bytes::from(msg_bytes))]);
        let request = request::new(Version::Http2, "ns.example.com", len).unwrap();
        let request = request.map(|()| stream);

        let from_post = message_from(Some(Arc::from("ns.example.com")), request);
        let bytes = match block_on(from_post) {
            Ok(bytes) => bytes,
            e => panic!("{:#?}", e),
        };

        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    #[test]
    fn test_from_post_multiple_chunks() {
        let msg_bytes = empty_message_bytes();
        let len = msg_bytes.len();
        let (head, tail) = msg_bytes.split_at(len / 2);
        let stream = TestBytesStream::new(vec![
            Ok(Bytes::copy_from_slice(head)),
            Ok(Bytes::copy_from_slice(tail)),
        ]);

        let bytes = block_on(message_from_post(stream, Some(len))).expect("body not collected");
        assert_eq!(&bytes[..], &msg_bytes[..]);
    }

    #[test]
    fn test_from_post_without_length() {
        let msg_bytes = empty_message_bytes();
        let stream = TestBytesStream::new(vec![Ok(Bytes::from(msg_bytes.clone()))]);

        let bytes = block_on(message_from_post(stream, None)).expect("body not collected");
        assert_eq!(&bytes[..], &msg_bytes[..]);
    }

    #[test]
    fn test_from_post_short_body() {
        let msg_bytes = empty_message_bytes();
        let len = msg_bytes.len();
        let stream = TestBytesStream::new(vec![Ok(Bytes::copy_from_slice(&msg_bytes[..len - 1]))]);

        assert!(block_on(message_from_post(stream, Some(len))).is_err());
    }

    #[test]
    fn test_from_post_too_many_bytes() {
        let mut oversized = empty_message_bytes();
        let len = oversized.len();
        oversized.push(0);
        let stream = TestBytesStream::new(vec![Ok(Bytes::from(oversized))]);

        assert!(block_on(message_from_post(stream, Some(len))).is_err());
    }
}