// diesel_async/pg/transaction_builder.rs

1use crate::{AnsiTransactionManager, AsyncConnection, TransactionManager};
2use diesel::backend::Backend;
3use diesel::pg::Pg;
4use diesel::query_builder::{AstPass, QueryBuilder, QueryFragment};
5use diesel::QueryResult;
6use scoped_futures::ScopedBoxFuture;
7
/// Builder for a PostgreSQL transaction with explicit characteristics.
///
/// Obtain one from [`AsyncPgConnection::build_transaction`], pick the isolation
/// level, access mode and deferrability with the methods on this type, and then
/// call `run` to execute a closure inside the transaction. Options that are never
/// set are left out of the generated `BEGIN TRANSACTION` statement entirely.
/// See [the PostgreSQL documentation for `SET TRANSACTION`][pg-docs] for what
/// each option means.
///
/// [`AsyncPgConnection::build_transaction`]: super::AsyncPgConnection::build_transaction()
/// [pg-docs]: https://www.postgresql.org/docs/current/static/sql-set-transaction.html
#[cfg(feature = "postgres")]
#[must_use = "Transaction builder does nothing unless you call `run` on it"]
pub struct TransactionBuilder<'a, C> {
    /// Connection the transaction will be started on.
    connection: &'a mut C,
    /// Rendered as ` ISOLATION LEVEL ...`; `None` omits the clause.
    isolation_level: Option<IsolationLevel>,
    /// Rendered as ` READ ONLY` / ` READ WRITE`; `None` omits the clause.
    read_mode: Option<ReadMode>,
    /// Rendered as ` DEFERRABLE` / ` NOT DEFERRABLE`; `None` omits the clause.
    deferrable: Option<Deferrable>,
}
25
26impl<'a, C> TransactionBuilder<'a, C>
27where
28    C: AsyncConnection<Backend = Pg, TransactionManager = AnsiTransactionManager>,
29{
30    pub(crate) fn new(connection: &'a mut C) -> Self {
31        Self {
32            connection,
33            isolation_level: None,
34            read_mode: None,
35            deferrable: None,
36        }
37    }
38
39    /// Makes the transaction `READ ONLY`
40    ///
41    /// # Example
42    ///
43    /// ```rust
44    /// # include!("../doctest_setup.rs");
45    /// # use diesel::sql_query;
46    /// use diesel_async::RunQueryDsl;
47    /// #
48    /// # #[tokio::main(flavor = "current_thread")]
49    /// # async fn main() {
50    /// #     run_test().await.unwrap();
51    /// # }
52    /// #
53    /// # diesel::table! {
54    /// #     users_for_read_only {
55    /// #         id -> Integer,
56    /// #         name -> Text,
57    /// #     }
58    /// # }
59    /// #
60    /// # async fn run_test() -> QueryResult<()> {
61    /// #     use users_for_read_only::table as users;
62    /// #     use users_for_read_only::columns::*;
63    /// #     let conn = &mut connection_no_transaction().await;
64    /// #     sql_query("CREATE TABLE IF NOT EXISTS users_for_read_only (
65    /// #       id SERIAL PRIMARY KEY,
66    /// #       name TEXT NOT NULL
67    /// #     )").execute(conn).await?;
68    /// conn.build_transaction()
69    ///     .read_only()
70    ///     .run::<_, diesel::result::Error, _>(|conn| Box::pin(async move {
71    ///         let read_attempt = users.select(name).load::<String>(conn).await;
72    ///         assert!(read_attempt.is_ok());
73    ///
74    ///         let write_attempt = diesel::insert_into(users)
75    ///             .values(name.eq("Ruby"))
76    ///             .execute(conn)
77    ///             .await;
78    ///         assert!(write_attempt.is_err());
79    ///
80    ///         Ok(())
81    ///     }) as _).await?;
82    /// #     sql_query("DROP TABLE users_for_read_only").execute(conn).await?;
83    /// #     Ok(())
84    /// # }
85    /// ```
86    pub fn read_only(mut self) -> Self {
87        self.read_mode = Some(ReadMode::ReadOnly);
88        self
89    }
90
91    /// Makes the transaction `READ WRITE`
92    ///
93    /// This is the default, unless you've changed the
94    /// `default_transaction_read_only` configuration parameter.
95    ///
96    /// # Example
97    ///
98    /// ```rust
99    /// # include!("../doctest_setup.rs");
100    /// # use diesel::result::Error::RollbackTransaction;
101    /// # use diesel::sql_query;
102    /// use diesel_async::RunQueryDsl;
103    ///
104    /// #
105    /// # #[tokio::main(flavor = "current_thread")]
106    /// # async fn main() {
107    /// #     assert_eq!(run_test().await, Err(RollbackTransaction));
108    /// # }
109    /// #
110    /// # async fn run_test() -> QueryResult<()> {
111    /// #     use schema::users::dsl::*;
112    /// #     let conn = &mut connection_no_transaction().await;
113    /// conn.build_transaction()
114    ///     .read_write()
115    ///     .run(|conn| Box::pin( async move {
116    /// #         sql_query("CREATE TABLE IF NOT EXISTS users (
117    /// #             id SERIAL PRIMARY KEY,
118    /// #             name TEXT NOT NULL
119    /// #         )").execute(conn).await?;
120    ///         let read_attempt = users.select(name).load::<String>(conn).await;
121    ///         assert!(read_attempt.is_ok());
122    ///
123    ///         let write_attempt = diesel::insert_into(users)
124    ///             .values(name.eq("Ruby"))
125    ///             .execute(conn)
126    ///             .await;
127    ///         assert!(write_attempt.is_ok());
128    ///
129    /// #       Err(RollbackTransaction)
130    /// #       /*
131    ///         Ok(())
132    /// #       */
133    ///     }) as _)
134    ///     .await
135    /// # }
136    /// ```
137    pub fn read_write(mut self) -> Self {
138        self.read_mode = Some(ReadMode::ReadWrite);
139        self
140    }
141
142    /// Makes the transaction `DEFERRABLE`
143    ///
144    /// # Example
145    ///
146    /// ```rust
147    /// # include!("../doctest_setup.rs");
148    /// #
149    /// # #[tokio::main(flavor = "current_thread")]
150    /// # async fn main() {
151    /// #     run_test().await.unwrap();
152    /// # }
153    /// #
154    /// # async fn run_test() -> QueryResult<()> {
155    /// #     use schema::users::dsl::*;
156    /// #     let conn = &mut connection_no_transaction().await;
157    /// conn.build_transaction()
158    ///     .deferrable()
159    ///     .run(|conn| Box::pin(async { Ok(()) }))
160    ///     .await
161    /// # }
162    /// ```
163    pub fn deferrable(mut self) -> Self {
164        self.deferrable = Some(Deferrable::Deferrable);
165        self
166    }
167
168    /// Makes the transaction `NOT DEFERRABLE`
169    ///
170    /// This is the default, unless you've changed the
171    /// `default_transaction_deferrable` configuration parameter.
172    ///
173    /// # Example
174    ///
175    /// ```rust
176    /// # include!("../doctest_setup.rs");
177    /// #
178    /// # #[tokio::main(flavor = "current_thread")]
179    /// # async fn main() {
180    /// #     run_test().await.unwrap();
181    /// # }
182    /// #
183    /// # async fn run_test() -> QueryResult<()> {
184    /// #     use schema::users::dsl::*;
185    /// #     let conn = &mut connection_no_transaction().await;
186    /// conn.build_transaction()
187    ///     .not_deferrable()
188    ///     .run(|conn| Box::pin(async { Ok(()) }) as _)
189    ///     .await
190    /// # }
191    /// ```
192    pub fn not_deferrable(mut self) -> Self {
193        self.deferrable = Some(Deferrable::NotDeferrable);
194        self
195    }
196
197    /// Makes the transaction `ISOLATION LEVEL READ COMMITTED`
198    ///
199    /// This is the default, unless you've changed the
200    /// `default_transaction_isolation_level` configuration parameter.
201    ///
202    /// # Example
203    ///
204    /// ```rust
205    /// # include!("../doctest_setup.rs");
206    /// #
207    /// # #[tokio::main(flavor = "current_thread")]
208    /// # async fn main() {
209    /// #     run_test().await.unwrap();
210    /// # }
211    /// #
212    /// # async fn run_test() -> QueryResult<()> {
213    /// #     use schema::users::dsl::*;
214    /// #     let conn = &mut connection_no_transaction().await;
215    /// conn.build_transaction()
216    ///     .read_committed()
217    ///     .run(|conn| Box::pin(async { Ok(()) }) as _)
218    ///     .await
219    /// # }
220    /// ```
221    pub fn read_committed(mut self) -> Self {
222        self.isolation_level = Some(IsolationLevel::ReadCommitted);
223        self
224    }
225
226    /// Makes the transaction `ISOLATION LEVEL REPEATABLE READ`
227    ///
228    /// # Example
229    ///
230    /// ```rust
231    /// # include!("../doctest_setup.rs");
232    /// #
233    /// # #[tokio::main(flavor = "current_thread")]
234    /// # async fn main() {
235    /// #     run_test().await.unwrap();
236    /// # }
237    /// #
238    /// # async fn run_test() -> QueryResult<()> {
239    /// #     use schema::users::dsl::*;
240    /// #     let conn = &mut connection_no_transaction().await;
241    /// conn.build_transaction()
242    ///     .repeatable_read()
243    ///     .run(|conn| Box::pin(async { Ok(()) }) as _)
244    ///     .await
245    /// # }
246    /// ```
247    pub fn repeatable_read(mut self) -> Self {
248        self.isolation_level = Some(IsolationLevel::RepeatableRead);
249        self
250    }
251
252    /// Makes the transaction `ISOLATION LEVEL SERIALIZABLE`
253    ///
254    /// # Example
255    ///
256    /// ```rust
257    /// # include!("../doctest_setup.rs");
258    /// #
259    /// # #[tokio::main(flavor = "current_thread")]
260    /// # async fn main() {
261    /// #     run_test().await.unwrap();
262    /// # }
263    /// #
264    /// # async fn run_test() -> QueryResult<()> {
265    /// #     use schema::users::dsl::*;
266    /// #     let conn = &mut connection_no_transaction().await;
267    /// conn.build_transaction()
268    ///     .serializable()
269    ///     .run(|conn| Box::pin(async { Ok(()) }) as _)
270    ///     .await
271    /// # }
272    /// ```
273    pub fn serializable(mut self) -> Self {
274        self.isolation_level = Some(IsolationLevel::Serializable);
275        self
276    }
277
278    /// Runs the given function inside of the transaction
279    /// with the parameters given to this builder.
280    ///
281    /// Returns an error if the connection is already inside a transaction,
282    /// or if the transaction fails to commit or rollback
283    ///
284    /// If the transaction fails to commit due to a `SerializationFailure` or a
285    /// `ReadOnlyTransaction` a rollback will be attempted. If the rollback succeeds,
286    /// the original error will be returned, otherwise the error generated by the rollback
287    /// will be returned. In the second case the connection should be considered broken
288    /// as it contains a uncommitted unabortable open transaction.
289    pub async fn run<'b, T, E, F>(&mut self, f: F) -> Result<T, E>
290    where
291        F: for<'r> FnOnce(&'r mut C) -> ScopedBoxFuture<'b, 'r, Result<T, E>> + Send + 'a,
292        T: 'b,
293        E: From<diesel::result::Error> + 'b,
294    {
295        let mut query_builder = <Pg as Backend>::QueryBuilder::default();
296        self.to_sql(&mut query_builder, &Pg)?;
297        let sql = query_builder.finish();
298
299        AnsiTransactionManager::begin_transaction_sql(&mut *self.connection, &sql).await?;
300        match f(&mut *self.connection).await {
301            Ok(value) => {
302                AnsiTransactionManager::commit_transaction(&mut *self.connection).await?;
303                Ok(value)
304            }
305            Err(e) => {
306                AnsiTransactionManager::rollback_transaction(&mut *self.connection).await?;
307                Err(e)
308            }
309        }
310    }
311}
312
313impl<'a, C> QueryFragment<Pg> for TransactionBuilder<'a, C> {
314    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
315        out.push_sql("BEGIN TRANSACTION");
316        if let Some(ref isolation_level) = self.isolation_level {
317            isolation_level.walk_ast(out.reborrow())?;
318        }
319        if let Some(ref read_mode) = self.read_mode {
320            read_mode.walk_ast(out.reborrow())?;
321        }
322        if let Some(ref deferrable) = self.deferrable {
323            deferrable.walk_ast(out.reborrow())?;
324        }
325        Ok(())
326    }
327}
328
329#[derive(Debug, Clone, Copy)]
330enum IsolationLevel {
331    ReadCommitted,
332    RepeatableRead,
333    Serializable,
334}
335
336impl QueryFragment<Pg> for IsolationLevel {
337    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
338        out.push_sql(" ISOLATION LEVEL ");
339        match *self {
340            IsolationLevel::ReadCommitted => out.push_sql("READ COMMITTED"),
341            IsolationLevel::RepeatableRead => out.push_sql("REPEATABLE READ"),
342            IsolationLevel::Serializable => out.push_sql("SERIALIZABLE"),
343        }
344        Ok(())
345    }
346}
347
348#[derive(Debug, Clone, Copy)]
349enum ReadMode {
350    ReadOnly,
351    ReadWrite,
352}
353
354impl QueryFragment<Pg> for ReadMode {
355    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
356        match *self {
357            ReadMode::ReadOnly => out.push_sql(" READ ONLY"),
358            ReadMode::ReadWrite => out.push_sql(" READ WRITE"),
359        }
360        Ok(())
361    }
362}
363
364#[derive(Debug, Clone, Copy)]
365enum Deferrable {
366    Deferrable,
367    NotDeferrable,
368}
369
370impl QueryFragment<Pg> for Deferrable {
371    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
372        match *self {
373            Deferrable::Deferrable => out.push_sql(" DEFERRABLE"),
374            Deferrable::NotDeferrable => out.push_sql(" NOT DEFERRABLE"),
375        }
376        Ok(())
377    }
378}
379
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the `BEGIN TRANSACTION` statement rendered for each builder option
    /// on its own, for combinations of options, and for an option set twice.
    ///
    /// Needs a reachable PostgreSQL server at `DATABASE_URL`, because the builders
    /// under test are obtained from an established connection.
    #[tokio::test]
    async fn test_transaction_builder_generates_correct_sql() {
        // Renders `$query` with the PostgreSQL query builder and compares the
        // resulting SQL text with `$sql`.
        macro_rules! assert_sql {
            ($query:expr, $sql:expr) => {
                let mut query_builder = <Pg as Backend>::QueryBuilder::default();
                $query.to_sql(&mut query_builder, &Pg).unwrap();
                let sql = query_builder.finish();
                assert_eq!(sql, $sql);
            };
        }

        // Deliberately not wrapped in `dbg!`: database URLs commonly embed
        // credentials, which would otherwise end up in test output.
        let database_url = std::env::var("DATABASE_URL")
            .expect("DATABASE_URL must be set in order to run tests");
        let mut conn = crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap();

        assert_sql!(conn.build_transaction(), "BEGIN TRANSACTION");
        assert_sql!(
            conn.build_transaction().read_only(),
            "BEGIN TRANSACTION READ ONLY"
        );
        assert_sql!(
            conn.build_transaction().read_write(),
            "BEGIN TRANSACTION READ WRITE"
        );
        assert_sql!(
            conn.build_transaction().deferrable(),
            "BEGIN TRANSACTION DEFERRABLE"
        );
        assert_sql!(
            conn.build_transaction().not_deferrable(),
            "BEGIN TRANSACTION NOT DEFERRABLE"
        );
        assert_sql!(
            conn.build_transaction().read_committed(),
            "BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"
        );
        assert_sql!(
            conn.build_transaction().repeatable_read(),
            "BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ"
        );
        assert_sql!(
            conn.build_transaction().serializable(),
            "BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE"
        );
        assert_sql!(
            conn.build_transaction()
                .serializable()
                .deferrable()
                .read_only(),
            "BEGIN TRANSACTION ISOLATION LEVEL SERIALIZABLE READ ONLY DEFERRABLE"
        );
        assert_sql!(
            conn.build_transaction()
                .repeatable_read()
                .not_deferrable()
                .read_write(),
            "BEGIN TRANSACTION ISOLATION LEVEL REPEATABLE READ READ WRITE NOT DEFERRABLE"
        );

        // Setting the same option twice keeps only the last value.
        assert_sql!(
            conn.build_transaction().read_only().read_write(),
            "BEGIN TRANSACTION READ WRITE"
        );
        assert_sql!(
            conn.build_transaction().serializable().read_committed(),
            "BEGIN TRANSACTION ISOLATION LEVEL READ COMMITTED"
        );
    }
}