1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// Type-erased error: any boxed `std::error::Error` that is `Send + Sync + 'static`.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Version of this crate, captured from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Language identifier for this client.
// NOTE(review): presumably reported to the server in `ConnectInfo::lang` — confirm at the
// construction site.
const LANG: &str = "rust";
/// Number of unanswered PINGs tolerated; once the pending count exceeds this value the
/// connection handler treats the connection as lost and disconnects.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id reserved for the shared request/response multiplexer subscription.
const MULTIPLEXER_SID: u64 = 0;
236
237pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252mod options;
253
254pub use auth::Auth;
255pub use client::{
256 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
257};
258pub use options::{AuthError, ConnectOptions};
259
260mod crypto;
261pub mod error;
262pub mod header;
263pub mod jetstream;
264pub mod message;
265#[cfg(feature = "service")]
266pub mod service;
267pub mod status;
268pub mod subject;
269mod tls;
270
271pub use message::Message;
272pub use status::StatusCode;
273
/// Information a server reports about itself, carried by `ServerOp::Info`.
///
/// Deserialized with serde. Every field is `#[serde(default)]`, so keys missing from the
/// payload fall back to the field type's default value. Field names double as the payload
/// keys, except `lame_duck_mode`, which is read from `ldm`.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    #[serde(default)]
    pub server_id: String,
    #[serde(default)]
    pub server_name: String,
    #[serde(default)]
    pub host: String,
    #[serde(default)]
    pub port: u16,
    #[serde(default)]
    pub version: String,
    #[serde(default)]
    pub auth_required: bool,
    #[serde(default)]
    pub tls_required: bool,
    #[serde(default)]
    pub max_payload: usize,
    #[serde(default)]
    pub proto: i8,
    #[serde(default)]
    pub client_id: u64,
    #[serde(default)]
    pub go: String,
    #[serde(default)]
    pub nonce: String,
    #[serde(default)]
    pub connect_urls: Vec<String>,
    #[serde(default)]
    pub client_ip: String,
    #[serde(default)]
    pub headers: bool,
    /// Lame duck mode flag (payload key `ldm`). When an INFO arrives with this set, the
    /// connection handler emits `Event::LameDuckMode`.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
}
328
/// A protocol operation read from the server (via `Connection::poll_read_op`) and handled
/// by `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// Acknowledgement; currently ignored by the handler.
    Ok,
    /// Server information, boxed to keep this enum small.
    Info(Box<ServerInfo>),
    /// Keep-alive probe from the server; the handler answers with `ClientOp::Pong`.
    Ping,
    /// Answer to one of our PINGs; decrements the handler's pending-ping count.
    Pong,
    /// Error reported by the server; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered for subscription `sid` (or the multiplexer, `MULTIPLEXER_SID`).
    Message {
        sid: u64,
        subject: Subject,
        reply: Option<Subject>,
        payload: Bytes,
        headers: Option<HeaderMap>,
        status: Option<StatusCode>,
        description: Option<String>,
        /// Copied into `Message::length`.
        length: usize,
    },
}
347
/// A message to publish, sent to the connection handler inside `Command::Publish`.
#[derive(Debug)]
pub struct PublishMessage {
    /// Subject to publish on.
    pub subject: Subject,
    /// Message body.
    pub payload: Bytes,
    /// Optional reply subject; forwarded as the `respond` field of `ClientOp::Publish`.
    pub reply: Option<Subject>,
    /// Optional message headers.
    pub headers: Option<HeaderMap>,
}
356
/// A request from a client-side handle to the connection handler, delivered over the
/// `mpsc` channel that `ConnectionHandler::process` reads from.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message.
    Publish(PublishMessage),
    /// Publish a request and route the first matching reply to `sender`.
    ///
    /// `respond` must contain a `.`. The part after the last `.` becomes the multiplexer
    /// token, and the handler panics if there is no `.`.
    Request {
        subject: Subject,
        payload: Bytes,
        respond: Subject,
        headers: Option<HeaderMap>,
        sender: oneshot::Sender<Message>,
    },
    /// Register subscription `sid` locally and subscribe on the server.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid` immediately (`max: None`) or once `max` messages were delivered.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Notify `observer` after the next successful flush of the write buffer.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Drain a single subscription (`Some(sid)`) or the whole connection (`None`).
    Drain {
        sid: Option<u64>,
    },
    /// Drop the current connection and reconnect.
    Reconnect,
}
386
/// A protocol operation the client sends to the server, queued with
/// `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish `payload` on `subject`, optionally with a reply subject and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Subscribe `sid` to `subject`, optionally as part of a queue group.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// Unsubscribe `sid`, with an optional `max` message count.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Keep-alive probe.
    Ping,
    /// Answer to a server PING.
    Pong,
    /// Connection handshake payload.
    Connect(ConnectInfo),
}
409
/// Handler-side state of one subscription, keyed by sid in `ConnectionHandler::subscriptions`.
#[derive(Debug)]
struct Subscription {
    /// Subject re-sent when re-subscribing after a reconnect.
    subject: Subject,
    /// Channel to the `Subscriber`; messages are pushed with `try_send`.
    sender: mpsc::Sender<Message>,
    /// Queue group re-sent when re-subscribing after a reconnect.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender`.
    delivered: u64,
    /// Auto-unsubscribe limit; the subscription is removed once `delivered` reaches it.
    max: Option<u64>,
    /// Set by a drain; the subscription is dropped at the start of the next poll.
    is_draining: bool,
}
419
/// A single wildcard subscription (sid `MULTIPLEXER_SID`) shared by all requests.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject subscribed on the server: `prefix` followed by `*`.
    subject: Subject,
    /// Prefix of every response subject; a request's reply subject is `prefix` + token.
    prefix: Subject,
    /// Pending requests keyed by token; an entry is removed when its reply arrives.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
