diesel_async/mysql/mod.rs

1use crate::stmt_cache::{PrepareCallback, StmtCache};
2use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
3use diesel::connection::statement_cache::{MaybeCached, StatementCacheKey};
4use diesel::connection::Instrumentation;
5use diesel::connection::InstrumentationEvent;
6use diesel::connection::StrQueryHelper;
7use diesel::mysql::{Mysql, MysqlQueryBuilder, MysqlType};
8use diesel::query_builder::QueryBuilder;
9use diesel::query_builder::{bind_collector::RawBytesBindCollector, QueryFragment, QueryId};
10use diesel::result::{ConnectionError, ConnectionResult};
11use diesel::QueryResult;
12use futures_util::future::BoxFuture;
13use futures_util::stream::{self, BoxStream};
14use futures_util::{Future, FutureExt, StreamExt, TryStreamExt};
15use mysql_async::prelude::Queryable;
16use mysql_async::{Opts, OptsBuilder, Statement};
17
18mod error_helper;
19mod row;
20mod serialize;
21
22use self::error_helper::ErrorHelper;
23use self::row::MysqlRow;
24use self::serialize::ToSqlHelper;
25
26/// A connection to a MySQL database. Connection URLs should be in the form
27/// `mysql://[user[:password]@]host/database_name`
28pub struct AsyncMysqlConnection {
29    conn: mysql_async::Conn,
30    stmt_cache: StmtCache<Mysql, Statement>,
31    transaction_manager: AnsiTransactionManager,
32    instrumentation: std::sync::Mutex<Option<Box<dyn Instrumentation>>>,
33}
34
35#[async_trait::async_trait]
36impl SimpleAsyncConnection for AsyncMysqlConnection {
37    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
38        self.instrumentation()
39            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
40                query,
41            )));
42        let result = self
43            .conn
44            .query_drop(query)
45            .await
46            .map_err(ErrorHelper)
47            .map_err(Into::into);
48        self.instrumentation()
49            .on_connection_event(InstrumentationEvent::finish_query(
50                &StrQueryHelper::new(query),
51                result.as_ref().err(),
52            ));
53        result
54    }
55}
56
/// Session settings applied to every connection.
///
/// They are passed as `mysql_async` init queries when connecting from a URL, and are
/// executed one by one when wrapping an existing `mysql_async::Conn`.
const CONNECTION_SETUP_QUERIES: &[&str] = &[
    // Use UTC as the session time zone.
    "SET time_zone = '+00:00';",
    // Use utf8mb4 for incoming statements, for the connection and for returned results.
    "SET character_set_client = 'utf8mb4'",
    "SET character_set_connection = 'utf8mb4'",
    "SET character_set_results = 'utf8mb4'",
];
63
64#[async_trait::async_trait]
65impl AsyncConnection for AsyncMysqlConnection {
66    type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
67    type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
68    type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
69    type Row<'conn, 'query> = MysqlRow;
70    type Backend = Mysql;
71
72    type TransactionManager = AnsiTransactionManager;
73
74    async fn establish(database_url: &str) -> diesel::ConnectionResult<Self> {
75        let mut instrumentation = diesel::connection::get_default_instrumentation();
76        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
77            database_url,
78        ));
79        let r = Self::establish_connection_inner(database_url).await;
80        instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
81            database_url,
82            r.as_ref().err(),
83        ));
84        let mut conn = r?;
85        conn.instrumentation = std::sync::Mutex::new(instrumentation);
86        Ok(conn)
87    }
88
89    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
90    where
91        T: diesel::query_builder::AsQuery,
92        T::Query: diesel::query_builder::QueryFragment<Self::Backend>
93            + diesel::query_builder::QueryId
94            + 'query,
95    {
96        self.with_prepared_statement(source.as_query(), |conn, stmt, binds| async move {
97            let stmt_for_exec = match stmt {
98                MaybeCached::Cached(ref s) => (*s).clone(),
99                MaybeCached::CannotCache(ref s) => s.clone(),
100                _ => unreachable!(
101                    "Diesel has only two variants here at the time of writing.\n\
102                     If you ever see this error message please open in issue in the diesel-async issue tracker"
103                ),
104            };
105
106            let (tx, rx) = futures_channel::mpsc::channel(0);
107
108            let yielder = async move {
109                let r = Self::poll_result_stream(conn, stmt_for_exec, binds, tx).await;
110                // We need to close any non-cached statement explicitly here as otherwise
111                // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
112                // for details
113                //
114                // This might be problematic for cases where the stream is dropped before the end is reached
115                //
116                // Such behaviour might happen if users:
117                // * Just drop the future/stream after polling at least once (timeouts!!)
118                // * Users only fetch a fixed number of elements from the stream
119                //
120                // For now there is not really a good solution to this problem as this would require something like async drop
121                // (and even with async drop that would be really hard to solve due to the involved lifetimes)
122                if let MaybeCached::CannotCache(stmt) = stmt {
123                    conn.close(stmt).await.map_err(ErrorHelper)?;
124                }
125                r
126            };
127
128            let fake_stream = stream::once(yielder).filter_map(|e: QueryResult<()>| async move {
129                if let Err(e) = e {
130                    Some(Err(e))
131                } else {
132                    None
133                }
134            });
135
136            let stream = stream::select(fake_stream, rx).boxed();
137
138            Ok(stream)
139        })
140        .boxed()
141    }
142
143    fn execute_returning_count<'conn, 'query, T>(
144        &'conn mut self,
145        source: T,
146    ) -> Self::ExecuteFuture<'conn, 'query>
147    where
148        T: diesel::query_builder::QueryFragment<Self::Backend>
149            + diesel::query_builder::QueryId
150            + 'query,
151    {
152        self.with_prepared_statement(source, |conn, stmt, binds| async move {
153            let params = mysql_async::Params::try_from(binds)?;
154            conn.exec_drop(&*stmt, params).await.map_err(ErrorHelper)?;
155            // We need to close any non-cached statement explicitly here as otherwise
156            // we might error out on too many open statements. See https://github.com/weiznich/diesel_async/issues/26
157            // for details
158            //
159            // This might be problematic for cases where the stream is dropped before the end is reached
160            //
161            // Such behaviour might happen if users:
162            // * Just drop the future after polling at least once (timeouts!!)
163            //
164            // For now there is not really a good solution to this problem as this would require something like async drop
165            // (and even with async drop that would be really hard to solve due to the involved lifetimes)
166            if let MaybeCached::CannotCache(stmt) = stmt {
167                conn.close(stmt).await.map_err(ErrorHelper)?;
168            }
169            conn.affected_rows()
170                .try_into()
171                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
172        })
173    }
174
175    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
176        &mut self.transaction_manager
177    }
178
179    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
180        self.instrumentation
181            .get_mut()
182            .unwrap_or_else(|p| p.into_inner())
183    }
184
185    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
186        *self
187            .instrumentation
188            .get_mut()
189            .unwrap_or_else(|p| p.into_inner()) = Some(Box::new(instrumentation));
190    }
191}
192
193#[inline(always)]
194fn update_transaction_manager_status<T>(
195    query_result: QueryResult<T>,
196    transaction_manager: &mut AnsiTransactionManager,
197) -> QueryResult<T> {
198    if let Err(diesel::result::Error::DatabaseError(
199        diesel::result::DatabaseErrorKind::SerializationFailure,
200        _,
201    )) = query_result
202    {
203        transaction_manager
204            .status
205            .set_requires_rollback_maybe_up_to_top_level(true)
206    }
207    query_result
208}
209
210#[async_trait::async_trait]
211impl PrepareCallback<Statement, MysqlType> for &'_ mut mysql_async::Conn {
212    async fn prepare(
213        self,
214        sql: &str,
215        _metadata: &[MysqlType],
216        _is_for_cache: diesel::connection::statement_cache::PrepareForCache,
217    ) -> QueryResult<(Statement, Self)> {
218        let s = self.prep(sql).await.map_err(ErrorHelper)?;
219        Ok((s, self))
220    }
221}
222
223impl AsyncMysqlConnection {
224    /// Wrap an existing [`mysql_async::Conn`] into a async diesel mysql connection
225    ///
226    /// This function constructs a new `AsyncMysqlConnection` based on an existing
227    /// [`mysql_async::Conn]`.
228    pub async fn try_from(conn: mysql_async::Conn) -> ConnectionResult<Self> {
229        use crate::run_query_dsl::RunQueryDsl;
230        let mut conn = AsyncMysqlConnection {
231            conn,
232            stmt_cache: StmtCache::new(),
233            transaction_manager: AnsiTransactionManager::default(),
234            instrumentation: std::sync::Mutex::new(
235                diesel::connection::get_default_instrumentation(),
236            ),
237        };
238
239        for stmt in CONNECTION_SETUP_QUERIES {
240            diesel::sql_query(*stmt)
241                .execute(&mut conn)
242                .await
243                .map_err(ConnectionError::CouldntSetupConfiguration)?;
244        }
245
246        Ok(conn)
247    }
248
249    fn with_prepared_statement<'conn, T, F, R>(
250        &'conn mut self,
251        query: T,
252        callback: impl (FnOnce(&'conn mut mysql_async::Conn, MaybeCached<'conn, Statement>, ToSqlHelper) -> F)
253            + Send
254            + 'conn,
255    ) -> BoxFuture<'conn, QueryResult<R>>
256    where
257        R: Send + 'conn,
258        T: QueryFragment<Mysql> + QueryId,
259        F: Future<Output = QueryResult<R>> + Send,
260    {
261        self.instrumentation()
262            .on_connection_event(InstrumentationEvent::start_query(&diesel::debug_query(
263                &query,
264            )));
265        let mut bind_collector = RawBytesBindCollector::<Mysql>::new();
266        let bind_collector = query
267            .collect_binds(&mut bind_collector, &mut (), &Mysql)
268            .map(|()| bind_collector);
269
270        let AsyncMysqlConnection {
271            ref mut conn,
272            ref mut stmt_cache,
273            ref mut transaction_manager,
274            ref mut instrumentation,
275            ..
276        } = self;
277
278        let is_safe_to_cache_prepared = query.is_safe_to_cache_prepared(&Mysql);
279        let mut qb = MysqlQueryBuilder::new();
280        let sql = query.to_sql(&mut qb, &Mysql).map(|()| qb.finish());
281        let query_id = T::query_id();
282
283        async move {
284            let RawBytesBindCollector {
285                metadata, binds, ..
286            } = bind_collector?;
287            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
288            let sql = sql?;
289            let inner = async {
290                let cache_key = if let Some(query_id) = query_id {
291                    StatementCacheKey::Type(query_id)
292                } else {
293                    StatementCacheKey::Sql {
294                        sql: sql.clone(),
295                        bind_types: metadata.clone(),
296                    }
297                };
298
299                let (stmt, conn) = stmt_cache
300                    .cached_prepared_statement(
301                        cache_key,
302                        sql.clone(),
303                        is_safe_to_cache_prepared,
304                        &metadata,
305                        conn,
306                        instrumentation,
307                    )
308                    .await?;
309                callback(conn, stmt, ToSqlHelper { metadata, binds }).await
310            };
311            let r = update_transaction_manager_status(inner.await, transaction_manager);
312            instrumentation
313                .get_mut()
314                .unwrap_or_else(|p| p.into_inner())
315                .on_connection_event(InstrumentationEvent::finish_query(
316                    &StrQueryHelper::new(&sql),
317                    r.as_ref().err(),
318                ));
319            r
320        }
321        .boxed()
322    }
323
324    async fn poll_result_stream(
325        conn: &mut mysql_async::Conn,
326        stmt_for_exec: mysql_async::Statement,
327        binds: ToSqlHelper,
328        mut tx: futures_channel::mpsc::Sender<QueryResult<MysqlRow>>,
329    ) -> QueryResult<()> {
330        use futures_util::sink::SinkExt;
331        let params = mysql_async::Params::try_from(binds)?;
332
333        let res = conn
334            .exec_iter(stmt_for_exec, params)
335            .await
336            .map_err(ErrorHelper)?;
337
338        let mut stream = res
339            .stream_and_drop::<MysqlRow>()
340            .await
341            .map_err(ErrorHelper)?
342            .ok_or_else(|| {
343                diesel::result::Error::DeserializationError(Box::new(
344                    diesel::result::UnexpectedEndOfRow,
345                ))
346            })?
347            .map_err(|e| diesel::result::Error::from(ErrorHelper(e)));
348
349        while let Some(row) = stream.next().await {
350            let row = row?;
351            tx.send(Ok(row))
352                .await
353                .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))?;
354        }
355
356        Ok(())
357    }
358
359    async fn establish_connection_inner(
360        database_url: &str,
361    ) -> Result<AsyncMysqlConnection, ConnectionError> {
362        let opts = Opts::from_url(database_url)
363            .map_err(|e| diesel::result::ConnectionError::InvalidConnectionUrl(e.to_string()))?;
364        let builder = OptsBuilder::from_opts(opts)
365            .init(CONNECTION_SETUP_QUERIES.to_vec())
366            .stmt_cache_size(0) // We have our own cache
367            .client_found_rows(true); // This allows a consistent behavior between MariaDB/MySQL and PostgreSQL (and is already set in `diesel`)
368
369        let conn = mysql_async::Conn::new(builder).await.map_err(ErrorHelper)?;
370
371        Ok(AsyncMysqlConnection {
372            conn,
373            stmt_cache: StmtCache::new(),
374            transaction_manager: AnsiTransactionManager::default(),
375            instrumentation: std::sync::Mutex::new(None),
376        })
377    }
378}
379
// Makes `AsyncMysqlConnection` usable with whichever optional pool integration is
// enabled. The empty body means it relies entirely on the trait's default methods.
#[cfg(any(
    feature = "bb8",
    feature = "deadpool",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncMysqlConnection {}
