async_sse/
encoder.rs

1use futures_lite::prelude::*;
2use futures_lite::ready;
3use std::task::{Context, Poll};
4
5use std::io;
6use std::pin::Pin;
7use std::time::Duration;
8
9pin_project_lite::pin_project! {
10    /// An SSE protocol encoder.
11    #[derive(Debug)]
12    pub struct Encoder {
13        buf: Box<[u8]>,
14        cursor: usize,
15        #[pin]
16        receiver: async_channel::Receiver<Vec<u8>>,
17    }
18}
19
20impl AsyncRead for Encoder {
21    fn poll_read(
22        self: Pin<&mut Self>,
23        cx: &mut Context<'_>,
24        buf: &mut [u8],
25    ) -> Poll<io::Result<usize>> {
26        let mut this = self.project();
27        // Request a new buffer if current one is exhausted.
28        if this.buf.len() <= *this.cursor {
29            match ready!(this.receiver.as_mut().poll_next(cx)) {
30                Some(buf) => {
31                    log::trace!("> Received a new buffer with len {}", buf.len());
32                    *this.buf = buf.into_boxed_slice();
33                    *this.cursor = 0;
34                }
35                None => {
36                    log::trace!("> Encoder done reading");
37                    return Poll::Ready(Ok(0));
38                }
39            };
40        }
41
42        // Write the current buffer to completion.
43        let local_buf = &this.buf[*this.cursor..];
44        let max = buf.len().min(local_buf.len());
45        buf[..max].clone_from_slice(&local_buf[..max]);
46        *this.cursor += max;
47
48        // Return bytes read.
49        Poll::Ready(Ok(max))
50    }
51}
52
53impl AsyncBufRead for Encoder {
54    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
55        let mut this = self.project();
56        // Request a new buffer if current one is exhausted.
57        if this.buf.len() <= *this.cursor {
58            match ready!(this.receiver.as_mut().poll_next(cx)) {
59                Some(buf) => {
60                    log::trace!("> Received a new buffer with len {}", buf.len());
61                    *this.buf = buf.into_boxed_slice();
62                    *this.cursor = 0;
63                }
64                None => {
65                    log::trace!("> Encoder done reading");
66                    return Poll::Ready(Ok(&[]));
67                }
68            };
69        }
70        Poll::Ready(Ok(&this.buf[*this.cursor..]))
71    }
72
73    fn consume(self: Pin<&mut Self>, amt: usize) {
74        let this = self.project();
75        *this.cursor += amt;
76    }
77}
78
79/// The sending side of the encoder.
80#[derive(Debug, Clone)]
81pub struct Sender(async_channel::Sender<Vec<u8>>);
82
83/// Create a new SSE encoder.
84pub fn encode() -> (Sender, Encoder) {
85    let (sender, receiver) = async_channel::bounded(1);
86    let encoder = Encoder {
87        receiver,
88        buf: Box::default(),
89        cursor: 0,
90    };
91    (Sender(sender), encoder)
92}
93
94impl Sender {
95    async fn inner_send(&self, bytes: impl Into<Vec<u8>>) -> io::Result<()> {
96        self.0
97            .send(bytes.into())
98            .await
99            .map_err(|_| io::Error::new(io::ErrorKind::ConnectionAborted, "sse disconnected"))
100    }
101
102    /// Send a new message over SSE.
103    pub async fn send(
104        &self,
105        name: impl Into<Option<&str>>,
106        data: &str,
107        id: Option<&str>,
108    ) -> io::Result<()> {
109        // Write the event name
110        if let Some(name) = name.into() {
111            self.inner_send(format!("event:{}\n", name)).await?;
112        }
113
114        // Write the id
115        if let Some(id) = id {
116            self.inner_send(format!("id:{}\n", id)).await?;
117        }
118
119        // Write the data section, and end.
120        for line in data.lines() {
121            let msg = format!("data:{}\n", line);
122            self.inner_send(msg).await?;
123        }
124        self.inner_send("\n").await?;
125
126        Ok(())
127    }
128
129    /// Send a new "retry" message over SSE.
130    pub async fn send_retry(&self, dur: Duration, id: Option<&str>) -> io::Result<()> {
131        // Write the id
132        if let Some(id) = id {
133            self.inner_send(format!("id:{}\n", id)).await?;
134        }
135
136        // Write the retry section, and end.
137        let dur = dur.as_secs_f64() as u64;
138        let msg = format!("retry:{}\n\n", dur);
139        self.inner_send(msg).await?;
140        Ok(())
141    }
142}