1use std::collections::VecDeque;
17use std::fmt::{self, Display, Write as _};
18use std::future::{self, Future};
19use std::io::IoSlice;
20use std::pin::Pin;
21use std::str::{self, FromStr};
22use std::sync::atomic::Ordering;
23use std::sync::Arc;
24use std::task::{Context, Poll};
25
26#[cfg(feature = "websockets")]
27use {
28 futures::{SinkExt, StreamExt},
29 pin_project::pin_project,
30 tokio::io::ReadBuf,
31 tokio_websockets::WebSocketStream,
32};
33
34use bytes::{Buf, Bytes, BytesMut};
35use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite};
36
37use crate::header::{HeaderMap, HeaderName, IntoHeaderValue};
38use crate::status::StatusCode;
39use crate::subject::Subject;
40use crate::{ClientOp, ServerError, ServerOp, Statistics};
41
/// Soft cap, in bytes, on data queued for writing. `Connection::is_write_buf_full` reports
/// `true` once this many bytes are waiting to be written.
const SOFT_WRITE_BUF_LIMIT: usize = u16::MAX as usize;

/// Chunks shorter than this many bytes are copied into one contiguous "flattened" buffer
/// instead of being queued individually.
const WRITE_FLATTEN_THRESHOLD: usize = 4 * 1024;

