1use crate::stmt_cache::{PrepareCallback, StmtCache};
2use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
3use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
4use diesel::connection::Instrumentation;
5use diesel::connection::InstrumentationEvent;
6use diesel::connection::StrQueryHelper;
7use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
8use diesel::query_builder::QueryBuilder;
9use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
10use diesel::result::{ConnectionError, ConnectionResult};
11use diesel::QueryResult;
12use futures_util::future::BoxFuture;
13use futures_util::stream::{self, BoxStream};
14use futures_util::{Future, FutureExt, StreamExt, TryStreamExt};
15use mysql_async::prelude::Queryable;
16use mysql_async::{Opts, OptsBuilder, Statement};
17
18mod error_helper;
19mod row;
20mod serialize;
21
22use self::error_helper::ErrorHelper;
23use self::row::MysqlRow;
24use self::serialize::ToSqlHelper;
25
/// An asynchronous diesel connection to a MySQL server, backed by a [`mysql_async::Conn`].
///
/// Prepared statements are cached on the diesel side (`stmt_cache`); when the
/// connection is created through `establish`, `mysql_async`'s own statement cache is
/// disabled (`stmt_cache_size(0)`), so diesel's cache is the only one in use.
pub struct AsyncMysqlConnection {
    // Underlying `mysql_async` connection every query is sent over.
    conn: mysql_async::Conn,
    // Diesel-side cache of prepared statements, consulted by `with_prepared_statement`.
    stmt_cache: StmtCache<Mysql, Statement>,
    // Transaction bookkeeping exposed through `AsyncConnection::transaction_state`.
    transaction_manager: AnsiTransactionManager,
    // Optional instrumentation hook; only ever accessed through `Mutex::get_mut`, with
    // poisoning ignored. NOTE(review): the `Mutex` presumably exists to keep the
    // connection `Sync` despite the boxed trait object — confirm.
    instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
}
34
35#[async_trait::async_trait]
36impl SimpleAsyncConnection for AsyncMysqlConnection {
37 async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
38 self.instrumentation()
39 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
40 query,
41 )));
42 let result = self
43 .conn
44 .query_drop(query)
45 .await
46 .map_err(ErrorHelper)
47 .map_err(Into::into);
48 self.instrumentation()
49 .on_connection_event(InstrumentationEvent::finish_query(
50 &StrQueryHelper::new(query),
51 result.as_ref().err(),
52 ));
53 result
54 }
55}
56
/// Session settings applied to every new connection: a `+00:00` time zone and
/// `utf8mb4` for the client, connection and result character sets.
///
/// Passed as `mysql_async` init statements by `establish_connection_inner`, and run
/// explicitly through diesel by `AsyncMysqlConnection::try_from`.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    "SET time_zone = '+00:00';",
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
63
64#[async_trait::async_trait]
65impl AsyncConnection for AsyncMysqlConnection {
66 type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
67 type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
68 type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
69 type Row<'conn, 'query> = MysqlRow;
70 type Backend = Mysql;
71
72 type TransactionManager = AnsiTransactionManager;
73
74 async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
75 let mut instrumentation = diesel::connection::get_default_instrumentation();
76 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
77 database_url,
78 ));
79 let r = Self::establish_connection_inner(database_url).await;
80 instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
81 database_url,
82 r.as_ref().err(),
83 ));
84 let mut conn = r?;
85 conn.instrumentation = std::sync::Mutex::new(instrumentation);
86 Ok(conn)
87 }
88
89 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
90 where
91 T: diesel::query_builder::AsQuery,
92 T::Query: diesel::query_builder::QueryFragment<Self::Backend>
93 + diesel::query_builder::QueryId
94 + 'query,
95 {
96 self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
97 let stmt_for_exec = match stmt {
98 MaybeCached::Cached(ref s) => (*s).clone(),
99 MaybeCached::CannotCache(ref s) => s.clone(),
100 _ => unreachable!(
101 "Diesel has only two variants here at the time of writing.\n\
102 If you ever see this error message please open in issue in the diesel-async issue tracker"
103 ),
104 };
105
106 let (tx, rx) = futures_channel::mpsc::channel(0);
107
108 let yielder = async move {
109 let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
110 if let MaybeCached::CannotCache(stmt) = stmt {
123 conn.close(stmt).await.map_err(ErrorHelper)?;
124 }
125 r
126 };
127
128 let fake_stream = stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
129 if let Err(e) = e {
130 Some(Err(e))
131 } else {
132 None
133 }
134 });
135
136 let stream = stream::select(fake_stream, rx).boxed();
137
138 Ok(stream)
139 })
140 .boxed()
141 }
142
143 fn execute_returning_count<'conn, 'query, T>(
144 &'conn mut self,
145 source: T,
146 ) -> Self::ExecuteFuture<'conn, 'query>
147 where
148 T: diesel::query_builder::QueryFragment<Self::Backend>
149 + diesel::query_builder::QueryId
150 + 'query,
151 {
152 self.with_prepared_statement(source, |conn, stmt, binds| async move {
153 let params = mysql_async::Params::try_from(binds)?;
154 conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
155 if let MaybeCached::CannotCache(stmt) = stmt {
167 conn.close(stmt).await.map_err(ErrorHelper)?;
168 }
169 conn.affected_rows()
170 .try_into()
171 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
172 })
173 }
174
175 fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
176 &mut self.transaction_manager
177 }
178
179 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
180 self.instrumentation
181 .get_mut()
182 .unwrap_or_else(|p| p.into_inner())
183 }
184
185 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
186 *self
187 .instrumentation
188 .get_mut()
189 .unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
190 }
191}
192
193#[inline(always)]
194fn update_transaction_manager_status<T>(
195 query_result: QueryResult<T>,
196 transaction_manager: &mut AnsiTransactionManager,
197) -> QueryResult<T> {
198 if let Err(diesel::result::Error::DatabaseError(
199 diesel::result::DatabaseErrorKind::SerializationFailure,
200 _,
201 )) = query_result
202 {
203 transaction_manager
204 .status
205 .set_requires_rollback_maybe_up_to_top_level(true)
206 }
207 query_result
208}
209
210#[async_trait::async_trait]
211impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
212 async fn prepare(
213 self,
214 sql: &str,
215 _metadata: &[MysqlType],
216 _is_for_cache: diesel::connection::statement_cache::PrepareForCache,
217 ) -> QueryResult<(Statement, Self)> {
218 let s = self.prep(sql).await.map_err(ErrorHelper)?;
219 Ok((s, self))
220 }
221}
222
223impl AsyncMysqlConnection {
224 pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
229 use crate::run_query_dsl::RunQueryDsl;
230 let mut conn = AsyncMysqlConnection {
231 conn,
232 stmt_cache: StmtCache::new(),
233 transaction_manager: AnsiTransactionManager::default(),
234 instrumentation: std::sync::Mutex::new(
235 diesel::connection::get_default_instrumentation(),
236 ),
237 };
238
239 for stmt in CONNECTION_SETUP_QUERIES {
240 diesel::sql_query(*stmt)
241 .execute(&mut conn)
242 .await
243 .map_err(ConnectionError::CouldntSetupConfiguration)?;
244 }
245
246 Ok(conn)
247 }
248
249 fn with_prepared_statement<'conn, T, F, R>(
250 &'conn mut self,
251 query: T,
252 callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
253 + Send
254 + 'conn,
255 ) -> BoxFuture<'conn, QueryResult<R>>
256 where
257 R: Send + 'conn,
258 T: QueryFragment<Mysql> + QueryId,
259 F: Future<Output = QueryResult<R>> + Send,
260 {
261 self.instrumentation()
262 .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
263 &query,
264 )));
265 let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
266 let bind_collector = query
267 .collect_binds(&mut bind_collector, &mut (), &Mysql)
268 .map(|()| bind_collector);
269
270 let AsyncMysqlConnection {
271 ref mut conn,
272 ref mut stmt_cache,
273 ref mut transaction_manager,
274 ref mut instrumentation,
275 ..
276 } = self;
277
278 let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
279 let mut qb = MysqlQueryBuilder::new();
280 let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
281 let query_id = T::query_id();
282
283 async move {
284 let RawBytesBindCollector {
285 metadata, binds, ..
286 } = bind_collector?;
287 let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
288 let sql = sql?;
289 let inner = async {
290 let cache_key = if let Some(query_id) = query_id {
291 StatementCacheKey::Type(query_id)
292 } else {
293 StatementCacheKey::Sql {
294 sql: sql.clone(),
295 bind_types: metadata.clone(),
296 }
297 };
298
299 let (stmt, conn) = stmt_cache
300 .cached_prepared_statement(
301 cache_key,
302 sql.clone(),
303 is_safe_to_cache_prepared,
304 &metadata,
305 conn,
306 instrumentation,
307 )
308 .await?;
309 callback(conn, stmt, ToSqlHelper { metadata, binds }).await
310 };
311 let r = update_transaction_manager_status(inner.await, transaction_manager);
312 instrumentation
313 .get_mut()
314 .unwrap_or_else(|p| p.into_inner())
315 .on_connection_event(InstrumentationEvent::finish_query(
316 &StrQueryHelper::new(&sql),
317 r.as_ref().err(),
318 ));
319 r
320 }
321 .boxed()
322 }
323
324 async fn poll_result_stream(
325 conn: &mut mysql_async::Conn,
326 stmt_for_exec: mysql_async::Statement,
327 binds: ToSqlHelper,
328 mut tx: futures_channel::mpsc::Sender<QueryResult<MysqlRow>>,
329 ) -> QueryResult<()> {
330 use futures_util::sink::SinkExt;
331 let params = mysql_async::Params::try_from(binds)?;
332
333 let res = conn
334 .exec_iter(stmt_for_exec, params)
335 .await
336 .map_err(ErrorHelper)?;
337
338 let mut stream = res
339 .stream_and_drop::<MysqlRow>()
340 .await
341 .map_err(ErrorHelper)?
342 .ok_or_else(|| {
343 diesel::result::Error::DeserializationError(Box::new(
344 diesel::result::UnexpectedEndOfRow,
345 ))
346 })?
347 .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
348
349 while let Some(row) = stream.next().await {
350 let row = row?;
351 tx.send(Ok(row))
352 .await
353 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
354 }
355
356 Ok(())
357 }
358
359 async fn establish_connection_inner(
360 database_url: &str,
361 ) -> Result<AsyncMysqlConnection, ConnectionError> {
362 let opts = Opts::from_url(database_url)
363 .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
364 let builder = OptsBuilder::from_opts(opts)
365 .init(CONNECTION_SETUP_QUERIES.to_vec())
366 .stmt_cache_size(0) .client_found_rows(true); let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
370
371 Ok(AsyncMysqlConnection {
372 conn,
373 stmt_cache: StmtCache::new(),
374 transaction_manager: AnsiTransactionManager::default(),
375 instrumentation: std::sync::Mutex::new(None),
376 })
377 }
378}
379
// Opts `AsyncMysqlConnection` into the pooled-connection support. The impl body is
// empty, so every `PoolableConnection` method uses the trait's default. It is only
// compiled when at least one pool backend feature is enabled.
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
387
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    // NOTE(review): presumably lets `diesel_async::...` paths inside the included
    // `doctest_setup.rs` resolve to this crate — confirm.
    mod diesel_async {
        pub use crate::*;
    }
    // Provides `establish_connection`, the `schema` module and the other names used below.
    include!("../doctest_setup.rs");

    /// Runs many statements on one connection and expects all of them to succeed.
    ///
    /// NOTE(review): 16382 appears to be MySQL's default `max_prepared_stmt_count`. Going
    /// past it presumably makes this test fail if uncached prepared statements are
    /// never closed on the server — confirm.
    #[tokio::test]
    async fn check_statements_are_dropped() {
        use self::schema::users;

        let mut conn = establish_connection().await;
        let stmt_count = 16382 + 10;

        for i in 0..stmt_count {
            diesel::insert_into(users::table)
                .values(Some(users::name.eq(format!("User{i}"))))
                .execute(&mut conn)
                .await
                .unwrap();
        }

        #[derive(QueryableByName)]
        #[diesel(table_name = users)]
        // The fields are never read; the test only checks that loading succeeds.
        #[allow(dead_code)]
        struct User {
            id: i32,
            name: String,
        }

        // Raw `sql_query` selects exercise the `load` path once per inserted row.
        for i in 0..stmt_count {
            diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
                .bind::<diesel::sql_types::Text, _>(format!("User{i}"))
                .load::<User>(&mut conn)
                .await
                .unwrap();
        }
    }
}