sqlx_postgres/listener.rs
1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9use sqlx_core::acquire::Acquire;
10use sqlx_core::transaction::Transaction;
11use sqlx_core::Either;
12
13use crate::describe::Describe;
14use crate::error::Error;
15use crate::executor::{Execute, Executor};
16use crate::message::{BackendMessageFormat, Notification};
17use crate::pool::PoolOptions;
18use crate::pool::{Pool, PoolConnection};
19use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
20
/// A stream of asynchronous notifications from Postgres.
///
/// This listener will auto-reconnect. If the active
/// connection being used ever dies, this listener will detect that event, create a
/// new connection, will re-subscribe to all of the originally specified channels, and will resume
/// operations as normal.
pub struct PgListener {
    /// Pool used to acquire the initial connection and every replacement connection.
    pool: Pool<Postgres>,
    /// The connection currently in use; `None` after a dead connection has been taken out
    /// and before a replacement has been acquired.
    connection: Option<PoolConnection<Postgres>>,
    /// Receiving half of the notification buffer, drained by `next_buffered()`.
    buffer_rx: mpsc::UnboundedReceiver<Notification>,
    /// Sending half of the notification buffer, parked here while no connection holds it;
    /// it is handed to the next connection when one is acquired.
    buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
    /// Channel names that are re-subscribed (`LISTEN`) whenever a new connection is set up.
    channels: Vec<String>,
    /// When `true`, `try_recv()` does not watch the pool's close event while waiting.
    ignore_close_event: bool,
    /// When `true`, `try_recv()` reconnects before returning `Ok(None)` for a lost connection.
    eager_reconnect: bool,
}
36
/// An asynchronous notification from Postgres.
///
/// Wraps the decoded backend `Notification` message and exposes it through
/// [`process_id`](Self::process_id), [`channel`](Self::channel) and
/// [`payload`](Self::payload).
pub struct PgNotification(Notification);
39
40impl PgListener {
41 pub async fn connect(url: &str) -> Result<Self, Error> {
42 // Create a pool of 1 without timeouts (as they don't apply here)
43 // We only use the pool to handle re-connections
44 let pool = PoolOptions::<Postgres>::new()
45 .max_connections(1)
46 .max_lifetime(None)
47 .idle_timeout(None)
48 .connect(url)
49 .await?;
50
51 let mut this = Self::connect_with(&pool).await?;
52 // We don't need to handle close events
53 this.ignore_close_event = true;
54
55 Ok(this)
56 }
57
58 pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
59 // Pull out an initial connection
60 let mut connection = pool.acquire().await?;
61
62 // Setup a notification buffer
63 let (sender, receiver) = mpsc::unbounded();
64 connection.inner.stream.notifications = Some(sender);
65
66 Ok(Self {
67 pool: pool.clone(),
68 connection: Some(connection),
69 buffer_rx: receiver,
70 buffer_tx: None,
71 channels: Vec::new(),
72 ignore_close_event: false,
73 eager_reconnect: true,
74 })
75 }
76
    /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
    ///
    /// By default, when [`Pool::close()`] is called on the pool this listener is using
    /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
    /// cancelled and `Err(PoolClosed)` is returned.
    ///
    /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
    /// including the one being used by this listener.
    ///
    /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
    /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
    /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
    /// on the attempt to acquire a new connection anyway.
    ///
    /// However, if you want `PgListener` to ignore the close event and continue waiting for a
    /// message as long as it can, set this to `true`.
    ///
    /// Has no practical effect if this was constructed with [`PgListener::connect()`], as that
    /// creates an internal pool just for the new instance of `PgListener` which cannot be closed
    /// manually (`connect()` also sets this flag to `true` itself).
    pub fn ignore_pool_close_event(&mut self, val: bool) {
        // Read at the start of each `try_recv()` call to decide whether to watch the close event.
        self.ignore_close_event = val;
    }
99
    /// Set whether a lost connection in `try_recv()` should be re-established before it returns
    /// `Ok(None)`, or on the next call to `try_recv()`.
    ///
    /// By default, this is `true` and the connection is re-established before returning `Ok(None)`.
    ///
    /// If this is set to `false` then notifications will continue to be lost until the next call
    /// to `try_recv()`. If your recovery logic uses a different database connection then
    /// notifications that occur after it completes may be lost without any way to tell that they
    /// have been.
    pub fn eager_reconnect(&mut self, val: bool) {
        // Read by `try_recv()` right after it has discarded a dead connection.
        self.eager_reconnect = val;
    }