/// Maximum number of buffers passed to a single vectored write call.
const WRITE_VECTORED_CHUNKS: usize = 64;
50
51pub(crate) trait AsyncReadWrite: AsyncWrite + AsyncRead + Send + Unpin {}
53
54impl<T> AsyncReadWrite for T where T: AsyncRead + AsyncWrite + Unpin + Send {}
56
/// Connection lifecycle state; rendered in lowercase by its `Display` implementation.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum State {
    /// Not yet connected.
    Pending,
    /// Connected.
    Connected,
    /// Disconnected.
    Disconnected,
}
64
/// Answer returned by `Connection::should_flush`.
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum ShouldFlush {
    /// Bytes were written since the last flush and no queued data remains.
    Yes,
    /// Bytes were written since the last flush, but more data is still queued.
    May,
    /// Nothing has been written since the last flush.
    No,
}
74
75impl Display for State {
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 match self {
78 State::Pending => write!(f, "pending"),
79 State::Connected => write!(f, "connected"),
80 State::Disconnected => write!(f, "disconnected"),
81 }
82 }
83}
84
85pub(crate) struct Connection {
87 pub(crate) stream: Box<dyn AsyncReadWrite>,
88 read_buf: BytesMut,
89 write_buf: VecDeque<Bytes>,
90 write_buf_len: usize,
91 flattened_writes: BytesMut,
92 can_flush: bool,
93 statistics: Arc<Statistics>,
94}
95
96impl Connection {
99 pub(crate) fn new(
100 stream: Box<dyn AsyncReadWrite>,
101 read_buffer_capacity: usize,
102 statistics: Arc<Statistics>,
103 ) -> Self {
104 Self {
105 stream,
106 read_buf: BytesMut::with_capacity(read_buffer_capacity),
107 write_buf: VecDeque::new(),
108 write_buf_len: 0,
109 flattened_writes: BytesMut::new(),
110 can_flush: false,
111 statistics,
112 }
113 }
114
115 pub(crate) fn is_write_buf_full(&self) -> bool {
117 self.write_buf_len >= SOFT_WRITE_BUF_LIMIT
118 }
119
120 pub(crate) fn should_flush(&self) -> ShouldFlush {
122 match (
123 self.can_flush,
124 self.write_buf.is_empty() && self.flattened_writes.is_empty(),
125 ) {
126 (true, true) => ShouldFlush::Yes,
127 (true, false) => ShouldFlush::May,
128 (false, _) => ShouldFlush::No,
129 }
130 }
131
132 pub(crate) fn try_read_op(&mut self) -> Result<Option<ServerOp>, io::Error> {
135 let len = match memchr::memmem::find(&self.read_buf, b"\r\n") {
136 Some(len) => len,
137 None => return Ok(None),
138 };
139
140 if self.read_buf.starts_with(b"+OK") {
141 self.read_buf.advance(len + 2);
142 return Ok(Some(ServerOp::Ok));
143 }
144
145 if self.read_buf.starts_with(b"PING") {
146 self.read_buf.advance(len + 2);
147 return Ok(Some(ServerOp::Ping));
148 }
149
150 if self.read_buf.starts_with(b"PONG") {
151 self.read_buf.advance(len + 2);
152 return Ok(Some(ServerOp::Pong));
153 }
154
155 if self.read_buf.starts_with(b"-ERR") {
156 let description = str::from_utf8(&self.read_buf[5..len])
157 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?
158 .trim_matches('\'')
159 .to_owned();
160
161 self.read_buf.advance(len + 2);
162
163 return Ok(Some(ServerOp::Error(ServerError::new(description))));
164 }
165
166 if self.read_buf.starts_with(b"INFO ") {
167 let info = serde_json::from_slice(&self.read_buf[4..len])
168 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
169
170 self.read_buf.advance(len + 2);
171
172 return Ok(Some(ServerOp::Info(Box::new(info))));
173 }
174
175 if self.read_buf.starts_with(b"MSG ") {
176 let line = str::from_utf8(&self.read_buf[4..len]).unwrap();
177 let mut args = line.split(' ').filter(|s| !s.is_empty());
178
179 let (subject, sid, reply_to, payload_len) = match (
181 args.next(),
182 args.next(),
183 args.next(),
184 args.next(),
185 args.next(),
186 ) {
187 (Some(subject), Some(sid), Some(reply_to), Some(payload_len), None) => {
188 (subject, sid, Some(reply_to), payload_len)
189 }
190 (Some(subject), Some(sid), Some(payload_len), None, None) => {
191 (subject, sid, None, payload_len)
192 }
193 _ => {
194 return Err(io::Error::new(
195 io::ErrorKind::InvalidInput,
196 "invalid number of arguments after MSG",
197 ))
198 }
199 };
200
201 let sid = sid
202 .parse::<u64>()
203 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
204
205 let payload_len = payload_len
207 .parse::<usize>()
208 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
209
210 if len + payload_len + 4 > self.read_buf.remaining() {
213 return Ok(None);
214 }
215
216 let length = payload_len
217 + reply_to.as_ref().map(|reply| reply.len()).unwrap_or(0)
218 + subject.len();
219
220 let subject = Subject::from(subject);
221 let reply = reply_to.map(Subject::from);
222
223 self.read_buf.advance(len + 2);
224 let payload = self.read_buf.split_to(payload_len).freeze();
225 self.read_buf.advance(2);
226
227 return Ok(Some(ServerOp::Message {
228 sid,
229 length,
230 reply,
231 headers: None,
232 subject,
233 payload,
234 status: None,
235 description: None,
236 }));
237 }
238
239 if self.read_buf.starts_with(b"HMSG ") {
240 let line = std::str::from_utf8(&self.read_buf[5..len]).unwrap();
242 let mut args = line.split_whitespace().filter(|s| !s.is_empty());
243
244 let (subject, sid, reply_to, header_len, total_len) = match (
246 args.next(),
247 args.next(),
248 args.next(),
249 args.next(),
250 args.next(),
251 args.next(),
252 ) {
253 (
254 Some(subject),
255 Some(sid),
256 Some(reply_to),
257 Some(header_len),
258 Some(total_len),
259 None,
260 ) => (subject, sid, Some(reply_to), header_len, total_len),
261 (Some(subject), Some(sid), Some(header_len), Some(total_len), None, None) => {
262 (subject, sid, None, header_len, total_len)
263 }
264 _ => {
265 return Err(io::Error::new(
266 io::ErrorKind::InvalidInput,
267 "invalid number of arguments after HMSG",
268 ))
269 }
270 };
271
272 let subject = Subject::from(subject);
274
275 let sid = sid.parse::<u64>().map_err(|_| {
277 io::Error::new(
278 io::ErrorKind::InvalidInput,
279 "cannot parse sid argument after HMSG",
280 )
281 })?;
282
283 let reply = reply_to.map(Subject::from);
285
286 let header_len = header_len.parse::<usize>().map_err(|_| {
288 io::Error::new(
289 io::ErrorKind::InvalidInput,
290 "cannot parse the number of header bytes argument after \
291 HMSG",
292 )
293 })?;
294
295 let total_len = total_len.parse::<usize>().map_err(|_| {
297 io::Error::new(
298 io::ErrorKind::InvalidInput,
299 "cannot parse the number of bytes argument after HMSG",
300 )
301 })?;
302
303 if total_len < header_len {
304 return Err(io::Error::new(
305 io::ErrorKind::InvalidInput,
306 "number of header bytes was greater than or equal to the \
307 total number of bytes after HMSG",
308 ));
309 }
310
311 if len + total_len + 4 > self.read_buf.remaining() {
312 return Ok(None);
313 }
314
315 self.read_buf.advance(len + 2);
316 let header = self.read_buf.split_to(header_len);
317 let payload = self.read_buf.split_to(total_len - header_len).freeze();
318 self.read_buf.advance(2);
319
320 let mut lines = std::str::from_utf8(&header)
321 .map_err(|_| {
322 io::Error::new(io::ErrorKind::InvalidInput, "header isn't valid utf-8")
323 })?
324 .lines()
325 .peekable();
326 let version_line = lines.next().ok_or_else(|| {
327 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
328 })?;
329
330 let version_line_suffix = version_line
331 .strip_prefix("NATS/1.0")
332 .map(str::trim)
333 .ok_or_else(|| {
334 io::Error::new(
335 io::ErrorKind::InvalidInput,
336 "header version line does not begin with `NATS/1.0`",
337 )
338 })?;
339
340 let (status, description) = version_line_suffix
341 .split_once(' ')
342 .map(|(status, description)| (status.trim(), description.trim()))
343 .unwrap_or((version_line_suffix, ""));
344 let status = if !status.is_empty() {
345 Some(status.parse::<StatusCode>().map_err(|_| {
346 std::io::Error::new(io::ErrorKind::Other, "could not parse status parameter")
347 })?)
348 } else {
349 None
350 };
351 let description = if !description.is_empty() {
352 Some(description.to_owned())
353 } else {
354 None
355 };
356
357 let mut headers = HeaderMap::new();
358 while let Some(line) = lines.next() {
359 if line.is_empty() {
360 continue;
361 }
362
363 let (name, value) = line.split_once(':').ok_or_else(|| {
364 io::Error::new(io::ErrorKind::InvalidInput, "no header version line found")
365 })?;
366
367 let name = HeaderName::from_str(name)
368 .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
369
370 let mut value = value.trim_start().to_owned();
373 while let Some(v) = lines.next_if(|s| s.starts_with(char::is_whitespace)) {
374 value.push_str(v);
375 }
376 value.truncate(value.trim_end().len());
377
378 headers.append(name, value.into_header_value());
379 }
380
381 return Ok(Some(ServerOp::Message {
382 length: reply.as_ref().map_or(0, |reply| reply.len()) + subject.len() + total_len,
383 sid,
384 reply,
385 subject,
386 headers: Some(headers),
387 payload,
388 status,
389 description,
390 }));
391 }
392
393 let buffer = self.read_buf.split_to(len + 2);
394 let line = str::from_utf8(&buffer).map_err(|_| {
395 io::Error::new(io::ErrorKind::InvalidInput, "unable to parse unknown input")
396 })?;
397
398 Err(io::Error::new(
399 io::ErrorKind::InvalidInput,
400 format!("invalid server operation: '{line}'"),
401 ))
402 }
403
404 pub(crate) fn read_op(&mut self) -> impl Future<Output = io::Result<Option<ServerOp>>> + '_ {
405 future::poll_fn(|cx| self.poll_read_op(cx))
406 }
407
408 pub(crate) fn poll_read_op(
412 &mut self,
413 cx: &mut Context<'_>,
414 ) -> Poll<io::Result<Option<ServerOp>>> {
415 loop {
416 if let Some(op) = self.try_read_op()? {
417 return Poll::Ready(Ok(Some(op)));
418 }
419
420 let read_buf = self.stream.read_buf(&mut self.read_buf);
421 tokio::pin!(read_buf);
422 return match read_buf.poll(cx) {
423 Poll::Pending => Poll::Pending,
424 Poll::Ready(Ok(0)) if self.read_buf.is_empty() => Poll::Ready(Ok(None)),
425 Poll::Ready(Ok(0)) => Poll::Ready(Err(io::ErrorKind::ConnectionReset.into())),
426 Poll::Ready(Ok(n)) => {
427 self.statistics.in_bytes.add(n as u64, Ordering::Relaxed);
428 continue;
429 }
430 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
431 };
432 }
433 }
434
435 pub(crate) async fn easy_write_and_flush<'a>(
436 &mut self,
437 items: impl Iterator<Item = &'a ClientOp>,
438 ) -> io::Result<()> {
439 for item in items {
440 self.enqueue_write_op(item);
441 }
442
443 future::poll_fn(|cx| self.poll_write(cx)).await?;
444 future::poll_fn(|cx| self.poll_flush(cx)).await?;
445 Ok(())
446 }
447
448 pub(crate) fn enqueue_write_op(&mut self, item: &ClientOp) {
450 macro_rules! small_write {
451 ($dst:expr) => {
452 write!(self.small_write(), $dst).expect("do small write to Connection");
453 };
454 }
455
456 match item {
457 ClientOp::Connect(connect_info) => {
458 let json = serde_json::to_vec(&connect_info).expect("serialize `ConnectInfo`");
459
460 self.write("CONNECT ");
461 self.write(json);
462 self.write("\r\n");
463 }
464 ClientOp::Publish {
465 subject,
466 payload,
467 respond,
468 headers,
469 } => {
470 let verb = match headers.as_ref() {
471 Some(headers) if !headers.is_empty() => "HPUB",
472 _ => "PUB",
473 };
474
475 small_write!("{verb} {subject} ");
476
477 if let Some(respond) = respond {
478 small_write!("{respond} ");
479 }
480
481 match headers {
482 Some(headers) if !headers.is_empty() => {
483 let headers = headers.to_bytes();
484
485 let headers_len = headers.len();
486 let total_len = headers_len + payload.len();
487 small_write!("{headers_len} {total_len}\r\n");
488 self.write(headers);
489 }
490 _ => {
491 let payload_len = payload.len();
492 small_write!("{payload_len}\r\n");
493 }
494 }
495
496 self.write(Bytes::clone(payload));
497 self.write("\r\n");
498 }
499
500 ClientOp::Subscribe {
501 sid,
502 subject,
503 queue_group,
504 } => match queue_group {
505 Some(queue_group) => {
506 small_write!("SUB {subject} {queue_group} {sid}\r\n");
507 }
508 None => {
509 small_write!("SUB {subject} {sid}\r\n");
510 }
511 },
512
513 ClientOp::Unsubscribe { sid, max } => match max {
514 Some(max) => {
515 small_write!("UNSUB {sid} {max}\r\n");
516 }
517 None => {
518 small_write!("UNSUB {sid}\r\n");
519 }
520 },
521 ClientOp::Ping => {
522 self.write("PING\r\n");
523 }
524 ClientOp::Pong => {
525 self.write("PONG\r\n");
526 }
527 }
528 }
529
530 pub(crate) fn poll_write(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
543 if !self.stream.is_write_vectored() {
544 self.poll_write_sequential(cx)
545 } else {
546 self.poll_write_vectored(cx)
547 }
548 }
549
550 fn poll_write_sequential(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
554 loop {
555 let buf = match self.write_buf.front() {
556 Some(buf) => &**buf,
557 None if !self.flattened_writes.is_empty() => &self.flattened_writes,
558 None => return Poll::Ready(Ok(())),
559 };
560
561 debug_assert!(!buf.is_empty());
562
563 match Pin::new(&mut self.stream).poll_write(cx, buf) {
564 Poll::Pending => return Poll::Pending,
565 Poll::Ready(Ok(n)) => {
566 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
567 self.write_buf_len -= n;
568 self.can_flush = true;
569
570 match self.write_buf.front_mut() {
571 Some(buf) if n < buf.len() => {
572 buf.advance(n);
573 }
574 Some(_buf) => {
575 self.write_buf.pop_front();
576 }
577 None => {
578 self.flattened_writes.advance(n);
579 }
580 }
581 continue;
582 }
583 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
584 }
585 }
586 }
587 fn poll_write_vectored(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
592 'outer: loop {
593 let mut writes = [IoSlice::new(b""); WRITE_VECTORED_CHUNKS];
594 let mut writes_len = 0;
595
596 self.write_buf
597 .iter()
598 .take(WRITE_VECTORED_CHUNKS)
599 .enumerate()
600 .for_each(|(i, buf)| {
601 writes[i] = IoSlice::new(buf);
602 writes_len += 1;
603 });
604
605 if writes_len < WRITE_VECTORED_CHUNKS && !self.flattened_writes.is_empty() {
606 writes[writes_len] = IoSlice::new(&self.flattened_writes);
607 writes_len += 1;
608 }
609
610 if writes_len == 0 {
611 return Poll::Ready(Ok(()));
612 }
613
614 match Pin::new(&mut self.stream).poll_write_vectored(cx, &writes[..writes_len]) {
615 Poll::Pending => return Poll::Pending,
616 Poll::Ready(Ok(mut n)) => {
617 self.statistics.out_bytes.add(n as u64, Ordering::Relaxed);
618 self.write_buf_len -= n;
619 self.can_flush = true;
620
621 while let Some(buf) = self.write_buf.front_mut() {
622 if n < buf.len() {
623 buf.advance(n);
624 continue 'outer;
625 }
626
627 n -= buf.len();
628 self.write_buf.pop_front();
629 }
630
631 self.flattened_writes.advance(n);
632 }
633 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
634 }
635 }
636 }
637
638 fn write(&mut self, buf: impl Into<Bytes>) {
645 let buf = buf.into();
646 if buf.is_empty() {
647 return;
648 }
649
650 self.write_buf_len += buf.len();
651 if buf.len() < WRITE_FLATTEN_THRESHOLD {
652 self.flattened_writes.extend_from_slice(&buf);
653 } else {
654 if !self.flattened_writes.is_empty() {
655 let buf = self.flattened_writes.split().freeze();
656 self.write_buf.push_back(buf);
657 }
658
659 self.write_buf.push_back(buf);
660 }
661 }
662
663 fn small_write(&mut self) -> impl fmt::Write + '_ {
665 struct Writer<'a> {
666 this: &'a mut Connection,
667 }
668
669 impl fmt::Write for Writer<'_> {
670 fn write_str(&mut self, s: &str) -> fmt::Result {
671 self.this.write_buf_len += s.len();
672 self.this.flattened_writes.write_str(s)
673 }
674 }
675
676 Writer { this: self }
677 }
678
679 pub(crate) fn poll_flush(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
683 match Pin::new(&mut self.stream).poll_flush(cx) {
684 Poll::Pending => Poll::Pending,
685 Poll::Ready(Ok(())) => {
686 self.can_flush = false;
687 Poll::Ready(Ok(()))
688 }
689 Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
690 }
691 }
692}
693
/// Adapts a WebSocket connection to the byte-stream `AsyncRead`/`AsyncWrite` interface:
/// outgoing bytes are sent as binary frames, and incoming frame payloads are buffered and
/// handed out as a contiguous byte stream.
#[cfg(feature = "websockets")]
#[pin_project]
pub(crate) struct WebSocketAdapter<T> {
    /// The wrapped WebSocket stream.
    #[pin]
    pub(crate) inner: WebSocketStream<T>,
    /// Payload bytes received from `inner` that have not yet been handed to a reader.
    pub(crate) read_buf: BytesMut,
}
701
#[cfg(feature = "websockets")]
impl<T> WebSocketAdapter<T> {
    /// Wraps `inner`, starting with an empty read buffer.
    pub(crate) fn new(inner: WebSocketStream<T>) -> Self {
        let read_buf = BytesMut::new();
        Self { inner, read_buf }
    }
}
711
#[cfg(feature = "websockets")]
impl<T> AsyncRead for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Copies buffered payload bytes into `buf`, pulling the next WebSocket message when the
    /// local buffer is empty.
    ///
    /// Only binary and text messages carry protocol data. Control messages (ping, pong, close)
    /// are skipped so their payloads never leak into the byte stream. A closed WebSocket is
    /// reported as an `UnexpectedEof` error.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        let mut this = self.project();

        loop {
            if !this.read_buf.is_empty() {
                // Hand out as much buffered data as the caller has room for.
                let len = std::cmp::min(buf.remaining(), this.read_buf.len());
                buf.put_slice(&this.read_buf[..len]);
                this.read_buf.advance(len);
                return Poll::Ready(Ok(()));
            }

            match this.inner.poll_next_unpin(cx) {
                Poll::Ready(Some(Ok(message))) => {
                    if message.is_binary() || message.is_text() {
                        this.read_buf.extend_from_slice(message.as_payload());
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)));
                }
                Poll::Ready(None) => {
                    return Poll::Ready(Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        "WebSocket closed",
                    )));
                }
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
752
#[cfg(feature = "websockets")]
impl<T> AsyncWrite for WebSocketAdapter<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    /// Sends all of `buf` as a single binary WebSocket message once the sink is ready.
    ///
    /// On success the whole buffer is reported as written. Sink errors are mapped to
    /// `ErrorKind::Other`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let mut this = self.project();

        match this.inner.poll_ready_unpin(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(e)) => {
                Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e)))
            }
            Poll::Ready(Ok(())) => {
                // Copy only once the sink can take the message; copying before the readiness
                // check wasted an allocation on every `Pending` poll.
                let message = tokio_websockets::Message::binary(buf.to_vec());
                match this.inner.start_send_unpin(message) {
                    Ok(()) => Poll::Ready(Ok(buf.len())),
                    Err(e) => Poll::Ready(Err(std::io::Error::new(std::io::ErrorKind::Other, e))),
                }
            }
        }
    }

    /// Flushes pending WebSocket messages.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_flush_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }

    /// Closes the WebSocket sink.
    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.project()
            .inner
            .poll_close_unpin(cx)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
    }
}
795
#[cfg(test)]
mod read_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{HeaderMap, ServerError, ServerInfo, ServerOp, Statistics, StatusCode};
    use tokio::io::{self, AsyncWriteExt, DuplexStream};

    /// Creates a client `Connection` and the opposite end of a 128-byte in-memory pipe, which
    /// the tests use to play the server.
    fn connect() -> (Connection, DuplexStream) {
        let (client, server) = io::duplex(128);
        let statistics = Arc::new(Statistics::default());
        (Connection::new(Box::new(client), 0, statistics), server)
    }

    /// Writes each chunk to the server end of the pipe, in order.
    async fn send(server: &mut DuplexStream, chunks: &[&str]) {
        for chunk in chunks {
            server.write_all(chunk.as_bytes()).await.unwrap();
        }
    }

    #[tokio::test]
    async fn ok() {
        let (mut connection, mut server) = connect();

        send(&mut server, &["+OK\r\n"]).await;
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Ok));
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut server) = connect();

        send(&mut server, &["PING\r\n"]).await;
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Ping));
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut server) = connect();

        send(&mut server, &["PONG\r\n"]).await;
        assert_eq!(connection.read_op().await.unwrap(), Some(ServerOp::Pong));
    }

    #[tokio::test]
    async fn info() {
        let (mut connection, mut server) = connect();

        send(&mut server, &["INFO {}\r\n"]).await;
        server.flush().await.unwrap();
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::default()))
        );

        send(&mut server, &["INFO { \"version\": \"1.0.0\" }\r\n"]).await;
        server.flush().await.unwrap();
        let expected = ServerInfo {
            version: "1.0.0".into(),
            ..Default::default()
        };
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::new(expected)))
        );
    }

    #[tokio::test]
    async fn error() {
        let (mut connection, mut server) = connect();

        send(&mut server, &["INFO {}\r\n"]).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Info(Box::default()))
        );

        send(&mut server, &["-ERR something went wrong\r\n"]).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Error(ServerError::Other(
                "something went wrong".into()
            )))
        );
    }

    #[tokio::test]
    async fn message() {
        let (mut connection, mut server) = connect();

        // Plain message without a reply subject.
        send(&mut server, &["MSG FOO.BAR 9 11\r\nHello World\r\n"]).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );

        // Plain message with a reply subject.
        send(&mut server, &["MSG FOO.BAR 9 INBOX.34 11\r\nHello World\r\n"]).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.34".into()),
                headers: None,
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 11,
            })
        );

        // Message with headers, delivered in several writes.
        send(
            &mut server,
            &[
                "HMSG FOO.BAR 10 INBOX.35 23 34\r\n",
                "NATS/1.0\r\n",
                "Header: X\r\n",
                "\r\n",
                "Hello World\r\n",
            ],
        )
        .await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "X".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34
            })
        );

        send(
            &mut server,
            &[
                "HMSG FOO.BAR 10 INBOX.35 23 34\r\n",
                "NATS/1.0\r\n",
                "Header: Y\r\n",
                "\r\n",
                "Hello World\r\n",
            ],
        )
        .await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::from_iter([(
                    "Header".parse().unwrap(),
                    "Y".parse().unwrap()
                )])),
                payload: "Hello World".into(),
                status: None,
                description: None,
                length: 7 + 8 + 34,
            })
        );

        // Status-only message: the header section is the whole message.
        send(
            &mut server,
            &[
                "HMSG FOO.BAR 10 INBOX.35 28 28\r\n",
                "NATS/1.0 404 No Messages\r\n",
                "\r\n",
                "\r\n",
            ],
        )
        .await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        send(&mut server, &["MSG FOO.BAR 9 11\r\nHello Again\r\n"]).await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 9,
                subject: "FOO.BAR".into(),
                reply: None,
                headers: None,
                payload: "Hello Again".into(),
                status: None,
                description: None,
                length: 7 + 11,
            })
        );
    }

    #[tokio::test]
    async fn unknown() {
        let (mut connection, mut server) = connect();

        // Unknown operations are errors, but parsing recovers on the next line.
        send(&mut server, &["ONE\r\n"]).await;
        connection.read_op().await.unwrap_err();

        send(&mut server, &["TWO\r\n"]).await;
        connection.read_op().await.unwrap_err();

        send(&mut server, &["PING\r\n"]).await;
        connection.read_op().await.unwrap();

        send(&mut server, &["THREE\r\n"]).await;
        connection.read_op().await.unwrap_err();

        send(
            &mut server,
            &[
                "HMSG FOO.BAR 10 INBOX.35 28 28\r\n",
                "NATS/1.0 404 No Messages\r\n",
                "\r\n",
                "\r\n",
            ],
        )
        .await;
        assert_eq!(
            connection.read_op().await.unwrap(),
            Some(ServerOp::Message {
                sid: 10,
                subject: "FOO.BAR".into(),
                reply: Some("INBOX.35".into()),
                headers: Some(HeaderMap::default()),
                payload: "".into(),
                status: Some(StatusCode::NOT_FOUND),
                description: Some("No Messages".to_string()),
                length: 7 + 8 + 28,
            })
        );

        send(&mut server, &["FOUR\r\n"]).await;
        connection.read_op().await.unwrap_err();

        send(&mut server, &["PONG\r\n"]).await;
        connection.read_op().await.unwrap();
    }
}
1080
#[cfg(test)]
mod write_op {
    use std::sync::Arc;

