1use std::{
5 convert::Infallible, error, fmt, io, marker::PhantomData, net::SocketAddr, pin::Pin, result,
6 sync::Arc, task::Poll,
7};
8
9use bytes::Bytes;
10use flume::{Receiver, Sender};
11use futures_lite::{Stream, StreamExt};
12use futures_sink::Sink;
13use hyper::{
14 client::{connect::Connect, HttpConnector, ResponseFuture},
15 server::conn::{AddrIncoming, AddrStream},
16 service::{make_service_fn, service_fn},
17 Body, Client, Request, Response, Server, StatusCode, Uri,
18};
19use tokio::{sync::mpsc, task::JoinHandle};
20use tracing::{debug, event, trace, Level};
21
22use crate::{
23 transport::{ConnectionErrors, Connector, Listener, LocalAddr, StreamTypes},
24 RpcMessage,
25};
26
27struct HyperConnectionInner {
28 client: Box<dyn Requester>,
29 config: Arc<ChannelConfig>,
30 uri: Uri,
31}
32
33pub struct HyperConnector<In: RpcMessage, Out: RpcMessage> {
35 inner: Arc<HyperConnectionInner>,
36 _p: PhantomData<(In, Out)>,
37}
38
39impl<In: RpcMessage, Out: RpcMessage> Clone for HyperConnector<In, Out> {
40 fn clone(&self) -> Self {
41 Self {
42 inner: self.inner.clone(),
43 _p: PhantomData,
44 }
45 }
46}
47
48trait Requester: Send + Sync + 'static {
50 fn request(&self, req: Request<Body>) -> ResponseFuture;
51}
52
53impl<C: Connect + Clone + Send + Sync + 'static> Requester for Client<C, Body> {
54 fn request(&self, req: Request<Body>) -> ResponseFuture {
55 self.request(req)
56 }
57}
58
59impl<In: RpcMessage, Out: RpcMessage> HyperConnector<In, Out> {
60 pub fn new(uri: Uri) -> Self {
62 Self::with_config(uri, ChannelConfig::default())
63 }
64
65 pub fn with_config(uri: Uri, config: ChannelConfig) -> Self {
67 let mut connector = HttpConnector::new();
68 connector.set_nodelay(true);
69 Self::with_connector(connector, uri, Arc::new(config))
70 }
71
72 pub fn with_connector<C: Connect + Clone + Send + Sync + 'static>(
74 connector: C,
75 uri: Uri,
76 config: Arc<ChannelConfig>,
77 ) -> Self {
78 let client = Client::builder()
79 .http2_only(true)
80 .http2_initial_connection_window_size(Some(config.max_frame_size))
81 .http2_initial_stream_window_size(Some(config.max_frame_size))
82 .http2_max_frame_size(Some(config.max_frame_size))
83 .http2_max_send_buf_size(config.max_frame_size.try_into().unwrap())
84 .build(connector);
85 Self {
86 inner: Arc::new(HyperConnectionInner {
87 client: Box::new(client),
88 uri,
89 config,
90 }),
91 _p: PhantomData,
92 }
93 }
94}
95
96impl<In: RpcMessage, Out: RpcMessage> fmt::Debug for HyperConnector<In, Out> {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 f.debug_struct("ClientChannel")
99 .field("uri", &self.inner.uri)
100 .field("config", &self.inner.config)
101 .finish()
102 }
103}
104
105type InternalChannel<In> = (
107 Receiver<result::Result<In, RecvError>>,
108 Sender<io::Result<Bytes>>,
109);
110
/// Error returned when a [`ChannelConfig`] setter receives an out-of-range value.
#[derive(Debug, Clone)]
pub enum ChannelConfigError {
    /// The requested max frame size lies outside `0x4000..=0xFF_FFFF`.
    InvalidMaxFrameSize(u32),
    /// The requested max payload size lies outside `4096..=0xFF_FFFF`.
    InvalidMaxPayloadSize(usize),
}

