1pub mod upgrade;
4
5use futures_util::ready;
6use hyper::service::HttpService;
7use std::future::Future;
8use std::marker::PhantomPinned;
9use std::mem::MaybeUninit;
10use std::pin::Pin;
11use std::task::{Context, Poll};
12use std::{error::Error as StdError, io, time::Duration};
13
14use bytes::Bytes;
15use http::{Request, Response};
16use http_body::Body;
17use hyper::{
18 body::Incoming,
19 rt::{Read, ReadBuf, Timer, Write},
20 service::Service,
21};
22
23#[cfg(feature = "http1")]
24use hyper::server::conn::http1;
25
26#[cfg(feature = "http2")]
27use hyper::{rt::bounds::Http2ServerConnExec, server::conn::http2};
28
29#[cfg(any(not(feature = "http2"), not(feature = "http1")))]
30use std::marker::PhantomData;
31
32use pin_project_lite::pin_project;
33
34use crate::common::rewind::Rewind;
35
/// Boxed, thread-safe error type produced by the auto connection futures.
type Error = Box<dyn std::error::Error + Send + Sync>;

/// `Result` alias using this module's boxed [`Error`].
type Result<T> = std::result::Result<T, Error>;

/// The HTTP/2 client connection preface (24 bytes). A connection whose first
/// bytes match this sequence is served as HTTP/2; anything else as HTTP/1.
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
41
42#[cfg(feature = "http2")]
44pub trait HttpServerConnExec<A, B: Body>: Http2ServerConnExec<A, B> {}
45
46#[cfg(feature = "http2")]
47impl<A, B: Body, T: Http2ServerConnExec<A, B>> HttpServerConnExec<A, B> for T {}
48
49#[cfg(not(feature = "http2"))]
51pub trait HttpServerConnExec<A, B: Body> {}
52
53#[cfg(not(feature = "http2"))]
54impl<A, B: Body, T> HttpServerConnExec<A, B> for T {}
55
/// Builder for connections served as HTTP/1 or HTTP/2.
///
/// The protocol is either forced (`http1_only` / `http2_only`) or detected
/// from the connection's first bytes.
#[derive(Clone, Debug)]
pub struct Builder<E> {
    /// Settings applied to HTTP/1 connections.
    #[cfg(feature = "http1")]
    http1: http1::Builder,
    /// Settings applied to HTTP/2 connections; it also holds the executor.
    #[cfg(feature = "http2")]
    http2: http2::Builder<E>,
    /// Forced protocol version; `None` means auto-detect per connection.
    #[cfg(any(feature = "http1", feature = "http2"))]
    version: Option<Version>,
    /// Keeps the executor (and thus the `E` parameter) when HTTP/2 is disabled.
    #[cfg(not(feature = "http2"))]
    _executor: E,
}
68
69impl<E> Builder<E> {
70 pub fn new(executor: E) -> Self {
86 Self {
87 #[cfg(feature = "http1")]
88 http1: http1::Builder::new(),
89 #[cfg(feature = "http2")]
90 http2: http2::Builder::new(executor),
91 #[cfg(any(feature = "http1", feature = "http2"))]
92 version: None,
93 #[cfg(not(feature = "http2"))]
94 _executor: executor,
95 }
96 }
97
98 #[cfg(feature = "http1")]
100 pub fn http1(&mut self) -> Http1Builder<'_, E> {
101 Http1Builder { inner: self }
102 }
103
104 #[cfg(feature = "http2")]
106 pub fn http2(&mut self) -> Http2Builder<'_, E> {
107 Http2Builder { inner: self }
108 }
109
110 #[cfg(feature = "http2")]
114 pub fn http2_only(mut self) -> Self {
115 assert!(self.version.is_none());
116 self.version = Some(Version::H2);
117 self
118 }
119
120 #[cfg(feature = "http1")]
124 pub fn http1_only(mut self) -> Self {
125 assert!(self.version.is_none());
126 self.version = Some(Version::H1);
127 self
128 }
129
130 pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
132 where
133 S: Service<Request<Incoming>, Response = Response<B>>,
134 S::Future: 'static,
135 S::Error: Into<Box<dyn StdError + Send + Sync>>,
136 B: Body + 'static,
137 B::Error: Into<Box<dyn StdError + Send + Sync>>,
138 I: Read + Write + Unpin + 'static,
139 E: HttpServerConnExec<S::Future, B>,
140 {
141 let state = match self.version {
142 #[cfg(feature = "http1")]
143 Some(Version::H1) => {
144 let io = Rewind::new_buffered(io, Bytes::new());
145 let conn = self.http1.serve_connection(io, service);
146 ConnState::H1 { conn }
147 }
148 #[cfg(feature = "http2")]
149 Some(Version::H2) => {
150 let io = Rewind::new_buffered(io, Bytes::new());
151 let conn = self.http2.serve_connection(io, service);
152 ConnState::H2 { conn }
153 }
154 #[cfg(any(feature = "http1", feature = "http2"))]
155 _ => ConnState::ReadVersion {
156 read_version: read_version(io),
157 builder: Cow::Borrowed(self),
158 service: Some(service),
159 },
160 };
161
162 Connection { state }
163 }
164
165 pub fn serve_connection_with_upgrades<I, S, B>(
173 &self,
174 io: I,
175 service: S,
176 ) -> UpgradeableConnection<'_, I, S, E>
177 where
178 S: Service<Request<Incoming>, Response = Response<B>>,
179 S::Future: 'static,
180 S::Error: Into<Box<dyn StdError + Send + Sync>>,
181 B: Body + 'static,
182 B::Error: Into<Box<dyn StdError + Send + Sync>>,
183 I: Read + Write + Unpin + Send + 'static,
184 E: HttpServerConnExec<S::Future, B>,
185 {
186 UpgradeableConnection {
187 state: UpgradeableConnState::ReadVersion {
188 read_version: read_version(io),
189 builder: Cow::Borrowed(self),
190 service: Some(service),
191 },
192 }
193 }
194}
195
/// Protocol version chosen for a connection (forced or detected).
#[derive(Copy, Clone, Debug)]
enum Version {
    /// HTTP/1.x.
    H1,
    /// HTTP/2.
    H2,
}
201
202impl Version {
203 #[must_use]
204 #[cfg(any(not(feature = "http2"), not(feature = "http1")))]
205 pub fn unsupported(self) -> Error {
206 match self {
207 Version::H1 => Error::from("HTTP/1 is not supported"),
208 Version::H2 => Error::from("HTTP/2 is not supported"),
209 }
210 }
211}
212
213fn read_version<I>(io: I) -> ReadVersion<I>
214where
215 I: Read + Unpin,
216{
217 ReadVersion {
218 io: Some(io),
219 buf: [MaybeUninit::uninit(); 24],
220 filled: 0,
221 version: Version::H2,
222 cancelled: false,
223 _pin: PhantomPinned,
224 }
225}
226
pin_project! {
    /// Future that reads up to `H2_PREFACE.len()` bytes from `io` to decide
    /// whether the peer speaks HTTP/2 or HTTP/1.
    struct ReadVersion<I> {
        // The IO being sniffed; taken (left `None`) when the future completes.
        io: Option<I>,
        // Bytes read so far; 24 bytes, the length of `H2_PREFACE`.
        buf: [MaybeUninit<u8>; 24],
        // Number of leading bytes of `buf` filled by previous polls.
        filled: usize,
        // Starts as `H2` and flips to `H1` as soon as the input stops matching
        // the preface.
        version: Version,
        // Set by `cancel()`; the next poll then fails with `Interrupted`.
        cancelled: bool,
        // Makes `ReadVersion` `!Unpin`. NOTE(review): nothing self-referential is
        // visible here; presumably kept for compatibility — confirm before removing.
        #[pin]
        _pin: PhantomPinned,
    }
}
240
241impl<I> ReadVersion<I> {
242 pub fn cancel(self: Pin<&mut Self>) {
243 *self.project().cancelled = true;
244 }
245}
246
impl<I> Future for ReadVersion<I>
where
    I: Read + Unpin,
{
    /// On success: the detected version, plus `io` wrapped in a [`Rewind`]
    /// buffered with the bytes consumed while sniffing. Presumably `Rewind`
    /// replays those bytes before further reads (see `common::rewind`).
    type Output = io::Result<(Version, Rewind<I>)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // `cancel()` was called (e.g. via `graceful_shutdown`) before the version
        // was known: give up with an `Interrupted` error.
        if *this.cancelled {
            return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "Cancelled")));
        }

        // Rebuild a ReadBuf over the persistent storage; progress from earlier
        // polls is restored from `filled` just below.
        let mut buf = ReadBuf::uninit(&mut *this.buf);
        // SAFETY: `*this.filled` is only ever set to `buf.filled().len()` from a
        // previous poll over this same `this.buf` storage. Those leading bytes
        // were therefore written by `poll_read` and are initialized, and advancing
        // by that amount marks exactly them as filled.
        unsafe {
            buf.unfilled().advance(*this.filled);
        };

        // Read until a full preface's worth of bytes has arrived, or until the
        // data proves this is not HTTP/2.
        while buf.filled().len() < H2_PREFACE.len() {
            let len = buf.filled().len();
            // Panics if polled again after completion (`io` has been taken).
            ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?;
            // Persist progress so a later `Pending` resumes from here.
            *this.filled = buf.filled().len();

            // A read of zero new bytes, or any new byte that differs from the
            // preface at the same offset, means the peer is not sending HTTP/2.
            if buf.filled().len() == len
                || buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
            {
                *this.version = Version::H1;
                break;
            }
        }

        // Hand back the IO together with every byte read so far.
        let io = this.io.take().unwrap();
        let buf = buf.filled().to_vec();
        Poll::Ready(Ok((
            *this.version,
            Rewind::new_buffered(io, Bytes::from(buf)),
        )))
    }
}
289
pin_project! {
    /// Future serving a single connection with the protocol chosen by [`Builder`].
    ///
    /// Resolves with the result of the underlying hyper connection, or with an
    /// error if version detection fails or is cancelled.
    pub struct Connection<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        #[pin]
        state: ConnState<'a, I, S, E>,
    }
}
300
/// Either-borrowed-or-owned holder for the [`Builder`].
///
/// It works like a read-only `std::borrow::Cow` without the `ToOwned`
/// requirement. The `Owned` case lets `into_owned` produce `'static` connections.
enum Cow<'a, T> {
    /// Borrowed from the caller.
    Borrowed(&'a T),
    /// Owned copy.
    Owned(T),
}
306
307impl<'a, T> std::ops::Deref for Cow<'a, T> {
308 type Target = T;
309 fn deref(&self) -> &T {
310 match self {
311 Cow::Borrowed(t) => &*t,
312 Cow::Owned(ref t) => t,
313 }
314 }
315}
316
/// hyper HTTP/1 connection over the (possibly pre-buffered) `Rewind` IO.
#[cfg(feature = "http1")]
type Http1Connection<I, S> = hyper::server::conn::http1::Connection<Rewind<I>, S>;