    use super::Connection;
    use crate::{ClientOp, ConnectInfo, HeaderMap, Protocol, Statistics};
    use tokio::io::{self, AsyncBufReadExt, BufReader, DuplexStream};

    /// Creates a client `Connection` and a buffered reader over the server end of an
    /// in-memory pipe with `capacity` bytes of buffering.
    fn connect(capacity: usize) -> (Connection, BufReader<DuplexStream>) {
        let (client, server) = io::duplex(capacity);
        let connection = Connection::new(Box::new(client), 0, Arc::new(Statistics::default()));
        (connection, BufReader::new(server))
    }

    /// Writes and flushes `op`, then reads `lines` lines from the server end and returns them
    /// concatenated.
    async fn exchange(
        connection: &mut Connection,
        server: &mut BufReader<DuplexStream>,
        op: ClientOp,
        lines: usize,
    ) -> String {
        connection.easy_write_and_flush([op].iter()).await.unwrap();

        let mut received = String::new();
        for _ in 0..lines {
            server.read_line(&mut received).await.unwrap();
        }
        received
    }

    #[tokio::test]
    async fn publish() {
        let (mut connection, mut server) = connect(128);

        let plain = ClientOp::Publish {
            subject: "FOO.BAR".into(),
            payload: "Hello World".into(),
            respond: None,
            headers: None,
        };
        assert_eq!(
            exchange(&mut connection, &mut server, plain, 2).await,
            "PUB FOO.BAR 11\r\nHello World\r\n"
        );

        let with_reply = ClientOp::Publish {
            subject: "FOO.BAR".into(),
            payload: "Hello World".into(),
            respond: Some("INBOX.67".into()),
            headers: None,
        };
        assert_eq!(
            exchange(&mut connection, &mut server, with_reply, 2).await,
            "PUB FOO.BAR INBOX.67 11\r\nHello World\r\n"
        );

        let with_headers = ClientOp::Publish {
            subject: "FOO.BAR".into(),
            payload: "Hello World".into(),
            respond: Some("INBOX.67".into()),
            headers: Some(HeaderMap::from_iter([(
                "Header".parse().unwrap(),
                "X".parse().unwrap(),
            )])),
        };
        assert_eq!(
            exchange(&mut connection, &mut server, with_headers, 4).await,
            "HPUB FOO.BAR INBOX.67 23 34\r\nNATS/1.0\r\nHeader: X\r\n\r\n"
        );
    }

