1use futures_util::Future;
13use futures_util::Stream;
14use futures_util::StreamExt;
15use std::pin::Pin;
16
/// Something that can drive a future to completion from synchronous code.
///
/// The blocking connection wrapper in this module uses an implementation of
/// this trait to turn each asynchronous connection call into a blocking one.
pub trait BlockOn {
    /// Produce the value that will later be used to block on futures.
    fn get_runtime() -> Self;

    /// Drive `f` until it completes and hand back its output, blocking the
    /// calling thread in the meantime.
    fn block_on<F: Future>(&self, f: F) -> F::Output;
}
32
33#[cfg(feature = "tokio")]
91pub type AsyncConnectionWrapper<C, B = self::implementation::Tokio> =
92 self::implementation::AsyncConnectionWrapper<C, B>;
93
94#[cfg(not(feature = "tokio"))]
100pub use self::implementation::AsyncConnectionWrapper;
101
102mod implementation {
103 use diesel::connection::{Instrumentation, SimpleConnection};
104 use std::ops::{Deref, DerefMut};
105
106 use super::*;
107
/// Synchronous adapter around an asynchronous connection `C`.
///
/// Blocking diesel calls are forwarded to `inner`, and the futures it returns
/// are driven to completion with `runtime` (a `BlockOn` implementation).
pub struct AsyncConnectionWrapper<C, B> {
    // The wrapped asynchronous connection.
    inner: C,
    // Runtime used to block on the futures produced by `inner`.
    // NOTE(review): fields drop in declaration order, so `inner` is dropped
    // before `runtime`. The connection may rely on the runtime while it is
    // being dropped (unconfirmed), so keep this order.
    runtime: B,
}
112
113 impl<C, B> From<C> for AsyncConnectionWrapper<C, B>
114 where
115 C: crate::AsyncConnection,
116 B: BlockOn + Send,
117 {
118 fn from(inner: C) -> Self {
119 Self {
120 inner,
121 runtime: B::get_runtime(),
122 }
123 }
124 }
125
126 impl<C, B> Deref for AsyncConnectionWrapper<C, B> {
127 type Target = C;
128
129 fn deref(&self) -> &Self::Target {
130 &self.inner
131 }
132 }
133
134 impl<C, B> DerefMut for AsyncConnectionWrapper<C, B> {
135 fn deref_mut(&mut self) -> &mut Self::Target {
136 &mut self.inner
137 }
138 }
139
140 impl<C, B> diesel::connection::SimpleConnection for AsyncConnectionWrapper<C, B>
141 where
142 C: crate::SimpleAsyncConnection,
143 B: BlockOn,
144 {
145 fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
146 let f = self.inner.batch_execute(query);
147 self.runtime.block_on(f)
148 }
149 }
150
// Sealed-trait marker impl. NOTE(review): presumably required so that diesel
// permits the `diesel::connection::Connection` impl for this wrapper — confirm
// against diesel's `ConnectionSealed` documentation.
impl<C, B> diesel::connection::ConnectionSealed for AsyncConnectionWrapper<C, B> {}
152
153 impl<C, B> diesel::connection::Connection for AsyncConnectionWrapper<C, B>
154 where
155 C: crate::AsyncConnection,
156 B: BlockOn + Send,
157 {
158 type Backend = C::Backend;
159
160 type TransactionManager = AsyncConnectionWrapperTransactionManagerWrapper;
161
162 fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
163 let runtime = B::get_runtime();
164 let f = C::establish(database_url);
165 let inner = runtime.block_on(f)?;
166 Ok(Self { inner, runtime })
167 }
168
169 fn execute_returning_count<T>(&mut self, source: &T) -> diesel::QueryResult<usize>
170 where
171 T: diesel::query_builder::QueryFragment<Self::Backend> + diesel::query_builder::QueryId,
172 {
173 let f = self.inner.execute_returning_count(source);
174 self.runtime.block_on(f)
175 }
176
177 fn transaction_state(
178 &mut self,
179 ) -> &mut <Self::TransactionManager as diesel::connection::TransactionManager<Self>>::TransactionStateData{
180 self.inner.transaction_state()
181 }
182
183 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
184 self.inner.instrumentation()
185 }
186
187 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
188 self.inner.set_instrumentation(instrumentation);
189 }
190 }
191
192 impl<C, B> diesel::connection::LoadConnection for AsyncConnectionWrapper<C, B>
193 where
194 C: crate::AsyncConnection,
195 B: BlockOn + Send,
196 {
197 type Cursor<'conn, 'query> = AsyncCursorWrapper<'conn, C::Stream<'conn, 'query>, B>
198 where
199 Self: 'conn;
200
201 type Row<'conn, 'query> = C::Row<'conn, 'query>
202 where
203 Self: 'conn;
204
205 fn load<'conn, 'query, T>(
206 &'conn mut self,
207 source: T,
208 ) -> diesel::QueryResult<Self::Cursor<'conn, 'query>>
209 where
210 T: diesel::query_builder::Query
211 + diesel::query_builder::QueryFragment<Self::Backend>
212 + diesel::query_builder::QueryId
213 + 'query,
214 Self::Backend: diesel::expression::QueryMetadata<T::SqlType>,
215 {
216 let f = self.inner.load(source);
217 let stream = self.runtime.block_on(f)?;
218
219 Ok(AsyncCursorWrapper {
220 stream: Box::pin(stream),
221 runtime: &self.runtime,
222 })
223 }
224 }
225
226 pub struct AsyncCursorWrapper<'a, S, B> {
227 stream: Pin<Box<S>>,
228 runtime: &'a B,
229 }
230
231 impl<'a, S, B> Iterator for AsyncCursorWrapper<'a, S, B>
232 where
233 S: Stream,
234 B: BlockOn,
235 {
236 type Item = S::Item;
237
238 fn next(&mut self) -> Option<Self::Item> {
239 let f = self.stream.next();
240 self.runtime.block_on(f)
241 }
242 }
243
244 pub struct AsyncConnectionWrapperTransactionManagerWrapper;
245
246 impl<C, B> diesel::connection::TransactionManager<AsyncConnectionWrapper<C, B>>
247 for AsyncConnectionWrapperTransactionManagerWrapper
248 where
249 C: crate::AsyncConnection,
250 B: BlockOn + Send,
251 {
252 type TransactionStateData =
253 <C::TransactionManager as crate::TransactionManager<C>>::TransactionStateData;
254
255 fn begin_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
256 let f = <C::TransactionManager as crate::TransactionManager<_>>::begin_transaction(
257 &mut conn.inner,
258 );
259 conn.runtime.block_on(f)
260 }
261
262 fn rollback_transaction(
263 conn: &mut AsyncConnectionWrapper<C, B>,
264 ) -> diesel::QueryResult<()> {
265 let f = <C::TransactionManager as crate::TransactionManager<_>>::rollback_transaction(
266 &mut conn.inner,
267 );
268 conn.runtime.block_on(f)
269 }
270
271 fn commit_transaction(conn: &mut AsyncConnectionWrapper<C, B>) -> diesel::QueryResult<()> {
272 let f = <C::TransactionManager as crate::TransactionManager<_>>::commit_transaction(
273 &mut conn.inner,
274 );
275 conn.runtime.block_on(f)
276 }
277
278 fn transaction_manager_status_mut(
279 conn: &mut AsyncConnectionWrapper<C, B>,
280 ) -> &mut diesel::connection::TransactionManagerStatus {
281 <C::TransactionManager as crate::TransactionManager<_>>::transaction_manager_status_mut(
282 &mut conn.inner,
283 )
284 }
285
286 fn is_broken_transaction_manager(conn: &mut AsyncConnectionWrapper<C, B>) -> bool {
287 <C::TransactionManager as crate::TransactionManager<_>>::is_broken_transaction_manager(
288 &mut conn.inner,
289 )
290 }
291 }
292
/// Allows the blocking wrapper to be managed by an `r2d2` pool.
///
/// Both health checks delegate to the wrapped connection's
/// `crate::pooled_connection::PoolableConnection` implementation. `ping`
/// blocks on the async check using `RecyclingMethod::Verified`.
#[cfg(feature = "r2d2")]
impl<C, B> diesel::r2d2::R2D2Connection for AsyncConnectionWrapper<C, B>
where
    B: BlockOn,
    Self: diesel::Connection,
    C: crate::AsyncConnection<Backend = <Self as diesel::Connection>::Backend>
        + crate::pooled_connection::PoolableConnection
        + 'static,
    diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
        crate::methods::ExecuteDsl<C>,
    diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<C>,
{
    fn is_broken(&mut self) -> bool {
        crate::pooled_connection::PoolableConnection::is_broken(&mut self.inner)
    }

    fn ping(&mut self) -> diesel::QueryResult<()> {
        self.runtime
            .block_on(crate::pooled_connection::PoolableConnection::ping(
                &mut self.inner,
                &crate::pooled_connection::RecyclingMethod::Verified,
            ))
    }
}
317
318 impl<C, B> diesel::migration::MigrationConnection for AsyncConnectionWrapper<C, B>
319 where
320 B: BlockOn,
321 Self: diesel::Connection,
322 {
323 fn setup(&mut self) -> diesel::QueryResult<usize> {
324 self.batch_execute(diesel::migration::CREATE_MIGRATIONS_TABLE)
325 .map(|()| 0)
326 }
327 }
328
/// `BlockOn` implementation backed by tokio.
///
/// If a tokio runtime is reachable from the calling thread
/// (`Handle::try_current` succeeds), futures are driven through a handle to
/// that runtime. Otherwise a private current-thread runtime is built and owned
/// by this value.
///
/// # Panics
///
/// `get_runtime` panics if the fallback runtime cannot be built.
/// NOTE(review): per tokio's docs, `Handle::block_on` also panics when called
/// from within an asynchronous execution context. Use this type from blocking
/// code (e.g. `spawn_blocking`), not from inside an async task.
#[cfg(feature = "tokio")]
pub struct Tokio {
    runtime: TokioRuntime,
}

