diesel_async/
transaction_manager.rs

1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4    InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use crate::AsyncConnection;
15// TODO: refactor this to share more code with diesel
16
17/// Manages the internal transaction state for a connection.
18///
19/// You will not need to interact with this trait, unless you are writing an
20/// implementation of [`AsyncConnection`].
21#[async_trait::async_trait]
22pub trait TransactionManager<Conn: AsyncConnection>: Send {
23    /// Data stored as part of the connection implementation
24    /// to track the current transaction state of a connection
25    type TransactionStateData;
26
27    /// Begin a new transaction or savepoint
28    ///
29    /// If the transaction depth is greater than 0,
30    /// this should create a savepoint instead.
31    /// This function is expected to increment the transaction depth by 1.
32    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()>;
33
34    /// Rollback the inner-most transaction or savepoint
35    ///
36    /// If the transaction depth is greater than 1,
37    /// this should rollback to the most recent savepoint.
38    /// This function is expected to decrement the transaction depth by 1.
39    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>;
40
41    /// Commit the inner-most transaction or savepoint
42    ///
43    /// If the transaction depth is greater than 1,
44    /// this should release the most recent savepoint.
45    /// This function is expected to decrement the transaction depth by 1.
46    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()>;
47
48    /// Fetch the current transaction status as mutable
49    ///
50    /// Used to ensure that `begin_test_transaction` is not called when already
51    /// inside of a transaction, and that operations are not run in a `InError`
52    /// transaction manager.
53    #[doc(hidden)]
54    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
55
56    /// Executes the given function inside of a database transaction
57    ///
58    /// Each implementation of this function needs to fulfill the documented
59    /// behaviour of [`AsyncConnection::transaction`]
60    async fn transaction<'a, F, R, E>(conn: &mut Conn, callback: F) -> Result<R, E>
61    where
62        F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
63        E: From<Error> + Send,
64        R: Send,
65    {
66        Self::begin_transaction(conn).await?;
67        match callback(&mut *conn).await {
68            Ok(value) => {
69                Self::commit_transaction(conn).await?;
70                Ok(value)
71            }
72            Err(user_error) => match Self::rollback_transaction(conn).await {
73                Ok(()) => Err(user_error),
74                Err(Error::BrokenTransactionManager) => {
75                    // In this case we are probably more interested by the
76                    // original error, which likely caused this
77                    Err(user_error)
78                }
79                Err(rollback_error) => Err(rollback_error.into()),
80            },
81        }
82    }
83
84    /// This methods checks if the connection manager is considered to be broken
85    /// by connection pool implementations
86    ///
87    /// A connection manager is considered to be broken by default if it either
88    /// contains an open transaction (because you don't want to have connections
89    /// with open transactions in your pool) or when the transaction manager is
90    /// in an error state.
91    #[doc(hidden)]
92    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
93        check_broken_transaction_state(conn)
94    }
95}
96
97fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
98where
99    Conn: AsyncConnection,
100{
101    match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
102        // all transactions are closed
103        // so we don't consider this connection broken
104        Ok(ValidTransactionManagerStatus {
105            in_transaction: None,
106            ..
107        }) => false,
108        // The transaction manager is in an error state
109        // Therefore we consider this connection broken
110        Err(_) => true,
111        // The transaction manager contains a open transaction
112        // we do consider this connection broken
113        // if that transaction was not opened by `begin_test_transaction`
114        Ok(ValidTransactionManagerStatus {
115            in_transaction: Some(s),
116            ..
117        }) => !s.test_transaction,
118    }
119}
120
121/// An implementation of `TransactionManager` which can be used for backends
122/// which use ANSI standard syntax for savepoints such as SQLite and PostgreSQL.
123#[derive(Default, Debug)]
124pub struct AnsiTransactionManager {
125    pub(crate) status: TransactionManagerStatus,
126    // this boolean flag tracks whether we are currently in the process
127    // of executing any transaction releated SQL (BEGIN, COMMIT, ROLLBACK)
128    // if we ever encounter a situation where this flag is set
129    // while the connection is returned to a pool
130    // that means the connection is broken as someone dropped the
131    // transaction future while these commands where executed
132    // and we cannot know the connection state anymore
133    //
134    // We ensure this by wrapping all calls to `.await`
135    // into `AnsiTransactionManager::critical_transaction_block`
136    // below
137    //
138    // See https://github.com/weiznich/diesel_async/issues/198 for
139    // details
140    pub(crate) is_broken: Arc<AtomicBool>,
141}
142
143impl AnsiTransactionManager {
144    fn get_transaction_state<Conn>(
145        conn: &mut Conn,
146    ) -> QueryResult<&mut ValidTransactionManagerStatus>
147    where
148        Conn: AsyncConnection<TransactionManager = Self>,
149    {
150        conn.transaction_state().status.transaction_state()
151    }
152
153    /// Begin a transaction with custom SQL
154    ///
155    /// This is used by connections to implement more complex transaction APIs
156    /// to set things such as isolation levels.
157    /// Returns an error if already inside of a transaction.
158    pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
159    where
160        Conn: AsyncConnection<TransactionManager = Self>,
161    {
162        let is_broken = conn.transaction_state().is_broken.clone();
163        let state = Self::get_transaction_state(conn)?;
164        match state.transaction_depth() {
165            None => {
166                Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
167                Self::get_transaction_state(conn)?
168                    .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
169                Ok(())
170            }
171            Some(_depth) => Err(Error::AlreadyInTransaction),
172        }
173    }
174
175    // This function should be used to await any connection
176    // related future in our transaction manager implementation
177    //
178    // It takes care of tracking entering and exiting executing the future
179    // which in turn is used to determine if it's safe to still use
180    // the connection in the event of a canceled transaction execution
181    async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
182    where
183        F: std::future::Future,
184    {
185        let was_broken = is_broken.swap(true, Ordering::Relaxed);
186        debug_assert!(
187            !was_broken,
188            "Tried to execute a transaction SQL on transaction manager that was previously cancled"
189        );
190        let res = f.await;
191        is_broken.store(false, Ordering::Relaxed);
192        res
193    }
194}
195
196#[async_trait::async_trait]
197impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
198where
199    Conn: AsyncConnection<TransactionManager = Self>,
200{
201    type TransactionStateData = Self;
202
203    async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
204        let transaction_state = Self::get_transaction_state(conn)?;
205        let start_transaction_sql = match transaction_state.transaction_depth() {
206            None => Cow::from("BEGIN"),
207            Some(transaction_depth) => {
208                Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
209            }
210        };
211        let depth = transaction_state
212            .transaction_depth()
213            .and_then(|d| d.checked_add(1))
214            .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
215        conn.instrumentation()
216            .on_connection_event(InstrumentationEvent::begin_transaction(depth));
217        Self::critical_transaction_block(
218            &conn.transaction_state().is_broken.clone(),
219            conn.batch_execute(&start_transaction_sql),
220        )
221        .await?;
222        Self::get_transaction_state(conn)?
223            .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
224
225        Ok(())
226    }
227
228    async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
229        let transaction_state = Self::get_transaction_state(conn)?;
230
231        let (
232            (rollback_sql, rolling_back_top_level),
233            requires_rollback_maybe_up_to_top_level_before_execute,
234        ) = match transaction_state.in_transaction {
235            Some(ref in_transaction) => (
236                match in_transaction.transaction_depth.get() {
237                    1 => (Cow::Borrowed("ROLLBACK"), true),
238                    depth_gt1 => (
239                        Cow::Owned(format!(
240                            "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
241                            depth_gt1 - 1
242                        )),
243                        false,
244                    ),
245                },
246                in_transaction.requires_rollback_maybe_up_to_top_level,
247            ),
248            None => return Err(Error::NotInTransaction),
249        };
250
251        let depth = transaction_state
252            .transaction_depth()
253            .expect("We know that we are in a transaction here");
254        conn.instrumentation()
255            .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
256
257        let is_broken = conn.transaction_state().is_broken.clone();
258
259        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
260        {
261            Ok(()) => {
262                match Self::get_transaction_state(conn)?
263                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
264                {
265                    Ok(()) => {}
266                    Err(Error::NotInTransaction) if rolling_back_top_level => {
267                        // Transaction exit may have already been detected by connection
268                        // implementation. It's fine.
269                    }
270                    Err(e) => return Err(e),
271                }
272                Ok(())
273            }
274            Err(rollback_error) => {
275                let tm_status = Self::transaction_manager_status_mut(conn);
276                match tm_status {
277                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
278                        in_transaction:
279                            Some(InTransactionStatus {
280                                transaction_depth,
281                                requires_rollback_maybe_up_to_top_level,
282                                ..
283                            }),
284                        ..
285                    }) if transaction_depth.get() > 1 => {
286                        // A savepoint failed to rollback - we may still attempt to repair
287                        // the connection by rolling back higher levels.
288
289                        // To make it easier on the user (that they don't have to really
290                        // look at actual transaction depth and can just rely on the number
291                        // of times they have called begin/commit/rollback) we still
292                        // decrement here:
293                        *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
294                            .expect("Depth was checked to be > 1");
295                        *requires_rollback_maybe_up_to_top_level = true;
296                        if requires_rollback_maybe_up_to_top_level_before_execute {
297                            // In that case, we tolerate that savepoint releases fail
298                            // -> we should ignore errors
299                            return Ok(());
300                        }
301                    }
302                    TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
303                        in_transaction: None,
304                        ..
305                    }) => {
306                        // we would have returned `NotInTransaction` if that was already the state
307                        // before we made our call
308                        // => Transaction manager status has been fixed by the underlying connection
309                        // so we don't need to set_in_error
310                    }
311                    _ => tm_status.set_in_error(),
312                }
313                Err(rollback_error)
314            }
315        }
316    }
317
318    /// If the transaction fails to commit due to a `SerializationFailure` or a
319    /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
320    /// the original error will be returned, otherwise the error generated by the rollback
321    /// will be returned. In the second case the connection will be considered broken
322    /// as it contains a uncommitted unabortable open transaction.
323    async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
324        let transaction_state = Self::get_transaction_state(conn)?;
325        let transaction_depth = transaction_state.transaction_depth();
326        let (commit_sql, committing_top_level) = match transaction_depth {
327            None => return Err(Error::NotInTransaction),
328            Some(transaction_depth) if transaction_depth.get() == 1 => {
329                (Cow::Borrowed("COMMIT"), true)
330            }
331            Some(transaction_depth) => (
332                Cow::Owned(format!(
333                    "RELEASE SAVEPOINT diesel_savepoint_{}",
334                    transaction_depth.get() - 1
335                )),
336                false,
337            ),
338        };
339        let depth = transaction_state
340            .transaction_depth()
341            .expect("We know that we are in a transaction here");
342        conn.instrumentation()
343            .on_connection_event(InstrumentationEvent::commit_transaction(depth));
344
345        let is_broken = conn.transaction_state().is_broken.clone();
346
347        match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
348            Ok(()) => {
349                match Self::get_transaction_state(conn)?
350                    .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
351                {
352                    Ok(()) => {}
353                    Err(Error::NotInTransaction) if committing_top_level => {
354                        // Transaction exit may have already been detected by connection.
355                        // It's fine
356                    }
357                    Err(e) => return Err(e),
358                }
359                Ok(())
360            }
361            Err(commit_error) => {
362                if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
363                    in_transaction:
364                        Some(InTransactionStatus {
365                            requires_rollback_maybe_up_to_top_level: true,
366                            ..
367                        }),
368                    ..
369                }) = conn.transaction_state().status
370                {
371                    match Self::critical_transaction_block(
372                        &is_broken,
373                        Self::rollback_transaction(conn),
374                    )
375                    .await
376                    {
377                        Ok(()) => {}
378                        Err(rollback_error) => {
379                            conn.transaction_state().status.set_in_error();
380                            return Err(Error::RollbackErrorOnCommit {
381                                rollback_error: Box::new(rollback_error),
382                                commit_error: Box::new(commit_error),
383                            });
384                        }
385                    }
386                }
387                Err(commit_error)
388            }
389        }
390    }
391
392    fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
393        &mut conn.transaction_state().status
394    }
395
396    fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
397        conn.transaction_state().is_broken.load(Ordering::Relaxed)
398            || check_broken_transaction_state(conn)
399    }
400}