impl fmt::Display for ChannelConfigError {
    // The displayed text is the same as the Debug representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}

impl error::Error for ChannelConfigError {}

/// Size limits for the hyper transport.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// HTTP/2 max frame size; also used for the initial flow-control windows and the
    /// send buffer size.
    max_frame_size: u32,
    /// Largest serialized outgoing message (excluding the 4-byte length prefix), in bytes.
    max_payload_size: usize,
}

impl ChannelConfig {
    /// Sets the HTTP/2 max frame size.
    ///
    /// # Errors
    /// [`ChannelConfigError::InvalidMaxFrameSize`] unless `value` is within
    /// `0x4000..=0xFF_FFFF` (2^14 ..= 2^24 - 1, the range HTTP/2 allows for
    /// `SETTINGS_MAX_FRAME_SIZE`).
    pub fn max_frame_size(mut self, value: u32) -> result::Result<Self, ChannelConfigError> {
        if (0x4000..=0xFF_FFFF).contains(&value) {
            self.max_frame_size = value;
            Ok(self)
        } else {
            Err(ChannelConfigError::InvalidMaxFrameSize(value))
        }
    }

    /// Sets the maximum serialized size of a single outgoing message.
    ///
    /// # Errors
    /// [`ChannelConfigError::InvalidMaxPayloadSize`] unless `value` is within
    /// `4096..=0xFF_FFFF`.
    pub fn max_payload_size(mut self, value: usize) -> result::Result<Self, ChannelConfigError> {
        if (4096..=0xFF_FFFF).contains(&value) {
            self.max_payload_size = value;
            Ok(self)
        } else {
            Err(ChannelConfigError::InvalidMaxPayloadSize(value))
        }
    }
}

