diesel_async/
async_connection_wrapper.rs

//! This module contains a wrapper type
//! that provides a [`diesel::Connection`]
//! implementation for types that implement
//! [`crate::AsyncConnection`]. Using this type
//! might be useful for the following use cases:
//!
//! * Executing migrations on application startup
//! * Using a pure Rust diesel connection implementation
//!   as a replacement for the existing connection
//!   implementations provided by diesel
11
12use futures_util::Future;
13use futures_util::Stream;
14use futures_util::StreamExt;
15use std::pin::Pin;
16
/// Abstraction over the async runtime that [`AsyncConnectionWrapper`] uses to
/// drive the futures returned by the wrapped connection to completion.
///
/// Implement this trait to plug in a runtime of your choice. When the `tokio`
/// feature is enabled, a tokio based implementation is used by default.
pub trait BlockOn {
    /// Waits until `fut` has completed and returns its output.
    fn block_on<Fut>(&self, fut: Fut) -> Fut::Output
    where
        Fut: Future;

    /// Creates the runtime instance used by a newly constructed wrapper.
    fn get_runtime() -> Self;
}
32
/// A helper type that wraps an [`AsyncConnection`][crate::AsyncConnection] to
/// provide a sync [`diesel::Connection`] implementation.
///
/// Internally this wrapper type will use `block_on` to wait for
/// the execution of futures from the inner connection. This implies you
/// cannot call methods of this type directly from async code running on an
/// existing tokio runtime. If you want to use this connection wrapper in the
/// scope of an existing tokio runtime (for example for running migrations via
/// `diesel_migrations`), you need to wrap the relevant code block into a
/// `tokio::task::spawn_blocking` task.
///
/// The second type parameter selects the [`BlockOn`] implementation used to
/// drive those futures. It defaults to a tokio based runtime.
///
/// # Examples
///
/// ```rust,no_run
/// # include!("doctest_setup.rs");
/// use schema::users;
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
/// #
/// # fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
/// use diesel::prelude::{RunQueryDsl, Connection};
/// # let database_url = database_url();
/// let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
///
/// let all_users = users::table.load::<(i32, String)>(&mut conn)?;
/// # assert_eq!(all_users.len(), 0);
/// # Ok(())
/// # }
/// ```
///
/// If you are in the scope of an existing tokio runtime, you need to use
/// `tokio::task::spawn_blocking` to encapsulate the blocking tasks:
///
/// ```rust,no_run
/// # include!("doctest_setup.rs");
/// use schema::users;
/// use diesel_async::async_connection_wrapper::AsyncConnectionWrapper;
///
/// async fn some_async_fn() {
/// # let database_url = database_url();
///     // need to use `spawn_blocking` to execute
///     // a blocking task in the scope of an existing runtime
///     let res = tokio::task::spawn_blocking(move || {
///         use diesel::prelude::{RunQueryDsl, Connection};
///         let mut conn = AsyncConnectionWrapper::<DbConnection>::establish(&database_url)?;
///
///         let all_users = users::table.load::<(i32, String)>(&mut conn)?;
/// #       assert_eq!(all_users.len(), 0);
///         Ok::<_, Box<dyn std::error::Error + Send + Sync>>(())
///     }).await;
///
/// # res.unwrap().unwrap();
/// }
///
/// # #[tokio::main]
/// # async fn main() {
/// #    some_async_fn().await;
/// # }
/// ```
#[cfg(feature = "tokio")]
pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
    self::implementation::AsyncConnectionWrapper<C, B>;
