diesel_async/pooled_connection/
mod.rs

//! This module contains support for using diesel-async with
//! various async Rust connection pooling solutions.
//!
//! See the concrete pool implementations for examples:
//! * [deadpool](self::deadpool)
//! * [bb8](self::bb8)
//! * [mobc](self::mobc)
8use crate::{AsyncConnection, SimpleAsyncConnection};
9use crate::{TransactionManager, UpdateAndFetchResults};
10use diesel::associations::HasTable;
11use diesel::connection::Instrumentation;
12use diesel::QueryResult;
13use futures_util::{future, FutureExt};
14use std::borrow::Cow;
15use std::fmt;
16use std::ops::DerefMut;
17
18#[cfg(feature = "bb8")]
19pub mod bb8;
20#[cfg(feature = "deadpool")]
21pub mod deadpool;
22#[cfg(feature = "mobc")]
23pub mod mobc;
24
25/// The error used when managing connections with `deadpool`.
26#[derive(Debug)]
27pub enum PoolError {
28    /// An error occurred establishing the connection
29    ConnectionError(diesel::result::ConnectionError),
30
31    /// An error occurred pinging the database
32    QueryError(diesel::result::Error),
33}
34
35impl fmt::Display for PoolError {
36    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
37        match *self {
38            PoolError::ConnectionError(ref e) => e.fmt(f),
39            PoolError::QueryError(ref e) => e.fmt(f),
40        }
41    }
42}
43
44impl std::error::Error for PoolError {}
45
46/// Type of the custom setup closure passed to [`ManagerConfig::custom_setup`]
47pub type SetupCallback<C> =
48    Box<dyn Fn(&str) -> future::BoxFuture<diesel::ConnectionResult<C>> + Send + Sync>;
49
50/// Type of the recycle check callback for the [`RecyclingMethod::CustomFunction`] variant
51pub type RecycleCheckCallback<C> =
52    dyn Fn(&mut C) -> future::BoxFuture<QueryResult<()>> + Send + Sync;
53
54/// Possible methods of how a connection is recycled.
55#[derive(Default)]
56pub enum RecyclingMethod<C> {
57    /// Only check for open transactions when recycling existing connections
58    /// Unless you have special needs this is a safe choice.
59    ///
60    /// If the database connection is closed you will recieve an error on the first place
61    /// you actually try to use the connection
62    Fast,
63    /// In addition to checking for open transactions a test query is executed
64    ///
65    /// This is slower, but guarantees that the database connection is ready to be used.
66    #[default]
67    Verified,
68    /// Like `Verified` but with a custom query
69    CustomQuery(Cow<'static, str>),
70    /// Like `Verified` but with a custom callback that allows to perform more checks
71    ///
72    /// The connection is only recycled if the callback returns `Ok(())`
73    CustomFunction(Box<RecycleCheckCallback<C>>),
74}
75
76impl<C: fmt::Debug> fmt::Debug for RecyclingMethod<C> {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        match self {
79            Self::Fast => write!(f, "Fast"),
80            Self::Verified => write!(f, "Verified"),
81            Self::CustomQuery(arg0) => f.debug_tuple("CustomQuery").field(arg0).finish(),
82            Self::CustomFunction(_) => f.debug_tuple("CustomFunction").finish(),
83        }
84    }
85}
86
87/// Configuration object for a Manager.
88///
89/// This makes it possible to specify which [`RecyclingMethod`]
90/// should be used when retrieving existing objects from the `Pool`
91/// and it allows to provide a custom setup function.
92#[non_exhaustive]
93pub struct ManagerConfig<C> {
94    /// Method of how a connection is recycled. See [RecyclingMethod].
95    pub recycling_method: RecyclingMethod<C>,
96    /// Construct a new connection manger
97    /// with a custom setup procedure
98    ///
99    /// This can be used to for example establish a SSL secured
100    /// postgres connection
101    pub custom_setup: SetupCallback<C>,
102}
103
104impl<C> Default for ManagerConfig<C>
105where
106    C: AsyncConnection + 'static,
107{
108    fn default() -> Self {
109        Self {
110            recycling_method: Default::default(),
111            custom_setup: Box::new(|url| C::establish(url).boxed()),
112        }
113    }
114}
115
116/// An connection manager for use with diesel-async.
117///
118/// See the concrete pool implementations for examples:
119/// * [deadpool](self::deadpool)
120/// * [bb8](self::bb8)
121/// * [mobc](self::mobc)
122#[allow(dead_code)]
123pub struct AsyncDieselConnectionManager<C> {
124    connection_url: String,
125    manager_config: ManagerConfig<C>,
126}
127
128impl<C> fmt::Debug for AsyncDieselConnectionManager<C> {
129    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
130        write!(
131            f,
132            "AsyncDieselConnectionManager<{}>",
133            std::any::type_name::<C>()
134        )
135    }
136}
137
138impl<C> AsyncDieselConnectionManager<C>
139where
140    C: AsyncConnection + 'static,
141{
142    /// Returns a new connection manager,
143    /// which establishes connections to the given database URL.
144    #[must_use]
145    pub fn new(connection_url: impl Into<String>) -> Self
146    where
147        C: AsyncConnection + 'static,
148    {
149        Self::new_with_config(connection_url, Default::default())
150    }
151
152    /// Returns a new connection manager,
153    /// which establishes connections with the given database URL
154    /// and that uses the specified configuration
155    #[must_use]
156    pub fn new_with_config(
157        connection_url: impl Into<String>,
158        manager_config: ManagerConfig<C>,
159    ) -> Self {
160        Self {
161            connection_url: connection_url.into(),
162            manager_config,
163        }
164    }
165}
166
167#[async_trait::async_trait]
168impl<C> SimpleAsyncConnection for C
169where
170    C: DerefMut + Send,
171    C::Target: SimpleAsyncConnection + Send,
172{
173    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
174        let conn = self.deref_mut();
175        conn.batch_execute(query).await
176    }
177}
178
179#[async_trait::async_trait]
180impl<C> AsyncConnection for C
181where
182    C: DerefMut + Send,
183    C::Target: AsyncConnection,
184{
185    type ExecuteFuture<'conn, 'query> =
186        <C::Target as AsyncConnection>::ExecuteFuture<'conn, 'query>;
187    type LoadFuture<'conn, 'query> = <C::Target as AsyncConnection>::LoadFuture<'conn, 'query>;
188    type Stream<'conn, 'query> = <C::Target as AsyncConnection>::Stream<'conn, 'query>;
189    type Row<'conn, 'query> = <C::Target as AsyncConnection>::Row<'conn, 'query>;
190
191    type Backend = <C::Target as AsyncConnection>::Backend;
192
193    type TransactionManager =
194        PoolTransactionManager<<C::Target as AsyncConnection>::TransactionManager>;
195
196    async fn establish(_database_url: &str) -> diesel::ConnectionResult<Self> {
197        Err(diesel::result::ConnectionError::BadConnection(
198            String::from("Cannot directly establish a pooled connection"),
199        ))
200    }
201
202    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
203    where
204        T: diesel::query_builder::AsQuery + 'query,
205        T::Query: diesel::query_builder::QueryFragment<Self::Backend>
206            + diesel::query_builder::QueryId
207            + 'query,
208    {
209        let conn = self.deref_mut();
210        conn.load(source)
211    }
212
213    fn execute_returning_count<'conn, 'query, T>(
214        &'conn mut self,
215        source: T,
216    ) -> Self::ExecuteFuture<'conn, 'query>
217    where
218        T: diesel::query_builder::QueryFragment<Self::Backend>
219            + diesel::query_builder::QueryId
220            + 'query,
221    {
222        let conn = self.deref_mut();
223        conn.execute_returning_count(source)
224    }
225
226    fn transaction_state(
227        &mut self,
228    ) -> &mut <Self::TransactionManager as crate::transaction_manager::TransactionManager<Self>>::TransactionStateData{
229        let conn = self.deref_mut();
230        conn.transaction_state()
231    }
232
233    async fn begin_test_transaction(&mut self) -> diesel::QueryResult<()> {
234        self.deref_mut().begin_test_transaction().await
235    }
236
237    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
238        self.deref_mut().instrumentation()
239    }
240
241    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
242        self.deref_mut().set_instrumentation(instrumentation);
243    }
244}
245
/// Transaction manager used by the blanket [`AsyncConnection`] implementation
/// for `DerefMut` wrappers; `TM` is the transaction manager of the wrapped
/// connection.
///
/// The type carries no data: `TM` is only recorded through `PhantomData`.
#[doc(hidden)]
#[allow(missing_debug_implementations)]
pub struct PoolTransactionManager<TM>(std::marker::PhantomData<TM>);
249
250#[async_trait::async_trait]
251impl<C, TM> TransactionManager<C> for PoolTransactionManager<TM>
252where
253    C: DerefMut + Send,
254    C::Target: AsyncConnection<TransactionManager = TM>,
255    TM: TransactionManager<C::Target>,
256{
257    type TransactionStateData = TM::TransactionStateData;
258
259    async fn begin_transaction(conn: &mut C) -> diesel::QueryResult<()> {
260        TM::begin_transaction(&mut **conn).await
261    }
262
263    async fn rollback_transaction(conn: &mut C) -> diesel::QueryResult<()> {
264        TM::rollback_transaction(&mut **conn).await
265    }
266
267    async fn commit_transaction(conn: &mut C) -> diesel::QueryResult<()> {
268        TM::commit_transaction(&mut **conn).await
269    }
270
271    fn transaction_manager_status_mut(
272        conn: &mut C,
273    ) -> &mut diesel::connection::TransactionManagerStatus {
274        TM::transaction_manager_status_mut(&mut **conn)
275    }
276
277    fn is_broken_transaction_manager(conn: &mut C) -> bool {
278        TM::is_broken_transaction_manager(&mut **conn)
279    }
280}
281
282#[async_trait::async_trait]
283impl<Changes, Output, Conn> UpdateAndFetchResults<Changes, Output> for Conn
284where
285    Conn: DerefMut + Send,
286    Changes: diesel::prelude::Identifiable + HasTable + Send,
287    Conn::Target: UpdateAndFetchResults<Changes, Output>,
288{
289    async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output>
290    where
291        Changes: 'async_trait,
292    {
293        self.deref_mut().update_and_fetch(changeset).await
294    }
295}
296
297#[derive(diesel::query_builder::QueryId)]
298struct CheckConnectionQuery;
299
300impl<DB> diesel::query_builder::QueryFragment<DB> for CheckConnectionQuery
301where
302    DB: diesel::backend::Backend,
303{
304    fn walk_ast<'b>(
305        &'b self,
306        mut pass: diesel::query_builder::AstPass<'_, 'b, DB>,
307    ) -> diesel::QueryResult<()> {
308        pass.push_sql("SELECT 1");
309        Ok(())
310    }
311}
312
313impl diesel::query_builder::Query for CheckConnectionQuery {
314    type SqlType = diesel::sql_types::Integer;
315}
316
317impl<C> diesel::query_dsl::RunQueryDsl<C> for CheckConnectionQuery {}
318
319#[doc(hidden)]
320#[async_trait::async_trait]
321pub trait PoolableConnection: AsyncConnection {
322    /// Check if a connection is still valid
323    ///
324    /// The default implementation will perform a check based on the provided
325    /// recycling method variant
326    async fn ping(&mut self, config: &RecyclingMethod<Self>) -> diesel::QueryResult<()>
327    where
328        for<'a> Self: 'a,
329        diesel::dsl::select<diesel::dsl::AsExprOf<i32, diesel::sql_types::Integer>>:
330            crate::methods::ExecuteDsl<Self>,
331        diesel::query_builder::SqlQuery: crate::methods::ExecuteDsl<Self>,
332    {
333        use crate::run_query_dsl::RunQueryDsl;
334        use diesel::IntoSql;
335
336        match config {
337            RecyclingMethod::Fast => Ok(()),
338            RecyclingMethod::Verified => {
339                diesel::select(1_i32.into_sql::<diesel::sql_types::Integer>())
340                    .execute(self)
341                    .await
342                    .map(|_| ())
343            }
344            RecyclingMethod::CustomQuery(query) => diesel::sql_query(query.as_ref())
345                .execute(self)
346                .await
347                .map(|_| ()),
348            RecyclingMethod::CustomFunction(c) => c(self).await,
349        }
350    }
351
352    /// Checks if the connection is broken and should not be reused
353    ///
354    /// This method should return only contain a fast non-blocking check
355    /// if the connection is considered to be broken or not. See
356    /// [ManageConnection::has_broken] for details.
357    ///
358    /// The default implementation uses
359    /// [TransactionManager::is_broken_transaction_manager].
360    fn is_broken(&mut self) -> bool {
361        Self::TransactionManager::is_broken_transaction_manager(self)
362    }
363}