diesel_async/sync_connection_wrapper/
mod.rs

1//! This module contains a wrapper type
2//! that provides a [`crate::AsyncConnection`]
3//! implementation for types that implement
4//! [`diesel::Connection`]. Using this type
//! might be useful for the following use cases:
6//!
7//! * using a sync Connection implementation in async context
8//! * using the same code base for async crates needing multiple backends
9
10use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager};
11use diesel::backend::{Backend, DieselReserveSpecialization};
12use diesel::connection::Instrumentation;
13use diesel::connection::{
14    Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup,
15};
16use diesel::query_builder::{
17    AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId,
18};
19use diesel::row::IntoOwnedRow;
20use diesel::{ConnectionResult, QueryResult};
21use futures_util::future::BoxFuture;
22use futures_util::stream::BoxStream;
23use futures_util::{FutureExt, StreamExt, TryFutureExt};
24use std::marker::PhantomData;
25use std::sync::{Arc, Mutex};
26use tokio::task::JoinError;
27
28#[cfg(feature = "sqlite")]
29mod sqlite;
30
31fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error {
32    diesel::result::Error::DatabaseError(
33        diesel::result::DatabaseErrorKind::UnableToSendCommand,
34        Box::new(join_error.to_string()),
35    )
36}
37
/// A wrapper of a [`diesel::connection::Connection`] usable in async context.
///
/// It implements [`AsyncConnection`] if the wrapped [`diesel::connection::Connection`]
/// fulfills the following requirements:
/// * it's a [`diesel::connection::LoadConnection`]
/// * its [`diesel::connection::Connection::Backend`] has a [`diesel::query_builder::BindCollector`] implementing [`diesel::query_builder::MoveableBindCollector`]
/// * its [`diesel::connection::LoadConnection::Row`] implements [`diesel::row::IntoOwnedRow`]
///
/// Internally this wrapper type will use `spawn_blocking` on tokio
/// to execute the request on the inner connection. This implies a
/// dependency on tokio and that the runtime is running.
///
/// Note that only SQLite is supported at the moment.
///
/// # Examples
///
/// ```rust
/// # include!("../doctest_setup.rs");
/// use diesel_async::RunQueryDsl;
/// use schema::users;
///
/// async fn some_async_fn() {
/// # let database_url = database_url();
///     use diesel_async::AsyncConnection;
///     use diesel::sqlite::SqliteConnection;
///     let mut conn =
///         SyncConnectionWrapper::<SqliteConnection>::establish(&database_url).await.unwrap();
/// # create_tables(&mut conn).await;
///
///     let all_users = users::table.load::<(i32, String)>(&mut conn).await.unwrap();
/// #   assert_eq!(all_users.len(), 2);
/// }
///
/// # #[cfg(feature = "sqlite")]
/// # #[tokio::main]
/// # async fn main() {
/// #    some_async_fn().await;
/// # }
/// ```
pub struct SyncConnectionWrapper<C> {
    // Shared so that each closure handed to tokio's `spawn_blocking` can own a
    // handle to the connection; the mutex serializes access between them.
    inner: Arc<Mutex<C>>,
}
79
80#[async_trait::async_trait]
81impl<C> SimpleAsyncConnection for SyncConnectionWrapper<C>
82where
83    C: diesel::connection::Connection + 'static,
84{
85    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
86        let query = query.to_string();
87        self.spawn_blocking(move |inner| inner.batch_execute(query.as_str()))
88            .await
89    }
90}
91
92#[async_trait::async_trait]
93impl<C, MD, O> AsyncConnection for SyncConnectionWrapper<C>
94where
95    // Backend bounds
96    <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
97    <C::Backend as Backend>::QueryBuilder: std::default::Default,
98    // Connection bounds
99    C: Connection + LoadConnection + WithMetadataLookup + 'static,
100    <C as Connection>::TransactionManager: Send,
101    // BindCollector bounds
102    MD: Send + 'static,
103    for<'a> <C::Backend as Backend>::BindCollector<'a>:
104        MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
105    // Row bounds
106    O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
107    for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
108        IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
109{
110    type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
111    type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
112    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
113    type Row<'conn, 'query> = O;
114    type Backend = <C as Connection>::Backend;
115    type TransactionManager = SyncTransactionManagerWrapper<<C as Connection>::TransactionManager>;
116
117    async fn establish(database_url: &str) -> ConnectionResult<Self> {
118        let database_url = database_url.to_string();
119        tokio::task::spawn_blocking(move || C::establish(&database_url))
120            .await
121            .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string())))
122            .map(|c| SyncConnectionWrapper::new(c))
123    }
124
125    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
126    where
127        T: AsQuery + 'query,
128        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
129    {
130        self.execute_with_prepared_query(source.as_query(), |conn, query| {
131            use diesel::row::IntoOwnedRow;
132            let mut cache = <<<C as LoadConnection>::Row<'_, '_> as IntoOwnedRow<
133                <C as Connection>::Backend,
134            >>::Cache as Default>::default();
135            let cursor = conn.load(&query)?;
136
137            let size_hint = cursor.size_hint();
138            let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0));
139            // we use an explicit loop here to easily propagate possible errors
140            // as early as possible
141            for row in cursor {
142                out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache)));
143            }
144
145            Ok(out)
146        })
147        .map_ok(|rows| futures_util::stream::iter(rows).boxed())
148        .boxed()
149    }
150
151    fn execute_returning_count<'query, T>(&mut self, source: T) -> Self::ExecuteFuture<'_, 'query>
152    where
153        T: QueryFragment<Self::Backend> + QueryId,
154    {
155        self.execute_with_prepared_query(source, |conn, query| conn.execute_returning_count(&query))
156    }
157
158    fn transaction_state(
159        &mut self,
160    ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
161        self.exclusive_connection().transaction_state()
162    }
163
164    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
165        // there should be no other pending future when this is called
166        // that means there is only one instance of this arc and
167        // we can simply access the inner data
168        if let Some(inner) = Arc::get_mut(&mut self.inner) {
169            inner
170                .get_mut()
171                .unwrap_or_else(|p| p.into_inner())
172                .instrumentation()
173        } else {
174            panic!("Cannot access shared instrumentation")
175        }
176    }
177
178    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
179        // there should be no other pending future when this is called
180        // that means there is only one instance of this arc and
181        // we can simply access the inner data
182        if let Some(inner) = Arc::get_mut(&mut self.inner) {
183            inner
184                .get_mut()
185                .unwrap_or_else(|p| p.into_inner())
186                .set_instrumentation(instrumentation)
187        } else {
188            panic!("Cannot access shared instrumentation")
189        }
190    }
191}
192
/// Adapts a sync diesel transaction manager `T` so it can drive a
/// [`SyncConnectionWrapper`] in async context.
///
/// This type holds no data: it only carries `T` at the type level, and every
/// transaction operation is forwarded to `T`.
pub struct SyncTransactionManagerWrapper<T>(PhantomData<T>);
195
196#[async_trait::async_trait]
197impl<T, C> TransactionManager<SyncConnectionWrapper<C>> for SyncTransactionManagerWrapper<T>
198where
199    SyncConnectionWrapper<C>: AsyncConnection,
200    C: Connection + 'static,
201    T: diesel::connection::TransactionManager<C> + Send,
202{
203    type TransactionStateData = T::TransactionStateData;
204
205    async fn begin_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
206        conn.spawn_blocking(move |inner| T::begin_transaction(inner))
207            .await
208    }
209
210    async fn commit_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
211        conn.spawn_blocking(move |inner| T::commit_transaction(inner))
212            .await
213    }
214
215    async fn rollback_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
216        conn.spawn_blocking(move |inner| T::rollback_transaction(inner))
217            .await
218    }
219
220    fn transaction_manager_status_mut(
221        conn: &mut SyncConnectionWrapper<C>,
222    ) -> &mut TransactionManagerStatus {
223        T::transaction_manager_status_mut(conn.exclusive_connection())
224    }
225}
226
227impl<C> SyncConnectionWrapper<C> {
228    /// Builds a wrapper with this underlying sync connection
229    pub fn new(connection: C) -> Self
230    where
231        C: Connection,
232    {
233        SyncConnectionWrapper {
234            inner: Arc::new(Mutex::new(connection)),
235        }
236    }
237
238    /// Run a operation directly with the inner connection
239    ///
240    /// This function is usful to register custom functions
241    /// and collection for Sqlite for example
242    ///
243    /// # Example
244    ///
245    /// ```rust
246    /// # include!("../doctest_setup.rs");
247    /// # #[tokio::main]
248    /// # async fn main() {
249    /// #     run_test().await.unwrap();
250    /// # }
251    /// #
252    /// # async fn run_test() -> QueryResult<()> {
253    /// #     let mut conn = establish_connection().await;
254    /// conn.spawn_blocking(|conn| {
255    ///    // sqlite.rs sqlite NOCASE only works for ASCII characters,
256    ///    // this collation allows handling UTF-8 (barring locale differences)
257    ///    conn.register_collation("RUSTNOCASE", |rhs, lhs| {
258    ///     rhs.to_lowercase().cmp(&lhs.to_lowercase())
259    ///   })
260    /// }).await
261    ///
262    /// # }
263    /// ```
264    pub fn spawn_blocking<'a, R>(
265        &mut self,
266        task: impl FnOnce(&mut C) -> QueryResult<R> + Send + 'static,
267    ) -> BoxFuture<'a, QueryResult<R>>
268    where
269        C: Connection + 'static,
270        R: Send + 'static,
271    {
272        let inner = self.inner.clone();
273        tokio::task::spawn_blocking(move || {
274            let mut inner = inner.lock().unwrap_or_else(|poison| {
275                // try to be resilient by providing the guard
276                inner.clear_poison();
277                poison.into_inner()
278            });
279            task(&mut inner)
280        })
281        .unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err)))
282        .boxed()
283    }
284
285    fn execute_with_prepared_query<'a, MD, Q, R>(
286        &mut self,
287        query: Q,
288        callback: impl FnOnce(&mut C, &CollectedQuery<MD>) -> QueryResult<R> + Send + 'static,
289    ) -> BoxFuture<'a, QueryResult<R>>
290    where
291        // Backend bounds
292        <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
293        <C::Backend as Backend>::QueryBuilder: std::default::Default,
294        // Connection bounds
295        C: Connection + LoadConnection + WithMetadataLookup + 'static,
296        <C as Connection>::TransactionManager: Send,
297        // BindCollector bounds
298        MD: Send + 'static,
299        for<'b> <C::Backend as Backend>::BindCollector<'b>:
300            MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
301        // Arguments/Return bounds
302        Q: QueryFragment<C::Backend> + QueryId,
303        R: Send + 'static,
304    {
305        let backend = C::Backend::default();
306
307        let (collect_bind_result, collector_data) = {
308            let exclusive = self.inner.clone();
309            let mut inner = exclusive.lock().unwrap_or_else(|poison| {
310                // try to be resilient by providing the guard
311                exclusive.clear_poison();
312                poison.into_inner()
313            });
314            let mut bind_collector =
315                <<C::Backend as Backend>::BindCollector<'_> as Default>::default();
316            let metadata_lookup = inner.metadata_lookup();
317            let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend);
318            let collector_data = bind_collector.moveable();
319
320            (result, collector_data)
321        };
322
323        let mut query_builder = <<C::Backend as Backend>::QueryBuilder as Default>::default();
324        let sql = query
325            .to_sql(&mut query_builder, &backend)
326            .map(|_| query_builder.finish());
327        let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend);
328
329        self.spawn_blocking(|inner| {
330            collect_bind_result?;
331            let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data);
332            callback(inner, &query)
333        })
334    }
335
336    /// Gets an exclusive access to the underlying diesel Connection
337    ///
338    /// It panics in case of shared access.
339    /// This is typically used only used during transaction.
340    pub(self) fn exclusive_connection(&mut self) -> &mut C
341    where
342        C: Connection,
343    {
344        // there should be no other pending future when this is called
345        // that means there is only one instance of this Arc and
346        // we can simply access the inner data
347        if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) {
348            conn_mutex
349                .get_mut()
350                .expect("Mutex is poisoned, a thread must have panicked holding it.")
351        } else {
352            panic!("Cannot access shared transaction state")
353        }
354    }
355}
356
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl<C> crate::pooled_connection::PoolableConnection for SyncConnectionWrapper<C>
where
    Self: AsyncConnection,
{
    /// Delegates to the transaction manager's `is_broken_transaction_manager`
    /// check for this connection.
    fn is_broken(&mut self) -> bool {
        <Self as AsyncConnection>::TransactionManager::is_broken_transaction_manager(self)
    }
}