use futures_util::ready;
use hyper::service::HttpService;
use std::future::Future;
use std::marker::PhantomPinned;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, io, time::Duration};
use bytes::Bytes;
use http::{Request, Response};
use http_body::Body;
use hyper::{
body::Incoming,
rt::{Read, ReadBuf, Timer, Write},
service::Service,
};
#[cfg(feature = "http1")]
use hyper::server::conn::http1;
#[cfg(feature = "http2")]
use hyper::{rt::bounds::Http2ServerConnExec, server::conn::http2};
#[cfg(any(not(feature = "http2"), not(feature = "http1")))]
use std::marker::PhantomData;
use pin_project_lite::pin_project;
use crate::common::rewind::Rewind;
/// Boxed, thread-safe error type returned by the fallible APIs in this module.
type Error = Box<dyn StdError + Send + Sync>;

/// Shorthand for a `Result` whose error type is this module's boxed [`Error`].
type Result<T> = std::result::Result<T, Error>;

/// The 24-byte HTTP/2 client connection preface.
///
/// A connection whose first bytes match this sequence is served as HTTP/2;
/// anything else is served as HTTP/1.
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Executor bound required by the connection types in this module.
///
/// With the `http2` feature enabled this is a blanket alias for hyper's
/// [`Http2ServerConnExec`]: every type implementing that trait implements
/// this one as well.
#[cfg(feature = "http2")]
pub trait HttpServerConnExec<A, B>: Http2ServerConnExec<A, B>
where
    B: Body,
{
}

#[cfg(feature = "http2")]
impl<A, B, T> HttpServerConnExec<A, B> for T
where
    B: Body,
    T: Http2ServerConnExec<A, B>,
{
}

/// Executor bound required by the connection types in this module.
///
/// Without the `http2` feature nothing is demanded of the executor, so the
/// trait is implemented for every type.
#[cfg(not(feature = "http2"))]
pub trait HttpServerConnExec<A, B>
where
    B: Body,
{
}

#[cfg(not(feature = "http2"))]
impl<A, B, T> HttpServerConnExec<A, B> for T where B: Body {}
/// Builder for server connections that can speak HTTP/1, HTTP/2, or detect
/// which of the two the client is using.
///
/// Which fields exist depends on the enabled cargo features (`http1`, `http2`).
#[derive(Clone, Debug)]
pub struct Builder<E> {
/// HTTP/1 connection settings.
#[cfg(feature = "http1")]
http1: http1::Builder,
/// HTTP/2 connection settings; owns the executor passed to `Builder::new`.
#[cfg(feature = "http2")]
http2: http2::Builder<E>,
/// `Some` when the builder has been restricted to a single protocol
/// (`http1_only` / `http2_only`); `None` means the protocol is detected.
#[cfg(any(feature = "http1", feature = "http2"))]
version: Option<Version>,
/// Keeps the executor around when there is no HTTP/2 builder to hand it to.
#[cfg(not(feature = "http2"))]
_executor: E,
}
impl<E> Builder<E> {
    /// Creates a builder with default HTTP/1 and HTTP/2 settings and no
    /// protocol restriction.
    ///
    /// With the `http2` feature enabled `executor` is handed to the HTTP/2
    /// builder; without it the executor is only stored.
    pub fn new(executor: E) -> Self {
        Self {
            #[cfg(feature = "http1")]
            http1: http1::Builder::new(),
            #[cfg(feature = "http2")]
            http2: http2::Builder::new(executor),
            #[cfg(any(feature = "http1", feature = "http2"))]
            version: None,
            #[cfg(not(feature = "http2"))]
            _executor: executor,
        }
    }

