use self::error_helper::ErrorHelper;
use self::row::PgRow;
use self::serialize::ToSqlHelper;
use crate::stmt_cache::{PrepareCallback, StmtCache};
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
use diesel::connection::statement_cache::{PrepareForCache, StatementCacheKey};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::StrQueryHelper;
use diesel::pg::{
    Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
};
use diesel::query_builder::bind_collector::RawBytesBindCollector;
use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
use diesel::result::{DatabaseErrorKind, Error};
use diesel::{ConnectionError, ConnectionResult, QueryResult};
use futures_util::future::BoxFuture;
use futures_util::future::Either;
use futures_util::stream::{BoxStream, TryStreamExt};
use futures_util::TryFutureExt;
use futures_util::{Future, FutureExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio_postgres::types::ToSql;
use tokio_postgres::types::Type;
use tokio_postgres::Statement;

pub use self::transaction_builder::TransactionBuilder;

mod error_helper;
mod row;
mod serialize;
mod transaction_builder;

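/// Placeholder OID handed out during the first bind collection pass; metadata that
/// still carries this value belongs to a custom type whose real OID is looked up
/// later (see `construct_bind_data`).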
const FAKE_OID: u32 = 0;

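/// A connection to a PostgreSQL database, implemented on top of the
/// `tokio_postgres` crate.
///
/// A minimal usage sketch; it assumes the `diesel_async` crate root re-exports
/// `AsyncConnection`, `AsyncPgConnection` and `RunQueryDsl`, and that
/// `database_url` points to a reachable database:
///
/// ```rust,no_run
/// # async fn run(database_url: &str) -> Result<(), Box<dyn std::error::Error>> {
/// use diesel::sql_types::Integer;
/// use diesel::IntoSql;
/// use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl};
///
/// // establish the connection and run a trivial `SELECT 1`
/// let mut conn = AsyncPgConnection::establish(database_url).await?;
/// let one = diesel::select(1_i32.into_sql::<Integer>())
///     .get_result::<i32>(&mut conn)
///     .await?;
/// assert_eq!(one, 1);
/// # Ok(())
/// # }
/// ```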
pub struct AsyncPgConnection {
    conn: Arc<tokio_postgres::Client>,
    stmt_cache: Arc<Mutex<StmtCache<diesel::pg::Pg, Statement>>>,
    transaction_state: Arc<Mutex<AnsiTransactionManager>>,
    metadata_cache: Arc<Mutex<PgMetadataCache>>,
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    shutdown_channel: Option<oneshot::Sender<()>>,
    instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
}

#[async_trait::async_trait]
impl SimpleAsyncConnection for AsyncPgConnection {
    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
        self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
            query,
        )));
        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
        let batch_execute = self
            .conn
            .batch_execute(query)
            .map_err(ErrorHelper)
            .map_err(Into::into);
        let r = drive_future(connection_future, batch_execute).await;
        self.record_instrumentation(InstrumentationEvent::finish_query(
            &StrQueryHelper::new(query),
            r.as_ref().err(),
        ));
        r
    }
}

#[async_trait::async_trait]
impl AsyncConnection for AsyncPgConnection {
    type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
    type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
    type Row<'conn, 'query> = PgRow;
    type Backend = diesel::pg::Pg;
    type TransactionManager = AnsiTransactionManager;

    async fn establish(database_url: &str) -> ConnectionResult<Self> {
        let mut instrumentation = diesel::connection::get_default_instrumentation();
        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
            database_url,
        ));
        let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
        let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
            .await
            .map_err(ErrorHelper)?;

        let (error_rx, shutdown_tx) = drive_connection(connection);

        let r = Self::setup(
            client,
            Some(error_rx),
            Some(shutdown_tx),
            Arc::clone(&instrumentation),
        )
        .await;

        instrumentation
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .on_connection_event(InstrumentationEvent::finish_establish_connection(
                database_url,
                r.as_ref().err(),
            ));
        r
    }

    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
    where
        T: AsQuery + 'query,
        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
    {
        let query = source.as_query();
        let load_future = self.with_prepared_statement(query, load_prepared);

        self.run_with_connection_future(load_future)
    }

    fn execute_returning_count<'conn, 'query, T>(
        &'conn mut self,
        source: T,
    ) -> Self::ExecuteFuture<'conn, 'query>
    where
        T: QueryFragment<Self::Backend> + QueryId + 'query,
    {
        let execute = self.with_prepared_statement(source, execute_prepared);
        self.run_with_connection_future(execute)
    }

    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
        if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
            tm.get_mut()
        } else {
            panic!("Cannot access shared transaction state")
        }
    }

    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
        if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
            instrumentation.get_mut().unwrap_or_else(|p| p.into_inner())
        } else {
            panic!("Cannot access shared instrumentation")
        }
    }

    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
        self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
    }
}

impl Drop for AsyncPgConnection {
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown_channel.take() {
            let _ = tx.send(());
        }
    }
}

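/// Executes an already prepared statement via `query_raw` and maps the resulting
/// rows into a boxed stream of [`PgRow`]s.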
async fn load_prepared(
    conn: Arc<tokio_postgres::Client>,
    stmt: Statement,
    binds: Vec<ToSqlHelper>,
) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
    let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;

    Ok(res
        .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
        .map_ok(PgRow::new)
        .boxed())
}

async fn execute_prepared(
    conn: Arc<tokio_postgres::Client>,
    stmt: Statement,
    binds: Vec<ToSqlHelper>,
) -> QueryResult<usize> {
    let binds = binds
        .iter()
        .map(|b| b as &(dyn ToSql + Sync))
        .collect::<Vec<_>>();

    let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
        .await
        .map_err(ErrorHelper)?;
    res.try_into()
        .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
}

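/// Flags the current transaction as requiring a rollback when the query failed
/// with a serialization failure.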
#[inline(always)]
fn update_transaction_manager_status<T>(
    query_result: QueryResult<T>,
    transaction_manager: &mut AnsiTransactionManager,
) -> QueryResult<T> {
    if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
        query_result
    {
        transaction_manager
            .status
            .set_requires_rollback_maybe_up_to_top_level(true)
    }
    query_result
}

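/// Prepares statements directly on the underlying `tokio_postgres` client,
/// translating the collected bind metadata into concrete `Type` values first.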
#[async_trait::async_trait]
impl PrepareCallback<Statement, PgTypeMetadata> for Arc<tokio_postgres::Client> {
    async fn prepare(
        self,
        sql: &str,
        metadata: &[PgTypeMetadata],
        _is_for_cache: PrepareForCache,
    ) -> QueryResult<(Statement, Self)> {
        let bind_types = metadata
            .iter()
            .map(type_from_oid)
            .collect::<QueryResult<Vec<_>>>()?;

        let stmt = self
            .prepare_typed(sql, &bind_types)
            .await
            .map_err(ErrorHelper);
        Ok((stmt?, self))
    }
}

fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
    let oid = t
        .oid()
        .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;

    if let Some(tpe) = Type::from_oid(oid) {
        return Ok(tpe);
    }

    Ok(Type::new(
        format!("diesel_custom_type_{oid}"),
        oid,
        tokio_postgres::types::Kind::Simple,
        "public".into(),
    ))
}

impl AsyncPgConnection {
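    /// Build a transaction, specifying additional details such as isolation
    /// level or read-only mode via the returned [`TransactionBuilder`].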
    pub fn build_transaction(&mut self) -> TransactionBuilder<Self> {
        TransactionBuilder::new(self)
    }

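    /// Construct a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`].
    ///
    /// No connection future is registered here, so the caller has to drive the
    /// corresponding [`tokio_postgres::Connection`] themselves; see
    /// [`Self::try_from_client_and_connection`] for a variant that handles this.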
    pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
        Self::setup(
            conn,
            None,
            None,
            Arc::new(std::sync::Mutex::new(
                diesel::connection::get_default_instrumentation(),
            )),
        )
        .await
    }

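    /// Construct a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`]
    /// and the matching [`tokio_postgres::Connection`]. The connection is spawned
    /// onto the tokio runtime and supervised for errors by this crate.
    ///
    /// A minimal sketch, assuming `diesel_async` re-exports this type at the crate
    /// root and `database_url` points to a reachable database:
    ///
    /// ```rust,no_run
    /// # async fn run(database_url: &str) -> Result<(), Box<dyn std::error::Error>> {
    /// // connect with tokio_postgres first, then hand both halves to diesel_async
    /// let (client, connection) =
    ///     tokio_postgres::connect(database_url, tokio_postgres::NoTls).await?;
    /// let _conn = diesel_async::AsyncPgConnection::try_from_client_and_connection(
    ///     client, connection,
    /// )
    /// .await?;
    /// # Ok(())
    /// # }
    /// ```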
    pub async fn try_from_client_and_connection<S>(
        client: tokio_postgres::Client,
        conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
    ) -> ConnectionResult<Self>
    where
        S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
    {
        let (error_rx, shutdown_tx) = drive_connection(conn);

        Self::setup(
            client,
            Some(error_rx),
            Some(shutdown_tx),
            Arc::new(std::sync::Mutex::new(
                diesel::connection::get_default_instrumentation(),
            )),
        )
        .await
    }

    async fn setup(
        conn: tokio_postgres::Client,
        connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
        shutdown_channel: Option<oneshot::Sender<()>>,
        instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
    ) -> ConnectionResult<Self> {
        let mut conn = Self {
            conn: Arc::new(conn),
            stmt_cache: Arc::new(Mutex::new(StmtCache::new())),
            transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
            metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
            connection_future,
            shutdown_channel,
            instrumentation,
        };
        conn.set_config_options()
            .await
            .map_err(ConnectionError::CouldntSetupConfiguration)?;
        Ok(conn)
    }

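    /// Constructs a cancellation token that can later be used to cancel a query
    /// running on this connection.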
    pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
        self.conn.cancel_token()
    }

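    /// Sets the session time zone to UTC and the client encoding to UTF-8,
    /// issuing both statements concurrently via `try_join!`.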
    async fn set_config_options(&mut self) -> QueryResult<()> {
        use crate::run_query_dsl::RunQueryDsl;

        futures_util::try_join!(
            diesel::sql_query("SET TIME ZONE 'UTC'").execute(self),
            diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self),
        )?;
        Ok(())
    }

    fn run_with_connection_future<'a, R: 'a>(
        &self,
        future: impl Future<Output = QueryResult<R>> + Send + 'a,
    ) -> BoxFuture<'a, QueryResult<R>> {
        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
        drive_future(connection_future, future).boxed()
    }

    fn with_prepared_statement<'a, T, F, R>(
        &mut self,
        query: T,
        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
    ) -> BoxFuture<'a, QueryResult<R>>
    where
        T: QueryFragment<diesel::pg::Pg> + QueryId,
        F: Future<Output = QueryResult<R>> + Send + 'a,
        R: Send,
    {
        self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
            &query,
        )));
        let mut query_builder = PgQueryBuilder::default();

        let bind_data = construct_bind_data(&query);

        self.with_prepared_statement_after_sql_built(
            callback,
            query.is_safe_to_cache_prepared(&Pg),
            T::query_id(),
            query.to_sql(&mut query_builder, &Pg),
            query_builder,
            bind_data,
        )
    }

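    /// Second half of statement preparation: resolves any placeholder OIDs for
    /// custom types (consulting the metadata cache or querying `pg_type`), patches
    /// the serialized binds accordingly, fetches the statement through the
    /// statement cache and finally invokes `callback` with the prepared statement
    /// and bind values, recording the outcome in the transaction manager and the
    /// instrumentation.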
    fn with_prepared_statement_after_sql_built<'a, F, R>(
        &mut self,
        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
        is_safe_to_cache_prepared: QueryResult<bool>,
        query_id: Option<std::any::TypeId>,
        to_sql_result: QueryResult<()>,
        query_builder: PgQueryBuilder,
        bind_data: BindData,
    ) -> BoxFuture<'a, QueryResult<R>>
    where
        F: Future<Output = QueryResult<R>> + Send + 'a,
        R: Send,
    {
        let raw_connection = self.conn.clone();
        let stmt_cache = self.stmt_cache.clone();
        let metadata_cache = self.metadata_cache.clone();
        let tm = self.transaction_state.clone();
        let instrumentation = self.instrumentation.clone();
        let BindData {
            collect_bind_result,
            fake_oid_locations,
            generated_oids,
            mut bind_collector,
        } = bind_data;

        async move {
            let sql = to_sql_result.map(|_| query_builder.finish())?;
            let res = async {
                let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
                collect_bind_result?;
                if let Some(ref unresolved_types) = generated_oids {
                    let metadata_cache = &mut *metadata_cache.lock().await;
                    let mut real_oids = HashMap::new();

                    for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
                        unresolved_types
                    {
                        let cache_key = PgMetadataCacheKey::new(
                            schema.as_deref().map(Into::into),
                            lookup_type_name.into(),
                        );
                        let real_metadata = if let Some(type_metadata) =
                            metadata_cache.lookup_type(&cache_key)
                        {
                            type_metadata
                        } else {
                            let type_metadata =
                                lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
                                    .await?;
                            metadata_cache.store_type(cache_key, type_metadata);

                            PgTypeMetadata::from_result(Ok(type_metadata))
                        };
                        let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
                        real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
                    }

                    for m in &mut bind_collector.metadata {
                        let (oid, array_oid) = unwrap_oids(m);
                        *m = PgTypeMetadata::new(
                            real_oids.get(&oid).copied().unwrap_or(oid),
                            real_oids.get(&array_oid).copied().unwrap_or(array_oid)
                        );
                    }
                    for (bind_index, byte_index) in fake_oid_locations {
                        replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
                            .ok_or_else(|| {
                                Error::SerializationError(
                                    format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
                                )
                            })?;
                    }
                }
                let key = match query_id {
                    Some(id) => StatementCacheKey::Type(id),
                    None => StatementCacheKey::Sql {
                        sql: sql.clone(),
                        bind_types: bind_collector.metadata.clone(),
                    },
                };
                let stmt = {
                    let mut stmt_cache = stmt_cache.lock().await;
                    stmt_cache
                        .cached_prepared_statement(
                            key,
                            sql.clone(),
                            is_safe_to_cache_prepared,
                            &bind_collector.metadata,
                            raw_connection.clone(),
                            &instrumentation
                        )
                        .await?
                        .0
                        .clone()
                };

                let binds = bind_collector
                    .metadata
                    .into_iter()
                    .zip(bind_collector.binds)
                    .map(|(meta, bind)| ToSqlHelper(meta, bind))
                    .collect::<Vec<_>>();
                callback(raw_connection, stmt.clone(), binds).await
            };
            let res = res.await;
            let mut tm = tm.lock().await;
            let r = update_transaction_manager_status(res, &mut tm);
            instrumentation
                .lock()
                .unwrap_or_else(|p| p.into_inner())
                .on_connection_event(InstrumentationEvent::finish_query(
                    &StrQueryHelper::new(&sql),
                    r.as_ref().err(),
                ));

            r
        }
        .boxed()
    }

    fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
        self.instrumentation
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .on_connection_event(event);
    }
}

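/// Outcome of the bind collection pass(es) in [`construct_bind_data`]: the
/// collection result itself, the byte locations of placeholder OIDs inside the
/// serialized binds, the generated OIDs for custom types (if any) and the bind
/// collector holding metadata and values.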
struct BindData {
    collect_bind_result: Result<(), Error>,
    fake_oid_locations: Vec<(usize, usize)>,
    generated_oids: GeneratedOidTypeMap,
    bind_collector: RawBytesBindCollector<Pg>,
}

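/// Collects binds in up to two passes.
///
/// The first pass hands out `FAKE_OID` for every custom type, which tells us
/// whether custom types are involved at all. If they are, a second pass generates
/// unique placeholder OIDs above the largest OID seen so far and records where
/// those placeholders ended up inside the serialized bind bytes, so they can be
/// replaced with the real OIDs once those have been looked up.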
fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
    let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
    let mut metadata_lookup_0 = PgAsyncMetadataLookup {
        custom_oid: false,
        generated_oids: None,
        oid_generator: |_, _| (FAKE_OID, FAKE_OID),
    };
    let collect_bind_result_0 =
        query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
    if metadata_lookup_0.custom_oid {
        let mut max_oid = bind_collector_0
            .metadata
            .iter()
            .flat_map(|t| {
                [
                    t.oid().unwrap_or_default(),
                    t.array_oid().unwrap_or_default(),
                ]
            })
            .max()
            .unwrap_or_default();
        let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
        let mut metadata_lookup_1 = PgAsyncMetadataLookup {
            custom_oid: false,
            generated_oids: Some(HashMap::new()),
            oid_generator: move |_, _| {
                max_oid += 2;
                (max_oid, max_oid + 1)
            },
        };
        let collect_bind_result_1 =
            query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);

        assert_eq!(
            bind_collector_0.binds.len(),
            bind_collector_0.metadata.len()
        );
        let fake_oid_locations = std::iter::zip(
            bind_collector_0
                .binds
                .iter()
                .zip(&bind_collector_0.metadata),
            &bind_collector_1.binds,
        )
        .enumerate()
        .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
            let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
                (
                    bytes_0.as_deref().unwrap_or_default(),
                    bytes_1.as_deref().unwrap_or_default(),
                )
            } else {
                (&[] as &[_], &[] as &[_])
            };
            let lookup_map = metadata_lookup_1
                .generated_oids
                .as_ref()
                .map(|map| {
                    map.values()
                        .flat_map(|(oid, array_oid)| [*oid, *array_oid])
                        .collect::<HashSet<_>>()
                })
                .unwrap_or_default();
            std::iter::zip(
                bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
                bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
            )
            .enumerate()
            .filter_map(move |(byte_index, (l, r))| {
                let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
                (l == FAKE_OID.to_be_bytes()
                    && r != FAKE_OID.to_be_bytes()
                    && lookup_map.contains(&r_val))
                .then_some((bind_index, byte_index))
            })
        })
        .collect::<Vec<_>>();
        BindData {
            collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
            fake_oid_locations,
            generated_oids: metadata_lookup_1.generated_oids,
            bind_collector: bind_collector_1,
        }
    } else {
        BindData {
            collect_bind_result: collect_bind_result_0,
            fake_oid_locations: Vec::new(),
            generated_oids: None,
            bind_collector: bind_collector_0,
        }
    }
}

type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;

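/// A [`PgMetadataLookup`] implementation that cannot query the database while
/// binds are collected; it merely records that a custom type was encountered and
/// hands out OIDs produced by the configured generator.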
struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
    custom_oid: bool,
    generated_oids: GeneratedOidTypeMap,
    oid_generator: F,
}

impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
where
    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
{
    fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
        self.custom_oid = true;

        let oid = if let Some(map) = &mut self.generated_oids {
            *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
                .or_insert_with(|| (self.oid_generator)(type_name, schema))
        } else {
            (self.oid_generator)(type_name, schema)
        };

        PgTypeMetadata::from_result(Ok(oid))
    }
}

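/// Resolves the real OID and array OID of a custom type by querying `pg_type`,
/// optionally restricted to a schema via `pg_namespace`.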
async fn lookup_type(
    schema: Option<String>,
    type_name: String,
    raw_connection: &tokio_postgres::Client,
) -> QueryResult<(u32, u32)> {
    let r = if let Some(schema) = schema.as_ref() {
        raw_connection
            .query_one(
                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
                 INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
                 WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
                 LIMIT 1",
                &[&type_name, schema],
            )
            .await
            .map_err(ErrorHelper)?
    } else {
        raw_connection
            .query_one(
                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
                 WHERE pg_type.oid = quote_ident($1)::regtype::oid \
                 LIMIT 1",
                &[&type_name],
            )
            .await
            .map_err(ErrorHelper)?
    };
    Ok((r.get(0), r.get(1)))
}

fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
    let err_msg = "PgTypeMetadata is supposed to always be Ok here";
    (
        metadata.oid().expect(err_msg),
        metadata.array_oid().expect(err_msg),
    )
}

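/// Overwrites the four placeholder OID bytes at `byte_index` of the bind at
/// `bind_index` with the corresponding real OID, returning `None` if either the
/// location or the mapping is missing.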
fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let serialized_oid = binds
        .get_mut(bind_index)?
        .as_mut()?
        .get_mut(byte_index..)?
        .first_chunk_mut::<4>()?;
    *serialized_oid = real_oids
        .get(&u32::from_be_bytes(*serialized_oid))?
        .to_be_bytes();
    Some(())
}

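/// Awaits the query future together with the background connection's error
/// channel, so that an error on the connection task is reported as a diesel error
/// for the in-flight query.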
async fn drive_future<R>(
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    client_future: impl Future<Output = Result<R, diesel::result::Error>>,
) -> Result<R, diesel::result::Error> {
    if let Some(mut connection_future) = connection_future {
        let client_future = std::pin::pin!(client_future);
        let connection_future = std::pin::pin!(connection_future.recv());
        match futures_util::future::select(client_future, connection_future).await {
            Either::Left((res, _)) => res,
            Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
            Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
                DatabaseErrorKind::UnableToSendCommand,
                Box::new(e.to_string()),
            )),
        }
    } else {
        client_future.await
    }
}

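/// Spawns the `tokio_postgres` connection onto the tokio runtime. Fatal connection
/// errors are forwarded through the returned broadcast receiver, and the returned
/// oneshot sender shuts the task down (used on `Drop`).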
fn drive_connection<S>(
    conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
) -> (
    broadcast::Receiver<Arc<tokio_postgres::Error>>,
    oneshot::Sender<()>,
)
where
    S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
    let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();

    tokio::spawn(async move {
        match futures_util::future::select(shutdown_rx, conn).await {
            Either::Left(_) | Either::Right((Ok(_), _)) => {}
            Either::Right((Err(e), _)) => {
                let _ = error_tx.send(Arc::new(e));
            }
        }
    });

    (error_rx, shutdown_tx)
}

#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        Self::TransactionManager::is_broken_transaction_manager(self) || self.conn.is_closed()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;

    #[tokio::test]
    async fn pipelining() {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
        let mut conn = crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap();

        let q1 = diesel::select(1_i32.into_sql::<Integer>());
        let q2 = diesel::select(2_i32.into_sql::<Integer>());

        let f1 = q1.get_result::<i32>(&mut conn);
        let f2 = q2.get_result::<i32>(&mut conn);

        let (r1, r2) = futures_util::try_join!(f1, f2).unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
    }
}