112
113 /// Starts listening for notifications on a channel.
114 /// The channel name is quoted here to ensure case sensitivity.
115 pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
116 self.connection()
117 .await?
118 .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
119 .await?;
120
121 self.channels.push(channel.to_owned());
122
123 Ok(())
124 }
125
126 /// Starts listening for notifications on all channels.
127 pub async fn listen_all(
128 &mut self,
129 channels: impl IntoIterator<Item = &str>,
130 ) -> Result<(), Error> {
131 let beg = self.channels.len();
132 self.channels.extend(channels.into_iter().map(|s| s.into()));
133
134 let query = build_listen_all_query(&self.channels[beg..]);
135 self.connection().await?.execute(&*query).await?;
136
137 Ok(())
138 }
139
140 /// Stops listening for notifications on a channel.
141 /// The channel name is quoted here to ensure case sensitivity.
142 pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
143 // use RAW connection and do NOT re-connect automatically, since this is not required for
144 // UNLISTEN (we've disconnected anyways)
145 if let Some(connection) = self.connection.as_mut() {
146 connection
147 .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
148 .await?;
149 }
150
151 if let Some(pos) = self.channels.iter().position(|s| s == channel) {
152 self.channels.remove(pos);
153 }
154
155 Ok(())
156 }
157
158 /// Stops listening for notifications on all channels.
159 pub async fn unlisten_all(&mut self) -> Result<(), Error> {
160 // use RAW connection and do NOT re-connect automatically, since this is not required for
161 // UNLISTEN (we've disconnected anyways)
162 if let Some(connection) = self.connection.as_mut() {
163 connection.execute("UNLISTEN *").await?;
164 }
165
166 self.channels.clear();
167
168 Ok(())
169 }
170
171 #[inline]
172 async fn connect_if_needed(&mut self) -> Result<(), Error> {
173 if self.connection.is_none() {
174 let mut connection = self.pool.acquire().await?;
175 connection.inner.stream.notifications = self.buffer_tx.take();
176
177 connection
178 .execute(&*build_listen_all_query(&self.channels))
179 .await?;
180
181 self.connection = Some(connection);
182 }
183
184 Ok(())
185 }
186
187 #[inline]
188 async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
189 // Ensure we have an active connection to work with.
190 self.connect_if_needed().await?;
191
192 Ok(self.connection.as_mut().unwrap())
193 }
194
195 /// Receives the next notification available from any of the subscribed channels.
196 ///
197 /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
198 /// call to `recv()`, and should be entirely transparent (as long as it was just an
199 /// intermittent network failure or long-lived connection reaper).
200 ///
201 /// As notifications are transient, any received while the connection was lost, will not
202 /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
203 /// do something before, please see [`try_recv`](Self::try_recv).
204 ///
205 /// # Example
206 ///
207 /// ```rust,no_run
208 /// # use sqlx::postgres::PgListener;
209 /// #
210 /// # sqlx::__rt::test_block_on(async move {
211 /// let mut listener = PgListener::connect("postgres:// ...").await?;
212 /// loop {
213 /// // ask for next notification, re-connecting (transparently) if needed
214 /// let notification = listener.recv().await?;
215 ///
216 /// // handle notification, do something interesting
217 /// }
218 /// # Result::<(), sqlx::Error>::Ok(())
219 /// # }).unwrap();
220 /// ```
221 pub async fn recv(&mut self) -> Result<PgNotification, Error> {
222 loop {
223 if let Some(notification) = self.try_recv().await? {
224 return Ok(notification);
225 }
226 }
227 }
228
    /// Receives the next notification available from any of the subscribed channels.
    ///
    /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
    /// reconnected either immediately, or on the next call to `try_recv()`, depending on
    /// the value of [`eager_reconnect`].
    ///
    /// Notifications that are already sitting in the internal buffer are returned first,
    /// without touching the connection.
    ///
    /// # Errors
    ///
    /// Unless the pool close event is ignored, returns `Err(PoolClosed)` if the pool is closed
    /// while waiting (or was already closed). I/O errors other than the connection-loss kinds
    /// handled below, and all non-I/O errors, are forwarded to the caller unchanged.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// # use sqlx::postgres::PgListener;
    /// #
    /// # sqlx::__rt::test_block_on(async move {
    /// # let mut listener = PgListener::connect("postgres:// ...").await?;
    /// loop {
    ///     // start handling notifications, connecting if needed
    ///     while let Some(notification) = listener.try_recv().await? {
    ///         // handle notification
    ///     }
    ///
    ///     // connection lost, do something interesting
    /// }
    /// # Result::<(), sqlx::Error>::Ok(())
    /// # }).unwrap();
    /// ```
    ///
    /// [`eager_reconnect`]: PgListener::eager_reconnect
    pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
        // Flush the buffer first, if anything
        // This would only fill up if this listener is used as a connection
        if let Some(notification) = self.next_buffered() {
            return Ok(Some(notification));
        }

        // Fetch our `CloseEvent` listener, if applicable.
        let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());

        loop {
            // `connection()` (re)connects if no connection is currently held.
            let next_message = self.connection().await?.inner.stream.recv_unchecked();

            let res = if let Some(ref mut close_event) = close_event {
                // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
                // before `next_message` returns, or if the pool was already closed
                close_event.do_until(next_message).await?
            } else {
                next_message.await
            };

            let message = match res {
                Ok(message) => message,

                // The connection is dead: ensure that it is dropped, update self state,
                // and report the loss to the caller by returning `Ok(None)`.
                Err(Error::Io(err))
                    if matches!(
                        err.kind(),
                        io::ErrorKind::ConnectionAborted |
                        io::ErrorKind::UnexpectedEof |
                        // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html)
                        io::ErrorKind::TimedOut |
                        io::ErrorKind::BrokenPipe
                    ) =>
                {
                    if let Some(mut conn) = self.connection.take() {
                        // Keep the buffer's sending half so the next connection can reuse it.
                        self.buffer_tx = conn.inner.stream.notifications.take();
                        // Close the connection in a background task, so we can continue.
                        conn.close_on_drop();
                    }

                    if self.eager_reconnect {
                        // A failure to reconnect is returned as an error instead of `Ok(None)`.
                        self.connect_if_needed().await?;
                    }

                    // lost connection
                    return Ok(None);
                }

                // Forward other errors
                Err(error) => {
                    return Err(error);
                }
            };

            match message.format {
                // We've received an async notification, return it.
                BackendMessageFormat::NotificationResponse => {
                    return Ok(Some(PgNotification(message.decode()?)));
                }

                // Mark the connection as ready for another query
                BackendMessageFormat::ReadyForQuery => {
                    self.connection().await?.inner.pending_ready_for_query_count -= 1;
                }

                // Ignore unexpected messages
                _ => {}
            }
        }
    }
