diesel_async/
transaction_manager.rs1use diesel::connection::InstrumentationEvent;
2use diesel::connection::TransactionManagerStatus;
3use diesel::connection::{
4 InTransactionStatus, TransactionDepthChange, ValidTransactionManagerStatus,
5};
6use diesel::result::Error;
7use diesel::QueryResult;
8use scoped_futures::ScopedBoxFuture;
9use std::borrow::Cow;
10use std::num::NonZeroU32;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13
14use crate::AsyncConnection;
15#[async_trait::async_trait]
22pub trait TransactionManager<Conn: AsyncConnection>: Send {
23 type TransactionStateData;
26
27 async fn begin_transaction(conn: &mut Conn) -> QueryResult<()>;
33
34 async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()>;
40
41 async fn commit_transaction(conn: &mut Conn) -> QueryResult<()>;
47
48 #[doc(hidden)]
54 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus;
55
56 async fn transaction<'a, F, R, E>(conn: &mut Conn, callback: F) -> Result<R, E>
61 where
62 F: for<'r> FnOnce(&'r mut Conn) -> ScopedBoxFuture<'a, 'r, Result<R, E>> + Send + 'a,
63 E: From<Error> + Send,
64 R: Send,
65 {
66 Self::begin_transaction(conn).await?;
67 match callback(&mut *conn).await {
68 Ok(value) => {
69 Self::commit_transaction(conn).await?;
70 Ok(value)
71 }
72 Err(user_error) => match Self::rollback_transaction(conn).await {
73 Ok(()) => Err(user_error),
74 Err(Error::BrokenTransactionManager) => {
75 Err(user_error)
78 }
79 Err(rollback_error) => Err(rollback_error.into()),
80 },
81 }
82 }
83
84 #[doc(hidden)]
92 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
93 check_broken_transaction_state(conn)
94 }
95}
96
97fn check_broken_transaction_state<Conn>(conn: &mut Conn) -> bool
98where
99 Conn: AsyncConnection,
100{
101 match Conn::TransactionManager::transaction_manager_status_mut(conn).transaction_state() {
102 Ok(ValidTransactionManagerStatus {
105 in_transaction: None,
106 ..
107 }) => false,
108 Err(_) => true,
111 Ok(ValidTransactionManagerStatus {
115 in_transaction: Some(s),
116 ..
117 }) => !s.test_transaction,
118 }
119}
120
121#[derive(Default, Debug)]
124pub struct AnsiTransactionManager {
125 pub(crate) status: TransactionManagerStatus,
126 pub(crate) is_broken: Arc<AtomicBool>,
141}
142
143impl AnsiTransactionManager {
144 fn get_transaction_state<Conn>(
145 conn: &mut Conn,
146 ) -> QueryResult<&mut ValidTransactionManagerStatus>
147 where
148 Conn: AsyncConnection<TransactionManager = Self>,
149 {
150 conn.transaction_state().status.transaction_state()
151 }
152
153 pub async fn begin_transaction_sql<Conn>(conn: &mut Conn, sql: &str) -> QueryResult<()>
159 where
160 Conn: AsyncConnection<TransactionManager = Self>,
161 {
162 let is_broken = conn.transaction_state().is_broken.clone();
163 let state = Self::get_transaction_state(conn)?;
164 match state.transaction_depth() {
165 None => {
166 Self::critical_transaction_block(&is_broken, conn.batch_execute(sql)).await?;
167 Self::get_transaction_state(conn)?
168 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
169 Ok(())
170 }
171 Some(_depth) => Err(Error::AlreadyInTransaction),
172 }
173 }
174
175 async fn critical_transaction_block<F>(is_broken: &AtomicBool, f: F) -> F::Output
182 where
183 F: std::future::Future,
184 {
185 let was_broken = is_broken.swap(true, Ordering::Relaxed);
186 debug_assert!(
187 !was_broken,
188 "Tried to execute a transaction SQL on transaction manager that was previously cancled"
189 );
190 let res = f.await;
191 is_broken.store(false, Ordering::Relaxed);
192 res
193 }
194}
195
196#[async_trait::async_trait]
197impl<Conn> TransactionManager<Conn> for AnsiTransactionManager
198where
199 Conn: AsyncConnection<TransactionManager = Self>,
200{
201 type TransactionStateData = Self;
202
203 async fn begin_transaction(conn: &mut Conn) -> QueryResult<()> {
204 let transaction_state = Self::get_transaction_state(conn)?;
205 let start_transaction_sql = match transaction_state.transaction_depth() {
206 None => Cow::from("BEGIN"),
207 Some(transaction_depth) => {
208 Cow::from(format!("SAVEPOINT diesel_savepoint_{transaction_depth}"))
209 }
210 };
211 let depth = transaction_state
212 .transaction_depth()
213 .and_then(|d| d.checked_add(1))
214 .unwrap_or(NonZeroU32::new(1).expect("It's not 0"));
215 conn.instrumentation()
216 .on_connection_event(InstrumentationEvent::begin_transaction(depth));
217 Self::critical_transaction_block(
218 &conn.transaction_state().is_broken.clone(),
219 conn.batch_execute(&start_transaction_sql),
220 )
221 .await?;
222 Self::get_transaction_state(conn)?
223 .change_transaction_depth(TransactionDepthChange::IncreaseDepth)?;
224
225 Ok(())
226 }
227
228 async fn rollback_transaction(conn: &mut Conn) -> QueryResult<()> {
229 let transaction_state = Self::get_transaction_state(conn)?;
230
231 let (
232 (rollback_sql, rolling_back_top_level),
233 requires_rollback_maybe_up_to_top_level_before_execute,
234 ) = match transaction_state.in_transaction {
235 Some(ref in_transaction) => (
236 match in_transaction.transaction_depth.get() {
237 1 => (Cow::Borrowed("ROLLBACK"), true),
238 depth_gt1 => (
239 Cow::Owned(format!(
240 "ROLLBACK TO SAVEPOINT diesel_savepoint_{}",
241 depth_gt1 - 1
242 )),
243 false,
244 ),
245 },
246 in_transaction.requires_rollback_maybe_up_to_top_level,
247 ),
248 None => return Err(Error::NotInTransaction),
249 };
250
251 let depth = transaction_state
252 .transaction_depth()
253 .expect("We know that we are in a transaction here");
254 conn.instrumentation()
255 .on_connection_event(InstrumentationEvent::rollback_transaction(depth));
256
257 let is_broken = conn.transaction_state().is_broken.clone();
258
259 match Self::critical_transaction_block(&is_broken, conn.batch_execute(&rollback_sql)).await
260 {
261 Ok(()) => {
262 match Self::get_transaction_state(conn)?
263 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
264 {
265 Ok(()) => {}
266 Err(Error::NotInTransaction) if rolling_back_top_level => {
267 }
270 Err(e) => return Err(e),
271 }
272 Ok(())
273 }
274 Err(rollback_error) => {
275 let tm_status = Self::transaction_manager_status_mut(conn);
276 match tm_status {
277 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
278 in_transaction:
279 Some(InTransactionStatus {
280 transaction_depth,
281 requires_rollback_maybe_up_to_top_level,
282 ..
283 }),
284 ..
285 }) if transaction_depth.get() > 1 => {
286 *transaction_depth = NonZeroU32::new(transaction_depth.get() - 1)
294 .expect("Depth was checked to be > 1");
295 *requires_rollback_maybe_up_to_top_level = true;
296 if requires_rollback_maybe_up_to_top_level_before_execute {
297 return Ok(());
300 }
301 }
302 TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
303 in_transaction: None,
304 ..
305 }) => {
306 }
311 _ => tm_status.set_in_error(),
312 }
313 Err(rollback_error)
314 }
315 }
316 }
317
318 async fn commit_transaction(conn: &mut Conn) -> QueryResult<()> {
324 let transaction_state = Self::get_transaction_state(conn)?;
325 let transaction_depth = transaction_state.transaction_depth();
326 let (commit_sql, committing_top_level) = match transaction_depth {
327 None => return Err(Error::NotInTransaction),
328 Some(transaction_depth) if transaction_depth.get() == 1 => {
329 (Cow::Borrowed("COMMIT"), true)
330 }
331 Some(transaction_depth) => (
332 Cow::Owned(format!(
333 "RELEASE SAVEPOINT diesel_savepoint_{}",
334 transaction_depth.get() - 1
335 )),
336 false,
337 ),
338 };
339 let depth = transaction_state
340 .transaction_depth()
341 .expect("We know that we are in a transaction here");
342 conn.instrumentation()
343 .on_connection_event(InstrumentationEvent::commit_transaction(depth));
344
345 let is_broken = conn.transaction_state().is_broken.clone();
346
347 match Self::critical_transaction_block(&is_broken, conn.batch_execute(&commit_sql)).await {
348 Ok(()) => {
349 match Self::get_transaction_state(conn)?
350 .change_transaction_depth(TransactionDepthChange::DecreaseDepth)
351 {
352 Ok(()) => {}
353 Err(Error::NotInTransaction) if committing_top_level => {
354 }
357 Err(e) => return Err(e),
358 }
359 Ok(())
360 }
361 Err(commit_error) => {
362 if let TransactionManagerStatus::Valid(ValidTransactionManagerStatus {
363 in_transaction:
364 Some(InTransactionStatus {
365 requires_rollback_maybe_up_to_top_level: true,
366 ..
367 }),
368 ..
369 }) = conn.transaction_state().status
370 {
371 match Self::critical_transaction_block(
372 &is_broken,
373 Self::rollback_transaction(conn),
374 )
375 .await
376 {
377 Ok(()) => {}
378 Err(rollback_error) => {
379 conn.transaction_state().status.set_in_error();
380 return Err(Error::RollbackErrorOnCommit {
381 rollback_error: Box::new(rollback_error),
382 commit_error: Box::new(commit_error),
383 });
384 }
385 }
386 }
387 Err(commit_error)
388 }
389 }
390 }
391
392 fn transaction_manager_status_mut(conn: &mut Conn) -> &mut TransactionManagerStatus {
393 &mut conn.transaction_state().status
394 }
395
396 fn is_broken_transaction_manager(conn: &mut Conn) -> bool {
397 conn.transaction_state().is_broken.load(Ordering::Relaxed)
398 || check_broken_transaction_state(conn)
399 }
400}