diesel_async/sync_connection_wrapper/
mod.rs1use crate::{AsyncConnection, SimpleAsyncConnection, TransactionManager};
11use diesel::backend::{Backend, DieselReserveSpecialization};
12use diesel::connection::Instrumentation;
13use diesel::connection::{
14 Connection, LoadConnection, TransactionManagerStatus, WithMetadataLookup,
15};
16use diesel::query_builder::{
17 AsQuery, CollectedQuery, MoveableBindCollector, QueryBuilder, QueryFragment, QueryId,
18};
19use diesel::row::IntoOwnedRow;
20use diesel::{ConnectionResult, QueryResult};
21use futures_util::future::BoxFuture;
22use futures_util::stream::BoxStream;
23use futures_util::{FutureExt, StreamExt, TryFutureExt};
24use std::marker::PhantomData;
25use std::sync::{Arc, Mutex};
26use tokio::task::JoinError;
27
28#[cfg(feature = "sqlite")]
29mod sqlite;
30
31fn from_tokio_join_error(join_error: JoinError) -> diesel::result::Error {
32 diesel::result::Error::DatabaseError(
33 diesel::result::DatabaseErrorKind::UnableToSendCommand,
34 Box::new(join_error.to_string()),
35 )
36}
37
/// Adapts a synchronous diesel [`Connection`] so it can be driven through the
/// async [`AsyncConnection`] interface.
///
/// Database work is executed on tokio's blocking thread pool via
/// [`SyncConnectionWrapper::spawn_blocking`]; the wrapped connection is shared
/// with those blocking tasks through an `Arc<Mutex<_>>`.
pub struct SyncConnectionWrapper<C> {
    // Shared with in-flight blocking tasks (each holds a clone of the `Arc`).
    // Methods that need direct access without locking require this `Arc` to be
    // unshared at the time of the call.
    inner: Arc<Mutex<C>>,
}
79
80#[async_trait::async_trait]
81impl<C> SimpleAsyncConnection for SyncConnectionWrapper<C>
82where
83 C: diesel::connection::Connection + 'static,
84{
85 async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
86 let query = query.to_string();
87 self.spawn_blocking(move |inner| inner.batch_execute(query.as_str()))
88 .await
89 }
90}
91
92#[async_trait::async_trait]
93impl<C, MD, O> AsyncConnection for SyncConnectionWrapper<C>
94where
95 <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
97 <C::Backend as Backend>::QueryBuilder: std::default::Default,
98 C: Connection + LoadConnection + WithMetadataLookup + 'static,
100 <C as Connection>::TransactionManager: Send,
101 MD: Send + 'static,
103 for<'a> <C::Backend as Backend>::BindCollector<'a>:
104 MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
105 O: 'static + Send + for<'conn> diesel::row::Row<'conn, C::Backend>,
107 for<'conn, 'query> <C as LoadConnection>::Row<'conn, 'query>:
108 IntoOwnedRow<'conn, <C as Connection>::Backend, OwnedRow = O>,
109{
110 type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
111 type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
112 type Stream<'conn, 'query> = BoxStream<'static, QueryResult<Self::Row<'conn, 'query>>>;
113 type Row<'conn, 'query> = O;
114 type Backend = <C as Connection>::Backend;
115 type TransactionManager = SyncTransactionManagerWrapper<<C as Connection>::TransactionManager>;
116
117 async fn establish(database_url: &str) -> ConnectionResult<Self> {
118 let database_url = database_url.to_string();
119 tokio::task::spawn_blocking(move || C::establish(&database_url))
120 .await
121 .unwrap_or_else(|e| Err(diesel::ConnectionError::BadConnection(e.to_string())))
122 .map(|c| SyncConnectionWrapper::new(c))
123 }
124
125 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
126 where
127 T: AsQuery + 'query,
128 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
129 {
130 self.execute_with_prepared_query(source.as_query(), |conn, query| {
131 use diesel::row::IntoOwnedRow;
132 let mut cache = <<<C as LoadConnection>::Row<'_, '_> as IntoOwnedRow<
133 <C as Connection>::Backend,
134 >>::Cache as Default>::default();
135 let cursor = conn.load(&query)?;
136
137 let size_hint = cursor.size_hint();
138 let mut out = Vec::with_capacity(size_hint.1.unwrap_or(size_hint.0));
139 for row in cursor {
142 out.push(Ok(IntoOwnedRow::into_owned(row?, &mut cache)));
143 }
144
145 Ok(out)
146 })
147 .map_ok(|rows| futures_util::stream::iter(rows).boxed())
148 .boxed()
149 }
150
151 fn execute_returning_count<'query, T>(&mut self, source: T) -> Self::ExecuteFuture<'_, 'query>
152 where
153 T: QueryFragment<Self::Backend> + QueryId,
154 {
155 self.execute_with_prepared_query(source, |conn, query| conn.execute_returning_count(&query))
156 }
157
158 fn transaction_state(
159 &mut self,
160 ) -> &mut <Self::TransactionManager as TransactionManager<Self>>::TransactionStateData {
161 self.exclusive_connection().transaction_state()
162 }
163
164 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
165 if let Some(inner) = Arc::get_mut(&mut self.inner) {
169 inner
170 .get_mut()
171 .unwrap_or_else(|p| p.into_inner())
172 .instrumentation()
173 } else {
174 panic!("Cannot access shared instrumentation")
175 }
176 }
177
178 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
179 if let Some(inner) = Arc::get_mut(&mut self.inner) {
183 inner
184 .get_mut()
185 .unwrap_or_else(|p| p.into_inner())
186 .set_instrumentation(instrumentation)
187 } else {
188 panic!("Cannot access shared instrumentation")
189 }
190 }
191}
192
/// Async transaction manager for [`SyncConnectionWrapper`] that forwards each
/// operation to the synchronous diesel transaction manager `T`.
///
/// It stores no data; `T` is only recorded at the type level via `PhantomData`.
pub struct SyncTransactionManagerWrapper<T>(PhantomData<T>);
195
196#[async_trait::async_trait]
197impl<T, C> TransactionManager<SyncConnectionWrapper<C>> for SyncTransactionManagerWrapper<T>
198where
199 SyncConnectionWrapper<C>: AsyncConnection,
200 C: Connection + 'static,
201 T: diesel::connection::TransactionManager<C> + Send,
202{
203 type TransactionStateData = T::TransactionStateData;
204
205 async fn begin_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
206 conn.spawn_blocking(move |inner| T::begin_transaction(inner))
207 .await
208 }
209
210 async fn commit_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
211 conn.spawn_blocking(move |inner| T::commit_transaction(inner))
212 .await
213 }
214
215 async fn rollback_transaction(conn: &mut SyncConnectionWrapper<C>) -> QueryResult<()> {
216 conn.spawn_blocking(move |inner| T::rollback_transaction(inner))
217 .await
218 }
219
220 fn transaction_manager_status_mut(
221 conn: &mut SyncConnectionWrapper<C>,
222 ) -> &mut TransactionManagerStatus {
223 T::transaction_manager_status_mut(conn.exclusive_connection())
224 }
225}
226
227impl<C> SyncConnectionWrapper<C> {
228 pub fn new(connection: C) -> Self
230 where
231 C: Connection,
232 {
233 SyncConnectionWrapper {
234 inner: Arc::new(Mutex::new(connection)),
235 }
236 }
237
238 pub fn spawn_blocking<'a, R>(
265 &mut self,
266 task: impl FnOnce(&mut C) -> QueryResult<R> + Send + 'static,
267 ) -> BoxFuture<'a, QueryResult<R>>
268 where
269 C: Connection + 'static,
270 R: Send + 'static,
271 {
272 let inner = self.inner.clone();
273 tokio::task::spawn_blocking(move || {
274 let mut inner = inner.lock().unwrap_or_else(|poison| {
275 inner.clear_poison();
277 poison.into_inner()
278 });
279 task(&mut inner)
280 })
281 .unwrap_or_else(|err| QueryResult::Err(from_tokio_join_error(err)))
282 .boxed()
283 }
284
285 fn execute_with_prepared_query<'a, MD, Q, R>(
286 &mut self,
287 query: Q,
288 callback: impl FnOnce(&mut C, &CollectedQuery<MD>) -> QueryResult<R> + Send + 'static,
289 ) -> BoxFuture<'a, QueryResult<R>>
290 where
291 <C as Connection>::Backend: std::default::Default + DieselReserveSpecialization,
293 <C::Backend as Backend>::QueryBuilder: std::default::Default,
294 C: Connection + LoadConnection + WithMetadataLookup + 'static,
296 <C as Connection>::TransactionManager: Send,
297 MD: Send + 'static,
299 for<'b> <C::Backend as Backend>::BindCollector<'b>:
300 MoveableBindCollector<C::Backend, BindData = MD> + std::default::Default,
301 Q: QueryFragment<C::Backend> + QueryId,
303 R: Send + 'static,
304 {
305 let backend = C::Backend::default();
306
307 let (collect_bind_result, collector_data) = {
308 let exclusive = self.inner.clone();
309 let mut inner = exclusive.lock().unwrap_or_else(|poison| {
310 exclusive.clear_poison();
312 poison.into_inner()
313 });
314 let mut bind_collector =
315 <<C::Backend as Backend>::BindCollector<'_> as Default>::default();
316 let metadata_lookup = inner.metadata_lookup();
317 let result = query.collect_binds(&mut bind_collector, metadata_lookup, &backend);
318 let collector_data = bind_collector.moveable();
319
320 (result, collector_data)
321 };
322
323 let mut query_builder = <<C::Backend as Backend>::QueryBuilder as Default>::default();
324 let sql = query
325 .to_sql(&mut query_builder, &backend)
326 .map(|_| query_builder.finish());
327 let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&backend);
328
329 self.spawn_blocking(|inner| {
330 collect_bind_result?;
331 let query = CollectedQuery::new(sql?, is_safe_to_cache_prepared?, collector_data);
332 callback(inner, &query)
333 })
334 }
335
336 pub(self) fn exclusive_connection(&mut self) -> &mut C
341 where
342 C: Connection,
343 {
344 if let Some(conn_mutex) = Arc::get_mut(&mut self.inner) {
348 conn_mutex
349 .get_mut()
350 .expect("Mutex is poisoned, a thread must have panicked holding it.")
351 } else {
352 panic!("Cannot access shared transaction state")
353 }
354 }
355}
356
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl<C> crate::pooled_connection::PoolableConnection for SyncConnectionWrapper<C>
where
    Self: AsyncConnection,
{
    /// Reports whether this connection must be discarded instead of being
    /// returned to the pool.
    ///
    /// Besides the transaction manager's own check, the connection is treated
    /// as broken when
    /// * a blocking task still holds a clone of the shared connection (for
    ///   example because the future that spawned it was dropped), or
    /// * a blocking task panicked while holding the connection's mutex.
    ///
    /// Both states are checked first because this wrapper reads transaction
    /// state through `exclusive_connection`, which panics in either case; a
    /// pool health check should discard such a connection, not panic.
    fn is_broken(&mut self) -> bool {
        let exclusive_and_unpoisoned =
            Arc::get_mut(&mut self.inner).is_some_and(|mutex| !mutex.is_poisoned());
        !exclusive_and_unpoisoned || Self::TransactionManager::is_broken_transaction_manager(self)
    }
}