    /// Returns a view through which the HTTP/1 settings can be configured.
    #[cfg(feature = "http1")]
    pub fn http1(&mut self) -> Http1Builder<'_, E> {
        Http1Builder { inner: self }
    }

    /// Returns a view through which the HTTP/2 settings can be configured.
    #[cfg(feature = "http2")]
    pub fn http2(&mut self) -> Http2Builder<'_, E> {
        Http2Builder { inner: self }
    }

    /// Restricts the builder to HTTP/2: connections are served as HTTP/2
    /// directly, without detecting the protocol first.
    ///
    /// # Panics
    ///
    /// Panics if a protocol restriction has already been set on this builder.
    #[cfg(feature = "http2")]
    pub fn http2_only(mut self) -> Self {
        assert!(self.version.is_none());
        self.version = Some(Version::H2);
        self
    }

    /// Restricts the builder to HTTP/1: connections are served as HTTP/1
    /// directly, without detecting the protocol first.
    ///
    /// # Panics
    ///
    /// Panics if a protocol restriction has already been set on this builder.
    #[cfg(feature = "http1")]
    pub fn http1_only(mut self) -> Self {
        assert!(self.version.is_none());
        self.version = Some(Version::H1);
        self
    }

    /// Binds `io` to `service` and returns the future that drives the
    /// connection.
    ///
    /// If the builder was restricted with [`Builder::http1_only`] or
    /// [`Builder::http2_only`], the matching protocol is used right away.
    /// Otherwise the returned connection first reads the start of the stream
    /// to decide between HTTP/1 and HTTP/2.
    pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let state = match self.version {
            #[cfg(feature = "http1")]
            Some(Version::H1) => {
                // Nothing was sniffed, so the rewind buffer starts out empty.
                let io = Rewind::new_buffered(io, Bytes::new());
                let conn = self.http1.serve_connection(io, service);
                ConnState::H1 { conn }
            }
            #[cfg(feature = "http2")]
            Some(Version::H2) => {
                let io = Rewind::new_buffered(io, Bytes::new());
                let conn = self.http2.serve_connection(io, service);
                ConnState::H2 { conn }
            }
            #[cfg(any(feature = "http1", feature = "http2"))]
            _ => ConnState::ReadVersion {
                read_version: read_version(io),
                builder: Cow::Borrowed(self),
                service: Some(service),
            },
        };
        Connection { state }
    }

    /// Binds `io` to `service` like [`Builder::serve_connection`], but serves
    /// HTTP/1 connections with upgrade support. This additionally requires
    /// `I: Send`.
    ///
    /// A protocol restriction set through [`Builder::http1_only`] or
    /// [`Builder::http2_only`] is honored here exactly as in
    /// [`Builder::serve_connection`]; only an unrestricted builder falls back
    /// to detecting the protocol from the first bytes of the stream.
    pub fn serve_connection_with_upgrades<I, S, B>(
        &self,
        io: I,
        service: S,
    ) -> UpgradeableConnection<'_, I, S, E>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + Send + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let state = match self.version {
            #[cfg(feature = "http1")]
            Some(Version::H1) => {
                // Nothing was sniffed, so the rewind buffer starts out empty.
                let io = Rewind::new_buffered(io, Bytes::new());
                let conn = self.http1.serve_connection(io, service).with_upgrades();
                UpgradeableConnState::H1 { conn }
            }
            #[cfg(feature = "http2")]
            Some(Version::H2) => {
                let io = Rewind::new_buffered(io, Bytes::new());
                let conn = self.http2.serve_connection(io, service);
                UpgradeableConnState::H2 { conn }
            }
            #[cfg(any(feature = "http1", feature = "http2"))]
            _ => UpgradeableConnState::ReadVersion {
                read_version: read_version(io),
                builder: Cow::Borrowed(self),
                service: Some(service),
            },
        };
        UpgradeableConnection { state }
    }
}
/// HTTP protocol version a connection is served with, either detected from
/// the stream or forced through the builder.
#[derive(Copy, Clone, Debug)]
enum Version {
/// HTTP/1.
H1,
/// HTTP/2.
H2,
}
impl Version {
#[must_use]
#[cfg(any(not(feature = "http2"), not(feature = "http1")))]
pub fn unsupported(self) -> Error {
match self {
Version::H1 => Error::from("HTTP/1 is not supported"),
Version::H2 => Error::from("HTTP/2 is not supported"),
}
}
}
/// Wraps `io` in a [`ReadVersion`] future that has not read anything yet.
///
/// The future starts out assuming HTTP/2 and is not cancelled.
fn read_version<I>(io: I) -> ReadVersion<I>
where
    I: Read + Unpin,
{
    // Scratch space for the 24 bytes that are sniffed; nothing is initialized yet.
    let buf = [MaybeUninit::uninit(); 24];
    ReadVersion {
        io: Some(io),
        buf,
        filled: 0,
        version: Version::H2,
        cancelled: false,
        _pin: PhantomPinned,
    }
}
// Future that sniffs the first bytes of a connection to decide whether the
// peer speaks HTTP/1 or HTTP/2.
pin_project! {
struct ReadVersion<I> {
// The I/O object; taken out (left as `None`) once the future resolves successfully.
io: Option<I>,
// Scratch buffer for the sniffed bytes; only the first `filled` bytes are initialized.
buf: [MaybeUninit<u8>; 24],
// Number of bytes read into `buf` so far, carried across polls.
filled: usize,
// Current guess; created as `H2` and switched to `H1` on the first mismatch.
version: Version,
// Set by `cancel`; makes the next poll fail instead of reading.
cancelled: bool,
// Opts the future out of `Unpin`.
#[pin]
_pin: PhantomPinned,
}
}
impl<I> ReadVersion<I> {
    /// Marks the detection as cancelled.
    ///
    /// The next call to `poll` then resolves to an
    /// [`io::ErrorKind::Interrupted`] error instead of reading from the I/O.
    pub fn cancel(self: Pin<&mut Self>) {
        let this = self.project();
        *this.cancelled = true;
    }
}
// Protocol detection: reads up to `H2_PREFACE.len()` bytes and compares them,
// as they arrive, against the HTTP/2 preface.
impl<I> Future for ReadVersion<I>
where
I: Read + Unpin,
{
// The detected version plus the I/O wrapped in a `Rewind` that is given the
// bytes consumed here (presumably so they are replayed to the protocol
// implementation; confirm in `crate::common::rewind`).
type Output = io::Result<(Version, Rewind<I>)>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
// A cancelled detection fails immediately without touching the I/O.
if *this.cancelled {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled")));
}
let mut buf = ReadBuf::uninit(&mut *this.buf);
// SAFETY: `this.filled` is only ever assigned from `buf.filled().len()` after a
// read below, so the first `filled` bytes of the buffer were written by earlier polls.
unsafe {
buf.unfilled().advance(*this.filled);
};
// Keep reading until the whole preface has been seen or a mismatch is found.
while buf.filled().len() < H2_PREFACE.len() {
let len = buf.filled().len();
// Returns `Pending` (via `ready!`) or the I/O error (via `?`) without consuming `io`.
ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?;
// Persist progress so a later poll resumes where this one stopped.
*this.filled = buf.filled().len();
// A read that added no bytes, or newly read bytes that differ from the
// corresponding slice of the preface, mean this is not HTTP/2.
if buf.filled().len() == len
|| buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
{
*this.version = Version::H1;
break;
}
}
// Detection is finished: hand the I/O over together with everything read so far.
let io = this.io.take().unwrap();
let buf = buf.filled().to_vec();
Poll::Ready(Ok((
*this.version,
Rewind::new_buffered(io, Bytes::from(buf)),
)))
}
}
pin_project! {
/// Connection future returned by `Builder::serve_connection`.
///
/// It must be polled to make progress; see the `Future` implementation.
pub struct Connection<'a, I, S, E>
where
S: HttpService<Incoming>,
{
// Current stage: still detecting the protocol, or serving HTTP/1 or HTTP/2.
#[pin]
state: ConnState<'a, I, S, E>,
}
}
/// Minimal borrowed-or-owned wrapper that, unlike `std::borrow::Cow`, places
/// no `ToOwned`/`Clone` requirement on `T`.
enum Cow<'a, T> {
    /// A reference to a value owned elsewhere.
    Borrowed(&'a T),
    /// A value owned by this wrapper.
    Owned(T),
}