426
/// Owns the live connection and multiplexes commands from client handles with operations
/// from the server. Driven by `ConnectionHandler::process`.
pub(crate) struct ConnectionHandler {
    /// Current server connection; replaced on every reconnect.
    connection: Connection,
    /// Used to reconnect; also holds the event/state channels and statistics.
    connector: Connector,
    /// Active subscriptions keyed by sid.
    subscriptions: HashMap<u64, Subscription>,
    /// Request multiplexer, created lazily by the first `Command::Request`.
    multiplexer: Option<Multiplexer>,
    /// PINGs sent without a PONG yet.
    pending_pings: usize,
    /// Publishes fresh `ServerInfo` after each (re)connect.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Keep-alive timer; reset whenever a server op or command is handled.
    ping_interval: Interval,
    /// Set by `Command::Reconnect`; consumed at the end of a poll.
    should_reconnect: bool,
    /// Waiters notified after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set when the whole connection is draining; the next poll then closes the handler.
    is_draining: bool,
}
440
441impl ConnectionHandler {
442 pub(crate) fn new(
443 connection: Connection,
444 connector: Connector,
445 info_sender: tokio::sync::watch::Sender<ServerInfo>,
446 ping_period: Duration,
447 ) -> ConnectionHandler {
448 let mut ping_interval = interval(ping_period);
449 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
450
451 ConnectionHandler {
452 connection,
453 connector,
454 subscriptions: HashMap::new(),
455 multiplexer: None,
456 pending_pings: 0,
457 info_sender,
458 ping_interval,
459 should_reconnect: false,
460 flush_observers: Vec::new(),
461 is_draining: false,
462 }
463 }
464
465 pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
466 struct ProcessFut<'a> {
467 handler: &'a mut ConnectionHandler,
468 receiver: &'a mut mpsc::Receiver<Command>,
469 recv_buf: &'a mut Vec<Command>,
470 }
471
472 enum ExitReason {
473 Disconnected(Option<io::Error>),
474 ReconnectRequested,
475 Closed,
476 }
477
478 impl ProcessFut<'_> {
479 const RECV_CHUNK_SIZE: usize = 16;
480
481 #[cold]
482 fn ping(&mut self) -> Poll<ExitReason> {
483 self.handler.pending_pings += 1;
484
485 if self.handler.pending_pings > MAX_PENDING_PINGS {
486 debug!(
487 "pending pings {}, max pings {}. disconnecting",
488 self.handler.pending_pings, MAX_PENDING_PINGS
489 );
490
491 Poll::Ready(ExitReason::Disconnected(None))
492 } else {
493 self.handler.connection.enqueue_write_op(&ClientOp::Ping);
494
495 Poll::Pending
496 }
497 }
498 }
499
500 impl Future for ProcessFut<'_> {
501 type Output = ExitReason;
502
503 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
517 while self.handler.ping_interval.poll_tick(cx).is_ready() {
521 if let Poll::Ready(exit) = self.ping() {
522 return Poll::Ready(exit);
523 }
524 }
525
526 loop {
527 match self.handler.connection.poll_read_op(cx) {
528 Poll::Pending => break,
529 Poll::Ready(Ok(Some(server_op))) => {
530 self.handler.handle_server_op(server_op);
531 }
532 Poll::Ready(Ok(None)) => {
533 return Poll::Ready(ExitReason::Disconnected(None))
534 }
535 Poll::Ready(Err(err)) => {
536 return Poll::Ready(ExitReason::Disconnected(Some(err)))
537 }
538 }
539 }
540
541 self.handler.subscriptions.retain(|_, s| !s.is_draining);
546
547 if self.handler.is_draining {
548 return Poll::Ready(ExitReason::Closed);
553 }
554
555 let mut made_progress = true;
561 loop {
562 while !self.handler.connection.is_write_buf_full() {
563 debug_assert!(self.recv_buf.is_empty());
564
565 let Self {
566 recv_buf,
567 handler,
568 receiver,
569 } = &mut *self;
570 match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
571 Poll::Pending => break,
572 Poll::Ready(1..) => {
573 made_progress = true;
574
575 for cmd in recv_buf.drain(..) {
576 handler.handle_command(cmd);
577 }
578 }
579 Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
581 }
582 }
583
584 if !mem::take(&mut made_progress) {
595 break;
596 }
597
598 match self.handler.connection.poll_write(cx) {
599 Poll::Pending => {
600 break;
602 }
603 Poll::Ready(Ok(())) => {
604 continue;
606 }
607 Poll::Ready(Err(err)) => {
608 return Poll::Ready(ExitReason::Disconnected(Some(err)))
609 }
610 }
611 }
612
613 if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
614 self.handler.connection.should_flush(),
615 self.handler.flush_observers.is_empty(),
616 ) {
617 match self.handler.connection.poll_flush(cx) {
618 Poll::Pending => {}
619 Poll::Ready(Ok(())) => {
620 for observer in self.handler.flush_observers.drain(..) {
621 let _ = observer.send(());
622 }
623 }
624 Poll::Ready(Err(err)) => {
625 return Poll::Ready(ExitReason::Disconnected(Some(err)))
626 }
627 }
628 }
629
630 if mem::take(&mut self.handler.should_reconnect) {
631 return Poll::Ready(ExitReason::ReconnectRequested);
632 }
633
634 Poll::Pending
635 }
636 }
637
638 let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
639 loop {
640 let process = ProcessFut {
641 handler: self,
642 receiver,
643 recv_buf: &mut recv_buf,
644 };
645 match process.await {
646 ExitReason::Disconnected(err) => {
647 debug!(?err, "disconnected");
648 if self.handle_disconnect().await.is_err() {
649 break;
650 };
651 debug!("reconnected");
652 }
653 ExitReason::Closed => {
654 self.connector.events_tx.try_send(Event::Closed).ok();
656 break;
657 }
658 ExitReason::ReconnectRequested => {
659 debug!("reconnect requested");
660 self.connection.stream.shutdown().await.ok();
662 if self.handle_disconnect().await.is_err() {
663 break;
664 };
665 }
666 }
667 }
668 }
669
670 fn handle_server_op(&mut self, server_op: ServerOp) {
671 self.ping_interval.reset();
672
673 match server_op {
674 ServerOp::Ping => {
675 self.connection.enqueue_write_op(&ClientOp::Pong);
676 }
677 ServerOp::Pong => {
678 debug!("received PONG");
679 self.pending_pings = self.pending_pings.saturating_sub(1);
680 }
681 ServerOp::Error(error) => {
682 self.connector
683 .events_tx
684 .try_send(Event::ServerError(error))
685 .ok();
686 }
687 ServerOp::Message {
688 sid,
689 subject,
690 reply,
691 payload,
692 headers,
693 status,
694 description,
695 length,
696 } => {
697 self.connector
698 .connect_stats
699 .in_messages
700 .add(1, Ordering::Relaxed);
701
702 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
703 let message: Message = Message {
704 subject,
705 reply,
706 payload,
707 headers,
708 status,
709 description,
710 length,
711 };
712
713 match subscription.sender.try_send(message) {
716 Ok(_) => {
717 subscription.delivered += 1;
718 if let Some(max) = subscription.max {
722 if subscription.delivered.ge(&max) {
723 self.subscriptions.remove(&sid);
724 }
725 }
726 }
727 Err(mpsc::error::TrySendError::Full(_)) => {
728 self.connector
729 .events_tx
730 .try_send(Event::SlowConsumer(sid))
731 .ok();
732 }
733 Err(mpsc::error::TrySendError::Closed(_)) => {
734 self.subscriptions.remove(&sid);
735 self.connection
736 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
737 }
738 }
739 } else if sid == MULTIPLEXER_SID {
740 if let Some(multiplexer) = self.multiplexer.as_mut() {
741 let maybe_token =
742 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
743
744 if let Some(token) = maybe_token {
745 if let Some(sender) = multiplexer.senders.remove(token) {
746 let message = Message {
747 subject,
748 reply,
749 payload,
750 headers,
751 status,
752 description,
753 length,
754 };
755
756 let _ = sender.send(message);
757 }
758 }
759 }
760 }
761 }
762 ServerOp::Info(info) => {
764 if info.lame_duck_mode {
765 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
766 }
767 }
768
769 _ => {
770 }
772 }
773 }
774
775 fn handle_command(&mut self, command: Command) {
776 self.ping_interval.reset();
777
778 match command {
779 Command::Unsubscribe { sid, max } => {
780 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
781 subscription.max = max;
782 match subscription.max {
783 Some(n) => {
784 if subscription.delivered >= n {
785 self.subscriptions.remove(&sid);
786 }
787 }
788 None => {
789 self.subscriptions.remove(&sid);
790 }
791 }
792
793 self.connection
794 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
795 }
796 }
797 Command::Flush { observer } => {
798 self.flush_observers.push(observer);
799 }
800 Command::Drain { sid } => {
801 let mut drain_sub = |sid: u64, sub: &mut Subscription| {
802 sub.is_draining = true;
803 self.connection
804 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
805 };
806
807 if let Some(sid) = sid {
808 if let Some(sub) = self.subscriptions.get_mut(&sid) {
809 drain_sub(sid, sub);
810 }
811 } else {
812 self.connector.events_tx.try_send(Event::Draining).ok();
814 self.is_draining = true;
815 for (&sid, sub) in self.subscriptions.iter_mut() {
816 drain_sub(sid, sub);
817 }
818 }
819 }
820 Command::Subscribe {
821 sid,
822 subject,
823 queue_group,
824 sender,
825 } => {
826 let subscription = Subscription {
827 sender,
828 delivered: 0,
829 max: None,
830 subject: subject.to_owned(),
831 queue_group: queue_group.to_owned(),
832 is_draining: false,
833 };
834
835 self.subscriptions.insert(sid, subscription);
836
837 self.connection.enqueue_write_op(&ClientOp::Subscribe {
838 sid,
839 subject,
840 queue_group,
841 });
842 }
843 Command::Request {
844 subject,
845 payload,
846 respond,
847 headers,
848 sender,
849 } => {
850 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
851
852 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
853 multiplexer
854 } else {
855 let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
856 let subject = Subject::from(format!("{}*", prefix));
857
858 self.connection.enqueue_write_op(&ClientOp::Subscribe {
859 sid: MULTIPLEXER_SID,
860 subject: subject.clone(),
861 queue_group: None,
862 });
863
864 self.multiplexer.insert(Multiplexer {
865 subject,
866 prefix,
867 senders: HashMap::new(),
868 })
869 };
870 self.connector
871 .connect_stats
872 .out_messages
873 .add(1, Ordering::Relaxed);
874
875 multiplexer.senders.insert(token.to_owned(), sender);
876
877 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
878
879 let pub_op = ClientOp::Publish {
880 subject,
881 payload,
882 respond: Some(respond),
883 headers,
884 };
885
886 self.connection.enqueue_write_op(&pub_op);
887 }
888
889 Command::Publish(PublishMessage {
890 subject,
891 payload,
892 reply: respond,
893 headers,
894 }) => {
895 self.connector
896 .connect_stats
897 .out_messages
898 .add(1, Ordering::Relaxed);
899
900 let header_len = headers
901 .as_ref()
902 .map(|headers| headers.len())
903 .unwrap_or_default();
904
905 self.connector.connect_stats.out_bytes.add(
906 (payload.len()
907 + respond.as_ref().map_or_else(|| 0, |r| r.len())
908 + subject.len()
909 + header_len) as u64,
910 Ordering::Relaxed,
911 );
912
913 self.connection.enqueue_write_op(&ClientOp::Publish {
914 subject,
915 payload,
916 respond,
917 headers,
918 });
919 }
920
921 Command::Reconnect => {
922 self.should_reconnect = true;
923 }
924 }
925 }
926
927 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
928 self.pending_pings = 0;
929 self.connector.events_tx.try_send(Event::Disconnected).ok();
930 self.connector.state_tx.send(State::Disconnected).ok();
931
932 self.handle_reconnect().await
933 }
934
935 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
936 let (info, connection) = self.connector.connect().await?;
937 self.connection = connection;
938 let _ = self.info_sender.send(info);
939
940 self.subscriptions
941 .retain(|_, subscription| !subscription.sender.is_closed());
942
943 for (sid, subscription) in &self.subscriptions {
944 self.connection.enqueue_write_op(&ClientOp::Subscribe {
945 sid: *sid,
946 subject: subscription.subject.to_owned(),
947 queue_group: subscription.queue_group.to_owned(),
948 });
949 }
950
951 if let Some(multiplexer) = &self.multiplexer {
952 self.connection.enqueue_write_op(&ClientOp::Subscribe {
953 sid: MULTIPLEXER_SID,
954 subject: multiplexer.subject.to_owned(),
955 queue_group: None,
956 });
957 }
958 Ok(())
959 }
960}
961
962pub async fn connect_with_options<A: ToServerAddrs>(
978 addrs: A,
979 options: ConnectOptions,
980) -> Result<Client, ConnectError> {
981 let ping_period = options.ping_interval;
982
983 let (events_tx, mut events_rx) = mpsc::channel(128);
984 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
985 let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
987 let statistics = Arc::new(Statistics::default());
988
989 let mut connector = Connector::new(
990 addrs,
991 ConnectorOptions {
992 tls_required: options.tls_required,
993 certificates: options.certificates,
994 client_key: options.client_key,
995 client_cert: options.client_cert,
996 tls_client_config: options.tls_client_config,
997 tls_first: options.tls_first,
998 auth: options.auth,
999 no_echo: options.no_echo,
1000 connection_timeout: options.connection_timeout,
1001 name: options.name,
1002 ignore_discovered_servers: options.ignore_discovered_servers,
1003 retain_servers_order: options.retain_servers_order,
1004 read_buffer_capacity: options.read_buffer_capacity,
1005 reconnect_delay_callback: options.reconnect_delay_callback,
1006 auth_callback: options.auth_callback,
1007 max_reconnects: options.max_reconnects,
1008 },
1009 events_tx,
1010 state_tx,
1011 max_payload.clone(),
1012 statistics.clone(),
1013 )
1014 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1015
1016 let mut info: ServerInfo = Default::default();
1017 let mut connection = None;
1018 if !options.retry_on_initial_connect {
1019 debug!("retry on initial connect failure is disabled");
1020 let (info_ok, connection_ok) = connector.try_connect().await?;
1021 connection = Some(connection_ok);
1022 info = info_ok;
1023 }
1024
1025 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1026 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1027
1028 let client = Client::new(
1029 info_watcher,
1030 state_rx,
1031 sender,
1032 options.subscription_capacity,
1033 options.inbox_prefix,
1034 options.request_timeout,
1035 max_payload,
1036 statistics,
1037 );
1038
1039 task::spawn(async move {
1040 while let Some(event) = events_rx.recv().await {
1041 tracing::info!("event: {}", event);
1042 if let Some(event_callback) = &options.event_callback {
1043 event_callback.call(event).await;
1044 }
1045 }
1046 });
1047
1048 task::spawn(async move {
1049 if connection.is_none() && options.retry_on_initial_connect {
1050 let (info, connection_ok) = match connector.connect().await {
1051 Ok((info, connection)) => (info, connection),
1052 Err(err) => {
1053 error!("connection closed: {}", err);
1054 return;
1055 }
1056 };
1057 info_sender.send(info).ok();
1058 connection = Some(connection_ok);
1059 }
1060 let connection = connection.unwrap();
1061 let mut connection_handler =
1062 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1063 connection_handler.process(&mut receiver).await
1064 });
1065
1066 Ok(client)
1067}
1068
/// Connection lifecycle events, delivered to the event callback configured in
/// `ConnectOptions` and logged by the client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A connection was established.
    // NOTE(review): not emitted in this file; presumably sent by the connector — confirm.
    Connected,
    /// The connection was lost; the handler is about to reconnect.
    Disconnected,
    /// The server announced lame duck mode in an INFO message.
    LameDuckMode,
    /// A drain of the whole connection started.
    Draining,
    /// The connection handler stopped (command channel closed or drain completed).
    Closed,
    /// A message for the subscription with this sid was dropped because its channel was full.
    SlowConsumer(u64),
    /// The server reported an error.
    ServerError(ServerError),
    /// A client-side error.
    // NOTE(review): not emitted in this file; confirm the producers.
    ClientError(ClientError),
}
1080
1081impl fmt::Display for Event {
1082 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1083 match self {
1084 Event::Connected => write!(f, "connected"),
1085 Event::Disconnected => write!(f, "disconnected"),
1086 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1087 Event::Draining => write!(f, "draining"),
1088 Event::Closed => write!(f, "closed"),
1089 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1090 Event::ServerError(err) => write!(f, "server error: {err}"),
1091 Event::ClientError(err) => write!(f, "client error: {err}"),
1092 }
1093 }
1094}
1095
1096pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1163 connect_with_options(addrs, ConnectOptions::default()).await
1164}
1165
/// Category of a failure while connecting to a server.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// A DNS error.
    Dns,
    /// Signing the server nonce failed.
    Authentication,
    /// An authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// A TLS error.
    Tls,
    /// An I/O error.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    /// Writes a fixed, human-readable description of the kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1200