    #[tokio::test]
    async fn subscribe() {
        let (mut connection, mut server) = connect(128);

        let plain = ClientOp::Subscribe {
            sid: 11,
            subject: "FOO.BAR".into(),
            queue_group: None,
        };
        assert_eq!(
            exchange(&mut connection, &mut server, plain, 1).await,
            "SUB FOO.BAR 11\r\n"
        );

        let queued = ClientOp::Subscribe {
            sid: 11,
            subject: "FOO.BAR".into(),
            queue_group: Some("QUEUE.GROUP".into()),
        };
        assert_eq!(
            exchange(&mut connection, &mut server, queued, 1).await,
            "SUB FOO.BAR QUEUE.GROUP 11\r\n"
        );
    }

    #[tokio::test]
    async fn unsubscribe() {
        let (mut connection, mut server) = connect(128);

        let unlimited = ClientOp::Unsubscribe { sid: 11, max: None };
        assert_eq!(
            exchange(&mut connection, &mut server, unlimited, 1).await,
            "UNSUB 11\r\n"
        );

        let limited = ClientOp::Unsubscribe {
            sid: 11,
            max: Some(2),
        };
        assert_eq!(
            exchange(&mut connection, &mut server, limited, 1).await,
            "UNSUB 11 2\r\n"
        );
    }