327
328 /// Receives the next notification that already exists in the connection buffer, if any.
329 ///
330 /// This is similar to `try_recv`, except it will not wait if the connection has not yet received a notification.
331 ///
332 /// This is helpful if you want to retrieve all buffered notifications and process them in batches.
333 pub fn next_buffered(&mut self) -> Option<PgNotification> {
334 if let Ok(Some(notification)) = self.buffer_rx.try_next() {
335 Some(PgNotification(notification))
336 } else {
337 None
338 }
339 }
340
    /// Consume this listener, returning a `Stream` of notifications.
    ///
    /// The backing connection will be automatically reconnected should it be lost.
    ///
    /// This has the same potential drawbacks as [`recv`](PgListener::recv).
    ///
    /// Each item is the result of one `recv()` call; an error from `recv()` is propagated
    /// through the `?` inside `try_stream!`.
    pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
        // Pinning the generated stream in a `Box` is what makes the returned value `Unpin`.
        Box::pin(try_stream! {
            loop {
                // NOTE(review): presumably the stream ends after yielding an error — confirm
                // against the `try_stream!` macro definition.
                r#yield!(self.recv().await?);
            }
        })
    }
354}
355
356impl Drop for PgListener {
357 fn drop(&mut self) {
358 if let Some(mut conn) = self.connection.take() {
359 let fut = async move {
360 let _ = conn.execute("UNLISTEN *").await;
361
362 // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
363 // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
364 // https://github.com/launchbadge/sqlx/issues/1389
365 conn.return_to_pool().await;
366 };
367
368 // Unregister any listeners before returning the connection to the pool.
369 crate::rt::spawn(fut);
370 }
371 }
372}
373
374impl<'c> Acquire<'c> for &'c mut PgListener {
375 type Database = Postgres;
376 type Connection = &'c mut PgConnection;
377
378 fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
379 self.connection().boxed()
380 }
381
382 fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
383 self.connection().and_then(|c| c.begin()).boxed()
384 }
385}
386
387impl<'c> Executor<'c> for &'c mut PgListener {
388 type Database = Postgres;
389
390 fn fetch_many<'e, 'q, E>(
391 self,
392 query: E,
393 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
394 where
395 'c: 'e,
396 E: Execute<'q, Self::Database>,
397 'q: 'e,
398 E: 'q,
399 {
400 futures_util::stream::once(async move {
401 // need some basic type annotation to help the compiler a bit
402 let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
403 res
404 })
405 .try_flatten()
406 .boxed()
407 }
408
409 fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
410 where
411 'c: 'e,
412 E: Execute<'q, Self::Database>,
413 'q: 'e,
414 E: 'q,
415 {
416 async move { self.connection().await?.fetch_optional(query).await }.boxed()
417 }
418
419 fn prepare_with<'e, 'q: 'e>(
420 self,
421 query: &'q str,
422 parameters: &'e [PgTypeInfo],
423 ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
424 where
425 'c: 'e,
426 {
427 async move {
428 self.connection()
429 .await?
430 .prepare_with(query, parameters)
431 .await
432 }
433 .boxed()
434 }
435
436 #[doc(hidden)]
437 fn describe<'e, 'q: 'e>(
438 self,
439 query: &'q str,
440 ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
441 where
442 'c: 'e,
443 {
444 async move { self.connection().await?.describe(query).await }.boxed()
445 }
446}
447
impl PgNotification {
    /// The process ID of the notifying backend process.
    #[inline]
    pub fn process_id(&self) -> u32 {
        self.0.process_id
    }