387
#[cfg(test)]
mod tests {
    use crate::RunQueryDsl;
    mod diesel_async {
        pub use crate::*;
    }
    include!("../doctest_setup.rs");

    /// Regression test for non-cached prepared statements not being closed.
    ///
    /// It runs more statements than the server allows to stay open at once, both
    /// through `execute` and through `load`. If statements leaked, the server would
    /// eventually reject new ones.
    #[tokio::test]
    async fn check_statements_are_dropped() {
        use self::schema::users;

        #[derive(QueryableByName)]
        #[diesel(table_name = users)]
        #[allow(dead_code)]
        struct User {
            id: i32,
            name: String,
        }

        let mut conn = establish_connection().await;
        // We cannot set a lower limit here without admin privileges, which makes this
        // test really slow. 16382 presumably matches MySQL's default
        // `max_prepared_stmt_count`; verify.
        let stmt_count = 16382 + 10;

        for n in 0..stmt_count {
            let user_name = format!("User{n}");
            diesel::insert_into(users::table)
                .values(Some(users::name.eq(user_name)))
                .execute(&mut conn)
                .await
                .unwrap();
        }

        for n in 0..stmt_count {
            let user_name = format!("User{n}");
            diesel::sql_query("SELECT id, name FROM users WHERE name = ?")
                .bind::<diesel::sql_types::Text, _>(user_name)
                .load::<User>(&mut conn)
                .await
                .unwrap();
        }
    }
}