1201pub type ConnectError = error::Error<ConnectErrorKind>;
1204
1205impl From<io::Error> for ConnectError {
1206 fn from(err: io::Error) -> Self {
1207 ConnectError::with_source(ConnectErrorKind::Io, err)
1208 }
1209}
1210
/// Handle to one subscription.
///
/// Messages are read through its `Stream` implementation. Dropping it closes the local
/// receiver and asks the connection handler to unsubscribe.
#[derive(Debug)]
pub struct Subscriber {
    /// Subscription id used in the commands this handle sends.
    sid: u64,
    /// Messages routed to this subscription by the connection handler.
    receiver: mpsc::Receiver<Message>,
    /// Command channel to the connection handler.
    sender: mpsc::Sender<Command>,
}
1230
1231impl Subscriber {
1232 fn new(
1233 sid: u64,
1234 sender: mpsc::Sender<Command>,
1235 receiver: mpsc::Receiver<Message>,
1236 ) -> Subscriber {
1237 Subscriber {
1238 sid,
1239 sender,
1240 receiver,
1241 }
1242 }
1243
1244 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1259 self.sender
1260 .send(Command::Unsubscribe {
1261 sid: self.sid,
1262 max: None,
1263 })
1264 .await?;
1265 self.receiver.close();
1266 Ok(())
1267 }
1268
1269 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1295 self.sender
1296 .send(Command::Unsubscribe {
1297 sid: self.sid,
1298 max: Some(unsub_after),
1299 })
1300 .await?;
1301 Ok(())
1302 }
1303
1304 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1337 self.sender
1338 .send(Command::Drain {
1339 sid: Some(self.sid),
1340 })
1341 .await?;
1342
1343 Ok(())
1344 }
1345}
1346
1347#[derive(Error, Debug, PartialEq)]
1348#[error("failed to send unsubscribe")]
1349pub struct UnsubscribeError(String);
1350
1351impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1352 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1353 UnsubscribeError(err.to_string())
1354 }
1355}
1356
1357impl Drop for Subscriber {
1358 fn drop(&mut self) {
1359 self.receiver.close();
1360 tokio::spawn({
1361 let sender = self.sender.clone();
1362 let sid = self.sid;
1363 async move {
1364 sender
1365 .send(Command::Unsubscribe { sid, max: None })
1366 .await
1367 .ok();
1368 }
1369 });
1370 }
1371}
1372
1373impl Stream for Subscriber {
1374 type Item = Message;
1375
1376 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1377 self.receiver.poll_recv(cx)
1378 }
1379}
1380
1381#[derive(Clone, Debug, Eq, PartialEq)]
1382pub enum CallbackError {
1383 Client(ClientError),
1384 Server(ServerError),
1385}
1386impl std::fmt::Display for CallbackError {
1387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1388 match self {
1389 Self::Client(error) => write!(f, "{error}"),
1390 Self::Server(error) => write!(f, "{error}"),
1391 }
1392 }
1393}
1394
1395impl From<ServerError> for CallbackError {
1396 fn from(server_error: ServerError) -> Self {
1397 CallbackError::Server(server_error)
1398 }
1399}
1400
1401impl From<ClientError> for CallbackError {
1402 fn from(client_error: ClientError) -> Self {
1403 CallbackError::Client(client_error)
1404 }
1405}
1406
/// An error reported by the server (`ServerOp::Error`).
///
/// `ServerError::new` maps the text "authorization violation" (after lowercasing) to
/// `AuthorizationViolation` and everything else to `Other`. It never produces `SlowConsumer`.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// Authorization violation.
    AuthorizationViolation,
    /// Slow consumer on the subscription with this sid.
    SlowConsumer(u64),
    /// Any other server error, carrying its text.
    Other(String),
}
1413
/// Errors produced by the client itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Free-form client error.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    /// Writes the error prefixed with `nats: `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let detail = match self {
            Self::Other(error) => error.as_str(),
            Self::MaxReconnects => "max reconnects reached",
        };
        write!(f, "nats: {detail}")
    }
}
1427
1428impl ServerError {
1429 fn new(error: String) -> ServerError {
1430 match error.to_lowercase().as_str() {
1431 "authorization violation" => ServerError::AuthorizationViolation,
1432 _ => ServerError::Other(error),
1434 }
1435 }
1436}
1437
1438impl std::fmt::Display for ServerError {
1439 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1440 match self {
1441 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1442 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1443 Self::Other(error) => write!(f, "nats: {error}"),
1444 }
1445 }
1446}
1447
/// Payload of the client's CONNECT operation (`ClientOp::Connect`), serialized with serde.
///
/// Field names are the serialized keys, except `user_jwt` (`jwt`) and `signature` (`sig`).
// NOTE(review): per-field meanings below follow the NATS CONNECT protocol fields of the same
// names — verify against the protocol documentation.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Request `+OK` acknowledgements from the server.
    pub verbose: bool,

    /// Request stricter protocol checking from the server.
    pub pedantic: bool,

    /// User JWT, serialized under the key `jwt`.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public NKey.
    pub nkey: Option<String>,

    /// Signature of the server nonce, serialized under the key `sig`.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Whether the server should echo this connection's own messages back to it.
    pub echo: bool,

    /// Client implementation language.
    pub lang: String,

    /// Client version.
    pub version: String,

    /// Protocol level, serialized as its numeric value.
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// Username for user/password authentication.
    pub user: Option<String>,

    /// Password for user/password authentication.
    pub pass: Option<String>,

    /// Token for token authentication.
    pub auth_token: Option<String>,

    /// Whether the client supports message headers.
    pub headers: bool,

    /// Whether the client wants no-responders status messages.
    pub no_responders: bool,
}
1508
/// Protocol level sent in `ConnectInfo::protocol`; (de)serialized as its `u8` discriminant.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Level `0`.
    Original = 0,
    /// Level `1`.
    Dynamic = 1,
}
1518
/// A NATS server address: a URL whose scheme is `nats`, `tls`, `ws` or `wss`.
///
/// The scheme is checked by `ServerAddr::from_url`, which `FromStr` also goes through.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1522
1523impl FromStr for ServerAddr {
1524 type Err = io::Error;
1525
1526 fn from_str(input: &str) -> Result<Self, Self::Err> {
1530 let url: Url = if input.contains("://") {
1531 input.parse()
1532 } else {
1533 format!("nats://{input}").parse()
1534 }
1535 .map_err(|e| {
1536 io::Error::new(
1537 ErrorKind::InvalidInput,
1538 format!("NATS server URL is invalid: {e}"),
1539 )
1540 })?;
1541
1542 Self::from_url(url)
1543 }
1544}
1545
1546impl ServerAddr {
1547 pub fn from_url(url: Url) -> io::Result<Self> {
1549 if url.scheme() != "nats"
1550 && url.scheme() != "tls"
1551 && url.scheme() != "ws"
1552 && url.scheme() != "wss"
1553 {
1554 return Err(std::io::Error::new(
1555 ErrorKind::InvalidInput,
1556 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1557 ));
1558 }
1559
1560 Ok(Self(url))
1561 }
1562
1563 pub fn into_inner(self) -> Url {
1565 self.0
1566 }
1567
1568 pub fn tls_required(&self) -> bool {
1570 self.0.scheme() == "tls"
1571 }
1572
1573 pub fn has_user_pass(&self) -> bool {
1575 self.0.username() != ""
1576 }
1577
1578 pub fn scheme(&self) -> &str {
1579 self.0.scheme()
1580 }
1581
1582 pub fn host(&self) -> &str {
1584 match self.0.host() {
1585 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1586 Some(Host::Ipv6 { .. }) => {
1588 let host = self.0.host_str().unwrap();
1589 &host[1..host.len() - 1]
1590 }
1591 None => "",
1592 }
1593 }
1594
1595 pub fn is_websocket(&self) -> bool {
1596 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1597 }
1598
1599 pub fn port(&self) -> u16 {
1601 self.0.port().unwrap_or(4222)
1602 }
1603
1604 pub fn username(&self) -> Option<&str> {
1606 let user = self.0.username();
1607 if user.is_empty() {
1608 None
1609 } else {
1610 Some(user)
1611 }
1612 }
1613
1614 pub fn password(&self) -> Option<&str> {
1616 self.0.password()
1617 }
1618
1619 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1621 tokio::net::lookup_host((self.host(), self.port())).await
1622 }
1623}
1624
/// Conversion into one or more `ServerAddr`s, accepted by `connect` and
/// `connect_with_options`.
///
/// Implemented for a single `ServerAddr`, `str`/`String`, slices and vectors of strings or of
/// `ServerAddr`s, and references to any implementor.
pub trait ToServerAddrs {
    /// Iterator over the produced addresses.
    type Iter: Iterator<Item = ServerAddr>;

