async_nats/
connection.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! This module provides a connection implementation for communicating with a NATS server.
15
16use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28    futures::{SinkExt, StreamExt},
29    pin_project::pin_project,
30    tokio::io::ReadBuf,
31    tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36
37use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
38use crate::status::StatusCode;
39use crate::subject::Subject;
40use crate::{ClientOp, ServerError, ServerOp, Statistics};
41
/// Soft cap, in bytes, on the data waiting in [`Connection::write_buf`] plus
/// [`Connection::flattened_writes`]. Once reached, callers are expected to stop
/// enqueueing more operations until some of it has been written out.
const SOFT_WRITE_BUF_LIMIT: usize = 65_535;
/// Buffers of at least this many bytes are queued as their own chunk;
/// anything shorter is copied into the flattened write buffer.
const WRITE_FLATTEN_THRESHOLD: usize = 4 * 1024;
/// Upper bound on the number of buffers passed to one vectored write call.
const WRITE_VECTORED_CHUNKS: usize = 64;
50
51/// Supertrait enabling trait object for containing both TLS and non TLS `TcpStream` connection.
52pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
53
54/// Blanked implementation that applies to both TLS and non-TLS `TcpStream`.
55impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
56
/// The state of the connection to the server.
///
/// Its `Display` form is the lowercase variant name.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum State {
    /// Displayed as `pending`.
    Pending,
    /// Displayed as `connected`.
    Connected,
    /// Displayed as `disconnected`.
    Disconnected,
}
64
/// Answer of `Connection::should_flush`: whether flushing the stream is useful.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShouldFlush {
    /// Nothing is left in the write buffers, but written data has not been flushed yet
    Yes,
    /// Written data has not been flushed yet, and the write buffers still hold data
    May,
    /// There is nothing to flush
    No,
}
74
75impl Display for State {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        match self {
78            State::Pending => write!(f, "pending"),
79            State::Connected => write!(f, "connected"),
80            State::Disconnected => write!(f, "disconnected"),
81        }
82    }
83}
84
85/// A framed connection
86pub(crate) struct Connection {
87    pub(crate) stream: Box<dyn AsyncReadWrite>,
88    read_buf: BytesMut,
89    write_buf: VecDeque<Bytes>,
90    write_buf_len: usize,
91    flattened_writes: BytesMut,
92    can_flush: bool,
93    statistics: Arc<Statistics>,
94}
95
96/// Internal representation of the connection.
97/// Holds connection with NATS Server and communicates with `Client` via channels.
98impl Connection {
99    pub(crate) fn new(
100        stream: Box<dyn AsyncReadWrite>,
101        read_buffer_capacity: usize,
102        statistics: Arc<Statistics>,
103    ) -> Self {
104        Self {
105            stream,
106            read_buf: BytesMut::with_capacity(read_buffer_capacity),
107            write_buf: VecDeque::new(),
108            write_buf_len: 0,
109            flattened_writes: BytesMut::new(),
110            can_flush: false,
111            statistics,
112        }
113    }
114
115    /// Returns `true` if no more calls to [`Self::enqueue_write_op`] _should_ be made.
116    pub(crate) fn is_write_buf_full(&self) -> bool {
117        self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
118    }
119
120    /// Returns `true` if [`Self::poll_flush`] should be polled.
121    pub(crate) fn should_flush(&self) -> ShouldFlush {
122        match (
123            self.can_flush,
124            self.write_buf.is_empty() && self.flattened_writes.is_empty(),
125        ) {
126            (true, true) => ShouldFlush::Yes,
127            (true, false) => ShouldFlush::May,
128            (false, _) => ShouldFlush::No,
129        }
130    }
131
132    /// Attempts to read a server operation from the read buffer.
133    /// Returns `None` if there is not enough data to parse an entire operation.
134    pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
135        let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
136            Some(len) => len,
137            None => return Ok(None),
138        };
139
140        if self.read_buf.starts_with(b"+OK") {
141            self.read_buf.advance(len + 2);
142            return Ok(Some(ServerOp::Ok));
143        }
144
145        if self.read_buf.starts_with(b"PING") {
146            self.read_buf.advance(len + 2);
147            return Ok(Some(ServerOp::Ping));
148        }
149
150        if self.read_buf.starts_with(b"PONG") {
151            self.read_buf.advance(len + 2);
152            return Ok(Some(ServerOp::Pong));
153        }
154
155        if self.read_buf.starts_with(b"-ERR") {
156            let description = str::from_utf8(&self.read_buf[5..len])
157                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
158                .trim_matches('\'')
159                .to_owned();
160
161            self.read_buf.advance(len + 2);
162
163            return Ok(Some(ServerOp::Error(ServerError::new(description))));
164        }
165
166        if self.read_buf.starts_with(b"INFO ") {
167            let info = serde_json::from_slice(&self.read_buf[4..len])
168                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
169
170            self.read_buf.advance(len + 2);
171
172            return Ok(Some(ServerOp::Info(Box::new(info))));
173        }
174
175        if self.read_buf.starts_with(b"MSG ") {
176            let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
177            let mut args = line.split(' ').filter(|s| !s.is_empty());
178
179            // Parse the operation syntax: MSG <subject> <sid> [reply-to] <#bytes>
180            let (subject, sid, reply_to, payload_len) = match (
181                args.next(),
182                args.next(),
183                args.next(),
184                args.next(),
185                args.next(),
186            ) {
187                (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
188                    (subject, sid, Some(reply_to), payload_len)
189                }
190                (Some(subject), Some(sid), Some(payload_len), None, None) => {
191                    (subject, sid, None, payload_len)
192                }
193                _ => {
194                    return Err(io::Error::new(
195                        io::ErrorKind::InvalidInput,
196                        "invalid number of arguments after MSG",
197                    ))
198                }
199            };
200
201            let sid = sid
202                .parse::<u64>()
203                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
204
205            // Parse the number of payload bytes.
206            let payload_len = payload_len
207                .parse::<usize>()
208                .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
209
210            // Return early without advancing if there is not enough data read the entire
211            // message
212            if len + payload_len + 4 > self.read_buf.remaining() {
213                return Ok(None);
214            }
215
216            let length = payload_len
217                + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
218                + subject.len();
219
220            let subject = Subject::from(subject);
221            let reply = reply_to.map(Subject::from);
222
223            self.read_buf.advance(len + 2);
224            let payload = self.read_buf.split_to(payload_len).freeze();
225            self.read_buf.advance(2);
226
227            return Ok(Some(ServerOp::Message {
228                sid,
229                length,
230                reply,
231                headers: None,
232                subject,
233                payload,
234                status: None,
235                description: None,
236            }));
237        }
238
239        if self.read_buf.starts_with(b"HMSG ") {
240            // Extract whitespace-delimited arguments that come after "HMSG".
241            let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
242            let mut args = line.split_whitespace().filter(|s| !s.is_empty());
243
244            // <subject> <sid> [reply-to] <# header bytes><# total bytes>
245            let (subject, sid, reply_to, header_len, total_len) = match (
246                args.next(),
247                args.next(),
248                args.next(),
249                args.next(),
250                args.next(),
251                args.next(),
252            ) {
253                (
254                    Some(subject),
255                    Some(sid),
256                    Some(reply_to),
257                    Some(header_len),
258                    Some(total_len),
259                    None,
260                ) => (subject, sid, Some(reply_to), header_len, total_len),
261                (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
262                    (subject, sid, None, header_len, total_len)
263                }
264                _ => {
265                    return Err(io::Error::new(
266                        io::ErrorKind::InvalidInput,
267                        "invalid number of arguments after HMSG",
268                    ))
269                }
270            };
271
272            // Convert the slice into a subject
273            let subject = Subject::from(subject);
274
275            // Parse the subject ID.
276            let sid = sid.parse::<u64>().map_err(|_| {
277                io::Error::new(
278                    io::ErrorKind::InvalidInput,
279                    "cannot parse sid argument after HMSG",
280                )
281            })?;
282
283            // Convert the slice into a subject.
284            let reply = reply_to.map(Subject::from);
285
286            // Parse the number of payload bytes.
287            let header_len = header_len.parse::<usize>().map_err(|_| {
288                io::Error::new(
289                    io::ErrorKind::InvalidInput,
290                    "cannot parse the number of header bytes argument after \
291                     HMSG",
292                )
293            })?;
294
295            // Parse the number of payload bytes.
296            let total_len = total_len.parse::<usize>().map_err(|_| {
297                io::Error::new(
298                    io::ErrorKind::InvalidInput,
299                    "cannot parse the number of bytes argument after HMSG",
300                )
301            })?;
302
303            if total_len < header_len {
304                return Err(io::Error::new(
305                    io::ErrorKind::InvalidInput,
306                    "number of header bytes was greater than or equal to the \
307                 total number of bytes after HMSG",
308                ));
309            }
310
311            if len + total_len + 4 > self.read_buf.remaining() {
312                return Ok(None);
313            }
314
315            self.read_buf.advance(len + 2);
316            let header = self.read_buf.split_to(header_len);
317            let payload = self.read_buf.split_to(total_len - header_len).freeze();
318            self.read_buf.advance(2);
319
320            let mut lines = std::str::from_utf8(&header)
321                .map_err(|_| {
322                    io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
323                })?
324                .lines()
325                .peekable();
326            let version_line = lines.next().ok_or_else(|| {
327                io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
328            })?;
329
330            let version_line_suffix = version_line
331                .strip_prefix("NATS/1.0")
332                .map(str::trim)
333                .ok_or_else(|| {
334                    io::Error::new(
335                        io::ErrorKind::InvalidInput,
336                        "header version line does not begin with `NATS/1.0`",
337                    )
338                })?;
339
340            let (status, description) = version_line_suffix
341                .split_once(' ')
342                .map(|(status, description)| (status.trim(), description.trim()))
343                .unwrap_or((version_line_suffix, ""));
344            let status = if !status.is_empty() {
345                Some(status.parse::<StatusCode>().map_err(|_| {
346                    std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
347                })?)
348            } else {
349                None
350            };
351            let description = if !description.is_empty() {
352                Some(description.to_owned())
353            } else {
354                None
355            };
356
357            let mut headers = HeaderMap::new();
358            while let Some(line) = lines.next() {
359                if line.is_empty() {
360                    continue;
361                }
362
363                let (name, value) = line.split_once(':').ok_or_else(|| {
364                    io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
365                })?;
366
367                let name = HeaderName::from_str(name)
368                    .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
369
370                // Read the header value, which might have been split into multiple lines
371                // `trim_start` and `trim_end` do the same job as doing `value.trim().to_owned()` at the end, but without a reallocation
372                let mut value = value.trim_start().to_owned();
373                while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
374                    value.push_str(v);
375                }
376                value.truncate(value.trim_end().len());
377
378                headers.append(name, value.into_header_value());
379            }
380
381            return Ok(Some(ServerOp::Message {
382                length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
383                sid,
384                reply,
385                subject,
386                headers: Some(headers),
387                payload,
388                status,
389                description,
390            }));
391        }
392
393        let buffer = self.read_buf.split_to(len + 2);
394        let line = str::from_utf8(&buffer).map_err(|_| {
395            io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
396        })?;
397
398        Err(io::Error::new(
399            io::ErrorKind::InvalidInput,
400            format!("invalid server operation: '{line}'"),
401        ))
402    }
403
404    pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
405        future::poll_fn(|cx| self.poll_read_op(cx))
406    }
407
408    // TODO: do we want an custom error here?
409    /// Read a server operation from read buffer.
410    /// Blocks until an operation ca be parsed.
411    pub(crate) fn poll_read_op(
412        &mut self,
413        cx: &mut Context<'_>,
414    ) -> Poll<io::Result<Option<ServerOp>>> {
415        loop {
416            if let Some(op) = self.try_read_op()? {
417                return Poll::Ready(Ok(Some(op)));
418            }
419
420            let read_buf = self.stream.read_buf(&mut self.read_buf);
421            tokio::pin!(read_buf);
422            return match read_buf.poll(cx) {
423                Poll::Pending => Poll::Pending,
424                Poll::Ready(Ok(0)) if self.read_buf.is_empty() => Poll::Ready(Ok(None)),
425                Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::ConnectionReset.into())),
426                Poll::Ready(Ok(n)) => {
427                    self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
428                    continue;
429                }
430                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
431            };
432        }
433    }
434
435    pub(crate) async fn easy_write_and_flush<'a>(
436        &mut self,
437        items: impl Iterator<Item = &'a ClientOp>,
438    ) -> io::Result<()> {
439        for item in items {
440            self.enqueue_write_op(item);
441        }
442
443        future::poll_fn(|cx| self.poll_write(cx)).await?;
444        future::poll_fn(|cx| self.poll_flush(cx)).await?;
445        Ok(())
446    }
447
448    /// Writes a client operation to the write buffer.
449    pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
450        macro_rules! small_write {
451            ($dst:expr) => {
452                write!(self.small_write(), $dst).expect("do small write to Connection");
453            };
454        }
455
456        match item {
457            ClientOp::Connect(connect_info) => {
458                let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
459
460                self.write("CONNECT ");
461                self.write(json);
462                self.write("\r\n");
463            }
464            ClientOp::Publish {
465                subject,
466                payload,
467                respond,
468                headers,
469            } => {
470                let verb = match headers.as_ref() {
471                    Some(headers) if !headers.is_empty() => "HPUB",
472                    _ => "PUB",
473                };
474
475                small_write!("{verb} {subject} ");
476
477                if let Some(respond) = respond {
478                    small_write!("{respond} ");
479                }
480
481                match headers {
482                    Some(headers) if !headers.is_empty() => {
483                        let headers = headers.to_bytes();
484
485                        let headers_len = headers.len();
486                        let total_len = headers_len + payload.len();
487                        small_write!("{headers_len} {total_len}\r\n");
488                        self.write(headers);
489                    }
490                    _ => {
491                        let payload_len = payload.len();
492                        small_write!("{payload_len}\r\n");
493                    }
494                }
495
496                self.write(Bytes::clone(payload));
497                self.write("\r\n");
498            }
499
500            ClientOp::Subscribe {
501                sid,
502                subject,
503                queue_group,
504            } => match queue_group {
505                Some(queue_group) => {
506                    small_write!("SUB {subject} {queue_group} {sid}\r\n");
507                }
508                None => {
509                    small_write!("SUB {subject} {sid}\r\n");
510                }
511            },
512
513            ClientOp::Unsubscribe { sid, max } => match max {
514                Some(max) => {
515                    small_write!("UNSUB {sid} {max}\r\n");
516                }
517                None => {
518                    small_write!("UNSUB {sid}\r\n");
519                }
520            },
521            ClientOp::Ping => {
522                self.write("PING\r\n");
523            }
524            ClientOp::Pong => {
525                self.write("PONG\r\n");
526            }
527        }
528    }
529
530    /// Write the internal buffers into the write stream
531    ///
532    /// Returns one of the following:
533    ///
534    /// * `Poll::Pending` means that we weren't able to fully empty
535    ///   the internal buffers. Compared to [`AsyncWrite::poll_write`],
536    ///   this implementation may do a partial write before yielding.
537    /// * `Poll::Ready(Ok())` means that the internal write buffers have
538    ///   been emptied or were already empty.
539    /// * `Poll::Ready(Err(err))` means that writing to the stream failed.
540    ///   Compared to [`AsyncWrite::poll_write`], this implementation
541    ///   may do a partial write before failing.
542    pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
543        if !self.stream.is_write_vectored() {
544            self.poll_write_sequential(cx)
545        } else {
546            self.poll_write_vectored(cx)
547        }
548    }
549
550    /// Write the internal buffers into the write stream using sequential write operations
551    ///
552    /// Writes one chunk at a time. Less efficient.
553    fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
554        loop {
555            let buf = match self.write_buf.front() {
556                Some(buf) => &**buf,
557                None if !self.flattened_writes.is_empty() => &self.flattened_writes,
558                None => return Poll::Ready(Ok(())),
559            };
560
561            debug_assert!(!buf.is_empty());
562
563            match Pin::new(&mut self.stream).poll_write(cx, buf) {
564                Poll::Pending => return Poll::Pending,
565                Poll::Ready(Ok(n)) => {
566                    self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
567                    self.write_buf_len -= n;
568                    self.can_flush = true;
569
570                    match self.write_buf.front_mut() {
571                        Some(buf) if n < buf.len() => {
572                            buf.advance(n);
573                        }
574                        Some(_buf) => {
575                            self.write_buf.pop_front();
576                        }
577                        None => {
578                            self.flattened_writes.advance(n);
579                        }
580                    }
581                    continue;
582                }
583                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
584            }
585        }
586    }
587    /// Write the internal buffers into the write stream using vectored write operations
588    ///
589    /// Writes [`WRITE_VECTORED_CHUNKS`] at a time. More efficient _if_
590    /// the underlying writer supports it.
591    fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
592        'outer: loop {
593            let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
594            let mut writes_len = 0;
595
596            self.write_buf
597                .iter()
598                .take(WRITE_VECTORED_CHUNKS)
599                .enumerate()
600                .for_each(|(i, buf)| {
601                    writes[i] = IoSlice::new(buf);
602                    writes_len += 1;
603                });
604
605            if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
606                writes[writes_len] = IoSlice::new(&self.flattened_writes);
607                writes_len += 1;
608            }
609
610            if writes_len == 0 {
611                return Poll::Ready(Ok(()));
612            }
613
614            match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
615                Poll::Pending => return Poll::Pending,
616                Poll::Ready(Ok(mut n)) => {
617                    self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
618                    self.write_buf_len -= n;
619                    self.can_flush = true;
620
621                    while let Some(buf) = self.write_buf.front_mut() {
622                        if n < buf.len() {
623                            buf.advance(n);
624                            continue 'outer;
625                        }
626
627                        n -= buf.len();
628                        self.write_buf.pop_front();
629                    }
630
631                    self.flattened_writes.advance(n);
632                }
633                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
634            }
635        }
636    }
637
638    /// Write `buf` into the writes buffer
639    ///
640    /// If `buf` is smaller than [`WRITE_FLATTEN_THRESHOLD`]
641    /// flattens it, otherwise appends it to the chunks queue.
642    ///
643    /// Empty `buf`s are a no-op.
644    fn write(&mut self, buf: impl Into<Bytes>) {
645        let buf = buf.into();
646        if buf.is_empty() {
647            return;
648        }
649
650        self.write_buf_len += buf.len();
651        if buf.len() < WRITE_FLATTEN_THRESHOLD {
652            self.flattened_writes.extend_from_slice(&buf);
653        } else {
654            if !self.flattened_writes.is_empty() {
655                let buf = self.flattened_writes.split().freeze();
656                self.write_buf.push_back(buf);
657            }
658
659            self.write_buf.push_back(buf);
660        }
661    }
662
663    /// Obtain an [`fmt::Write`]r for the small writes buffer.
664    fn small_write(&mut self) -> impl fmt::Write + '_ {
665        struct Writer<'a> {
666            this: &'a mut Connection,
667        }
668
669        impl fmt::Write for Writer<'_> {
670            fn write_str(&mut self, s: &str) -> fmt::Result {
671                self.this.write_buf_len += s.len();
672                self.this.flattened_writes.write_str(s)
673            }
674        }
675
676        Writer { this: self }
677    }
678
679    /// Flush the write buffer, sending all pending data down the current write stream.
680    ///
681    /// no-op if the write stream didn't need to be flushed.
682    pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
683        match Pin::new(&mut self.stream).poll_flush(cx) {
684            Poll::Pending => Poll::Pending,
685            Poll::Ready(Ok(())) => {
686                self.can_flush = false;
687                Poll::Ready(Ok(()))
688            }
689            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
690        }
691    }
692}
693
/// Exposes a WebSocket stream through `AsyncRead`/`AsyncWrite`, so it can be
/// used as a connection stream.
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// The wrapped WebSocket stream.
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Payload bytes received from `inner` that have not been handed to a reader yet.
    pub(crate) read_buf: BytesMut,
}
701
#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        WebSocketAdapter { inner, read_buf }
    }
}
711
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Copies buffered payload bytes into `buf`, pulling messages from the
    /// WebSocket only while the local buffer is empty.
    ///
    /// A closed WebSocket is reported as [`std::io::ErrorKind::UnexpectedEof`];
    /// WebSocket errors are wrapped as [`std::io::ErrorKind::Other`].
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        while this.read_buf.is_empty() {
            let message = match this.inner.poll_next_unpin(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
                }
                Poll::Ready(Some(Ok(message))) => message,
            };
            this.read_buf.extend_from_slice(message.as_payload());
        }

        let count = buf.remaining().min(this.read_buf.len());
        buf.put_slice(&this.read_buf.split_to(count));
        Poll::Ready(Ok(()))
    }
}
752
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends all of `buf` as a single binary WebSocket message.
    ///
    /// WebSocket errors are wrapped as [`std::io::ErrorKind::Other`].
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => {
                Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
            }
            Poll::Ready(Ok(())) => {
                // Copy `buf` only once the sink can take a message, so a `Pending`
                // poll no longer allocates a copy that is immediately dropped.
                let message = tokio_websockets::Message::binary(buf.to_vec());
                Poll::Ready(
                    this.inner
                        .start_send_unpin(message)
                        .map(|()| buf.len())
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)),
                )
            }
        }
    }

    /// Flushes messages queued in the underlying WebSocket sink.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }

    /// Closes the underlying WebSocket sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }
}
795
#[cfg(test)]
mod read_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
    use tokio::io::{self, AsyncWriteExt};

    /// Builds a [`Connection`] over one end of an in-memory duplex pipe whose buffer
    /// holds `max_buf_size` bytes, and returns it together with the other end,
    /// which the tests write to in order to play the part of the server.
    fn connection_pair(max_buf_size: usize) -> (Connection, io::DuplexStream) {
        let (stream, server) = io::duplex(max_buf_size);
        let connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));
        (connection, server)
    }

    #[tokio::test]
    async fn ok() {
        let (mut connection, mut server) = connection_pair(128);

        server.write_all(b"+OK\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut server) = connection_pair(128);

        server.write_all(b"PING\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut server) = connection_pair(128);

        server.write_all(b"PONG\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Pong));
    }

    #[tokio::test]
    async fn info() {
        let (mut connection, mut server) = connection_pair(128);

        server.write_all(b"INFO {}\r\n").await.unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"INFO { \"version\": \"1.0.0\" }\r\n")
            .await
            .unwrap();
        server.flush().await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Info(Box::new(ServerInfo {
                version: "1.0.0".into(),
                ..Default::default()
            })))
        );
    }

    #[tokio::test]
    async fn error() {
        let (mut connection, mut server) = connection_pair(128);

        server.write_all(b"INFO {}\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Info(Box::default())));

        server
            .write_all(b"-ERR something went wrong\r\n")
            .await
            .unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }

    #[tokio::test]
    async fn message() {
        let (mut connection, mut server) = connection_pair(128);

        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // Header block is 23 bytes ("NATS/1.0\r\n" + "Header: X\r\n" + "\r\n");
        // total is 23 + the 11-byte payload = 34.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: X\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();

        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 23 34\r\n")
            .await
            .unwrap();
        server.write_all(b"NATS/1.0\r\n").await.unwrap();
        server.write_all(b"Header: Y\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"Hello World\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // Status-only message: headers carry a status line and the payload is empty.
        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server
            .write_all(b"MSG FOO.BAR 9 11\r\nHello Again\r\n")
            .await
            .unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    /// Unknown operations must be reported as errors without desynchronising the
    /// parser: each valid op that follows must still decode to the right variant.
    #[tokio::test]
    async fn unknown() {
        let (mut connection, mut server) = connection_pair(128);

        server.write_all(b"ONE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"TWO\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PING\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Ping));

        server.write_all(b"THREE\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server
            .write_all(b"HMSG FOO.BAR 10 INBOX.35 28 28\r\n")
            .await
            .unwrap();
        server
            .write_all(b"NATS/1.0 404 No Messages\r\n")
            .await
            .unwrap();
        server.write_all(b"\r\n").await.unwrap();
        server.write_all(b"\r\n").await.unwrap();

        let result = connection.read_op().await.unwrap();
        assert_eq!(
            result,
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        server.write_all(b"FOUR\r\n").await.unwrap();
        connection.read_op().await.unwrap_err();

        server.write_all(b"PONG\r\n").await.unwrap();
        let result = connection.read_op().await.unwrap();
        assert_eq!(result, Some(ServerOp::Pong));
    }
}
1080
#[cfg(test)]
mod write_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
    use tokio::io::{self, AsyncBufReadExt, BufReader};

    /// Builds a [`Connection`] over one end of an in-memory duplex pipe whose buffer
    /// holds `max_buf_size` bytes, and returns it together with the other end, from
    /// which the tests read what the client wrote (playing the part of the server).
    fn connection_pair(max_buf_size: usize) -> (Connection, io::DuplexStream) {
        let (stream, server) = io::duplex(max_buf_size);
        let connection = Connection::new(Box::new(stream), 0, Arc::new(Statistics::default()));
        (connection, server)
    }

    #[tokio::test]
    async fn publish() {
        let (mut connection, server) = connection_pair(128);

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: None,
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Publish {
                    subject: "FOO.BAR".into(),
                    payload: "Hello World".into(),
                    respond: Some("INBOX.67".into()),
                    headers: Some(HeaderMap::from_iter([(
                        "Header".parse().unwrap(),
                        "X".parse().unwrap(),
                    )])),
                }]
                .iter(),
            )
            .await
            .unwrap();

        // Five lines: the HPUB control line, the three lines of the header block
        // (version, header, blank terminator) and finally the payload itself, so
        // the payload's placement after the headers is verified too.
        buffer.clear();
        for _ in 0..5 {
            reader.read_line(&mut buffer).await.unwrap();
        }
        assert_eq!(
            buffer,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\nHello World\r\n"
        );
    }

    #[tokio::test]
    async fn subscribe() {
        let (mut connection, server) = connection_pair(128);

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: None,
                }]
                .iter(),
            )
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Subscribe {
                    sid: 11,
                    subject: "FOO.BAR".into(),
                    queue_group: Some("QUEUE.GROUP".into()),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "SUB FOO.BAR QUEUE.GROUP 11\r\n");
    }

    #[tokio::test]
    async fn unsubscribe() {
        let (mut connection, server) = connection_pair(128);

        connection
            .easy_write_and_flush([ClientOp::Unsubscribe { sid: 11, max: None }].iter())
            .await
            .unwrap();

        let mut buffer = String::new();
        let mut reader = BufReader::new(server);
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11\r\n");

        connection
            .easy_write_and_flush(
                [ClientOp::Unsubscribe {
                    sid: 11,
                    max: Some(2),
                }]
                .iter(),
            )
            .await
            .unwrap();

        buffer.clear();
        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(buffer, "UNSUB 11 2\r\n");
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, server) = connection_pair(128);

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Ping].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PING\r\n");
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, server) = connection_pair(128);

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush([ClientOp::Pong].iter())
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();

        assert_eq!(buffer, "PONG\r\n");
    }

    #[tokio::test]
    async fn connect() {
        // The serialized CONNECT line is longer than 128 bytes, so use a larger pipe.
        let (mut connection, server) = connection_pair(1024);

        let mut reader = BufReader::new(server);
        let mut buffer = String::new();

        connection
            .easy_write_and_flush(
                [ClientOp::Connect(ConnectInfo {
                    verbose: false,
                    pedantic: false,
                    user_jwt: None,
                    nkey: None,
                    signature: None,
                    name: None,
                    echo: false,
                    lang: "Rust".into(),
                    version: "1.0.0".into(),
                    protocol: Protocol::Dynamic,
                    tls_required: false,
                    user: None,
                    pass: None,
                    auth_token: None,
                    headers: false,
                    no_responders: false,
                })]
                .iter(),
            )
            .await
            .unwrap();

        reader.read_line(&mut buffer).await.unwrap();
        assert_eq!(
            buffer,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}