sqlx_postgres/
listener.rs

1use std::fmt::{self, Debug};
2use std::io;
3use std::str::from_utf8;
4
5use futures_channel::mpsc;
6use futures_core::future::BoxFuture;
7use futures_core::stream::{BoxStream, Stream};
8use futures_util::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
9use sqlx_core::acquire::Acquire;
10use sqlx_core::transaction::Transaction;
11use sqlx_core::Either;
12
13use crate::describe::Describe;
14use crate::error::Error;
15use crate::executor::{Execute, Executor};
16use crate::message::{BackendMessageFormat, Notification};
17use crate::pool::PoolOptions;
18use crate::pool::{Pool, PoolConnection};
19use crate::{PgConnection, PgQueryResult, PgRow, PgStatement, PgTypeInfo, Postgres};
20
21/// A stream of asynchronous notifications from Postgres.
22///
23/// This listener will auto-reconnect. If the active
24/// connection being used ever dies, this listener will detect that event, create a
25/// new connection, will re-subscribe to all of the originally specified channels, and will resume
26/// operations as normal.
27pub struct PgListener {
28    pool: Pool<Postgres>,
29    connection: Option<PoolConnection<Postgres>>,
30    buffer_rx: mpsc::UnboundedReceiver<Notification>,
31    buffer_tx: Option<mpsc::UnboundedSender<Notification>>,
32    channels: Vec<String>,
33    ignore_close_event: bool,
34    eager_reconnect: bool,
35}
36
37/// An asynchronous notification from Postgres.
38pub struct PgNotification(Notification);
39
40impl PgListener {
41    pub async fn connect(url: &str) -> Result<Self, Error> {
42        // Create a pool of 1 without timeouts (as they don't apply here)
43        // We only use the pool to handle re-connections
44        let pool = PoolOptions::<Postgres>::new()
45            .max_connections(1)
46            .max_lifetime(None)
47            .idle_timeout(None)
48            .connect(url)
49            .await?;
50
51        let mut this = Self::connect_with(&pool).await?;
52        // We don't need to handle close events
53        this.ignore_close_event = true;
54
55        Ok(this)
56    }
57
58    pub async fn connect_with(pool: &Pool<Postgres>) -> Result<Self, Error> {
59        // Pull out an initial connection
60        let mut connection = pool.acquire().await?;
61
62        // Setup a notification buffer
63        let (sender, receiver) = mpsc::unbounded();
64        connection.inner.stream.notifications = Some(sender);
65
66        Ok(Self {
67            pool: pool.clone(),
68            connection: Some(connection),
69            buffer_rx: receiver,
70            buffer_tx: None,
71            channels: Vec::new(),
72            ignore_close_event: false,
73            eager_reconnect: true,
74        })
75    }
76
77    /// Set whether or not to ignore [`Pool::close_event()`]. Defaults to `false`.
78    ///
79    /// By default, when [`Pool::close()`] is called on the pool this listener is using
80    /// while [`Self::recv()`] or [`Self::try_recv()`] are waiting for a message, the wait is
81    /// cancelled and `Err(PoolClosed)` is returned.
82    ///
83    /// This is because `Pool::close()` will wait until _all_ connections are returned and closed,
84    /// including the one being used by this listener.
85    ///
86    /// Otherwise, `pool.close().await` would have to wait until `PgListener` encountered a
87    /// need to acquire a new connection (timeout, error, etc.) and dropped the one it was
88    /// currently holding, at which point `.recv()` or `.try_recv()` would return `Err(PoolClosed)`
89    /// on the attempt to acquire a new connection anyway.
90    ///
91    /// However, if you want `PgListener` to ignore the close event and continue waiting for a
92    /// message as long as it can, set this to `true`.
93    ///
94    /// Does nothing if this was constructed with [`PgListener::connect()`], as that creates an
95    /// internal pool just for the new instance of `PgListener` which cannot be closed manually.
96    pub fn ignore_pool_close_event(&mut self, val: bool) {
97        self.ignore_close_event = val;
98    }
99
100    /// Set whether a lost connection in `try_recv()` should be re-established before it returns
101    /// `Ok(None)`, or on the next call to `try_recv()`.
102    ///
103    /// By default, this is `true` and the connection is re-established before returning `Ok(None)`.
104    ///
105    /// If this is set to `false` then notifications will continue to be lost until the next call
106    /// to `try_recv()`. If your recovery logic uses a different database connection then
107    /// notifications that occur after it completes may be lost without any way to tell that they
108    /// have been.
109    pub fn eager_reconnect(&mut self, val: bool) {
110        self.eager_reconnect = val;
111    }
112
113    /// Starts listening for notifications on a channel.
114    /// The channel name is quoted here to ensure case sensitivity.
115    pub async fn listen(&mut self, channel: &str) -> Result<(), Error> {
116        self.connection()
117            .await?
118            .execute(&*format!(r#"LISTEN "{}""#, ident(channel)))
119            .await?;
120
121        self.channels.push(channel.to_owned());
122
123        Ok(())
124    }
125
126    /// Starts listening for notifications on all channels.
127    pub async fn listen_all(
128        &mut self,
129        channels: impl IntoIterator<Item = &str>,
130    ) -> Result<(), Error> {
131        let beg = self.channels.len();
132        self.channels.extend(channels.into_iter().map(|s| s.into()));
133
134        let query = build_listen_all_query(&self.channels[beg..]);
135        self.connection().await?.execute(&*query).await?;
136
137        Ok(())
138    }
139
140    /// Stops listening for notifications on a channel.
141    /// The channel name is quoted here to ensure case sensitivity.
142    pub async fn unlisten(&mut self, channel: &str) -> Result<(), Error> {
143        // use RAW connection and do NOT re-connect automatically, since this is not required for
144        // UNLISTEN (we've disconnected anyways)
145        if let Some(connection) = self.connection.as_mut() {
146            connection
147                .execute(&*format!(r#"UNLISTEN "{}""#, ident(channel)))
148                .await?;
149        }
150
151        if let Some(pos) = self.channels.iter().position(|s| s == channel) {
152            self.channels.remove(pos);
153        }
154
155        Ok(())
156    }
157
158    /// Stops listening for notifications on all channels.
159    pub async fn unlisten_all(&mut self) -> Result<(), Error> {
160        // use RAW connection and do NOT re-connect automatically, since this is not required for
161        // UNLISTEN (we've disconnected anyways)
162        if let Some(connection) = self.connection.as_mut() {
163            connection.execute("UNLISTEN *").await?;
164        }
165
166        self.channels.clear();
167
168        Ok(())
169    }
170
171    #[inline]
172    async fn connect_if_needed(&mut self) -> Result<(), Error> {
173        if self.connection.is_none() {
174            let mut connection = self.pool.acquire().await?;
175            connection.inner.stream.notifications = self.buffer_tx.take();
176
177            connection
178                .execute(&*build_listen_all_query(&self.channels))
179                .await?;
180
181            self.connection = Some(connection);
182        }
183
184        Ok(())
185    }
186
187    #[inline]
188    async fn connection(&mut self) -> Result<&mut PgConnection, Error> {
189        // Ensure we have an active connection to work with.
190        self.connect_if_needed().await?;
191
192        Ok(self.connection.as_mut().unwrap())
193    }
194
195    /// Receives the next notification available from any of the subscribed channels.
196    ///
197    /// If the connection to PostgreSQL is lost, it is automatically reconnected on the next
198    /// call to `recv()`, and should be entirely transparent (as long as it was just an
199    /// intermittent network failure or long-lived connection reaper).
200    ///
201    /// As notifications are transient, any received while the connection was lost, will not
202    /// be returned. If you'd prefer the reconnection to be explicit and have a chance to
203    /// do something before, please see [`try_recv`](Self::try_recv).
204    ///
205    /// # Example
206    ///
207    /// ```rust,no_run
208    /// # use sqlx::postgres::PgListener;
209    /// #
210    /// # sqlx::__rt::test_block_on(async move {
211    /// let mut listener = PgListener::connect("postgres:// ...").await?;
212    /// loop {
213    ///     // ask for next notification, re-connecting (transparently) if needed
214    ///     let notification = listener.recv().await?;
215    ///
216    ///     // handle notification, do something interesting
217    /// }
218    /// # Result::<(), sqlx::Error>::Ok(())
219    /// # }).unwrap();
220    /// ```
221    pub async fn recv(&mut self) -> Result<PgNotification, Error> {
222        loop {
223            if let Some(notification) = self.try_recv().await? {
224                return Ok(notification);
225            }
226        }
227    }
228
229    /// Receives the next notification available from any of the subscribed channels.
230    ///
231    /// If the connection to PostgreSQL is lost, `None` is returned, and the connection is
232    /// reconnected either immediately, or on the next call to `try_recv()`, depending on
233    /// the value of [`eager_reconnect`].
234    ///
235    /// # Example
236    ///
237    /// ```rust,no_run
238    /// # use sqlx::postgres::PgListener;
239    /// #
240    /// # sqlx::__rt::test_block_on(async move {
241    /// # let mut listener = PgListener::connect("postgres:// ...").await?;
242    /// loop {
243    ///     // start handling notifications, connecting if needed
244    ///     while let Some(notification) = listener.try_recv().await? {
245    ///         // handle notification
246    ///     }
247    ///
248    ///     // connection lost, do something interesting
249    /// }
250    /// # Result::<(), sqlx::Error>::Ok(())
251    /// # }).unwrap();
252    /// ```
253    ///
254    /// [`eager_reconnect`]: PgListener::eager_reconnect
255    pub async fn try_recv(&mut self) -> Result<Option<PgNotification>, Error> {
256        // Flush the buffer first, if anything
257        // This would only fill up if this listener is used as a connection
258        if let Some(notification) = self.next_buffered() {
259            return Ok(Some(notification));
260        }
261
262        // Fetch our `CloseEvent` listener, if applicable.
263        let mut close_event = (!self.ignore_close_event).then(|| self.pool.close_event());
264
265        loop {
266            let next_message = self.connection().await?.inner.stream.recv_unchecked();
267
268            let res = if let Some(ref mut close_event) = close_event {
269                // cancels the wait and returns `Err(PoolClosed)` if the pool is closed
270                // before `next_message` returns, or if the pool was already closed
271                close_event.do_until(next_message).await?
272            } else {
273                next_message.await
274            };
275
276            let message = match res {
277                Ok(message) => message,
278
279                // The connection is dead, ensure that it is dropped,
280                // update self state, and loop to try again.
281                Err(Error::Io(err))
282                    if matches!(
283                        err.kind(),
284                        io::ErrorKind::ConnectionAborted |
285                        io::ErrorKind::UnexpectedEof |
286                        // see ERRORS section in tcp(7) man page (https://man7.org/linux/man-pages/man7/tcp.7.html)
287                        io::ErrorKind::TimedOut |
288                        io::ErrorKind::BrokenPipe
289                    ) =>
290                {
291                    if let Some(mut conn) = self.connection.take() {
292                        self.buffer_tx = conn.inner.stream.notifications.take();
293                        // Close the connection in a background task, so we can continue.
294                        conn.close_on_drop();
295                    }
296
297                    if self.eager_reconnect {
298                        self.connect_if_needed().await?;
299                    }
300
301                    // lost connection
302                    return Ok(None);
303                }
304
305                // Forward other errors
306                Err(error) => {
307                    return Err(error);
308                }
309            };
310
311            match message.format {
312                // We've received an async notification, return it.
313                BackendMessageFormat::NotificationResponse => {
314                    return Ok(Some(PgNotification(message.decode()?)));
315                }
316
317                // Mark the connection as ready for another query
318                BackendMessageFormat::ReadyForQuery => {
319                    self.connection().await?.inner.pending_ready_for_query_count -= 1;
320                }
321
322                // Ignore unexpected messages
323                _ => {}
324            }
325        }
326    }
327
328    /// Receives the next notification that already exists in the connection buffer, if any.
329    ///
330    /// This is similar to `try_recv`, except it will not wait if the connection has not yet received a notification.
331    ///
332    /// This is helpful if you want to retrieve all buffered notifications and process them in batches.
333    pub fn next_buffered(&mut self) -> Option<PgNotification> {
334        if let Ok(Some(notification)) = self.buffer_rx.try_next() {
335            Some(PgNotification(notification))
336        } else {
337            None
338        }
339    }
340
341    /// Consume this listener, returning a `Stream` of notifications.
342    ///
343    /// The backing connection will be automatically reconnected should it be lost.
344    ///
345    /// This has the same potential drawbacks as [`recv`](PgListener::recv).
346    ///
347    pub fn into_stream(mut self) -> impl Stream<Item = Result<PgNotification, Error>> + Unpin {
348        Box::pin(try_stream! {
349            loop {
350                r#yield!(self.recv().await?);
351            }
352        })
353    }
354}
355
356impl Drop for PgListener {
357    fn drop(&mut self) {
358        if let Some(mut conn) = self.connection.take() {
359            let fut = async move {
360                let _ = conn.execute("UNLISTEN *").await;
361
362                // inline the drop handler from `PoolConnection` so it doesn't try to spawn another task
363                // otherwise, it may trigger a panic if this task is dropped because the runtime is going away:
364                // https://github.com/launchbadge/sqlx/issues/1389
365                conn.return_to_pool().await;
366            };
367
368            // Unregister any listeners before returning the connection to the pool.
369            crate::rt::spawn(fut);
370        }
371    }
372}
373
374impl<'c> Acquire<'c> for &'c mut PgListener {
375    type Database = Postgres;
376    type Connection = &'c mut PgConnection;
377
378    fn acquire(self) -> BoxFuture<'c, Result<Self::Connection, Error>> {
379        self.connection().boxed()
380    }
381
382    fn begin(self) -> BoxFuture<'c, Result<Transaction<'c, Self::Database>, Error>> {
383        self.connection().and_then(|c| c.begin()).boxed()
384    }
385}
386
387impl<'c> Executor<'c> for &'c mut PgListener {
388    type Database = Postgres;
389
390    fn fetch_many<'e, 'q, E>(
391        self,
392        query: E,
393    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
394    where
395        'c: 'e,
396        E: Execute<'q, Self::Database>,
397        'q: 'e,
398        E: 'q,
399    {
400        futures_util::stream::once(async move {
401            // need some basic type annotation to help the compiler a bit
402            let res: Result<_, Error> = Ok(self.connection().await?.fetch_many(query));
403            res
404        })
405        .try_flatten()
406        .boxed()
407    }
408
409    fn fetch_optional<'e, 'q, E>(self, query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
410    where
411        'c: 'e,
412        E: Execute<'q, Self::Database>,
413        'q: 'e,
414        E: 'q,
415    {
416        async move { self.connection().await?.fetch_optional(query).await }.boxed()
417    }
418
419    fn prepare_with<'e, 'q: 'e>(
420        self,
421        query: &'q str,
422        parameters: &'e [PgTypeInfo],
423    ) -> BoxFuture<'e, Result<PgStatement<'q>, Error>>
424    where
425        'c: 'e,
426    {
427        async move {
428            self.connection()
429                .await?
430                .prepare_with(query, parameters)
431                .await
432        }
433        .boxed()
434    }
435
436    #[doc(hidden)]
437    fn describe<'e, 'q: 'e>(
438        self,
439        query: &'q str,
440    ) -> BoxFuture<'e, Result<Describe<Self::Database>, Error>>
441    where
442        'c: 'e,
443    {
444        async move { self.connection().await?.describe(query).await }.boxed()
445    }
446}
447
448impl PgNotification {
449    /// The process ID of the notifying backend process.
450    #[inline]
451    pub fn process_id(&self) -> u32 {
452        self.0.process_id
453    }
454
455    /// The channel that the notify has been raised on. This can be thought
456    /// of as the message topic.
457    #[inline]
458    pub fn channel(&self) -> &str {
459        from_utf8(&self.0.channel).unwrap()
460    }
461
462    /// The payload of the notification. An empty payload is received as an
463    /// empty string.
464    #[inline]
465    pub fn payload(&self) -> &str {
466        from_utf8(&self.0.payload).unwrap()
467    }
468}
469
470impl Debug for PgListener {
471    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
472        f.debug_struct("PgListener").finish()
473    }
474}
475
476impl Debug for PgNotification {
477    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
478        f.debug_struct("PgNotification")
479            .field("process_id", &self.process_id())
480            .field("channel", &self.channel())
481            .field("payload", &self.payload())
482            .finish()
483    }
484}
485
/// Escapes `name` for use inside a double-quoted SQL identifier.
///
/// Everything from the first NUL byte onward is discarded. Each embedded `"` is doubled.
fn ident(name: &str) -> String {
    let before_nul = match name.find('\0') {
        Some(end) => &name[..end],
        None => name,
    };

    // Any double quotes must be escaped.
    before_nul.replace('"', "\"\"")
}
496
497fn build_listen_all_query(channels: impl IntoIterator<Item = impl AsRef<str>>) -> String {
498    channels.into_iter().fold(String::new(), |mut acc, chan| {
499        acc.push_str(r#"LISTEN ""#);
500        acc.push_str(&ident(chan.as_ref()));
501        acc.push_str(r#"";"#);
502        acc
503    })
504}
505
#[test]
fn test_build_listen_all_query_with_single_channel() {
    let expected = r#"LISTEN "test";"#;
    assert_eq!(build_listen_all_query(&["test"]), expected);
}
511
#[test]
fn test_build_listen_all_query_with_multiple_channels() {
    let output = build_listen_all_query(&["channel.0", "channel.1"]);
    assert_eq!(output.as_str(), r#"LISTEN "channel.0";LISTEN "channel.1";"#);
}

/// An empty channel list must produce no statements at all.
#[test]
fn test_build_listen_all_query_with_no_channels() {
    let output = build_listen_all_query(std::iter::empty::<&str>());
    assert_eq!(output, "");
}

/// Embedded double quotes are doubled, and anything after a NUL byte is dropped.
#[test]
fn test_build_listen_all_query_escapes_channel_names() {
    let output = build_listen_all_query(&[r#"we"ird"#, "nul\0suffix"]);
    assert_eq!(output, r#"LISTEN "we""ird";LISTEN "nul";"#);
}