/// The runtime futures are blocked on.
///
/// Modelled as an enum so that "neither present" is unrepresentable; the
/// previous pair of `Option`s needed an `unreachable!()` branch.
#[cfg(feature = "tokio")]
enum TokioRuntime {
    /// Handle to a runtime that already existed when `get_runtime` ran.
    Handle(tokio::runtime::Handle),
    /// Current-thread runtime created and owned by this `Tokio` value.
    Owned(tokio::runtime::Runtime),
}

#[cfg(feature = "tokio")]
impl BlockOn for Tokio {
    fn block_on<F>(&self, f: F) -> F::Output
    where
        F: Future,
    {
        match &self.runtime {
            TokioRuntime::Handle(handle) => handle.block_on(f),
            TokioRuntime::Owned(runtime) => runtime.block_on(f),
        }
    }

    fn get_runtime() -> Self {
        let runtime = match tokio::runtime::Handle::try_current() {
            Ok(handle) => TokioRuntime::Handle(handle),
            Err(_) => TokioRuntime::Owned(
                tokio::runtime::Builder::new_current_thread()
                    // `enable_all` turns on every driver compiled into tokio:
                    // I/O as before, plus the time driver when available. A
                    // connection future that uses tokio timers would otherwise
                    // panic on this fallback path but not on the handle path.
                    .enable_all()
                    .build()
                    .expect(
                        "failed to build a tokio current-thread runtime for AsyncConnectionWrapper",
                    ),
            ),
        };
        Self { runtime }
    }
}
368}