// diesel_async/run_query_dsl/mod.rs

1use crate::AsyncConnection;
2use diesel::associations::HasTable;
3use diesel::query_builder::IntoUpdateTarget;
4use diesel::result::QueryResult;
5use diesel::AsChangeset;
6use futures_util::{future, stream, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
7use std::pin::Pin;
8
9/// The traits used by `QueryDsl`.
10///
11/// Each trait in this module represents exactly one method from [`RunQueryDsl`].
12/// Apps should general rely on [`RunQueryDsl`] directly, rather than these traits.
13/// However, generic code may need to include a where clause that references
14/// these traits.
15pub mod methods {
16    use super::*;
17    use diesel::backend::Backend;
18    use diesel::deserialize::FromSqlRow;
19    use diesel::expression::QueryMetadata;
20    use diesel::query_builder::{AsQuery, QueryFragment, QueryId};
21    use diesel::query_dsl::CompatibleType;
22    use futures_util::{Future, Stream, TryFutureExt};
23
24    /// The `execute` method
25    ///
26    /// This trait should not be relied on directly by most apps. Its behavior is
27    /// provided by [`RunQueryDsl`]. However, you may need a where clause on this trait
28    /// to call `execute` from generic code.
29    ///
30    /// [`RunQueryDsl`]: super::RunQueryDsl
31    pub trait ExecuteDsl<Conn, DB = <Conn as AsyncConnection>::Backend>
32    where
33        Conn: AsyncConnection<Backend = DB>,
34        DB: Backend,
35    {
36        /// Execute this command
37        fn execute<'conn, 'query>(
38            query: Self,
39            conn: &'conn mut Conn,
40        ) -> Conn::ExecuteFuture<'conn, 'query>
41        where
42            Self: 'query;
43    }
44
45    impl<Conn, DB, T> ExecuteDsl<Conn, DB> for T
46    where
47        Conn: AsyncConnection<Backend = DB>,
48        DB: Backend,
49        T: QueryFragment<DB> + QueryId + Send,
50    {
51        fn execute<'conn, 'query>(
52            query: Self,
53            conn: &'conn mut Conn,
54        ) -> Conn::ExecuteFuture<'conn, 'query>
55        where
56            Self: 'query,
57        {
58            conn.execute_returning_count(query)
59        }
60    }
61
62    /// The `load` method
63    ///
64    /// This trait should not be relied on directly by most apps. Its behavior is
65    /// provided by [`RunQueryDsl`]. However, you may need a where clause on this trait
66    /// to call `load` from generic code.
67    ///
68    /// [`RunQueryDsl`]: super::RunQueryDsl
69    pub trait LoadQuery<'query, Conn: AsyncConnection, U> {
70        /// The future returned by [`LoadQuery::internal_load`]
71        type LoadFuture<'conn>: Future<Output = QueryResult<Self::Stream<'conn>>> + Send
72        where
73            Conn: 'conn;
74        /// The inner stream returned by [`LoadQuery::internal_load`]
75        type Stream<'conn>: Stream<Item = QueryResult<U>> + Send
76        where
77            Conn: 'conn;
78
79        /// Load this query
80        fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_>;
81    }
82
83    impl<'query, Conn, DB, T, U, ST> LoadQuery<'query, Conn, U> for T
84    where
85        Conn: AsyncConnection<Backend = DB>,
86        U: Send,
87        DB: Backend + 'static,
88        T: AsQuery + Send + 'query,
89        T::Query: QueryFragment<DB> + QueryId + Send + 'query,
90        T::SqlType: CompatibleType<U, DB, SqlType = ST>,
91        U: FromSqlRow<ST, DB> + Send + 'static,
92        DB: QueryMetadata<T::SqlType>,
93        ST: 'static,
94    {
95        type LoadFuture<'conn> = future::MapOk<
96            Conn::LoadFuture<'conn, 'query>,
97            fn(Conn::Stream<'conn, 'query>) -> Self::Stream<'conn>,
98        > where Conn: 'conn;
99
100        type Stream<'conn> = stream::Map<
101            Conn::Stream<'conn, 'query>,
102            fn(
103                QueryResult<Conn::Row<'conn, 'query>>,
104            ) -> QueryResult<U>,
105        >where  Conn: 'conn;
106
107        fn internal_load(self, conn: &mut Conn) -> Self::LoadFuture<'_> {
108            conn.load(self)
109                .map_ok(map_result_stream_future::<U, _, _, DB, ST>)
110        }
111    }
112
113    #[allow(clippy::type_complexity)]
114    fn map_result_stream_future<'s, 'a, U, S, R, DB, ST>(
115        stream: S,
116    ) -> stream::Map<S, fn(QueryResult<R>) -> QueryResult<U>>
117    where
118        S: Stream<Item = QueryResult<R>> + Send + 's,
119        R: diesel::row::Row<'a, DB> + 's,
120        DB: Backend + 'static,
121        U: FromSqlRow<ST, DB> + 'static,
122        ST: 'static,
123    {
124        stream.map(map_row_helper::<_, DB, U, ST>)
125    }
126
127    fn map_row_helper<'a, R, DB, U, ST>(row: QueryResult<R>) -> QueryResult<U>
128    where
129        U: FromSqlRow<ST, DB>,
130        R: diesel::row::Row<'a, DB>,
131        DB: Backend,
132    {
133        U::build_from_row(&row?).map_err(diesel::result::Error::DeserializationError)
134    }
135}
136
137/// The return types produced by the various [`RunQueryDsl`] methods
138///
139// We cannot box these types as this would require specifying a lifetime,
140// but concrete connection implementations might want to have control
141// about that so that they can support multiple simultaneous queries on
142// the same connection
143#[allow(type_alias_bounds)] // we need these bounds otherwise we cannot use GAT's
144pub mod return_futures {
145    use super::methods::LoadQuery;
146    use diesel::QueryResult;
147    use futures_util::{future, stream};
148    use std::pin::Pin;
149
150    /// The future returned by [`RunQueryDsl::load`](super::RunQueryDsl::load)
151    /// and [`RunQueryDsl::get_results`](super::RunQueryDsl::get_results)
152    ///
153    /// This is essentially `impl Future<Output = QueryResult<Vec<U>>>`
154    pub type LoadFuture<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen<
155        Q::LoadFuture<'conn>,
156        stream::TryCollect<Q::Stream<'conn>, Vec<U>>,
157        fn(Q::Stream<'conn>) -> stream::TryCollect<Q::Stream<'conn>, Vec<U>>,
158    >;
159
160    /// The future returned by [`RunQueryDsl::get_result`](super::RunQueryDsl::get_result)
161    ///
162    /// This is essentially `impl Future<Output = QueryResult<U>>`
163    pub type GetResult<'conn, 'query, Q: LoadQuery<'query, Conn, U>, Conn, U> = future::AndThen<
164        Q::LoadFuture<'conn>,
165        future::Map<
166            stream::StreamFuture<Pin<Box<Q::Stream<'conn>>>>,
167            fn((Option<QueryResult<U>>, Pin<Box<Q::Stream<'conn>>>)) -> QueryResult<U>,
168        >,
169        fn(
170            Q::Stream<'conn>,
171        ) -> future::Map<
172            stream::StreamFuture<Pin<Box<Q::Stream<'conn>>>>,
173            fn((Option<QueryResult<U>>, Pin<Box<Q::Stream<'conn>>>)) -> QueryResult<U>,
174        >,
175    >;
176}
177
178/// Methods used to execute queries.
179pub trait RunQueryDsl<Conn>: Sized {
180    /// Executes the given command, returning the number of rows affected.
181    ///
182    /// `execute` is usually used in conjunction with [`insert_into`](diesel::insert_into()),
183    /// [`update`](diesel::update()) and [`delete`](diesel::delete()) where the number of
184    /// affected rows is often enough information.
185    ///
186    /// When asking the database to return data from a query, [`load`](crate::run_query_dsl::RunQueryDsl::load()) should
187    /// probably be used instead.
188    ///
189    /// # Example
190    ///
191    /// ```rust
192    /// # include!("../doctest_setup.rs");
193    /// #
194    /// use diesel_async::RunQueryDsl;
195    ///
196    /// # #[tokio::main(flavor = "current_thread")]
197    /// # async fn main() {
198    /// #     run_test().await;
199    /// # }
200    /// #
201    /// # async fn run_test() -> QueryResult<()> {
202    /// #     use diesel::insert_into;
203    /// #     use schema::users::dsl::*;
204    /// #     let connection = &mut establish_connection().await;
205    /// let inserted_rows = insert_into(users)
206    ///     .values(name.eq("Ruby"))
207    ///     .execute(connection)
208    ///     .await?;
209    /// assert_eq!(1, inserted_rows);
210    ///
211    /// # #[cfg(not(feature = "sqlite"))]
212    /// let inserted_rows = insert_into(users)
213    ///     .values(&vec![name.eq("Jim"), name.eq("James")])
214    ///     .execute(connection)
215    ///     .await?;
216    /// # #[cfg(not(feature = "sqlite"))]
217    /// assert_eq!(2, inserted_rows);
218    /// #     Ok(())
219    /// # }
220    /// ```
221    fn execute<'conn, 'query>(self, conn: &'conn mut Conn) -> Conn::ExecuteFuture<'conn, 'query>
222    where
223        Conn: AsyncConnection + Send,
224        Self: methods::ExecuteDsl<Conn> + 'query,
225    {
226        methods::ExecuteDsl::execute(self, conn)
227    }
228
229    /// Executes the given query, returning a [`Vec`] with the returned rows.
230    ///
231    /// When using the query builder, the return type can be
232    /// a tuple of the values, or a struct which implements [`Queryable`].
233    ///
234    /// When this method is called on [`sql_query`],
235    /// the return type can only be a struct which implements [`QueryableByName`]
236    ///
237    /// For insert, update, and delete operations where only a count of affected is needed,
238    /// [`execute`] should be used instead.
239    ///
240    /// [`Queryable`]: diesel::deserialize::Queryable
241    /// [`QueryableByName`]: diesel::deserialize::QueryableByName
242    /// [`execute`]: crate::run_query_dsl::RunQueryDsl::execute()
243    /// [`sql_query`]: diesel::sql_query()
244    ///
245    /// # Examples
246    ///
247    /// ## Returning a single field
248    ///
249    /// ```rust
250    /// # include!("../doctest_setup.rs");
251    /// #
252    /// use diesel_async::{RunQueryDsl, AsyncConnection};
253    ///
254    /// #
255    /// # #[tokio::main(flavor = "current_thread")]
256    /// # async fn main() {
257    /// #     run_test().await;
258    /// # }
259    /// #
260    /// # async fn run_test() -> QueryResult<()> {
261    /// #     use diesel::insert_into;
262    /// #     use schema::users::dsl::*;
263    /// #     let connection = &mut establish_connection().await;
264    /// let data = users.select(name)
265    ///     .load::<String>(connection)
266    ///     .await?;
267    /// assert_eq!(vec!["Sean", "Tess"], data);
268    /// #     Ok(())
269    /// # }
270    /// ```
271    ///
272    /// ## Returning a tuple
273    ///
274    /// ```rust
275    /// # include!("../doctest_setup.rs");
276    /// use diesel_async::RunQueryDsl;
277    ///
278    /// #
279    /// # #[tokio::main(flavor = "current_thread")]
280    /// # async fn main() {
281    /// #     run_test().await;
282    /// # }
283    /// #
284    /// # async fn run_test() -> QueryResult<()> {
285    /// #     use diesel::insert_into;
286    /// #     use schema::users::dsl::*;
287    /// #     let connection = &mut establish_connection().await;
288    /// let data = users
289    ///     .load::<(i32, String)>(connection)
290    ///     .await?;
291    /// let expected_data = vec![
292    ///     (1, String::from("Sean")),
293    ///     (2, String::from("Tess")),
294    /// ];
295    /// assert_eq!(expected_data, data);
296    /// #     Ok(())
297    /// # }
298    /// ```
299    ///
300    /// ## Returning a struct
301    ///
302    /// ```rust
303    /// # include!("../doctest_setup.rs");
304    /// use diesel_async::RunQueryDsl;
305    ///
306    /// #
307    /// #[derive(Queryable, PartialEq, Debug)]
308    /// struct User {
309    ///     id: i32,
310    ///     name: String,
311    /// }
312    ///
313    /// # #[tokio::main(flavor = "current_thread")]
314    /// # async fn main() {
315    /// #     run_test().await;
316    /// # }
317    /// #
318    /// # async fn run_test() -> QueryResult<()> {
319    /// #     use diesel::insert_into;
320    /// #     use schema::users::dsl::*;
321    /// #     let connection = &mut establish_connection().await;
322    /// let data = users
323    ///     .load::<User>(connection)
324    ///     .await?;
325    /// let expected_data = vec![
326    ///     User { id: 1, name: String::from("Sean") },
327    ///     User { id: 2, name: String::from("Tess") },
328    /// ];
329    /// assert_eq!(expected_data, data);
330    /// #     Ok(())
331    /// # }
332    /// ```
333    fn load<'query, 'conn, U>(
334        self,
335        conn: &'conn mut Conn,
336    ) -> return_futures::LoadFuture<'conn, 'query, Self, Conn, U>
337    where
338        U: Send,
339        Conn: AsyncConnection,
340        Self: methods::LoadQuery<'query, Conn, U> + 'query,
341    {
342        fn collect_result<U, S>(stream: S) -> stream::TryCollect<S, Vec<U>>
343        where
344            S: Stream<Item = QueryResult<U>>,
345        {
346            stream.try_collect()
347        }
348        self.internal_load(conn).and_then(collect_result::<U, _>)
349    }
350
351    /// Executes the given query, returning a [`Stream`] with the returned rows.
352    ///
353    /// **You should normally prefer to use [`RunQueryDsl::load`] instead**. This method
354    /// is provided for situations where the result needs to be collected into a different
355    /// container than a [`Vec`]
356    ///
357    /// When using the query builder, the return type can be
358    /// a tuple of the values, or a struct which implements [`Queryable`].
359    ///
360    /// When this method is called on [`sql_query`],
361    /// the return type can only be a struct which implements [`QueryableByName`]
362    ///
363    /// For insert, update, and delete operations where only a count of affected is needed,
364    /// [`execute`] should be used instead.
365    ///
366    /// [`Queryable`]: diesel::deserialize::Queryable
367    /// [`QueryableByName`]: diesel::deserialize::QueryableByName
368    /// [`execute`]: crate::run_query_dsl::RunQueryDsl::execute()
369    /// [`sql_query`]: diesel::sql_query()
370    ///
371    /// # Examples
372    ///
373    /// ## Returning a single field
374    ///
375    /// ```rust
376    /// # include!("../doctest_setup.rs");
377    /// #
378    /// use diesel_async::RunQueryDsl;
379    ///
380    /// # #[tokio::main(flavor = "current_thread")]
381    /// # async fn main() {
382    /// #     run_test().await;
383    /// # }
384    /// #
385    /// # async fn run_test() -> QueryResult<()> {
386    /// #     use diesel::insert_into;
387    /// #     use schema::users::dsl::*;
388    /// #     use futures_util::stream::TryStreamExt;
389    /// #     let connection = &mut establish_connection().await;
390    /// let data = users.select(name)
391    ///     .load_stream::<String>(connection)
392    ///     .await?
393    ///     .try_fold(Vec::new(), |mut acc, item| {
394    ///          acc.push(item);
395    ///          futures_util::future::ready(Ok(acc))
396    ///      })
397    ///     .await?;
398    /// assert_eq!(vec!["Sean", "Tess"], data);
399    /// #     Ok(())
400    /// # }
401    /// ```
402    ///
403    /// ## Returning a tuple
404    ///
405    /// ```rust
406    /// # include!("../doctest_setup.rs");
407    /// use diesel_async::RunQueryDsl;
408    /// #
409    /// # #[tokio::main(flavor = "current_thread")]
410    /// # async fn main() {
411    /// #     run_test().await;
412    /// # }
413    /// #
414    /// # async fn run_test() -> QueryResult<()> {
415    /// #     use diesel::insert_into;
416    /// #     use schema::users::dsl::*;
417    /// #     use futures_util::stream::TryStreamExt;
418    /// #     let connection = &mut establish_connection().await;
419    /// let data = users
420    ///     .load_stream::<(i32, String)>(connection)
421    ///     .await?
422    ///     .try_fold(Vec::new(), |mut acc, item| {
423    ///          acc.push(item);
424    ///          futures_util::future::ready(Ok(acc))
425    ///      })
426    ///     .await?;
427    /// let expected_data = vec![
428    ///     (1, String::from("Sean")),
429    ///     (2, String::from("Tess")),
430    /// ];
431    /// assert_eq!(expected_data, data);
432    /// #     Ok(())
433    /// # }
434    /// ```
435    ///
436    /// ## Returning a struct
437    ///
438    /// ```rust
439    /// # include!("../doctest_setup.rs");
440    /// #
441    /// use diesel_async::RunQueryDsl;
442    ///
443    /// #[derive(Queryable, PartialEq, Debug)]
444    /// struct User {
445    ///     id: i32,
446    ///     name: String,
447    /// }
448    ///
449    /// # #[tokio::main(flavor = "current_thread")]
450    /// # async fn main() {
451    /// #     run_test().await;
452    /// # }
453    /// #
454    /// # async fn run_test() -> QueryResult<()> {
455    /// #     use diesel::insert_into;
456    /// #     use schema::users::dsl::*;
457    /// #     use futures_util::stream::TryStreamExt;
458    /// #     let connection = &mut establish_connection().await;
459    /// let data = users
460    ///     .load_stream::<User>(connection)
461    ///     .await?
462    ///     .try_fold(Vec::new(), |mut acc, item| {
463    ///          acc.push(item);
464    ///          futures_util::future::ready(Ok(acc))
465    ///      })
466    ///     .await?;
467    /// let expected_data = vec![
468    ///     User { id: 1, name: String::from("Sean") },
469    ///     User { id: 2, name: String::from("Tess") },
470    /// ];
471    /// assert_eq!(expected_data, data);
472    /// #     Ok(())
473    /// # }
474    /// ```
475    fn load_stream<'conn, 'query, U>(self, conn: &'conn mut Conn) -> Self::LoadFuture<'conn>
476    where
477        Conn: AsyncConnection,
478        U: 'conn,
479        Self: methods::LoadQuery<'query, Conn, U> + 'query,
480    {
481        self.internal_load(conn)
482    }
483
484    /// Runs the command, and returns the affected row.
485    ///
486    /// `Err(NotFound)` will be returned if the query affected 0 rows. You can
487    /// call `.optional()` on the result of this if the command was optional to
488    /// get back a `Result<Option<U>>`
489    ///
490    /// When this method is called on an insert, update, or delete statement,
491    /// it will implicitly add a `RETURNING *` to the query,
492    /// unless a returning clause was already specified.
493    ///
494    /// This method only returns the first row that was affected, even if more
495    /// rows are affected.
496    ///
497    /// # Example
498    ///
499    /// ```rust
500    /// # include!("../doctest_setup.rs");
501    /// use diesel_async::RunQueryDsl;
502    ///
503    /// #
504    /// # #[tokio::main(flavor = "current_thread")]
505    /// # async fn main() {
506    /// #     run_test().await;
507    /// # }
508    /// #
509    /// # #[cfg(feature = "postgres")]
510    /// # async fn run_test() -> QueryResult<()> {
511    /// #     use diesel::{insert_into, update};
512    /// #     use schema::users::dsl::*;
513    /// #     let connection = &mut establish_connection().await;
514    /// let inserted_row = insert_into(users)
515    ///     .values(name.eq("Ruby"))
516    ///     .get_result(connection)
517    ///     .await?;
518    /// assert_eq!((3, String::from("Ruby")), inserted_row);
519    ///
520    /// // This will return `NotFound`, as there is no user with ID 4
521    /// let update_result = update(users.find(4))
522    ///     .set(name.eq("Jim"))
523    ///     .get_result::<(i32, String)>(connection)
524    ///     .await;
525    /// assert_eq!(Err(diesel::NotFound), update_result);
526    /// #     Ok(())
527    /// # }
528    /// #
529    /// # #[cfg(not(feature = "postgres"))]
530    /// # async fn run_test() -> QueryResult<()> {
531    /// #     Ok(())
532    /// # }
533    /// ```
534    fn get_result<'query, 'conn, U>(
535        self,
536        conn: &'conn mut Conn,
537    ) -> return_futures::GetResult<'conn, 'query, Self, Conn, U>
538    where
539        U: Send + 'conn,
540        Conn: AsyncConnection,
541        Self: methods::LoadQuery<'query, Conn, U> + 'query,
542    {
543        #[allow(clippy::type_complexity)]
544        fn get_next_stream_element<S, U>(
545            stream: S,
546        ) -> future::Map<
547            stream::StreamFuture<Pin<Box<S>>>,
548            fn((Option<QueryResult<U>>, Pin<Box<S>>)) -> QueryResult<U>,
549        >
550        where
551            S: Stream<Item = QueryResult<U>>,
552        {
553            fn map_option_to_result<U, S>(
554                (o, _): (Option<QueryResult<U>>, Pin<Box<S>>),
555            ) -> QueryResult<U> {
556                match o {
557                    Some(s) => s,
558                    None => Err(diesel::result::Error::NotFound),
559                }
560            }
561
562            Box::pin(stream).into_future().map(map_option_to_result)
563        }
564
565        self.load_stream(conn).and_then(get_next_stream_element)
566    }
567
568    /// Runs the command, returning an `Vec` with the affected rows.
569    ///
570    /// This method is an alias for [`load`], but with a name that makes more
571    /// sense for insert, update, and delete statements.
572    ///
573    /// [`load`]: crate::run_query_dsl::RunQueryDsl::load()
574    fn get_results<'query, 'conn, U>(
575        self,
576        conn: &'conn mut Conn,
577    ) -> return_futures::LoadFuture<'conn, 'query, Self, Conn, U>
578    where
579        U: Send,
580        Conn: AsyncConnection,
581        Self: methods::LoadQuery<'query, Conn, U> + 'query,
582    {
583        self.load(conn)
584    }
585
586    /// Attempts to load a single record.
587    ///
588    /// This method is equivalent to `.limit(1).get_result()`
589    ///
590    /// Returns `Ok(record)` if found, and `Err(NotFound)` if no results are
591    /// returned. If the query truly is optional, you can call `.optional()` on
592    /// the result of this to get a `Result<Option<U>>`.
593    ///
594    /// # Example:
595    ///
596    /// ```rust
597    /// # include!("../doctest_setup.rs");
598    /// use diesel_async::RunQueryDsl;
599    ///
600    /// #
601    /// # #[tokio::main(flavor = "current_thread")]
602    /// # async fn main() {
603    /// #     run_test();
604    /// # }
605    /// #
606    /// # async fn run_test() -> QueryResult<()> {
607    /// #     use schema::users::dsl::*;
608    /// #     let connection = &mut establish_connection().await;
609    /// for n in &["Sean", "Pascal"] {
610    ///     diesel::insert_into(users)
611    ///         .values(name.eq(n))
612    ///         .execute(connection)
613    ///         .await?;
614    /// }
615    ///
616    /// let first_name = users.order(id)
617    ///     .select(name)
618    ///     .first(connection)
619    ///     .await;
620    /// assert_eq!(Ok(String::from("Sean")), first_name);
621    ///
622    /// let not_found = users
623    ///     .filter(name.eq("Foo"))
624    ///     .first::<(i32, String)>(connection)
625    ///     .await;
626    /// assert_eq!(Err(diesel::NotFound), not_found);
627    /// #     Ok(())
628    /// # }
629    /// ```
630    fn first<'query, 'conn, U>(
631        self,
632        conn: &'conn mut Conn,
633    ) -> return_futures::GetResult<'conn, 'query, diesel::dsl::Limit<Self>, Conn, U>
634    where
635        U: Send + 'conn,
636        Conn: AsyncConnection,
637        Self: diesel::query_dsl::methods::LimitDsl,
638        diesel::dsl::Limit<Self>: methods::LoadQuery<'query, Conn, U> + Send + 'query,
639    {
640        diesel::query_dsl::methods::LimitDsl::limit(self, 1).get_result(conn)
641    }
642}
643
644impl<T, Conn> RunQueryDsl<Conn> for T {}
645
646/// Sugar for types which implement both `AsChangeset` and `Identifiable`
647///
648/// On backends which support the `RETURNING` keyword,
649/// `foo.save_changes(&conn)` is equivalent to
650/// `update(&foo).set(&foo).get_result(&conn)`.
651/// On other backends, two queries will be executed.
652///
653/// # Example
654///
655/// ```rust
656/// # include!("../doctest_setup.rs");
657/// # use schema::animals;
658/// #
659/// use diesel_async::{SaveChangesDsl, AsyncConnection};
660///
661/// #[derive(Queryable, Debug, PartialEq)]
662/// struct Animal {
663///    id: i32,
664///    species: String,
665///    legs: i32,
666///    name: Option<String>,
667/// }
668///
669/// #[derive(AsChangeset, Identifiable)]
670/// #[diesel(table_name = animals)]
671/// struct AnimalForm<'a> {
672///     id: i32,
673///     name: &'a str,
674/// }
675///
676/// # #[tokio::main(flavor = "current_thread")]
677/// # async fn main() {
678/// #     run_test().await.unwrap();
679/// # }
680/// #
681/// # async fn run_test() -> QueryResult<()> {
682/// #     use self::animals::dsl::*;
683/// #     let connection = &mut establish_connection().await;
684/// let form = AnimalForm { id: 2, name: "Super scary" };
685/// # #[cfg(not(feature = "sqlite"))]
686/// let changed_animal = form.save_changes(connection).await?;
687/// let expected_animal = Animal {
688///     id: 2,
689///     species: String::from("spider"),
690///     legs: 8,
691///     name: Some(String::from("Super scary")),
692/// };
693/// # #[cfg(not(feature = "sqlite"))]
694/// assert_eq!(expected_animal, changed_animal);
695/// #     Ok(())
696/// # }
697/// ```
698#[async_trait::async_trait]
699pub trait SaveChangesDsl<Conn> {
700    /// See the trait documentation
701    async fn save_changes<T>(self, connection: &mut Conn) -> QueryResult<T>
702    where
703        Self: Sized + diesel::prelude::Identifiable,
704        Conn: UpdateAndFetchResults<Self, T>,
705    {
706        connection.update_and_fetch(self).await
707    }
708}
709
710impl<T, Conn> SaveChangesDsl<Conn> for T where
711    T: Copy + AsChangeset<Target = <T as HasTable>::Table> + IntoUpdateTarget
712{
713}
714
715/// A trait defining how to update a record and fetch the updated entry
716/// on a certain backend.
717///
718/// The only case where it is required to work with this trait is while
719/// implementing a new connection type.
720/// Otherwise use [`SaveChangesDsl`]
721///
722/// For implementing this trait for a custom backend:
723/// * The `Changes` generic parameter represents the changeset that should be stored
724/// * The `Output` generic parameter represents the type of the response.
725#[async_trait::async_trait]
726pub trait UpdateAndFetchResults<Changes, Output>: AsyncConnection
727where
728    Changes: diesel::prelude::Identifiable + HasTable,
729{
730    /// See the traits documentation.
731    async fn update_and_fetch(&mut self, changeset: Changes) -> QueryResult<Output>
732    where
733        Changes: 'async_trait;
734}
735
#[cfg(feature = "mysql")]
#[async_trait::async_trait]
impl<'q, Changes, Output> UpdateAndFetchResults<Changes, Output> for crate::AsyncMysqlConnection
where
    Changes: Copy + diesel::Identifiable + Send,
    Changes: AsChangeset<Target = <Changes as HasTable>::Table> + IntoUpdateTarget,
    Changes::Id: Send,
    Changes::Changeset: Send,
    Changes::WhereClause: Send,
    Changes::Table: diesel::query_dsl::methods::FindDsl<Changes::Id> + Send,
    <Changes::Table as diesel::query_source::QuerySource>::FromClause: Send,
    <Changes::Table as diesel::Table>::AllColumns: diesel::expression::ValidGrouping<()>,
    <<Changes::Table as diesel::Table>::AllColumns as diesel::expression::ValidGrouping<()>>::IsAggregate:
        diesel::expression::MixedAggregates<
            diesel::expression::is_aggregate::No,
            Output = diesel::expression::is_aggregate::No,
        >,
    diesel::dsl::Update<Changes, Changes>: methods::ExecuteDsl<crate::AsyncMysqlConnection>,
    diesel::dsl::Find<Changes::Table, Changes::Id>:
        methods::LoadQuery<'q, crate::AsyncMysqlConnection, Output> + Send + 'q,
    Output: Send,
{
    async fn update_and_fetch(&mut self, changes: Changes) -> QueryResult<Output>
    where
        Changes: 'async_trait,
    {
        use diesel::query_dsl::methods::FindDsl;

        // Two queries: apply the UPDATE, then load the row back by its id.
        diesel::update(changes).set(changes).execute(self).await?;
        Changes::table().find(changes.id()).get_result(self).await
    }
}

#[cfg(feature = "postgres")]
#[async_trait::async_trait]
impl<'q, Changes, Output, Tab, V> UpdateAndFetchResults<Changes, Output>
    for crate::AsyncPgConnection
where
    Changes: Copy
        + Send
        + AsChangeset<Target = Tab>
        + diesel::associations::Identifiable<Table = Tab>,
    Changes::Changeset: Send + 'q,
    Tab: diesel::Table + diesel::query_dsl::methods::FindDsl<Changes::Id> + 'q,
    Tab::FromClause: Send,
    V: Send + 'q,
    diesel::dsl::Find<Tab, Changes::Id>: IntoUpdateTarget<Table = Tab, WhereClause = V>,
    diesel::query_builder::UpdateStatement<Tab, V, Changes::Changeset>:
        diesel::query_builder::AsQuery,
    diesel::dsl::Update<Changes, Changes>: methods::LoadQuery<'q, Self, Output>,
    Output: Send,
{
    async fn update_and_fetch(&mut self, changes: Changes) -> QueryResult<Output>
    where
        Changes: 'async_trait,
        Changes::Changeset: 'async_trait,
    {
        // A single round trip: `get_result` on an update statement adds a
        // returning clause, so the updated row comes back with the UPDATE.
        diesel::update(changes).set(changes).get_result(self).await
    }
}