trust_dns_proto/https/
https_server.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! HTTPS related server items
9
10use std::fmt::Debug;
11use std::str::FromStr;
12use std::sync::Arc;
13
14use bytes::{Bytes, BytesMut};
15use futures_util::stream::{Stream, StreamExt};
16use h2;
17use http::header::CONTENT_LENGTH;
18use http::{Method, Request};
19use tracing::debug;
20
21use crate::https::HttpsError;
22
23/// Given an HTTP request, return a future that will result in the next sequence of bytes.
24///
25/// To allow downstream clients to do something interesting with the lifetime of the bytes, this doesn't
26///   perform a conversion to a Message, only collects all the bytes.
27pub async fn message_from<R>(
28    this_server_name: Option<Arc<str>>,
29    request: Request<R>,
30) -> Result<BytesMut, HttpsError>
31where
32    R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
33{
34    debug!("Received request: {:#?}", request);
35
36    let this_server_name = this_server_name.as_deref();
37    match crate::https::request::verify(this_server_name, &request) {
38        Ok(_) => (),
39        Err(err) => return Err(err),
40    }
41
42    // attempt to get the content length
43    let mut content_length = None;
44    if let Some(length) = request.headers().get(CONTENT_LENGTH) {
45        let length = usize::from_str(length.to_str()?)?;
46        debug!("got message length: {}", length);
47        content_length = Some(length);
48    }
49
50    match *request.method() {
51        Method::GET => Err(format!("GET unimplemented: {}", request.method()).into()),
52        Method::POST => message_from_post(request.into_body(), content_length).await,
53        _ => Err(format!("bad method: {}", request.method()).into()),
54    }
55}
56
57/// Deserialize the message from a POST message
58pub(crate) async fn message_from_post<R>(
59    mut request_stream: R,
60    length: Option<usize>,
61) -> Result<BytesMut, HttpsError>
62where
63    R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
64{
65    let mut bytes = BytesMut::with_capacity(length.unwrap_or(0).clamp(512, 4096));
66
67    loop {
68        match request_stream.next().await {
69            Some(Ok(mut frame)) => bytes.extend_from_slice(&frame.split_off(0)),
70            Some(Err(err)) => return Err(err.into()),
71            None => {
72                return if let Some(length) = length {
73                    // wait until we have all the bytes
74                    if bytes.len() == length {
75                        Ok(bytes)
76                    } else {
77                        Err("not all bytes received".into())
78                    }
79                } else {
80                    Ok(bytes)
81                };
82            }
83        };
84
85        if let Some(length) = length {
86            // wait until we have all the bytes
87            if bytes.len() == length {
88                return Ok(bytes);
89            }
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use std::collections::VecDeque;
    use std::pin::Pin;
    use std::task::{Context, Poll};

    use futures_executor::block_on;

    use crate::https::request;
    use crate::op::Message;

    use super::*;

    /// In-memory request body that yields its queued frames in insertion order.
    #[derive(Debug)]
    struct TestBytesStream(VecDeque<Result<Bytes, h2::Error>>);

    impl TestBytesStream {
        /// Builds a stream that yields `frames` from first to last.
        fn new(frames: Vec<Result<Bytes, h2::Error>>) -> Self {
            Self(VecDeque::from(frames))
        }
    }

    impl Stream for TestBytesStream {
        type Item = Result<Bytes, h2::Error>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Always ready: take from the front so frames come out in the order they were queued.
            Poll::Ready(self.0.pop_front())
        }
    }

    /// A single-frame POST body round-trips through `message_from` unchanged.
    #[test]
    fn test_from_post() {
        let message = Message::new();
        let msg_bytes = message.to_vec().unwrap();
        let len = msg_bytes.len();
        let stream = TestBytesStream::new(vec![Ok(Bytes::from(msg_bytes))]);
        let request = request::new("ns.example.com", len).unwrap();
        let request = request.map(|()| stream);

        let from_post = message_from(Some(Arc::from("ns.example.com")), request);
        let bytes = match block_on(from_post) {
            Ok(bytes) => bytes,
            e => panic!("{:#?}", e),
        };

        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    /// A body split across several frames is reassembled in arrival order.
    #[test]
    fn test_from_post_multiple_frames() {
        let stream = TestBytesStream::new(vec![
            Ok(Bytes::from_static(b"ab")),
            Ok(Bytes::from_static(b"cd")),
        ]);

        let bytes = block_on(message_from_post(stream, Some(4))).expect("post failed");
        assert_eq!(&bytes[..], &b"abcd"[..]);
    }

    /// A body that ends before the declared length is reached is rejected.
    #[test]
    fn test_from_post_truncated() {
        let stream = TestBytesStream::new(vec![Ok(Bytes::from_static(b"ab"))]);

        assert!(block_on(message_from_post(stream, Some(4))).is_err());
    }
}