impl Default for ChannelConfig {
    /// Both limits default to `0xFF_FFFF`, the largest value each setter accepts.
    fn default() -> Self {
        Self {
            max_frame_size: 0xFF_FFFF,
            max_payload_size: 0xFF_FFFF,
        }
    }
}
166
167#[derive(Debug)]
176pub struct HyperListener<In: RpcMessage, Out: RpcMessage> {
177 channel: Receiver<InternalChannel<In>>,
179 config: Arc<ChannelConfig>,
181 stop_tx: mpsc::Sender<()>,
186 local_addr: [LocalAddr; 1],
191 _p: PhantomData<(In, Out)>,
193}
194
195impl<In: RpcMessage, Out: RpcMessage> HyperListener<In, Out> {
196 pub fn serve(addr: &SocketAddr) -> hyper::Result<Self> {
198 Self::serve_with_config(addr, Default::default())
199 }
200
201 pub fn serve_with_config(addr: &SocketAddr, config: ChannelConfig) -> hyper::Result<Self> {
203 let (accept_tx, accept_rx) = flume::bounded(32);
204
205 let service = make_service_fn(move |socket: &AddrStream| {
208 let remote_addr = socket.remote_addr();
209 event!(Level::TRACE, "Connection from {:?}", remote_addr);
210
211 let accept_tx = accept_tx.clone();
213 async move {
214 let one_req_service = service_fn(move |req: Request<Body>| {
215 Self::handle_one_http2_request(req, accept_tx.clone())
217 });
218 Ok::<_, Infallible>(one_req_service)
219 }
220 });
221
222 let mut incoming = AddrIncoming::bind(addr)?;
223 incoming.set_nodelay(true);
224 let server = Server::builder(incoming)
225 .http2_only(true)
226 .http2_initial_connection_window_size(Some(config.max_frame_size))
227 .http2_initial_stream_window_size(Some(config.max_frame_size))
228 .http2_max_frame_size(Some(config.max_frame_size))
229 .http2_max_send_buf_size(config.max_frame_size.try_into().unwrap())
230 .serve(service);
231 let local_addr = server.local_addr();
232
233 let (stop_tx, mut stop_rx) = mpsc::channel::<()>(1);
234 let server = server.with_graceful_shutdown(async move {
235 stop_rx.recv().await;
237 });
238 tokio::spawn(server);
239
240 Ok(Self {
241 channel: accept_rx,
242 config: Arc::new(config),
243 stop_tx,
244 local_addr: [LocalAddr::Socket(local_addr)],
245 _p: PhantomData,
246 })
247 }
248
249 async fn handle_one_http2_request(
254 req: Request<Body>,
255 accept_tx: Sender<InternalChannel<In>>,
256 ) -> Result<Response<Body>, String> {
257 let (req_tx, req_rx) = flume::bounded::<result::Result<In, RecvError>>(32);
258 let (res_tx, res_rx) = flume::bounded::<io::Result<Bytes>>(32);
259 accept_tx
260 .send_async((req_rx, res_tx))
261 .await
262 .map_err(|_e| "unable to send")?;
263
264 spawn_recv_forwarder(req.into_body(), req_tx);
265 let response = Response::builder()
267 .status(StatusCode::OK)
268 .body(Body::wrap_stream(res_rx.into_stream()))
269 .map_err(|_| "unable to set body")?;
270 Ok(response)
271 }
272}
273
/// Extracts the first length-prefixed frame from `buf`, if it is complete.
///
/// A frame is a 4-byte big-endian `u32` length followed by that many payload bytes.
/// Returns the payload slice (without the prefix), or `None` if `buf` does not yet hold
/// the full prefix and payload.
///
/// The length comes from the network, so no `4 + len` arithmetic is done. On 32-bit
/// targets that sum overflows for lengths near `u32::MAX`, which previously led to a
/// panic. Checked slicing with `get` is used instead.
fn try_get_length_prefixed(buf: &[u8]) -> Option<&[u8]> {
    let header = buf.get(..4)?;
    let len = u32::from_be_bytes([header[0], header[1], header[2], header[3]]);
    // Lossless on 32/64-bit; on narrower targets an unrepresentable length can never be
    // satisfied by an in-memory buffer anyway.
    let len = usize::try_from(len).ok()?;
    buf[4..].get(..len)
}
284
285async fn try_forward_all<In: RpcMessage>(
294 buffer: &[u8],
295 req_tx: &Sender<Result<In, RecvError>>,
296) -> result::Result<usize, ()> {
297 let mut sent = 0;
298 while let Some(msg) = try_get_length_prefixed(&buffer[sent..]) {
299 sent += msg.len() + 4;
300 let item = postcard::from_bytes::<In>(msg).map_err(RecvError::DeserializeError);
301 if let Err(_cause) = req_tx.send_async(item).await {
302 trace!("Flume receiver dropped");
309 return Err(());
310 }
311 }
312 Ok(sent)
313}
314
315fn spawn_recv_forwarder<In: RpcMessage>(
327 req: Body,
328 req_tx: Sender<result::Result<In, RecvError>>,
329) -> JoinHandle<result::Result<(), ()>> {
330 tokio::spawn(async move {
331 let mut stream = req;
332 let mut buf = Vec::new();
333
334 while let Some(chunk) = stream.next().await {
335 match chunk.as_ref() {
336 Ok(chunk) => {
337 event!(Level::TRACE, "Server got {} bytes", chunk.len());
338 if buf.is_empty() {
339 let sent = try_forward_all(chunk, &req_tx).await?;
341 buf.extend_from_slice(&chunk[sent..]);
343 } else {
344 buf.extend_from_slice(chunk);
346 }
347 }
348 Err(cause) => {
349 debug!("Network error: {}", cause);
353 break;
354 }
355 };
356 let sent = try_forward_all(&buf, &req_tx).await?;
357 buf.drain(..sent);
360 }
361 Ok(())
362 })
363}
364
365impl<In: RpcMessage, Out: RpcMessage> Clone for HyperListener<In, Out> {
369 fn clone(&self) -> Self {
370 Self {
371 channel: self.channel.clone(),
372 stop_tx: self.stop_tx.clone(),
373 local_addr: self.local_addr.clone(),
374 config: self.config.clone(),
375 _p: PhantomData,
376 }
377 }
378}
379
380pub struct RecvStream<Res: RpcMessage> {
385 recv: flume::r#async::RecvStream<'static, result::Result<Res, RecvError>>,
386}
387
388impl<Res: RpcMessage> RecvStream<Res> {
389 pub fn new(recv: flume::Receiver<result::Result<Res, RecvError>>) -> Self {
391 Self {
392 recv: recv.into_stream(),
393 }
394 }
395
396 }
399
400impl<In: RpcMessage> Clone for RecvStream<In> {
401 fn clone(&self) -> Self {
402 Self {
403 recv: self.recv.clone(),
404 }
405 }
406}
407
408impl<Res: RpcMessage> Stream for RecvStream<Res> {
409 type Item = Result<Res, RecvError>;
410
411 fn poll_next(
412 mut self: Pin<&mut Self>,
413 cx: &mut std::task::Context<'_>,
414 ) -> Poll<Option<Self::Item>> {
415 Pin::new(&mut self.recv).poll_next(cx)
416 }
417}
418
419pub struct SendSink<Out: RpcMessage> {
421 sink: flume::r#async::SendSink<'static, io::Result<Bytes>>,
422 config: Arc<ChannelConfig>,
423 _p: PhantomData<Out>,
424}
425
426impl<Out: RpcMessage> SendSink<Out> {
427 fn new(sender: flume::Sender<io::Result<Bytes>>, config: Arc<ChannelConfig>) -> Self {
428 Self {
429 sink: sender.into_sink(),
430 config,
431 _p: PhantomData,
432 }
433 }
434 fn serialize(&self, item: Out) -> Result<Bytes, SendError> {
435 let mut data = Vec::with_capacity(1024);
436 data.extend_from_slice(&[0u8; 4]);
437 let mut data = postcard::to_extend(&item, data).map_err(SendError::SerializeError)?;
438 let len = data.len() - 4;
439 if len > self.config.max_payload_size {
440 return Err(SendError::SizeError(len));
441 }
442 let len: u32 = len.try_into().expect("max_payload_size fits into u32");
443 data[0..4].copy_from_slice(&len.to_be_bytes());
444 Ok(data.into())
445 }
446
447 pub fn into_inner(self) -> flume::r#async::SendSink<'static, io::Result<Bytes>> {
452 self.sink
453 }
454}
455
456impl<Out: RpcMessage> Sink<Out> for SendSink<Out> {
457 type Error = SendError;
458
459 fn poll_ready(
460 mut self: Pin<&mut Self>,
461 cx: &mut std::task::Context<'_>,
462 ) -> Poll<Result<(), Self::Error>> {
463 Pin::new(&mut self.sink)
464 .poll_ready(cx)
465 .map_err(|_| SendError::ReceiverDropped)
466 }
467
468 fn start_send(mut self: Pin<&mut Self>, item: Out) -> Result<(), Self::Error> {
469 let (send, res) = match self.serialize(item) {
471 Ok(data) => (Ok(data), Ok(())),
472 Err(cause) => (
473 Err(io::Error::new(io::ErrorKind::Other, cause.to_string())),
474 Err(cause),
475 ),
476 };
477 Pin::new(&mut self.sink)
479 .start_send(send)
480 .map_err(|_| SendError::ReceiverDropped)?;
481 res
482 }
483
484 fn poll_flush(
485 mut self: Pin<&mut Self>,
486 cx: &mut std::task::Context<'_>,
487 ) -> Poll<Result<(), Self::Error>> {
488 Pin::new(&mut self.sink)
489 .poll_flush(cx)
490 .map_err(|_| SendError::ReceiverDropped)
491 }
492
493 fn poll_close(
494 mut self: Pin<&mut Self>,
495 cx: &mut std::task::Context<'_>,
496 ) -> Poll<Result<(), Self::Error>> {
497 Pin::new(&mut self.sink)
498 .poll_close(cx)
499 .map_err(|_| SendError::ReceiverDropped)
500 }
501}
502
503#[derive(Debug)]
505pub enum SendError {
506 SerializeError(postcard::Error),
508 SizeError(usize),
510 ReceiverDropped,
512}
513
514impl fmt::Display for SendError {
515 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
516 fmt::Debug::fmt(&self, f)
517 }
518}
519
520impl error::Error for SendError {}
521
522#[derive(Debug)]
524pub enum RecvError {
525 DeserializeError(postcard::Error),
527 NetworkError(hyper::Error),
529}
530
531impl fmt::Display for RecvError {
532 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
533 fmt::Debug::fmt(&self, f)
534 }
535}
536
537impl error::Error for RecvError {}
538
539#[derive(Debug)]
541pub enum OpenError {
542 HyperHttp(hyper::http::Error),
544 Hyper(hyper::Error),
546 RemoteDropped,
548}
549
550impl fmt::Display for OpenError {
551 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
552 fmt::Debug::fmt(self, f)
553 }
554}
555
556impl std::error::Error for OpenError {}
557
558#[derive(Debug)]
562pub enum AcceptError {
563 Hyper(hyper::http::Error),
565 RemoteDropped,
567}
568
569impl fmt::Display for AcceptError {
570 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
571 fmt::Debug::fmt(self, f)
572 }
573}
574
575impl error::Error for AcceptError {}
576
577impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for HyperConnector<In, Out> {
578 type SendError = self::SendError;
579
580 type RecvError = self::RecvError;
581
582 type OpenError = OpenError;
583
584 type AcceptError = AcceptError;
585}
586
587impl<In: RpcMessage, Out: RpcMessage> StreamTypes for HyperConnector<In, Out> {
588 type In = In;
589 type Out = Out;
590 type RecvStream = self::RecvStream<In>;
591 type SendSink = self::SendSink<Out>;
592}
593
594impl<In: RpcMessage, Out: RpcMessage> Connector for HyperConnector<In, Out> {
595 async fn open(&self) -> Result<(Self::SendSink, Self::RecvStream), Self::OpenError> {
596 let (out_tx, out_rx) = flume::bounded::<io::Result<Bytes>>(32);
597 let req: Request<Body> = Request::post(&self.inner.uri)
598 .body(Body::wrap_stream(out_rx.into_stream()))
599 .map_err(OpenError::HyperHttp)?;
600 let res = self
601 .inner
602 .client
603 .request(req)
604 .await
605 .map_err(OpenError::Hyper)?;
606 let (in_tx, in_rx) = flume::bounded::<result::Result<In, RecvError>>(32);
607 spawn_recv_forwarder(res.into_body(), in_tx);
608
609 let out_tx = self::SendSink::new(out_tx, self.inner.config.clone());
610 let in_rx = self::RecvStream::new(in_rx);
611 Ok((out_tx, in_rx))
612 }
613}
614
615impl<In: RpcMessage, Out: RpcMessage> ConnectionErrors for HyperListener<In, Out> {
616 type SendError = self::SendError;
617 type RecvError = self::RecvError;
618 type OpenError = AcceptError;
619 type AcceptError = AcceptError;
620}
621
622impl<In: RpcMessage, Out: RpcMessage> StreamTypes for HyperListener<In, Out> {
623 type In = In;
624 type Out = Out;
625 type RecvStream = self::RecvStream<In>;
626 type SendSink = self::SendSink<Out>;
627}
628
629impl<In: RpcMessage, Out: RpcMessage> Listener for HyperListener<In, Out> {
630 fn local_addr(&self) -> &[LocalAddr] {
631 &self.local_addr
632 }
633
634 async fn accept(&self) -> Result<(Self::SendSink, Self::RecvStream), AcceptError> {
635 let (recv, send) = self
636 .channel
637 .recv_async()
638 .await
639 .map_err(|_| AcceptError::RemoteDropped)?;
640 Ok((
641 SendSink::new(send, self.config.clone()),
642 RecvStream::new(recv),
643 ))
644 }
645}