93
94/// A helper type that wraps an [`crate::AsyncConnectionWrapper`] to
95/// provide a sync [`diesel::Connection`] implementation.
96///
97/// Internally this wrapper type will use `block_on` to wait for
98/// the execution of futures from the inner connection.
99#[cfg(not(feature = "tokio"))]
100pub use self::implementation::AsyncConnectionWrapper;
101
102mod implementation {
103    use diesel::connection::{Instrumentation, SimpleConnection};
104    use std::ops::{Deref, DerefMut};
105
106    use super::*;
107
108    pub struct AsyncConnectionWrapper<C, B> {
109        inner: C,
110        runtime: B,
111    }
112
113    impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
114    where
115        C: crate::AsyncConnection,
116        B: BlockOn + Send,
117    {
118        fn from(inner: C) -> Self {
119            Self {
120                inner,
121                runtime: B::get_runtime(),
122            }
123        }
124    }
125
126    impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
127        type Target = C;
128
129        fn deref(&self) -> &Self::Target {
130            &self.inner
131        }
132    }
133
134    impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
135        fn deref_mut(&mut self) -> &mut Self::Target {
136            &mut self.inner
137        }
138    }
139
140    impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
141    where
142        C: crate::SimpleAsyncConnection,
143        B: BlockOn,
144    {
145        fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
146            let f = self.inner.batch_execute(query);
147            self.runtime.block_on(f)
148        }
149    }
150
151    impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
152
153    impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
154    where
155        C: crate::AsyncConnection,
156        B: BlockOn + Send,
157    {
158        type Backend = C::Backend;
159
160        type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
161
162        fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
163            let runtime = B::get_runtime();
164            let f = C::establish(database_url);
165            let inner = runtime.block_on(f)?;
166            Ok(Self { inner, runtime })
167        }
168
169        fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
170        where
171            T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
172        {
173            let f = self.inner.execute_returning_count(source);
174            self.runtime.block_on(f)
175        }
176
177        fn transaction_state(
178            &mut self,
179        ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
180            self.inner.transaction_state()
181        }
182
183        fn instrumentation(&mut self) -> &mut dyn Instrumentation {
184            self.inner.instrumentation()
185        }
186
187        fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
188            self.inner.set_instrumentation(instrumentation);
189        }
190    }
191
192    impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
193    where
194        C: crate::AsyncConnection,
195        B: BlockOn + Send,
196    {
197        type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
198    where
199        Self: 'conn;
200
201        type Row<'conn, 'query> = C::Row<'conn, 'query>
202    where
203        Self: 'conn;
204
205        fn load<'conn, 'query, T>(
206            &'conn mut self,
207            source: T,
208        ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
209        where
210            T: diesel::query_builder::Query
211                + diesel::query_builder::QueryFragment<Self::Backend>
212                + diesel::query_builder::QueryId
213                + 'query,
214            Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
215        {
216            let f = self.inner.load(source);
217            let stream = self.runtime.block_on(f)?;
218
219            Ok(AsyncCursorWrapper {
220                stream: Box::pin(stream),
221                runtime: &self.runtime,
222            })
223        }
224    }
225
226    pub struct AsyncCursorWrapper<'a, S, B> {
227        stream: Pin<Box<S>>,
228        runtime: &'a B,
229    }
230
231    impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B>
232    where
233        S: Stream,
234        B: BlockOn,
235    {
236        type Item = S::Item;
237
238        fn next(&mut self) -> Option<Self::Item> {
239            let f = self.stream.next();
240            self.runtime.block_on(f)
241        }
242    }
243
244    pub struct AsyncConnectionWrapperTransactionManagerWrapper;
245
246    impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
247        for AsyncConnectionWrapperTransactionManagerWrapper
248    where
249        C: crate::AsyncConnection,
250        B: BlockOn + Send,
251    {
252        type TransactionStateData =
253            <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
254
255        fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
256            let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
257                &mut conn.inner,
258            );
259            conn.runtime.block_on(f)
260        }
261
262        fn rollback_transaction(
263            conn: &mut AsyncConnectionWrapper<C, B>,
264        ) -> diesel::QueryResult<()> {
265            let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
266                &mut conn.inner,
267            );
268            conn.runtime.block_on(f)
269        }
270
271        fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
272            let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
273                &mut conn.inner,
274            );
275            conn.runtime.block_on(f)
276        }
277
278        fn transaction_manager_status_mut(
279            conn: &mut AsyncConnectionWrapper<C, B>,
280        ) -> &mut diesel::connection::TransactionManagerStatus {
281            <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
282                &mut conn.inner,
283            )
284        }
285
286        fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
287            <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
288                &mut conn.inner,
289            )
290        }
291    }
292
293    #[cfg(feature = "r2d2")]
294    impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
295    where
296        B: BlockOn,
297        Self: diesel::Connection,
298        C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
299            + crate::pooled_connection::PoolableConnection
300            + 'static,
301        diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
302            crate::methods::ExecuteDsl<C>,
303        diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
304    {
305        fn ping(&mut self) -> diesel::QueryResult<()> {
306            let fut = crate::pooled_connection::PoolableConnection::ping(
307                &mut self.inner,
308                &crate::pooled_connection::RecyclingMethod::Verified,
309            );
310            self.runtime.block_on(fut)
311        }
312
313        fn is_broken(&mut self) -> bool {
314            crate::pooled_connection::PoolableConnection::is_broken(&mut self.inner)
315        }
316    }
317
318    impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
319    where
320        B: BlockOn,
321        Self: diesel::Connection,
322    {
323        fn setup(&mut self) -> diesel::QueryResult<usize> {
324            self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
325                .map(|()| 0)
326        }
327    }
328
329    #[cfg(feature = "tokio")]
330    pub struct Tokio {
331        handle: Option<tokio::runtime::Handle>,
332        runtime: Option<tokio::runtime::Runtime>,
333    }
334
335    #[cfg(feature = "tokio")]
336    impl BlockOn for Tokio {
337        fn block_on<F>(&self, f: F) -> F::Output
338        where
339            F: Future,
340        {
341            if let Some(handle) = &self.handle {
342                handle.block_on(f)
343            } else if let Some(runtime) = &self.runtime {
344                runtime.block_on(f)
345            } else {
346                unreachable!()
347            }
348        }
349
350        fn get_runtime() -> Self {
351            if let Ok(handle) = tokio::runtime::Handle::try_current() {
352                Self {
353                    handle: Some(handle),
354                    runtime: None,
355                }
356            } else {
357                let runtime = tokio::runtime::Builder::new_current_thread()
358                    .enable_io()
359                    .build()
360                    .unwrap();
361                Self {
362                    handle: None,
363                    runtime: Some(runtime),
364                }
365            }
366        }
367    }
368}