diesel_async/pg/
mod.rs

1//! Provides types and functions related to working with PostgreSQL
2//!
3//! Much of this module is re-exported from database agnostic locations.
4//! However, if you are writing code specifically to extend Diesel on
5//! PostgreSQL, you may need to work with this module directly.
6
7use self::error_helper::ErrorHelper;
8use self::row::PgRow;
9use self::serialize::ToSqlHelper;
10use crate::stmt_cache::{PrepareCallback, StmtCache};
11use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
12use diesel::connection::statement_cache::{PrepareForCache, StatementCacheKey};
13use diesel::connection::Instrumentation;
14use diesel::connection::InstrumentationEvent;
15use diesel::connection::StrQueryHelper;
16use diesel::pg::{
17    Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
18};
19use diesel::query_builder::bind_collector::RawBytesBindCollector;
20use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
21use diesel::result::{DatabaseErrorKind, Error};
22use diesel::{ConnectionError, ConnectionResult, QueryResult};
23use futures_util::future::BoxFuture;
24use futures_util::future::Either;
25use futures_util::stream::{BoxStream, TryStreamExt};
26use futures_util::TryFutureExt;
27use futures_util::{Future, FutureExt, StreamExt};
28use std::collections::{HashMap, HashSet};
29use std::sync::Arc;
30use tokio::sync::broadcast;
31use tokio::sync::oneshot;
32use tokio::sync::Mutex;
33use tokio_postgres::types::ToSql;
34use tokio_postgres::types::Type;
35use tokio_postgres::Statement;
36
37pub use self::transaction_builder::TransactionBuilder;
38
39mod error_helper;
40mod row;
41mod serialize;
42mod transaction_builder;
43
/// Placeholder OID handed out for custom types during the first bind-collection
/// pass. Occurrences of it in the bind metadata and in the serialized bind
/// values are replaced later by the OIDs looked up on the server.
const FAKE_OID: u32 = 0;
45
46/// A connection to a PostgreSQL database.
47///
48/// Connection URLs should be in the form
49/// `postgres://[user[:password]@]host/database_name`
50///
/// Check out the documentation of the [tokio_postgres]
52/// crate for details about the format
53///
54/// [tokio_postgres]: https://docs.rs/tokio-postgres/0.7.6/tokio_postgres/config/struct.Config.html#url
55///
56/// ## Pipelining
57///
58/// This connection supports *pipelined* requests. Pipelining can improve performance in use cases in which multiple,
59/// independent queries need to be executed. In a traditional workflow, each query is sent to the server after the
60/// previous query completes. In contrast, pipelining allows the client to send all of the queries to the server up
61/// front, minimizing time spent by one side waiting for the other to finish sending data:
62///
63/// ```not_rust
64///             Sequential                              Pipelined
65/// | Client         | Server          |    | Client         | Server          |
66/// |----------------|-----------------|    |----------------|-----------------|
67/// | send query 1   |                 |    | send query 1   |                 |
68/// |                | process query 1 |    | send query 2   | process query 1 |
69/// | receive rows 1 |                 |    | send query 3   | process query 2 |
70/// | send query 2   |                 |    | receive rows 1 | process query 3 |
71/// |                | process query 2 |    | receive rows 2 |                 |
72/// | receive rows 2 |                 |    | receive rows 3 |                 |
73/// | send query 3   |                 |
74/// |                | process query 3 |
75/// | receive rows 3 |                 |
76/// ```
77///
78/// In both cases, the PostgreSQL server is executing the queries **sequentially** - pipelining just allows both sides of
79/// the connection to work concurrently when possible.
80///
81/// Pipelining happens automatically when futures are polled concurrently (for example, by using the futures `join`
82/// combinator):
83///
84/// ```rust
85/// # include!("../doctest_setup.rs");
86/// use diesel_async::RunQueryDsl;
87///
88/// #
89/// # #[tokio::main(flavor = "current_thread")]
90/// # async fn main() {
91/// #     run_test().await.unwrap();
92/// # }
93/// #
94/// # async fn run_test() -> QueryResult<()> {
95/// #     use diesel::sql_types::{Text, Integer};
96/// #     let conn = &mut establish_connection().await;
97///       let q1 = diesel::select(1_i32.into_sql::<Integer>());
98///       let q2 = diesel::select(2_i32.into_sql::<Integer>());
99///
100///       // construct multiple futures for different queries
101///       let f1 = q1.get_result::<i32>(conn);
102///       let f2 = q2.get_result::<i32>(conn);
103///
104///       // wait on both results
105///       let res = futures_util::try_join!(f1, f2)?;
106///
107///       assert_eq!(res.0, 1);
108///       assert_eq!(res.1, 2);
109///       # Ok(())
110/// # }
111/// ```
112///
113/// ## TLS
114///
115/// Connections created by [`AsyncPgConnection::establish`] do not support TLS.
116///
117/// TLS support for tokio_postgres connections is implemented by external crates, e.g. [tokio_postgres_rustls].
118///
119/// [`AsyncPgConnection::try_from_client_and_connection`] can be used to construct a connection from an existing
120/// [`tokio_postgres::Connection`] with TLS enabled.
121///
122/// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/
pub struct AsyncPgConnection {
    // Shared handle to the underlying client; cloned into every query future.
    conn: Arc<tokio_postgres::Client>,
    // Prepared statement cache, locked while a statement is looked up or prepared.
    stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
    // Transaction manager state; query futures lock it to update the status
    // after a query finished, `transaction_state()` accesses it exclusively.
    transaction_state: Arc<Mutex<AnsiTransactionManager>>,
    // Cache of already resolved custom type OIDs.
    metadata_cache: Arc<Mutex<PgMetadataCache>>,
    // Receiver for errors reported by the task driving the connection.
    // `None` if the connection was built from a bare client.
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    // Signalled on drop. `None` if the connection was built from a bare client.
    shutdown_channel: Option<oneshot::Sender<()>>,
    // a sync mutex is fine here as we only hold it for a really short time
    instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
}
133
134#[async_trait::async_trait]
135impl SimpleAsyncConnection for AsyncPgConnection {
136    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
137        self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
138            query,
139        )));
140        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
141        let batch_execute = self
142            .conn
143            .batch_execute(query)
144            .map_err(ErrorHelper)
145            .map_err(Into::into);
146        let r = drive_future(connection_future, batch_execute).await;
147        self.record_instrumentation(InstrumentationEvent::finish_query(
148            &StrQueryHelper::new(query),
149            r.as_ref().err(),
150        ));
151        r
152    }
153}
154
155#[async_trait::async_trait]
156impl AsyncConnection for AsyncPgConnection {
157    type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
158    type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
159    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
160    type Row<'conn, 'query> = PgRow;
161    type Backend = diesel::pg::Pg;
162    type TransactionManager = AnsiTransactionManager;
163
164    async fn establish(database_url: &str) -> ConnectionResult<Self> {
165        let mut instrumentation = diesel::connection::get_default_instrumentation();
166        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
167            database_url,
168        ));
169        let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
170        let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
171            .await
172            .map_err(ErrorHelper)?;
173
174        let (error_rx, shutdown_tx) = drive_connection(connection);
175
176        let r = Self::setup(
177            client,
178            Some(error_rx),
179            Some(shutdown_tx),
180            Arc::clone(&instrumentation),
181        )
182        .await;
183
184        instrumentation
185            .lock()
186            .unwrap_or_else(|e| e.into_inner())
187            .on_connection_event(InstrumentationEvent::finish_establish_connection(
188                database_url,
189                r.as_ref().err(),
190            ));
191        r
192    }
193
194    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
195    where
196        T: AsQuery + 'query,
197        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
198    {
199        let query = source.as_query();
200        let load_future = self.with_prepared_statement(query, load_prepared);
201
202        self.run_with_connection_future(load_future)
203    }
204
205    fn execute_returning_count<'conn, 'query, T>(
206        &'conn mut self,
207        source: T,
208    ) -> Self::ExecuteFuture<'conn, 'query>
209    where
210        T: QueryFragment<Self::Backend> + QueryId + 'query,
211    {
212        let execute = self.with_prepared_statement(source, execute_prepared);
213        self.run_with_connection_future(execute)
214    }
215
216    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
217        // there should be no other pending future when this is called
218        // that means there is only one instance of this arc and
219        // we can simply access the inner data
220        if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
221            tm.get_mut()
222        } else {
223            panic!("Cannot access shared transaction state")
224        }
225    }
226
227    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
228        // there should be no other pending future when this is called
229        // that means there is only one instance of this arc and
230        // we can simply access the inner data
231        if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
232            instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())
233        } else {
234            panic!("Cannot access shared instrumentation")
235        }
236    }
237
238    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
239        self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
240    }
241}
242
243impl Drop for AsyncPgConnection {
244    fn drop(&mut self) {
245        if let Some(tx) = self.shutdown_channel.take() {
246            let _ = tx.send(());
247        }
248    }
249}
250
251async fn load_prepared(
252    conn: Arc<tokio_postgres::Client>,
253    stmt: Statement,
254    binds: Vec<ToSqlHelper>,
255) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
256    let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
257
258    Ok(res
259        .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
260        .map_ok(PgRow::new)
261        .boxed())
262}
263
264async fn execute_prepared(
265    conn: Arc<tokio_postgres::Client>,
266    stmt: Statement,
267    binds: Vec<ToSqlHelper>,
268) -> QueryResult<usize> {
269    let binds = binds
270        .iter()
271        .map(|b| b as &(dyn ToSql + Sync))
272        .collect::<Vec<_>>();
273
274    let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
275        .await
276        .map_err(ErrorHelper)?;
277    res.try_into()
278        .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
279}
280
281#[inline(always)]
282fn update_transaction_manager_status<T>(
283    query_result: QueryResult<T>,
284    transaction_manager: &mut AnsiTransactionManager,
285) -> QueryResult<T> {
286    if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
287        query_result
288    {
289        transaction_manager
290            .status
291            .set_requires_rollback_maybe_up_to_top_level(true)
292    }
293    query_result
294}
295
296#[async_trait::async_trait]
297impl PrepareCallback<Statement, PgTypeMetadata> for Arc<tokio_postgres::Client> {
298    async fn prepare(
299        self,
300        sql: &str,
301        metadata: &[PgTypeMetadata],
302        _is_for_cache: PrepareForCache,
303    ) -> QueryResult<(Statement, Self)> {
304        let bind_types = metadata
305            .iter()
306            .map(type_from_oid)
307            .collect::<QueryResult<Vec<_>>>()?;
308
309        let stmt = self
310            .prepare_typed(sql, &bind_types)
311            .await
312            .map_err(ErrorHelper);
313        Ok((stmt?, self))
314    }
315}
316
317fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
318    let oid = t
319        .oid()
320        .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;
321
322    if let Some(tpe) = Type::from_oid(oid) {
323        return Ok(tpe);
324    }
325
326    Ok(Type::new(
327        format!("diesel_custom_type_{oid}"),
328        oid,
329        tokio_postgres::types::Kind::Simple,
330        "public".into(),
331    ))
332}
333
334impl AsyncPgConnection {
335    /// Build a transaction, specifying additional details such as isolation level
336    ///
337    /// See [`TransactionBuilder`] for more examples.
338    ///
339    /// [`TransactionBuilder`]: crate::pg::TransactionBuilder
340    ///
341    /// ```rust
342    /// # include!("../doctest_setup.rs");
343    /// # use scoped_futures::ScopedFutureExt;
344    /// #
345    /// # #[tokio::main(flavor = "current_thread")]
346    /// # async fn main() {
347    /// #     run_test().await.unwrap();
348    /// # }
349    /// #
350    /// # async fn run_test() -> QueryResult<()> {
351    /// #     use schema::users::dsl::*;
352    /// #     let conn = &mut connection_no_transaction().await;
353    /// conn.build_transaction()
354    ///     .read_only()
355    ///     .serializable()
356    ///     .deferrable()
357    ///     .run(|conn| async move { Ok(()) }.scope_boxed())
358    ///     .await
359    /// # }
360    /// ```
361    pub fn build_transaction(&mut self) -> TransactionBuilder<Self> {
362        TransactionBuilder::new(self)
363    }
364
365    /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
366    pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
367        Self::setup(
368            conn,
369            None,
370            None,
371            Arc::new(std::sync::Mutex::new(
372                diesel::connection::get_default_instrumentation(),
373            )),
374        )
375        .await
376    }
377
378    /// Constructs a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`] and
379    /// [`tokio_postgres::Connection`]
380    pub async fn try_from_client_and_connection<S>(
381        client: tokio_postgres::Client,
382        conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
383    ) -> ConnectionResult<Self>
384    where
385        S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
386    {
387        let (error_rx, shutdown_tx) = drive_connection(conn);
388
389        Self::setup(
390            client,
391            Some(error_rx),
392            Some(shutdown_tx),
393            Arc::new(std::sync::Mutex::new(
394                diesel::connection::get_default_instrumentation(),
395            )),
396        )
397        .await
398    }
399
400    async fn setup(
401        conn: tokio_postgres::Client,
402        connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
403        shutdown_channel: Option<oneshot::Sender<()>>,
404        instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
405    ) -> ConnectionResult<Self> {
406        let mut conn = Self {
407            conn: Arc::new(conn),
408            stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
409            transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
410            metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
411            connection_future,
412            shutdown_channel,
413            instrumentation,
414        };
415        conn.set_config_options()
416            .await
417            .map_err(ConnectionError::CouldntSetupConfiguration)?;
418        Ok(conn)
419    }
420
421    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the connection associated with this client.
422    pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
423        self.conn.cancel_token()
424    }
425
426    async fn set_config_options(&mut self) -> QueryResult<()> {
427        use crate::run_query_dsl::RunQueryDsl;
428
429        futures_util::try_join!(
430            diesel::sql_query("SET TIME ZONE 'UTC'").execute(self),
431            diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self),
432        )?;
433        Ok(())
434    }
435
436    fn run_with_connection_future<'a, R: 'a>(
437        &self,
438        future: impl Future<Output = QueryResult<R>> + Send + 'a,
439    ) -> BoxFuture<'a, QueryResult<R>> {
440        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
441        drive_future(connection_future, future).boxed()
442    }
443
444    fn with_prepared_statement<'a, T, F, R>(
445        &mut self,
446        query: T,
447        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
448    ) -> BoxFuture<'a, QueryResult<R>>
449    where
450        T: QueryFragment<diesel::pg::Pg> + QueryId,
451        F: Future<Output = QueryResult<R>> + Send + 'a,
452        R: Send,
453    {
454        self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
455            &query,
456        )));
457        // we explicilty descruct the query here before going into the async block
458        //
459        // That's required to remove the send bound from `T` as we have translated
460        // the query type to just a string (for the SQL) and a bunch of bytes (for the binds)
461        // which both are `Send`.
462        // We also collect the query id (essentially an integer) and the safe_to_cache flag here
463        // so there is no need to even access the query in the async block below
464        let mut query_builder = PgQueryBuilder::default();
465
466        let bind_data = construct_bind_data(&query);
467
468        // The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
469        self.with_prepared_statement_after_sql_built(
470            callback,
471            query.is_safe_to_cache_prepared(&Pg),
472            T::query_id(),
473            query.to_sql(&mut query_builder, &Pg),
474            query_builder,
475            bind_data,
476        )
477    }
478
479    fn with_prepared_statement_after_sql_built<'a, F, R>(
480        &mut self,
481        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
482        is_safe_to_cache_prepared: QueryResult<bool>,
483        query_id: Option<std::any::TypeId>,
484        to_sql_result: QueryResult<()>,
485        query_builder: PgQueryBuilder,
486        bind_data: BindData,
487    ) -> BoxFuture<'a, QueryResult<R>>
488    where
489        F: Future<Output = QueryResult<R>> + Send + 'a,
490        R: Send,
491    {
492        let raw_connection = self.conn.clone();
493        let stmt_cache = self.stmt_cache.clone();
494        let metadata_cache = self.metadata_cache.clone();
495        let tm = self.transaction_state.clone();
496        let instrumentation = self.instrumentation.clone();
497        let BindData {
498            collect_bind_result,
499            fake_oid_locations,
500            generated_oids,
501            mut bind_collector,
502        } = bind_data;
503
504        async move {
505            let sql = to_sql_result.map(|_| query_builder.finish())?;
506            let res = async {
507            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
508            collect_bind_result?;
509            // Check whether we need to resolve some types at all
510            //
511            // If the user doesn't use custom types there is no need
512            // to borther with that at all
513            if let Some(ref unresolved_types) = generated_oids {
514                let metadata_cache = &mut *metadata_cache.lock().await;
515                let mut real_oids = HashMap::new();
516
517                for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
518                    unresolved_types
519                {
520                    // for each unresolved item
521                    // we check whether it's arleady in the cache
522                    // or perform a lookup and insert it into the cache
523                    let cache_key = PgMetadataCacheKey::new(
524                        schema.as_deref().map(Into::into),
525                        lookup_type_name.into(),
526                    );
527                    let real_metadata = if let Some(type_metadata) =
528                        metadata_cache.lookup_type(&cache_key)
529                    {
530                        type_metadata
531                    } else {
532                        let type_metadata =
533                            lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
534                                .await?;
535                        metadata_cache.store_type(cache_key, type_metadata);
536
537                        PgTypeMetadata::from_result(Ok(type_metadata))
538                    };
539                    // let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
540                    let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
541                    real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
542                }
543
544                // Replace fake OIDs with real OIDs in `bind_collector.metadata`
545                for m in &mut bind_collector.metadata {
546                    let (oid, array_oid) = unwrap_oids(m);
547                    *m = PgTypeMetadata::new(
548                        real_oids.get(&oid).copied().unwrap_or(oid),
549                        real_oids.get(&array_oid).copied().unwrap_or(array_oid)
550                    );
551                }
552                // Replace fake OIDs with real OIDs in `bind_collector.binds`
553                for (bind_index, byte_index) in fake_oid_locations {
554                    replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
555                        .ok_or_else(|| {
556                            Error::SerializationError(
557                                format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
558                            )
559                        })?;
560                }
561            }
562            let key = match query_id {
563                Some(id) => StatementCacheKey::Type(id),
564                None => StatementCacheKey::Sql {
565                    sql: sql.clone(),
566                    bind_types: bind_collector.metadata.clone(),
567                },
568            };
569            let stmt = {
570                let mut stmt_cache = stmt_cache.lock().await;
571                stmt_cache
572                    .cached_prepared_statement(
573                        key,
574                        sql.clone(),
575                        is_safe_to_cache_prepared,
576                        &bind_collector.metadata,
577                        raw_connection.clone(),
578                        &instrumentation
579                    )
580                    .await?
581                    .0
582                    .clone()
583            };
584
585            let binds = bind_collector
586                .metadata
587                .into_iter()
588                .zip(bind_collector.binds)
589                .map(|(meta, bind)| ToSqlHelper(meta, bind))
590                .collect::<Vec<_>>();
591                callback(raw_connection, stmt.clone(), binds).await
592            };
593            let res = res.await;
594            let mut tm = tm.lock().await;
595            let r = update_transaction_manager_status(res, &mut tm);
596            instrumentation
597                .lock()
598                .unwrap_or_else(|p| p.into_inner())
599                .on_connection_event(InstrumentationEvent::finish_query(
600                    &StrQueryHelper::new(&sql),
601                    r.as_ref().err(),
602                ));
603
604            r
605        }
606        .boxed()
607    }
608
609    fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
610        self.instrumentation
611            .lock()
612            .unwrap_or_else(|p| p.into_inner())
613            .on_connection_event(event);
614    }
615}
616
/// Everything collected from a query's bind parameters before the query itself is dropped.
struct BindData {
    // Outcome of collecting the binds (of both passes, if two were necessary).
    collect_bind_result: Result<(), Error>,
    // `(bind_index, byte_index)` positions inside the serialized bind values
    // at which a placeholder OID has to be overwritten with the real OID.
    fake_oid_locations: Vec<(usize, usize)>,
    // Placeholder OIDs handed out per custom type, `None` if the query
    // did not request any custom type.
    generated_oids: GeneratedOidTypeMap,
    // Bind metadata and serialized bind values used to execute the query.
    bind_collector: RawBytesBindCollector<Pg>,
}
623
624fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
625    // we don't resolve custom types here yet, we do that later
626    // in the async block below as we might need to perform lookup
627    // queries for that.
628    //
629    // We apply this workaround to prevent requiring all the diesel
630    // serialization code to beeing async
631    //
632    // We give out constant fake oids here to optimize for the "happy" path
633    // without custom type lookup
634    let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
635    let mut metadata_lookup_0 = PgAsyncMetadataLookup {
636        custom_oid: false,
637        generated_oids: None,
638        oid_generator: |_, _| (FAKE_OID, FAKE_OID),
639    };
640    let collect_bind_result_0 =
641        query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
642    // we have encountered a custom type oid, so we need to perform more work here.
643    // These oids can occure in two locations:
644    //
645    // * In the collected metadata -> relativly easy to resolve, just need to replace them below
646    // * As part of the seralized bind blob -> hard to replace
647    //
648    // To address the second case, we perform a second run of the bind collector
649    // with a different set of fake oids. Then we compare the output of the two runs
650    // and use that information to infer where to replace bytes in the serialized output
651    if metadata_lookup_0.custom_oid {
652        // we try to get the maxium oid we encountered here
653        // to be sure that we don't accidently give out a fake oid below that collides with
654        // something
655        let mut max_oid = bind_collector_0
656            .metadata
657            .iter()
658            .flat_map(|t| {
659                [
660                    t.oid().unwrap_or_default(),
661                    t.array_oid().unwrap_or_default(),
662                ]
663            })
664            .max()
665            .unwrap_or_default();
666        let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
667        let mut metadata_lookup_1 = PgAsyncMetadataLookup {
668            custom_oid: false,
669            generated_oids: Some(HashMap::new()),
670            oid_generator: move |_, _| {
671                max_oid += 2;
672                (max_oid, max_oid + 1)
673            },
674        };
675        let collect_bind_result_1 =
676            query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
677
678        assert_eq!(
679            bind_collector_0.binds.len(),
680            bind_collector_0.metadata.len()
681        );
682        let fake_oid_locations = std::iter::zip(
683            bind_collector_0
684                .binds
685                .iter()
686                .zip(&bind_collector_0.metadata),
687            &bind_collector_1.binds,
688        )
689        .enumerate()
690        .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
691            // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
692            // in both cases the relevant buffer is a custom type on it's own
693            // so we only need to check the cases that contain a fake OID on their own
694            let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
695                (
696                    bytes_0.as_deref().unwrap_or_default(),
697                    bytes_1.as_deref().unwrap_or_default(),
698                )
699            } else {
700                // for all other cases, just return an empty
701                // list to make the iteration below a no-op
702                // and prevent the need of boxing
703                (&[] as &[_], &[] as &[_])
704            };
705            let lookup_map = metadata_lookup_1
706                .generated_oids
707                .as_ref()
708                .map(|map| {
709                    map.values()
710                        .flat_map(|(oid, array_oid)| [*oid, *array_oid])
711                        .collect::<HashSet<_>>()
712                })
713                .unwrap_or_default();
714            std::iter::zip(
715                bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
716                bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
717            )
718            .enumerate()
719            .filter_map(move |(byte_index, (l, r))| {
720                // here we infer if some byte sequence is a fake oid
721                // We use the following conditions for that:
722                //
723                // * The first byte sequence matches the constant FAKE_OID
724                // * The second sequence does not match the constant FAKE_OID
725                // * The second sequence is contained in the set of generated oid,
726                //   otherwise we get false positives around the boundary
727                //   of a to be replaced byte sequence
728                let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
729                (l == FAKE_OID.to_be_bytes()
730                    && r != FAKE_OID.to_be_bytes()
731                    && lookup_map.contains(&r_val))
732                .then_some((bind_index, byte_index))
733            })
734        })
735        // Avoid storing the bind collectors in the returned Future
736        .collect::<Vec<_>>();
737        BindData {
738            collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
739            fake_oid_locations,
740            generated_oids: metadata_lookup_1.generated_oids,
741            bind_collector: bind_collector_1,
742        }
743    } else {
744        BindData {
745            collect_bind_result: collect_bind_result_0,
746            fake_oid_locations: Vec::new(),
747            generated_oids: None,
748            bind_collector: bind_collector_0,
749        }
750    }
751}
752
/// Placeholder OIDs handed out per custom type: maps `(schema, type_name)` to
/// `(oid, array_oid)`. `None` means that generated OIDs are not tracked.
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
754
/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
/// so they can be replaced with asynchronously fetched OIDs after the original query is dropped
struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
    // Set to `true` as soon as any type lookup was requested.
    custom_oid: bool,
    // If `Some`, remembers the OIDs generated per type so that the same type
    // always receives the same pair.
    generated_oids: GeneratedOidTypeMap,
    // Produces the `(oid, array_oid)` pair for a `(type_name, schema)` request.
    oid_generator: F,
}
762
763impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
764where
765    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
766{
767    fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
768        self.custom_oid = true;
769
770        let oid = if let Some(map) = &mut self.generated_oids {
771            *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
772                .or_insert_with(|| (self.oid_generator)(type_name, schema))
773        } else {
774            (self.oid_generator)(type_name, schema)
775        };
776
777        PgTypeMetadata::from_result(Ok(oid))
778    }
779}
780
781async fn lookup_type(
782    schema: Option<String>,
783    type_name: String,
784    raw_connection: &tokio_postgres::Client,
785) -> QueryResult<(u32, u32)> {
786    let r = if let Some(schema) = schema.as_ref() {
787        raw_connection
788            .query_one(
789                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
790             INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
791             WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
792             LIMIT 1",
793                &[&type_name, schema],
794            )
795            .await
796            .map_err(ErrorHelper)?
797    } else {
798        raw_connection
799            .query_one(
800                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
801             WHERE pg_type.oid = quote_ident($1)::regtype::oid \
802             LIMIT 1",
803                &[&type_name],
804            )
805            .await
806            .map_err(ErrorHelper)?
807    };
808    Ok((r.get(0), r.get(1)))
809}
810
811fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
812    let err_msg = "PgTypeMetadata is supposed to always be Ok here";
813    (
814        metadata.oid().expect(err_msg),
815        metadata.array_oid().expect(err_msg),
816    )
817}
818
/// Overwrites the four-byte big-endian OID stored at `byte_index` inside the
/// bind at `bind_index` with the value `real_oids` maps it to.
///
/// Returns `Some(())` when the replacement happened. Returns `None`, leaving
/// `binds` untouched, when the bind does not exist or is `None`, when fewer
/// than four bytes are available at `byte_index`, or when the stored OID has
/// no entry in `real_oids`.
fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let payload = binds.get_mut(bind_index)?.as_mut()?;

    // The OID occupies exactly four bytes starting at `byte_index`.
    let end = byte_index.checked_add(4)?;
    let window = payload.get_mut(byte_index..end)?;

    let placeholder = match *window {
        [b0, b1, b2, b3] => u32::from_be_bytes([b0, b1, b2, b3]),
        _ => return None,
    };

    let actual = real_oids.get(&placeholder)?;
    window.copy_from_slice(&actual.to_be_bytes());
    Some(())
}
835
836async fn drive_future<R>(
837    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
838    client_future: impl Future<Output = Result<R, diesel::result::Error>>,
839) -> Result<R, diesel::result::Error> {
840    if let Some(mut connection_future) = connection_future {
841        let client_future = std::pin::pin!(client_future);
842        let connection_future = std::pin::pin!(connection_future.recv());
843        match futures_util::future::select(client_future, connection_future).await {
844            Either::Left((res, _)) => res,
845            // we got an error from the background task
846            // return it to the user
847            Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
848            // seems like the background thread died for whatever reason
849            Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
850                DatabaseErrorKind::UnableToSendCommand,
851                Box::new(e.to_string()),
852            )),
853        }
854    } else {
855        client_future.await
856    }
857}
858
859fn drive_connection<S>(
860    conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
861) -> (
862    broadcast::Receiver<Arc<tokio_postgres::Error>>,
863    oneshot::Sender<()>,
864)
865where
866    S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
867{
868    let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
869    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
870
871    tokio::spawn(async move {
872        match futures_util::future::select(shutdown_rx, conn).await {
873            Either::Left(_) | Either::Right((Ok(_), _)) => {}
874            Either::Right((Err(e), _)) => {
875                let _ = error_tx.send(Arc::new(e));
876            }
877        }
878    });
879
880    (error_rx, shutdown_tx)
881}
882
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    /// Reports whether this connection should no longer be used: either the
    /// transaction manager is in a broken state or the underlying client
    /// reports itself as closed.
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        if Self::TransactionManager::is_broken_transaction_manager(self) {
            return true;
        }
        self.conn.is_closed()
    }
}
896
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;

    /// Two queries started on the same connection before either is awaited
    /// must each resolve to their own result.
    #[tokio::test]
    async fn pipelining() {
        let url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
        let mut connection = crate::AsyncPgConnection::establish(&url).await.unwrap();

        let select_one = diesel::select(1_i32.into_sql::<Integer>());
        let select_two = diesel::select(2_i32.into_sql::<Integer>());

        // Both futures are created up front and only then awaited together.
        let pending_one = select_one.get_result::<i32>(&mut connection);
        let pending_two = select_two.get_result::<i32>(&mut connection);

        let (one, two) = futures_util::try_join!(pending_one, pending_two).unwrap();

        assert_eq!(one, 1);
        assert_eq!(two, 2);
    }
}