trust_dns_proto/https/
https_server.rs1use std::fmt::Debug;
11use std::str::FromStr;
12use std::sync::Arc;
13
14use bytes::{Bytes, BytesMut};
15use futures_util::stream::{Stream, StreamExt};
16use h2;
17use http::header::CONTENT_LENGTH;
18use http::{Method, Request};
19use tracing::debug;
20
21use crate::https::HttpsError;
22
23pub async fn message_from<R>(
28 this_server_name: Option<Arc<str>>,
29 request: Request<R>,
30) -> Result<BytesMut, HttpsError>
31where
32 R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
33{
34 debug!("Received request: {:#?}", request);
35
36 let this_server_name = this_server_name.as_deref();
37 match crate::https::request::verify(this_server_name, &request) {
38 Ok(_) => (),
39 Err(err) => return Err(err),
40 }
41
42 let mut content_length = None;
44 if let Some(length) = request.headers().get(CONTENT_LENGTH) {
45 let length = usize::from_str(length.to_str()?)?;
46 debug!("got message length: {}", length);
47 content_length = Some(length);
48 }
49
50 match *request.method() {
51 Method::GET => Err(format!("GET unimplemented: {}", request.method()).into()),
52 Method::POST => message_from_post(request.into_body(), content_length).await,
53 _ => Err(format!("bad method: {}", request.method()).into()),
54 }
55}
56
57pub(crate) async fn message_from_post<R>(
59 mut request_stream: R,
60 length: Option<usize>,
61) -> Result<BytesMut, HttpsError>
62where
63 R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
64{
65 let mut bytes = BytesMut::with_capacity(length.unwrap_or(0).clamp(512, 4096));
66
67 loop {
68 match request_stream.next().await {
69 Some(Ok(mut frame)) => bytes.extend_from_slice(&frame.split_off(0)),
70 Some(Err(err)) => return Err(err.into()),
71 None => {
72 return if let Some(length) = length {
73 if bytes.len() == length {
75 Ok(bytes)
76 } else {
77 Err("not all bytes received".into())
78 }
79 } else {
80 Ok(bytes)
81 };
82 }
83 };
84
85 if let Some(length) = length {
86 if bytes.len() == length {
88 return Ok(bytes);
89 }
90 }
91 }
92}
93
#[cfg(test)]
mod tests {
    use futures_executor::block_on;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use crate::https::request;
    use crate::op::Message;

    use super::*;

    /// In-memory request body that yields its chunks in the order given.
    #[derive(Debug)]
    struct TestBytesStream(Vec<Result<Bytes, h2::Error>>);

    impl Stream for TestBytesStream {
        type Item = Result<Bytes, h2::Error>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Take from the front: `Vec::pop` would hand multi-chunk bodies
            // back in reverse order.
            if self.0.is_empty() {
                Poll::Ready(None)
            } else {
                Poll::Ready(Some(self.0.remove(0)))
            }
        }
    }

    /// Builds an empty `Message` together with its wire-format encoding.
    fn empty_message_bytes() -> (Message, Vec<u8>) {
        let message = Message::new();
        let msg_bytes = message.to_vec().unwrap();
        (message, msg_bytes)
    }

    #[test]
    fn test_from_post() {
        let (message, msg_bytes) = empty_message_bytes();
        let len = msg_bytes.len();
        let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes))]);
        let request = request::new("ns.example.com", len).unwrap();
        let request = request.map(|()| stream);

        let from_post = message_from(Some(Arc::from("ns.example.com")), request);
        let bytes = match block_on(from_post) {
            Ok(bytes) => bytes,
            e => panic!("{:#?}", e),
        };

        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    #[test]
    fn test_from_post_multiple_chunks() {
        let (message, msg_bytes) = empty_message_bytes();
        let (head, tail) = msg_bytes.split_at(msg_bytes.len() / 2);
        let stream = TestBytesStream(vec![
            Ok(Bytes::copy_from_slice(head)),
            Ok(Bytes::copy_from_slice(tail)),
        ]);

        let bytes = block_on(message_from_post(stream, Some(msg_bytes.len())))
            .expect("chunked body should be reassembled");
        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    #[test]
    fn test_from_post_without_length() {
        let (_, msg_bytes) = empty_message_bytes();
        let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes.clone()))]);

        let bytes = block_on(message_from_post(stream, None)).expect("body without length");
        assert_eq!(bytes.as_ref(), msg_bytes.as_slice());
    }

    #[test]
    fn test_from_post_short_body() {
        let (_, msg_bytes) = empty_message_bytes();
        let len = msg_bytes.len();
        let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes))]);

        assert!(block_on(message_from_post(stream, Some(len + 1))).is_err());
    }

    #[test]
    fn test_from_post_long_body() {
        let (_, msg_bytes) = empty_message_bytes();
        let len = msg_bytes.len();
        let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes))]);

        assert!(block_on(message_from_post(stream, Some(len - 1))).is_err());
    }
}