    /// Converts `self` into server addresses.
    ///
    /// # Errors
    /// String-based implementations return the `io::Error` from parsing.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
1636
1637impl ToServerAddrs for ServerAddr {
1638 type Iter = option::IntoIter<ServerAddr>;
1639 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1640 Ok(Some(self.clone()).into_iter())
1641 }
1642}
1643
1644impl ToServerAddrs for str {
1645 type Iter = option::IntoIter<ServerAddr>;
1646 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1647 self.parse::<ServerAddr>()
1648 .map(|addr| Some(addr).into_iter())
1649 }
1650}
1651
1652impl ToServerAddrs for String {
1653 type Iter = option::IntoIter<ServerAddr>;
1654 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1655 (**self).to_server_addrs()
1656 }
1657}
1658
1659impl<T: AsRef<str>> ToServerAddrs for [T] {
1660 type Iter = std::vec::IntoIter<ServerAddr>;
1661 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1662 self.iter()
1663 .map(AsRef::as_ref)
1664 .map(str::parse)
1665 .collect::<io::Result<_>>()
1666 .map(Vec::into_iter)
1667 }
1668}
1669
1670impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1671 type Iter = std::vec::IntoIter<ServerAddr>;
1672 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1673 self.as_slice().to_server_addrs()
1674 }
1675}
1676
1677impl<'a> ToServerAddrs for &'a [ServerAddr] {
1678 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1679
1680 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1681 Ok(self.iter().cloned())
1682 }
1683}
1684
1685impl ToServerAddrs for Vec<ServerAddr> {
1686 type Iter = std::vec::IntoIter<ServerAddr>;
1687
1688 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1689 Ok(self.clone().into_iter())
1690 }
1691}
1692
1693impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1694 type Iter = T::Iter;
1695 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1696 (**self).to_server_addrs()
1697 }
1698}
1699
/// Reports whether `subject` is acceptable as a NATS subject.
///
/// A subject is rejected when it is empty, starts or ends with `.`, or contains ASCII
/// whitespace (space, tab, line feed, form feed, carriage return). Empty tokens in the
/// middle, such as `a..b`, are not checked here.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject_str = subject.as_ref();
    // An empty string has no tokens at all, so it cannot name a subject.
    !subject_str.is_empty()
        && !subject_str.starts_with('.')
        && !subject_str.ends_with('.')
        && subject_str.bytes().all(|c| !c.is_ascii_whitespace())
}
/// Generates `impl From<$origin> for $t`.
///
/// An error whose `kind()` is `<$origin_kind>::TimedOut` becomes `$t::new(<$k>::TimedOut)`,
/// and the original error is dropped. Every other error becomes
/// `$t::with_source(<$k>::Other, err)`.
///
/// - `$t`: target error type; must provide `new` and `with_source`.
/// - `$k`: target kind type; must have `TimedOut` and `Other` variants.
/// - `$origin`: source error type; must have a `kind()` method.
/// - `$origin_kind`: the type matched against the result of `kind()`.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Re-exported so other modules in the crate can invoke the macro by path.
pub(crate) use from_with_timeout;
1719
1720use crate::connection::ShouldFlush;
1721
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` as a server address and returns its host.
    fn parsed_host(input: &str) -> String {
        input.parse::<ServerAddr>().unwrap().host().to_owned()
    }

    /// Converts `addrs` and collects the host of every resulting address, in order.
    fn hosts<A: ToServerAddrs>(addrs: A) -> Vec<String> {
        addrs
            .to_server_addrs()
            .unwrap()
            .map(|addr| addr.host().to_owned())
            .collect()
    }

    #[test]
    fn server_address_ipv6() {
        assert_eq!(parsed_host("nats://[::]"), "::");
    }

    #[test]
    fn server_address_ipv4() {
        assert_eq!(parsed_host("nats://127.0.0.1"), "127.0.0.1");
    }

    #[test]
    fn server_address_domain() {
        assert_eq!(parsed_host("nats://example.com"), "example.com");
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let urls = vec!["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(&urls), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let urls = ["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(&urls[..]), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let urls = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(&urls), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let urls = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(&urls[..]), ["127.0.0.1", "::"]);
    }
}
1779}