    /// The channel that the notify has been raised on. This can be thought
    /// of as the message topic.
    ///
    /// # Panics
    ///
    /// Panics if the stored channel bytes are not valid UTF-8.
    #[inline]
    pub fn channel(&self) -> &str {
        from_utf8(&self.0.channel).unwrap()
    }

    /// The payload of the notification. An empty payload is received as an
    /// empty string.
    ///
    /// # Panics
    ///
    /// Panics if the stored payload bytes are not valid UTF-8.
    #[inline]
    pub fn payload(&self) -> &str {
        from_utf8(&self.0.payload).unwrap()
    }
}
469
470impl Debug for PgListener {
471 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
472 f.debug_struct("PgListener").finish()
473 }
474}
475
476impl Debug for PgNotification {
477 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478 f.debug_struct("PgNotification")
479 .field("process_id", &self.process_id())
480 .field("channel", &self.channel())
481 .field("payload", &self.payload())
482 .finish()
483 }
484}
485
/// Prepares `name` for use inside a double-quoted SQL identifier.
///
/// Everything from the first NUL byte onwards is discarded, and each double quote in the
/// remainder is doubled. The surrounding quotes are NOT added here.
fn ident(name: &str) -> String {
    // Truncate the identifier at the first NUL byte, if there is one.
    let end = name.find('\0').unwrap_or(name.len());

    // Any double quotes must be escaped by doubling them.
    name[..end].replace('"', "\"\"")
}
496
497fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
498 channels.into_iter().fold(String::new(), |mut acc, chan| {
499 acc.push_str(r#"LISTEN ""#);
500 acc.push_str(&ident(chan.as_ref()));
501 acc.push_str(r#"";"#);
502 acc
503 })
504}
505
/// A single channel produces exactly one quoted `LISTEN` statement.
#[test]
fn test_build_listen_all_query_with_single_channel() {
    let expected = r#"LISTEN "test";"#;
    assert_eq!(build_listen_all_query(&["test"]).as_str(), expected);
}
511
/// Several channels produce one `LISTEN` statement each, concatenated in input order.
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let expected = r#"LISTEN "channel.0";LISTEN "channel.1";"#;
    assert_eq!(
        build_listen_all_query(&["channel.0", "channel.1"]).as_str(),
        expected
    );
}