/// Placeholder when the `http1` feature is disabled; no code constructs it.
#[cfg(not(feature = "http1"))]
type Http1Connection<I, S> = (PhantomData<I>, PhantomData<S>);

/// hyper HTTP/2 connection over the (possibly pre-buffered) `Rewind` IO.
#[cfg(feature = "http2")]
type Http2Connection<I, S, E> = hyper::server::conn::http2::Connection<Rewind<I>, S, E>;

/// Placeholder when the `http2` feature is disabled; no code constructs it.
#[cfg(not(feature = "http2"))]
type Http2Connection<I, S, E> = (PhantomData<I>, PhantomData<S>, PhantomData<E>);
328
pin_project! {
    // State machine behind `Connection`: sniff the version first (unless it
    // was forced), then drive the chosen hyper connection.
    #[project = ConnStateProj]
    enum ConnState<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        // Still reading the first bytes; `service` is taken once the version is known.
        ReadVersion {
            #[pin]
            read_version: ReadVersion<I>,
            builder: Cow<'a, Builder<E>>,
            service: Option<S>,
        },
        // Serving HTTP/1.
        H1 {
            #[pin]
            conn: Http1Connection<I, S>,
        },
        // Serving HTTP/2.
        H2 {
            #[pin]
            conn: Http2Connection<I, S, E>,
        },
    }
}
351
352impl<I, S, E, B> Connection<'_, I, S, E>
353where
354 S: HttpService<Incoming, ResBody = B>,
355 S::Error: Into<Box<dyn StdError + Send + Sync>>,
356 I: Read + Write + Unpin,
357 B: Body + 'static,
358 B::Error: Into<Box<dyn StdError + Send + Sync>>,
359 E: HttpServerConnExec<S::Future, B>,
360{
361 pub fn graceful_shutdown(self: Pin<&mut Self>) {
370 match self.project().state.project() {
371 ConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
372 #[cfg(feature = "http1")]
373 ConnStateProj::H1 { conn } => conn.graceful_shutdown(),
374 #[cfg(feature = "http2")]
375 ConnStateProj::H2 { conn } => conn.graceful_shutdown(),
376 #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
377 _ => unreachable!(),
378 }
379 }
380
381 pub fn into_owned(self) -> Connection<'static, I, S, E>
383 where
384 Builder<E>: Clone,
385 {
386 Connection {
387 state: match self.state {
388 ConnState::ReadVersion {
389 read_version,
390 builder,
391 service,
392 } => ConnState::ReadVersion {
393 read_version,
394 service,
395 builder: Cow::Owned(builder.clone()),
396 },
397 #[cfg(feature = "http1")]
398 ConnState::H1 { conn } => ConnState::H1 { conn },
399 #[cfg(feature = "http2")]
400 ConnState::H2 { conn } => ConnState::H2 { conn },
401 #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
402 _ => unreachable!(),
403 },
404 }
405 }
406}
407
impl<I, S, E, B> Future for Connection<'_, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin + 'static,
    E: HttpServerConnExec<S::Future, B>,
{
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Loop so that, right after switching from `ReadVersion` to a protocol
        // state, the new hyper connection is polled within this same call.
        loop {
            let mut this = self.as_mut().project();

            match this.state.as_mut().project() {
                ConnStateProj::ReadVersion {
                    read_version,
                    builder,
                    service,
                } => {
                    // Errors here include the `Interrupted` error of a cancelled sniff.
                    let (version, io) = ready!(read_version.poll(cx))?;
                    // Taken exactly once: the state is replaced immediately below.
                    let service = service.take().unwrap();
                    match version {
                        #[cfg(feature = "http1")]
                        Version::H1 => {
                            let conn = builder.http1.serve_connection(io, service);
                            this.state.set(ConnState::H1 { conn });
                        }
                        #[cfg(feature = "http2")]
                        Version::H2 => {
                            let conn = builder.http2.serve_connection(io, service);
                            this.state.set(ConnState::H2 { conn });
                        }
                        // Detected a protocol whose feature is compiled out.
                        #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                        _ => return Poll::Ready(Err(version.unsupported())),
                    }
                }
                #[cfg(feature = "http1")]
                ConnStateProj::H1 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(feature = "http2")]
                ConnStateProj::H2 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                _ => unreachable!(),
            }
        }
    }
}
461
pin_project! {
    /// Like [`Connection`], but HTTP/1 connections are served with upgrade
    /// support (`with_upgrades`).
    pub struct UpgradeableConnection<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        #[pin]
        state: UpgradeableConnState<'a, I, S, E>,
    }
}
472
/// hyper HTTP/1 connection with upgrade support.
#[cfg(feature = "http1")]
type Http1UpgradeableConnection<I, S> = hyper::server::conn::http1::UpgradeableConnection<I, S>;