    #[tokio::test]
    async fn ping() {
        let (mut connection, mut server) = connect(128);

        assert_eq!(
            exchange(&mut connection, &mut server, ClientOp::Ping, 1).await,
            "PING\r\n"
        );
    }

    #[tokio::test]
    async fn pong() {
        let (mut connection, mut server) = connect(128);

        assert_eq!(
            exchange(&mut connection, &mut server, ClientOp::Pong, 1).await,
            "PONG\r\n"
        );
    }

    #[tokio::test]
    async fn connect() {
        // The serialized CONNECT line exceeds 128 bytes, so use a larger pipe.
        let (mut connection, mut server) = self::connect(1024);

        let op = ClientOp::Connect(ConnectInfo {
            verbose: false,
            pedantic: false,
            user_jwt: None,
            nkey: None,
            signature: None,
            name: None,
            echo: false,
            lang: "Rust".into(),
            version: "1.0.0".into(),
            protocol: Protocol::Dynamic,
            tls_required: false,
            user: None,
            pass: None,
            auth_token: None,
            headers: false,
            no_responders: false,
        });

        assert_eq!(
            exchange(&mut connection, &mut server, op, 1).await,
            "CONNECT {\"verbose\":false,\"pedantic\":false,\"jwt\":null,\"nkey\":null,\"sig\":null,\"name\":null,\"echo\":false,\"lang\":\"Rust\",\"version\":\"1.0.0\",\"protocol\":1,\"tls_required\":false,\"user\":null,\"pass\":null,\"auth_token\":null,\"headers\":false,\"no_responders\":false}\r\n"
        );
    }
}