hickory_proto/h2/
h2_server.rs1use alloc::sync::Arc;
11use core::fmt::Debug;
12use core::str::FromStr;
13
14use bytes::{Bytes, BytesMut};
15use futures_util::stream::{Stream, StreamExt};
16use h2;
17use http::header::CONTENT_LENGTH;
18use http::{Method, Request};
19use tracing::debug;
20
21use crate::h2::HttpsError;
22use crate::http::Version;
23
24pub async fn message_from<R>(
29 this_server_name: Option<Arc<str>>,
30 this_server_endpoint: Arc<str>,
31 request: Request<R>,
32) -> Result<BytesMut, HttpsError>
33where
34 R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
35{
36 debug!("Received request: {:#?}", request);
37
38 let this_server_name = this_server_name.as_deref();
39 match crate::http::request::verify(
40 Version::Http2,
41 this_server_name,
42 &this_server_endpoint,
43 &request,
44 ) {
45 Ok(_) => (),
46 Err(err) => return Err(err),
47 }
48
49 let mut content_length = None;
51 if let Some(length) = request.headers().get(CONTENT_LENGTH) {
52 let length = usize::from_str(length.to_str()?)?;
53 debug!("got message length: {}", length);
54 content_length = Some(length);
55 }
56
57 match *request.method() {
58 Method::GET => Err(format!("GET unimplemented: {}", request.method()).into()),
59 Method::POST => message_from_post(request.into_body(), content_length).await,
60 _ => Err(format!("bad method: {}", request.method()).into()),
61 }
62}
63
64pub(crate) async fn message_from_post<R>(
66 mut request_stream: R,
67 length: Option<usize>,
68) -> Result<BytesMut, HttpsError>
69where
70 R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
71{
72 let mut bytes = BytesMut::with_capacity(length.unwrap_or(0).clamp(512, 4_096));
73
74 loop {
75 match request_stream.next().await {
76 Some(Ok(mut frame)) => bytes.extend_from_slice(&frame.split_off(0)),
77 Some(Err(err)) => return Err(err.into()),
78 None => {
79 return if let Some(length) = length {
80 if bytes.len() == length {
82 Ok(bytes)
83 } else {
84 Err("not all bytes received".into())
85 }
86 } else {
87 Ok(bytes)
88 };
89 }
90 };
91
92 if let Some(length) = length {
93 if bytes.len() == length {
95 return Ok(bytes);
96 }
97 }
98 }
99}
100
#[cfg(test)]
mod tests {
    use alloc::collections::VecDeque;
    use alloc::vec::Vec;
    use core::pin::Pin;
    use core::task::{Context, Poll};
    use futures_executor::block_on;

    use test_support::subscribe;

    use crate::http::request;
    use crate::op::Message;

    use super::*;

    /// In-memory body stream that yields its frames in the order they were given.
    #[derive(Debug)]
    struct TestBytesStream(VecDeque<Result<Bytes, h2::Error>>);

    impl TestBytesStream {
        fn new(frames: Vec<Result<Bytes, h2::Error>>) -> Self {
            Self(frames.into())
        }
    }

    impl Stream for TestBytesStream {
        type Item = Result<Bytes, h2::Error>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Pop from the front so multi-frame bodies are reassembled in order.
            Poll::Ready(self.0.pop_front())
        }
    }

    #[test]
    fn test_from_post() {
        subscribe();
        let message = Message::new();
        let msg_bytes = message.to_vec().unwrap();
        let len = msg_bytes.len();
        let stream = TestBytesStream::new(vec![Ok(Bytes::from(msg_bytes))]);
        let request = request::new(Version::Http2, "ns.example.com", "/dns-query", len).unwrap();
        let request = request.map(|()| stream);

        let from_post = message_from(
            Some(Arc::from("ns.example.com")),
            "/dns-query".into(),
            request,
        );
        let bytes = match block_on(from_post) {
            Ok(bytes) => bytes,
            e => panic!("{:#?}", e),
        };

        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    #[test]
    fn test_from_post_multiple_frames_in_order() {
        subscribe();
        let stream = TestBytesStream::new(vec![
            Ok(Bytes::from_static(b"hello ")),
            Ok(Bytes::from_static(b"world")),
        ]);
        let bytes = block_on(message_from_post(stream, Some(11))).expect("complete body");
        assert_eq!(&bytes[..], &b"hello world"[..]);
    }

    #[test]
    fn test_from_post_without_content_length() {
        subscribe();
        let stream = TestBytesStream::new(vec![
            Ok(Bytes::from_static(b"hello ")),
            Ok(Bytes::from_static(b"world")),
        ]);
        let bytes = block_on(message_from_post(stream, None)).expect("stream read to end");
        assert_eq!(&bytes[..], &b"hello world"[..]);
    }

    #[test]
    fn test_from_post_short_body() {
        subscribe();
        let stream = TestBytesStream::new(vec![Ok(Bytes::from_static(b"hello world"))]);
        assert!(block_on(message_from_post(stream, Some(12))).is_err());
    }

    #[test]
    fn test_from_post_long_body() {
        subscribe();
        let stream = TestBytesStream::new(vec![Ok(Bytes::from_static(b"hello world"))]);
        assert!(block_on(message_from_post(stream, Some(5))).is_err());
    }
}