/// Placeholder when the `http1` feature is disabled; no code constructs it.
#[cfg(not(feature = "http1"))]
type Http1UpgradeableConnection<I, S> = (PhantomData<I>, PhantomData<S>);
478
pin_project! {
    // State machine behind `UpgradeableConnection`; mirrors `ConnState`, but
    // the HTTP/1 state is the upgrade-capable hyper connection.
    #[project = UpgradeableConnStateProj]
    enum UpgradeableConnState<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        // Still reading the first bytes; `service` is taken once the version is known.
        ReadVersion {
            #[pin]
            read_version: ReadVersion<I>,
            builder: Cow<'a, Builder<E>>,
            service: Option<S>,
        },
        // Serving HTTP/1 with upgrades enabled.
        H1 {
            #[pin]
            conn: Http1UpgradeableConnection<Rewind<I>, S>,
        },
        // Serving HTTP/2.
        H2 {
            #[pin]
            conn: Http2Connection<I, S, E>,
        },
    }
}
501
502impl<I, S, E, B> UpgradeableConnection<'_, I, S, E>
503where
504 S: HttpService<Incoming, ResBody = B>,
505 S::Error: Into<Box<dyn StdError + Send + Sync>>,
506 I: Read + Write + Unpin,
507 B: Body + 'static,
508 B::Error: Into<Box<dyn StdError + Send + Sync>>,
509 E: HttpServerConnExec<S::Future, B>,
510{
511 pub fn graceful_shutdown(self: Pin<&mut Self>) {
520 match self.project().state.project() {
521 UpgradeableConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
522 #[cfg(feature = "http1")]
523 UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
524 #[cfg(feature = "http2")]
525 UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
526 #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
527 _ => unreachable!(),
528 }
529 }
530
531 pub fn into_owned(self) -> UpgradeableConnection<'static, I, S, E>
533 where
534 Builder<E>: Clone,
535 {
536 UpgradeableConnection {
537 state: match self.state {
538 UpgradeableConnState::ReadVersion {
539 read_version,
540 builder,
541 service,
542 } => UpgradeableConnState::ReadVersion {
543 read_version,
544 service,
545 builder: Cow::Owned(builder.clone()),
546 },
547 #[cfg(feature = "http1")]
548 UpgradeableConnState::H1 { conn } => UpgradeableConnState::H1 { conn },
549 #[cfg(feature = "http2")]
550 UpgradeableConnState::H2 { conn } => UpgradeableConnState::H2 { conn },
551 #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
552 _ => unreachable!(),
553 },
554 }
555 }
556}
557
impl<I, S, E, B> Future for UpgradeableConnection<'_, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin + Send + 'static,
    E: HttpServerConnExec<S::Future, B>,
{
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Loop so that, right after switching from `ReadVersion` to a protocol
        // state, the new hyper connection is polled within this same call.
        loop {
            let mut this = self.as_mut().project();

            match this.state.as_mut().project() {
                UpgradeableConnStateProj::ReadVersion {
                    read_version,
                    builder,
                    service,
                } => {
                    // Errors here include the `Interrupted` error of a cancelled sniff.
                    let (version, io) = ready!(read_version.poll(cx))?;
                    // Taken exactly once: the state is replaced immediately below.
                    let service = service.take().unwrap();
                    match version {
                        #[cfg(feature = "http1")]
                        Version::H1 => {
                            // Upgrade support is what distinguishes this from `Connection`.
                            let conn = builder.http1.serve_connection(io, service).with_upgrades();
                            this.state.set(UpgradeableConnState::H1 { conn });
                        }
                        #[cfg(feature = "http2")]
                        Version::H2 => {
                            let conn = builder.http2.serve_connection(io, service);
                            this.state.set(UpgradeableConnState::H2 { conn });
                        }
                        // Detected a protocol whose feature is compiled out.
                        #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                        _ => return Poll::Ready(Err(version.unsupported())),
                    }
                }
                #[cfg(feature = "http1")]
                UpgradeableConnStateProj::H1 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(feature = "http2")]
                UpgradeableConnStateProj::H2 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                _ => unreachable!(),
            }
        }
    }
}
611
/// Mutable view of a [`Builder`] for configuring its HTTP/1 settings.
#[cfg(feature = "http1")]
pub struct Http1Builder<'a, E> {
    inner: &'a mut Builder<E>,
}
617
#[cfg(feature = "http1")]
impl<E> Http1Builder<'_, E> {
    /// Switch to configuring the HTTP/2 side of the same [`Builder`].
    #[cfg(feature = "http2")]
    pub fn http2(&mut self) -> Http2Builder<'_, E> {
        Http2Builder { inner: self.inner }
    }

    /// Set whether HTTP/1 connections support half-closures.
    ///
    /// Forwards to hyper's `http1::Builder::half_close` (see it for the default).
    pub fn half_close(&mut self, val: bool) -> &mut Self {
        self.inner.http1.half_close(val);
        self
    }

    /// Enable or disable HTTP/1 keep-alive.
    ///
    /// Forwards to hyper's `http1::Builder::keep_alive`.
    pub fn keep_alive(&mut self, val: bool) -> &mut Self {
        self.inner.http1.keep_alive(val);
        self
    }

    /// Set whether header names are written in title case.
    ///
    /// Forwards to hyper's `http1::Builder::title_case_headers`.
    pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self {
        self.inner.http1.title_case_headers(enabled);
        self
    }

    /// Set whether the original case of header names is preserved.
    ///
    /// Forwards to hyper's `http1::Builder::preserve_header_case`.
    pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self {
        self.inner.http1.preserve_header_case(enabled);
        self
    }

    /// Set the maximum number of headers accepted per message.
    ///
    /// Forwards to hyper's `http1::Builder::max_headers`.
    pub fn max_headers(&mut self, val: usize) -> &mut Self {
        self.inner.http1.max_headers(val);
        self
    }

    /// Set a timeout for reading the request headers (`None` disables it).
    ///
    /// Forwards to hyper's `http1::Builder::header_read_timeout`; a [`Timer`]
    /// must also be configured for it to take effect (see hyper's docs).
    pub fn header_read_timeout(&mut self, read_timeout: impl Into<Option<Duration>>) -> &mut Self {
        self.inner.http1.header_read_timeout(read_timeout);
        self
    }

    /// Set whether vectored writes are used.
    ///
    /// Forwards to hyper's `http1::Builder::writev`.
    pub fn writev(&mut self, val: bool) -> &mut Self {
        self.inner.http1.writev(val);
        self
    }

    /// Set the maximum connection buffer size.
    ///
    /// Forwards to hyper's `http1::Builder::max_buf_size`, which documents the
    /// accepted range (it may panic on too-small values).
    pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
        self.inner.http1.max_buf_size(max);
        self
    }

    /// Set whether pipelined responses are aggregated before flushing.
    ///
    /// Forwards to hyper's `http1::Builder::pipeline_flush`.
    pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self {
        self.inner.http1.pipeline_flush(enabled);
        self
    }

    /// Set the timer used for HTTP/1 timeouts.
    ///
    /// Forwards to hyper's `http1::Builder::timer`.
    pub fn timer<M>(&mut self, timer: M) -> &mut Self
    where
        M: Timer + Send + Sync + 'static,
    {
        self.inner.http1.timer(timer);
        self
    }

    /// Bind a connection together with a [`Service`] and drive it to completion.
    ///
    /// This awaits the [`Connection`] returned by [`Builder::serve_connection`].
    /// A single definition covers both feature sets: without `http2`,
    /// `HttpServerConnExec` is implemented for every type, so the `E` bound
    /// is always satisfied.
    pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        self.inner.serve_connection(io, service).await
    }

    /// Bind a connection together with a [`Service`], with the ability to
    /// handle HTTP/1 upgrades. This requires that the IO object implements `Send`.
    ///
    /// Available whenever HTTP/1 is, matching the ungated
    /// [`Builder::serve_connection_with_upgrades`].
    pub fn serve_connection_with_upgrades<I, S, B>(
        &self,
        io: I,
        service: S,
    ) -> UpgradeableConnection<'_, I, S, E>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + Send + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        self.inner.serve_connection_with_upgrades(io, service)
    }
}
808
/// Mutable view of a [`Builder`] for configuring its HTTP/2 settings.
#[cfg(feature = "http2")]
pub struct Http2Builder<'a, E> {
    inner: &'a mut Builder<E>,
}
814
#[cfg(feature = "http2")]
impl<E> Http2Builder<'_, E> {
    /// Switch to configuring the HTTP/1 side of the same [`Builder`].
    #[cfg(feature = "http1")]
    pub fn http1(&mut self) -> Http1Builder<'_, E> {
        Http1Builder { inner: self.inner }
    }

    /// The hyper HTTP/2 builder that every setter below forwards to.
    fn h2(&mut self) -> &mut http2::Builder<E> {
        &mut self.inner.http2
    }

    /// Limit how many locally-reset streams may await acceptance
    /// (hyper `http2::Builder::max_pending_accept_reset_streams`).
    pub fn max_pending_accept_reset_streams(&mut self, max: impl Into<Option<usize>>) -> &mut Self {
        self.h2().max_pending_accept_reset_streams(max);
        self
    }

    /// Set the initial per-stream flow-control window
    /// (hyper `http2::Builder::initial_stream_window_size`).
    pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.h2().initial_stream_window_size(sz);
        self
    }

    /// Set the initial connection-level flow-control window
    /// (hyper `http2::Builder::initial_connection_window_size`).
    pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.h2().initial_connection_window_size(sz);
        self
    }

    /// Toggle adaptive flow control (hyper `http2::Builder::adaptive_window`).
    pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self {
        self.h2().adaptive_window(enabled);
        self
    }

    /// Set the maximum frame size (hyper `http2::Builder::max_frame_size`).
    pub fn max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.h2().max_frame_size(sz);
        self
    }

    /// Set the maximum number of concurrent streams
    /// (hyper `http2::Builder::max_concurrent_streams`).
    pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self {
        self.h2().max_concurrent_streams(max);
        self
    }

    /// Set the keep-alive PING interval; `None` disables it
    /// (hyper `http2::Builder::keep_alive_interval`).
    pub fn keep_alive_interval(&mut self, interval: impl Into<Option<Duration>>) -> &mut Self {
        self.h2().keep_alive_interval(interval);
        self
    }

    /// Set how long to wait for a keep-alive PING acknowledgement
    /// (hyper `http2::Builder::keep_alive_timeout`).
    pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.h2().keep_alive_timeout(timeout);
        self
    }

    /// Set the maximum per-stream send buffer size
    /// (hyper `http2::Builder::max_send_buf_size`).
    pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self {
        self.h2().max_send_buf_size(max);
        self
    }

    /// Enable the extended CONNECT protocol
    /// (hyper `http2::Builder::enable_connect_protocol`).
    pub fn enable_connect_protocol(&mut self) -> &mut Self {
        self.h2().enable_connect_protocol();
        self
    }

    /// Set the maximum size of received header lists
    /// (hyper `http2::Builder::max_header_list_size`).
    pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
        self.h2().max_header_list_size(max);
        self
    }

    /// Set the timer used for HTTP/2 timeouts (hyper `http2::Builder::timer`).
    pub fn timer<M>(&mut self, timer: M) -> &mut Self
    where
        M: Timer + Send + Sync + 'static,
    {
        self.h2().timer(timer);
        self
    }

    /// Bind a connection together with a [`Service`] and drive it to completion,
    /// using [`Builder::serve_connection`].
    pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let connection = self.inner.serve_connection(io, service);
        connection.await
    }

    /// Bind a connection together with a [`Service`], with HTTP/1 upgrade
    /// support, using [`Builder::serve_connection_with_upgrades`].
    /// Requires the IO object to be `Send`.
    pub fn serve_connection_with_upgrades<I, S, B>(
        &self,
        io: I,
        service: S,
    ) -> UpgradeableConnection<'_, I, S, E>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + Send + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let builder = &*self.inner;
        builder.serve_connection_with_upgrades(io, service)
    }
}
987
#[cfg(test)]
mod tests {
    use crate::{
        rt::{TokioExecutor, TokioIo},
        server::conn::auto,
    };
    use http::{Request, Response};
    use http_body::Body;
    use http_body_util::{BodyExt, Empty, Full};
    use hyper::{body, body::Bytes, client, service::service_fn};
    use std::{convert::Infallible, error::Error as StdError, net::SocketAddr, time::Duration};
    use tokio::{
        net::{TcpListener, TcpStream},
        pin,
    };

