hickory_proto/h2/
h2_server.rs

1// Copyright 2015-2018 Benjamin Fry <benjaminfry@me.com>
2//
3// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
4// https://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// https://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! HTTPS related server items
9
10use alloc::sync::Arc;
11use core::fmt::Debug;
12use core::str::FromStr;
13
14use bytes::{Bytes, BytesMut};
15use futures_util::stream::{Stream, StreamExt};
16use h2;
17use http::header::CONTENT_LENGTH;
18use http::{Method, Request};
19use tracing::debug;
20
21use crate::h2::HttpsError;
22use crate::http::Version;
23
24/// Given an HTTP request, return a future that will result in the next sequence of bytes.
25///
26/// To allow downstream clients to do something interesting with the lifetime of the bytes, this doesn't
27///   perform a conversion to a Message, only collects all the bytes.
28pub async fn message_from<R>(
29    this_server_name: Option<Arc<str>>,
30    this_server_endpoint: Arc<str>,
31    request: Request<R>,
32) -> Result<BytesMut, HttpsError>
33where
34    R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
35{
36    debug!("Received request: {:#?}", request);
37
38    let this_server_name = this_server_name.as_deref();
39    match crate::http::request::verify(
40        Version::Http2,
41        this_server_name,
42        &this_server_endpoint,
43        &request,
44    ) {
45        Ok(_) => (),
46        Err(err) => return Err(err),
47    }
48
49    // attempt to get the content length
50    let mut content_length = None;
51    if let Some(length) = request.headers().get(CONTENT_LENGTH) {
52        let length = usize::from_str(length.to_str()?)?;
53        debug!("got message length: {}", length);
54        content_length = Some(length);
55    }
56
57    match *request.method() {
58        Method::GET => Err(format!("GET unimplemented: {}", request.method()).into()),
59        Method::POST => message_from_post(request.into_body(), content_length).await,
60        _ => Err(format!("bad method: {}", request.method()).into()),
61    }
62}
63
64/// Deserialize the message from a POST message
65pub(crate) async fn message_from_post<R>(
66    mut request_stream: R,
67    length: Option<usize>,
68) -> Result<BytesMut, HttpsError>
69where
70    R: Stream<Item = Result<Bytes, h2::Error>> + 'static + Send + Debug + Unpin,
71{
72    let mut bytes = BytesMut::with_capacity(length.unwrap_or(0).clamp(512, 4_096));
73
74    loop {
75        match request_stream.next().await {
76            Some(Ok(mut frame)) => bytes.extend_from_slice(&frame.split_off(0)),
77            Some(Err(err)) => return Err(err.into()),
78            None => {
79                return if let Some(length) = length {
80                    // wait until we have all the bytes
81                    if bytes.len() == length {
82                        Ok(bytes)
83                    } else {
84                        Err("not all bytes received".into())
85                    }
86                } else {
87                    Ok(bytes)
88                };
89            }
90        };
91
92        if let Some(length) = length {
93            // wait until we have all the bytes
94            if bytes.len() == length {
95                return Ok(bytes);
96            }
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use alloc::vec::Vec;
    use core::pin::Pin;
    use core::task::{Context, Poll};
    use futures_executor::block_on;

    use test_support::subscribe;

    use crate::http::request;
    use crate::op::Message;

    use super::*;

    /// In-memory body stream that yields its frames front-to-back, then ends.
    #[derive(Debug)]
    struct TestBytesStream(Vec<Result<Bytes, h2::Error>>);

    impl Stream for TestBytesStream {
        type Item = Result<Bytes, h2::Error>;

        fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
            // Take from the front so frames keep the order they were given in; `Vec::pop`
            // would hand them out in reverse.
            let next = if self.0.is_empty() {
                None
            } else {
                Some(self.0.remove(0))
            };
            Poll::Ready(next)
        }
    }

    #[test]
    fn test_from_post() {
        subscribe();
        let message = Message::new();
        let msg_bytes = message.to_vec().unwrap();
        let len = msg_bytes.len();
        let stream = TestBytesStream(vec![Ok(Bytes::from(msg_bytes))]);
        let request = request::new(Version::Http2, "ns.example.com", "/dns-query", len).unwrap();
        let request = request.map(|()| stream);

        let from_post = message_from(
            Some(Arc::from("ns.example.com")),
            "/dns-query".into(),
            request,
        );
        let bytes = match block_on(from_post) {
            Ok(bytes) => bytes,
            e => panic!("{:#?}", e),
        };

        let msg_from_post = Message::from_vec(bytes.as_ref()).expect("bytes failed");
        assert_eq!(message, msg_from_post);
    }

    #[test]
    fn test_from_post_multiple_frames_without_length() {
        subscribe();
        let stream = TestBytesStream(vec![
            Ok(Bytes::from_static(b"hello ")),
            Ok(Bytes::from_static(b"world")),
        ]);

        let bytes = block_on(message_from_post(stream, None)).expect("body should be collected");
        assert_eq!(&bytes[..], b"hello world");
    }

    #[test]
    fn test_from_post_multiple_frames_with_length() {
        subscribe();
        let stream = TestBytesStream(vec![
            Ok(Bytes::from_static(b"hello ")),
            Ok(Bytes::from_static(b"world")),
        ]);

        let bytes = block_on(message_from_post(stream, Some(11))).expect("body should be collected");
        assert_eq!(&bytes[..], b"hello world");
    }

    #[test]
    fn test_from_post_truncated_body() {
        subscribe();
        let stream = TestBytesStream(vec![Ok(Bytes::from_static(b"hello"))]);

        assert!(block_on(message_from_post(stream, Some(11))).is_err());
    }

    #[test]
    fn test_from_post_oversized_body() {
        subscribe();
        let stream = TestBytesStream(vec![Ok(Bytes::from_static(b"hello world"))]);

        assert!(block_on(message_from_post(stream, Some(5))).is_err());
    }
}