impl<'a, T> std::ops::Deref for Cow<'a, T> {
    type Target = T;

    /// Gives shared access to the wrapped value, whichever way it is held.
    fn deref(&self) -> &T {
        match self {
            Self::Borrowed(borrowed) => *borrowed,
            Self::Owned(owned) => owned,
        }
    }
}
/// HTTP/1 connection over the rewindable I/O.
#[cfg(feature = "http1")]
type Http1Connection<I, S> = http1::Connection<Rewind<I>, S>;

/// Zero-sized stand-in used when the `http1` feature is disabled.
#[cfg(not(feature = "http1"))]
type Http1Connection<I, S> = (PhantomData<I>, PhantomData<S>);

/// HTTP/2 connection over the rewindable I/O.
#[cfg(feature = "http2")]
type Http2Connection<I, S, E> = http2::Connection<Rewind<I>, S, E>;

/// Zero-sized stand-in used when the `http2` feature is disabled.
#[cfg(not(feature = "http2"))]
type Http2Connection<I, S, E> = (PhantomData<I>, PhantomData<S>, PhantomData<E>);
// State machine behind `Connection`.
pin_project! {
#[project = ConnStateProj]
enum ConnState<'a, I, S, E>
where
S: HttpService<Incoming>,
{
// Protocol not known yet: the preface is still being read.
ReadVersion {
#[pin]
read_version: ReadVersion<I>,
// Settings used to build the real connection once the version is known.
builder: Cow<'a, Builder<E>>,
// The service, held until it is moved into the real connection.
service: Option<S>,
},
// Serving HTTP/1.
H1 {
#[pin]
conn: Http1Connection<I, S>,
},
// Serving HTTP/2.
H2 {
#[pin]
conn: Http2Connection<I, S, E>,
},
}
}
impl<I, S, E, B> Connection<'_, I, S, E>
where
    S: HttpService<Incoming, ResBody = B>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    E: HttpServerConnExec<S::Future, B>,
{
    /// Starts a graceful shutdown of the connection.
    ///
    /// While the protocol is still being detected this cancels the detection;
    /// afterwards the call is forwarded to the HTTP/1 or HTTP/2 connection.
    /// The connection future must keep being polled for the shutdown to
    /// complete.
    pub fn graceful_shutdown(self: Pin<&mut Self>) {
        let state = self.project().state.project();
        match state {
            ConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
            #[cfg(feature = "http1")]
            ConnStateProj::H1 { conn } => conn.graceful_shutdown(),
            #[cfg(feature = "http2")]
            ConnStateProj::H2 { conn } => conn.graceful_shutdown(),
            #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
            _ => unreachable!(),
        }
    }

    /// Detaches the connection from the builder it was created from by
    /// storing a clone of that builder inside the connection.
    pub fn into_owned(self) -> Connection<'static, I, S, E>
    where
        Builder<E>: Clone,
    {
        let state = match self.state {
            ConnState::ReadVersion {
                read_version,
                builder,
                service,
            } => {
                // Only this state refers to the builder at all.
                let builder = Cow::Owned(builder.clone());
                ConnState::ReadVersion {
                    read_version,
                    service,
                    builder,
                }
            }
            #[cfg(feature = "http1")]
            ConnState::H1 { conn } => ConnState::H1 { conn },
            #[cfg(feature = "http2")]
            ConnState::H2 { conn } => ConnState::H2 { conn },
            #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
            _ => unreachable!(),
        };
        Connection { state }
    }
}
// Drives the connection: first the protocol detection, then the HTTP/1 or
// HTTP/2 connection created from its result.
impl<I, S, E, B> Future for Connection<'_, I, S, E>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + 'static,
E: HttpServerConnExec<S::Future, B>,
{
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Loop so that, right after the state switches away from `ReadVersion`,
// the freshly created connection is polled within the same call.
loop {
let mut this = self.as_mut().project();
match this.state.as_mut().project() {
ConnStateProj::ReadVersion {
read_version,
builder,
service,
} => {
// Pending detection or an I/O error (including cancellation) ends this poll.
let (version, io) = ready!(read_version.poll(cx))?;
// The service is moved into the connection built below.
let service = service.take().unwrap();
match version {
#[cfg(feature = "http1")]
Version::H1 => {
let conn = builder.http1.serve_connection(io, service);
this.state.set(ConnState::H1 { conn });
}
#[cfg(feature = "http2")]
Version::H2 => {
let conn = builder.http2.serve_connection(io, service);
this.state.set(ConnState::H2 { conn });
}
// The detected protocol was compiled out.
#[cfg(any(not(feature = "http1"), not(feature = "http2")))]
_ => return Poll::Ready(Err(version.unsupported())),
}
}
#[cfg(feature = "http1")]
ConnStateProj::H1 { conn } => {
return conn.poll(cx).map_err(Into::into);
}
#[cfg(feature = "http2")]
ConnStateProj::H2 { conn } => {
return conn.poll(cx).map_err(Into::into);
}
#[cfg(any(not(feature = "http1"), not(feature = "http2")))]
_ => unreachable!(),
}
}
}
}
pin_project! {
/// Connection future returned by `Builder::serve_connection_with_upgrades`.
///
/// It must be polled to make progress; see the `Future` implementation.
pub struct UpgradeableConnection<'a, I, S, E>
where
S: HttpService<Incoming>,
{
// Current stage: still detecting the protocol, or serving HTTP/1 (with upgrades) or HTTP/2.
#[pin]
state: UpgradeableConnState<'a, I, S, E>,
}
}
/// HTTP/1 connection with upgrade support.
#[cfg(feature = "http1")]
type Http1UpgradeableConnection<I, S> = http1::UpgradeableConnection<I, S>;