    /// Response body returned by every test server.
    const BODY: &[u8] = b"Hello, world!";

    /// Compile-level check that both configuration styles type-check.
    #[test]
    fn configuration() {
        // Chained: `http1()` / `http2()` hop between the two protocol sides.
        auto::Builder::new(TokioExecutor::new())
            .http1()
            .keep_alive(true)
            .http2()
            .keep_alive_interval(None);

        // Through a mutable binding.
        let mut builder = auto::Builder::new(TokioExecutor::new());

        builder.http1().keep_alive(true);
        builder.http2().keep_alive_interval(None);
    }

    /// An HTTP/1 client is detected and served by the auto-detecting server.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http1() {
        let addr = start_server(false, false).await;
        let mut sender = connect_h1(addr).await;

        let response = sender
            .send_request(Request::new(Empty::<Bytes>::new()))
            .await
            .unwrap();

        let body = response.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(body, BODY);
    }

    /// An HTTP/2 client is detected and served by the auto-detecting server.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http2() {
        let addr = start_server(false, false).await;
        let mut sender = connect_h2(addr).await;

        let response = sender
            .send_request(Request::new(Empty::<Bytes>::new()))
            .await
            .unwrap();

        let body = response.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(body, BODY);
    }

    /// An `http2_only` server serves HTTP/2 clients.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http2_only() {
        let addr = start_server(false, true).await;
        let mut sender = connect_h2(addr).await;

        let response = sender
            .send_request(Request::new(Empty::<Bytes>::new()))
            .await
            .unwrap();

        let body = response.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(body, BODY);
    }

    /// An `http2_only` server rejects HTTP/1 clients.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http2_only_fail_if_client_is_http1() {
        let addr = start_server(false, true).await;
        let mut sender = connect_h1(addr).await;

        let _ = sender
            .send_request(Request::new(Empty::<Bytes>::new()))
            .await
            .expect_err("should fail");
    }

    /// An `http1_only` server serves HTTP/1 clients.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http1_only() {
        let addr = start_server(true, false).await;
        let mut sender = connect_h1(addr).await;

        let response = sender
            .send_request(Request::new(Empty::<Bytes>::new()))
            .await
            .unwrap();

        let body = response.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(body, BODY);
    }

    /// An `http1_only` server rejects HTTP/2 clients.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http1_only_fail_if_client_is_http2() {
        let addr = start_server(true, false).await;
        let mut sender = connect_h2(addr).await;

        let _ = sender
            .send_request(Request::new(Empty::<Bytes>::new()))
            .await
            .expect_err("should fail");
    }

    /// `graceful_shutdown` during version detection resolves the connection
    /// promptly with an `Interrupted` I/O error.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn graceful_shutdown() {
        let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
            .await
            .unwrap();

        let listener_addr = listener.local_addr().unwrap();

        // Accept in the background so the connect below can complete.
        let listen_task = tokio::spawn(async move { listener.accept().await.unwrap() });
        // Connect but never send anything: the server stays in version detection.
        let _stream = TcpStream::connect(listener_addr).await.unwrap();

        let (stream, _) = listen_task.await.unwrap();
        let stream = TokioIo::new(stream);
        let builder = auto::Builder::new(TokioExecutor::new());
        let connection = builder.serve_connection(stream, service_fn(hello));

        pin!(connection);

        // Cancel before the first poll; the cancellation is observed on that poll.
        connection.as_mut().graceful_shutdown();

        let connection_error = tokio::time::timeout(Duration::from_millis(200), connection)
            .await
            .expect("Connection should have finished in a timely manner after graceful shutdown.")
            .expect_err("Connection should have been interrupted.");

        let connection_error = connection_error
            .downcast_ref::<std::io::Error>()
            .expect("The error should have been `std::io::Error`.");
        assert_eq!(connection_error.kind(), std::io::ErrorKind::Interrupted);
    }

    /// Open an HTTP/1 client connection to `addr`, driving it on a spawned task.
    async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B>
    where
        B: Body + Send + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (sender, connection) = client::conn::http1::handshake(stream).await.unwrap();

        tokio::spawn(connection);

        sender
    }

    /// Open an HTTP/2 client connection to `addr`, driving it on a spawned task.
    async fn connect_h2<B>(addr: SocketAddr) -> client::conn::http2::SendRequest<B>
    where
        B: Body + Unpin + Send + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let stream = TokioIo::new(TcpStream::connect(addr).await.unwrap());
        let (sender, connection) = client::conn::http2::Builder::new(TokioExecutor::new())
            .handshake(stream)
            .await
            .unwrap();

        tokio::spawn(connection);

        sender
    }

    /// Spawn a server on an ephemeral localhost port and return its address.
    ///
    /// `h1_only` / `h2_only` force the protocol via `serve_connection`. Otherwise
    /// connections are auto-detected via `serve_connection_with_upgrades`.
    /// Each per-connection task panics if serving returns an error.
    async fn start_server(h1_only: bool, h2_only: bool) -> SocketAddr {
        let addr: SocketAddr = ([127, 0, 0, 1], 0).into();
        let listener = TcpListener::bind(addr).await.unwrap();

        let local_addr = listener.local_addr().unwrap();

        tokio::spawn(async move {
            loop {
                let (stream, _) = listener.accept().await.unwrap();
                let stream = TokioIo::new(stream);
                tokio::task::spawn(async move {
                    let mut builder = auto::Builder::new(TokioExecutor::new());
                    if h1_only {
                        builder = builder.http1_only();
                        builder.serve_connection(stream, service_fn(hello)).await
                    } else if h2_only {
                        builder = builder.http2_only();
                        builder.serve_connection(stream, service_fn(hello)).await
                    } else {
                        builder
                            .http2()
                            .max_header_list_size(4096)
                            .serve_connection_with_upgrades(stream, service_fn(hello))
                            .await
                    }
                    .unwrap();
                });
            }
        });

        local_addr
    }

    /// Test service: always responds with [`BODY`].
    async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
        Ok(Response::new(Full::new(Bytes::from(BODY))))
    }
}