/// Zero-sized stand-in used when the `http1` feature is disabled.
#[cfg(not(feature = "http1"))]
type Http1UpgradeableConnection<I, S> = (PhantomData<I>, PhantomData<S>);
// State machine behind `UpgradeableConnection`.
pin_project! {
#[project = UpgradeableConnStateProj]
enum UpgradeableConnState<'a, I, S, E>
where
S: HttpService<Incoming>,
{
// Protocol not known yet: the preface is still being read.
ReadVersion {
#[pin]
read_version: ReadVersion<I>,
// Settings used to build the real connection once the version is known.
builder: Cow<'a, Builder<E>>,
// The service, held until it is moved into the real connection.
service: Option<S>,
},
// Serving HTTP/1 with upgrade support.
H1 {
#[pin]
conn: Http1UpgradeableConnection<Rewind<I>, S>,
},
// Serving HTTP/2.
H2 {
#[pin]
conn: Http2Connection<I, S, E>,
},
}
}
impl<I, S, E, B> UpgradeableConnection<'_, I, S, E>
where
    S: HttpService<Incoming, ResBody = B>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    E: HttpServerConnExec<S::Future, B>,
{
    /// Starts a graceful shutdown of the connection.
    ///
    /// While the protocol is still being detected this cancels the detection;
    /// afterwards the call is forwarded to the HTTP/1 or HTTP/2 connection.
    /// The connection future must keep being polled for the shutdown to
    /// complete.
    pub fn graceful_shutdown(self: Pin<&mut Self>) {
        let state = self.project().state.project();
        match state {
            UpgradeableConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
            #[cfg(feature = "http1")]
            UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
            #[cfg(feature = "http2")]
            UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
            #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
            _ => unreachable!(),
        }
    }

    /// Detaches the connection from the builder it was created from by
    /// storing a clone of that builder inside the connection.
    pub fn into_owned(self) -> UpgradeableConnection<'static, I, S, E>
    where
        Builder<E>: Clone,
    {
        let state = match self.state {
            UpgradeableConnState::ReadVersion {
                read_version,
                builder,
                service,
            } => {
                // Only this state refers to the builder at all.
                let builder = Cow::Owned(builder.clone());
                UpgradeableConnState::ReadVersion {
                    read_version,
                    service,
                    builder,
                }
            }
            #[cfg(feature = "http1")]
            UpgradeableConnState::H1 { conn } => UpgradeableConnState::H1 { conn },
            #[cfg(feature = "http2")]
            UpgradeableConnState::H2 { conn } => UpgradeableConnState::H2 { conn },
            #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
            _ => unreachable!(),
        };
        UpgradeableConnection { state }
    }
}
// Drives the connection: first the protocol detection, then the HTTP/1
// (upgrade-capable) or HTTP/2 connection created from its result.
impl<I, S, E, B> Future for UpgradeableConnection<'_, I, S, E>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + Send + 'static,
E: HttpServerConnExec<S::Future, B>,
{
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Loop so that, right after the state switches away from `ReadVersion`,
// the freshly created connection is polled within the same call.
loop {
let mut this = self.as_mut().project();
match this.state.as_mut().project() {
UpgradeableConnStateProj::ReadVersion {
read_version,
builder,
service,
} => {
// Pending detection or an I/O error (including cancellation) ends this poll.
let (version, io) = ready!(read_version.poll(cx))?;
// The service is moved into the connection built below.
let service = service.take().unwrap();
match version {
#[cfg(feature = "http1")]
Version::H1 => {
// Only the HTTP/1 connection is switched to its upgrade-capable form.
let conn = builder.http1.serve_connection(io, service).with_upgrades();
this.state.set(UpgradeableConnState::H1 { conn });
}
#[cfg(feature = "http2")]
Version::H2 => {
let conn = builder.http2.serve_connection(io, service);
this.state.set(UpgradeableConnState::H2 { conn });
}
// The detected protocol was compiled out.
#[cfg(any(not(feature = "http1"), not(feature = "http2")))]
_ => return Poll::Ready(Err(version.unsupported())),
}
}
#[cfg(feature = "http1")]
UpgradeableConnStateProj::H1 { conn } => {
return conn.poll(cx).map_err(Into::into);
}
#[cfg(feature = "http2")]
UpgradeableConnStateProj::H2 { conn } => {
return conn.poll(cx).map_err(Into::into);
}
#[cfg(any(not(feature = "http1"), not(feature = "http2")))]
_ => unreachable!(),
}
}
}
}
/// View over a [`Builder`] that exposes its HTTP/1 settings.
///
/// Every setter forwards to the wrapped hyper `http1::Builder` and returns
/// `&mut Self` so calls can be chained; see hyper's `http1::Builder`
/// documentation for the exact semantics of each option.
#[cfg(feature = "http1")]
pub struct Http1Builder<'a, E> {
inner: &'a mut Builder<E>,
}
#[cfg(feature = "http1")]
impl<E> Http1Builder<'_, E> {
/// Switches to the HTTP/2 settings view of the same builder.
#[cfg(feature = "http2")]
pub fn http2(&mut self) -> Http2Builder<'_, E> {
Http2Builder { inner: self.inner }
}
/// Forwards `val` to the HTTP/1 builder's `half_close` option.
pub fn half_close(&mut self, val: bool) -> &mut Self {
self.inner.http1.half_close(val);
self
}
/// Forwards `val` to the HTTP/1 builder's `keep_alive` option.
pub fn keep_alive(&mut self, val: bool) -> &mut Self {
self.inner.http1.keep_alive(val);
self
}
/// Forwards `enabled` to the HTTP/1 builder's `title_case_headers` option.
pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self {
self.inner.http1.title_case_headers(enabled);
self
}
/// Forwards `enabled` to the HTTP/1 builder's `preserve_header_case` option.
pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self {
self.inner.http1.preserve_header_case(enabled);
self
}
/// Forwards `val` to the HTTP/1 builder's `max_headers` option.
pub fn max_headers(&mut self, val: usize) -> &mut Self {
self.inner.http1.max_headers(val);
self
}
/// Forwards `read_timeout` to the HTTP/1 builder's `header_read_timeout` option.
pub fn header_read_timeout(&mut self, read_timeout: impl Into<Option<Duration>>) -> &mut Self {
self.inner.http1.header_read_timeout(read_timeout);
self
}
/// Forwards `val` to the HTTP/1 builder's `writev` option.
pub fn writev(&mut self, val: bool) -> &mut Self {
self.inner.http1.writev(val);
self
}
/// Forwards `max` to the HTTP/1 builder's `max_buf_size` option.
pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
self.inner.http1.max_buf_size(max);
self
}
/// Forwards `enabled` to the HTTP/1 builder's `pipeline_flush` option.
pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self {
self.inner.http1.pipeline_flush(enabled);
self
}
/// Installs `timer` on the HTTP/1 builder.
pub fn timer<M>(&mut self, timer: M) -> &mut Self
where
M: Timer + Send + Sync + 'static,
{
self.inner.http1.timer(timer);
self
}
/// Serves `io` with `service` through the wrapped [`Builder`] and waits
/// for the connection to finish.
#[cfg(feature = "http2")]
pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + 'static,
E: HttpServerConnExec<S::Future, B>,
{
self.inner.serve_connection(io, service).await
}
/// Serves `io` with `service` through the wrapped [`Builder`] and waits
/// for the connection to finish.
///
/// Variant compiled without the `http2` feature; it carries no executor bound.
#[cfg(not(feature = "http2"))]
pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + 'static,
{
self.inner.serve_connection(io, service).await
}
/// Forwards to `Builder::serve_connection_with_upgrades` on the wrapped
/// builder; requires `I: Send`.
///
/// Only available when the `http2` feature is enabled.
#[cfg(feature = "http2")]
pub fn serve_connection_with_upgrades<I, S, B>(
&self,
io: I,
service: S,
) -> UpgradeableConnection<'_, I, S, E>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + Send + 'static,
E: HttpServerConnExec<S::Future, B>,
{
self.inner.serve_connection_with_upgrades(io, service)
}
}
/// View over a [`Builder`] that exposes its HTTP/2 settings.
///
/// Every setter forwards to the wrapped hyper `http2::Builder` and returns
/// `&mut Self` so calls can be chained; see hyper's `http2::Builder`
/// documentation for the exact semantics of each option.
#[cfg(feature = "http2")]
pub struct Http2Builder<'a, E> {
inner: &'a mut Builder<E>,
}
#[cfg(feature = "http2")]
impl<E> Http2Builder<'_, E> {
/// Switches to the HTTP/1 settings view of the same builder.
#[cfg(feature = "http1")]
pub fn http1(&mut self) -> Http1Builder<'_, E> {
Http1Builder { inner: self.inner }
}
/// Forwards `max` to the HTTP/2 builder's `max_pending_accept_reset_streams` option.
pub fn max_pending_accept_reset_streams(&mut self, max: impl Into<Option<usize>>) -> &mut Self {
self.inner.http2.max_pending_accept_reset_streams(max);
self
}
/// Forwards `sz` to the HTTP/2 builder's `initial_stream_window_size` option.
pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
self.inner.http2.initial_stream_window_size(sz);
self
}
/// Forwards `sz` to the HTTP/2 builder's `initial_connection_window_size` option.
pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
self.inner.http2.initial_connection_window_size(sz);
self
}
/// Forwards `enabled` to the HTTP/2 builder's `adaptive_window` option.
pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self {
self.inner.http2.adaptive_window(enabled);
self
}
/// Forwards `sz` to the HTTP/2 builder's `max_frame_size` option.
pub fn max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
self.inner.http2.max_frame_size(sz);
self
}
/// Forwards `max` to the HTTP/2 builder's `max_concurrent_streams` option.
pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self {
self.inner.http2.max_concurrent_streams(max);
self
}
/// Forwards `interval` to the HTTP/2 builder's `keep_alive_interval` option.
pub fn keep_alive_interval(&mut self, interval: impl Into<Option<Duration>>) -> &mut Self {
self.inner.http2.keep_alive_interval(interval);
self
}
/// Forwards `timeout` to the HTTP/2 builder's `keep_alive_timeout` option.
pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self {
self.inner.http2.keep_alive_timeout(timeout);
self
}
/// Forwards `max` to the HTTP/2 builder's `max_send_buf_size` option.
pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self {
self.inner.http2.max_send_buf_size(max);
self
}
/// Calls `enable_connect_protocol` on the HTTP/2 builder.
pub fn enable_connect_protocol(&mut self) -> &mut Self {
self.inner.http2.enable_connect_protocol();
self
}
/// Forwards `max` to the HTTP/2 builder's `max_header_list_size` option.
pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
self.inner.http2.max_header_list_size(max);
self
}
/// Installs `timer` on the HTTP/2 builder.
pub fn timer<M>(&mut self, timer: M) -> &mut Self
where
M: Timer + Send + Sync + 'static,
{
self.inner.http2.timer(timer);
self
}
/// Serves `io` with `service` through the wrapped [`Builder`] and waits
/// for the connection to finish.
pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + 'static,
E: HttpServerConnExec<S::Future, B>,
{
self.inner.serve_connection(io, service).await
}
/// Forwards to `Builder::serve_connection_with_upgrades` on the wrapped
/// builder; requires `I: Send`.
pub fn serve_connection_with_upgrades<I, S, B>(
&self,
io: I,
service: S,
) -> UpgradeableConnection<'_, I, S, E>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + Send + 'static,
E: HttpServerConnExec<S::Future, B>,
{
self.inner.serve_connection_with_upgrades(io, service)
}
}
// Integration-style tests: each one starts a real TCP server on an ephemeral
// localhost port and talks to it with hyper's HTTP/1 or HTTP/2 client.
#[cfg(test)]
mod tests {
use crate::{
rt::{TokioExecutor, TokioIo},
server::conn::auto,
};
use http::{Request, Response};
use http_body::Body;
use http_body_util::{BodyExt, Empty, Full};
use hyper::{body, body::Bytes, client, service::service_fn};
use std::{convert::Infallible, error::Error as StdError, net::SocketAddr, time::Duration};
use tokio::{
net::{TcpListener, TcpStream},
pin,
};
// Body returned by the `hello` service and expected by the clients.
const BODY: &[u8] = b"Hello, world!";
// Compile-time check that the HTTP/1 and HTTP/2 setting views can be chained
// from one another and also used on a named builder.
#[test]
fn configuration() {
auto::Builder::new(TokioExecutor::new())
.http1()
.keep_alive(true)
.http2()
.keep_alive_interval(None);
let mut builder = auto::Builder::new(TokioExecutor::new());
builder.http1().keep_alive(true);
builder.http2().keep_alive_interval(None);
}
// An HTTP/1 client is served by the auto-detecting server.
#[cfg(not(miri))]
#[tokio::test]
async fn http1() {
let addr = start_server(false, false).await;
let mut sender = connect_h1(addr).await;
let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(body, BODY);
}
// An HTTP/2 client is served by the auto-detecting server.
#[cfg(not(miri))]
#[tokio::test]
async fn http2() {
let addr = start_server(false, false).await;
let mut sender = connect_h2(addr).await;
let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(body, BODY);
}
// An HTTP/2 client is served by a server restricted to HTTP/2.
#[cfg(not(miri))]
#[tokio::test]
async fn http2_only() {
let addr = start_server(false, true).await;
let mut sender = connect_h2(addr).await;
let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(body, BODY);
}
// An HTTP/1 request against an HTTP/2-only server must fail.
#[cfg(not(miri))]
#[tokio::test]
async fn http2_only_fail_if_client_is_http1() {
let addr = start_server(false, true).await;
let mut sender = connect_h1(addr).await;
let _ = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.expect_err("should fail");
}
// An HTTP/1 client is served by a server restricted to HTTP/1.
#[cfg(not(miri))]
#[tokio::test]
async fn http1_only() {
let addr = start_server(true, false).await;
let mut sender = connect_h1(addr).await;
let response = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.unwrap();
let body = response.into_body().collect().await.unwrap().to_bytes();
assert_eq!(body, BODY);
}
// An HTTP/2 request against an HTTP/1-only server must fail.
#[cfg(not(miri))]
#[tokio::test]
async fn http1_only_fail_if_client_is_http2() {
let addr = start_server(true, false).await;
let mut sender = connect_h2(addr).await;
let _ = sender
.send_request(Request::new(Empty::<Bytes>::new()))
.await
.expect_err("should fail");
}
// Calling `graceful_shutdown` before anything was read must make the
// connection finish promptly with an `Interrupted` I/O error.
#[cfg(not(miri))]
#[tokio::test]
async fn graceful_shutdown() {
let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
.await
.unwrap();
let listener_addr = listener.local_addr().unwrap();
let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() });
// Keep the client side open (but silent) for the duration of the test.
let _stream = TcpStream::connect(listener_addr).await.unwrap();
let (stream, _) = listen_task.await.unwrap();
let stream = TokioIo::new(stream);
let builder = auto::Builder::new(TokioExecutor::new());
let connection = builder.serve_connection(stream, service_fn(hello));
pin!(connection);
connection.as_mut().graceful_shutdown();
let connection_error = tokio::time::timeout(Duration::from_millis(200), connection)
.await
.expect("Connection should have finished in a timely manner after graceful shutdown.")
.expect_err("Connection should have been interrupted.");
let connection_error = connection_error
.downcast_ref::<std::io::Error>()
.expect("The error should have been `std::io::Error`.");
assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted);
}
// Opens an HTTP/1 client connection to `addr`, spawns its driver task and
// returns the request sender.
async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B>
where
B: Body + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (sender, connection) = client::conn::http1::handshake(stream).await.unwrap();
tokio::spawn(connection);
sender
}
// Opens an HTTP/2 client connection to `addr`, spawns its driver task and
// returns the request sender.
async fn connect_h2<B>(addr: SocketAddr) -> client::conn::http2::SendRequest<B>
where
B: Body + Unpin + Send + 'static,
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (sender, connection) = client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(stream)
.await
.unwrap();
tokio::spawn(connection);
sender
}
// Binds an ephemeral localhost port and serves every accepted connection with
// `hello`. `h1_only` / `h2_only` select a protocol-restricted builder (checked
// in that order); otherwise protocol detection with upgrade support is used,
// with the HTTP/2 max header list size set to 4096. Returns the bound address.
async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
let listener = TcpListener::bind(addr).await.unwrap();
let local_addr = listener.local_addr().unwrap();
tokio::spawn(async move {
loop {
let (stream, _) = listener.accept().await.unwrap();
let stream = TokioIo::new(stream);
// One task per connection; a connection error panics only that task.
tokio::task::spawn(async move {
let mut builder = auto::Builder::new(TokioExecutor::new());
if h1_only {
builder = builder.http1_only();
builder.serve_connection(stream, service_fn(hello)).await
} else if h2_only {
builder = builder.http2_only();
builder.serve_connection(stream, service_fn(hello)).await
} else {
builder
.http2()
.max_header_list_size(4096)
.serve_connection_with_upgrades(stream, service_fn(hello))
.await
}
.unwrap();
});
}
});
local_addr
}
// Test service: answers every request with `BODY`.
async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
Ok(Response::new(Full::new(Bytes::from(BODY))))
}
}