use super::{
executor::Executor, mutation::MutationExt, query::QueryExt, schema::Schema, DatabaseDriver,
EncodeColumn,
};
use std::fmt::Display;
use zino_core::{
error::Error,
extension::JsonValueExt,
model::{Mutation, Query},
BoxFuture, Map,
};
#[cfg(feature = "orm-sqlx")]
use sqlx::Acquire;
pub trait Transaction<K, Tx>: Schema<PrimaryKey = K>
where
K: Default + Display + PartialEq,
{
async fn transaction<F, T>(tx: F) -> Result<T, Error>
where
F: for<'t> FnOnce(&'t mut Tx) -> BoxFuture<'t, Result<T, Error>>;
async fn transactional_execute(queries: &[&str], params: Option<&Map>) -> Result<u64, Error>;
async fn transactional_insert<M: Schema>(self, models: Vec<M>) -> Result<u64, Error>;
async fn transactional_update<M: Schema>(
queries: (&Query, &Query),
mutations: (&mut Mutation, &mut Mutation),
) -> Result<u64, Error>;
async fn transactional_delete<M: Schema>(queries: (&Query, &Query)) -> Result<u64, Error>;
}
/// [`Transaction`] implementation backed by a `sqlx` transaction on the model's writer pool.
///
/// Every method begins a transaction on the pool returned by `Self::acquire_writer()`,
/// runs its statements on that transaction, and commits only after all statements and
/// hooks succeed. Any error returns early via `?` without calling `commit`. The
/// uncommitted `sqlx::Transaction` is then dropped, which sqlx documents as rolling it back.
#[cfg(feature = "orm-sqlx")]
impl<'c, M, K> Transaction<K, sqlx::Transaction<'c, DatabaseDriver>> for M
where
    M: Schema<PrimaryKey = K>,
    K: Default + Display + PartialEq,
{
    /// Runs `tx` inside a new transaction and commits it if `tx` succeeds.
    ///
    /// # Errors
    ///
    /// Returns an error if the transaction cannot be started, or if `tx` or the commit
    /// fails. When `tx` fails the transaction is not committed.
    async fn transaction<F, T>(tx: F) -> Result<T, Error>
    where
        F: for<'t> FnOnce(
            &'t mut sqlx::Transaction<'c, DatabaseDriver>,
        ) -> BoxFuture<'t, Result<T, Error>>,
    {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let data = tx(&mut transaction).await?;
        transaction.commit().await?;
        Ok(data)
    }

    /// Executes every statement in `queries` inside a single transaction.
    ///
    /// Each statement is prepared by `Query::prepare_query` with the optional `params` and
    /// wrapped in the `before_scan`/`after_scan` hooks. Returns the summed affected-row
    /// count of all statements.
    ///
    /// # Errors
    ///
    /// Returns the first error from a hook, a statement, or the commit. Statements that
    /// already ran are not committed in that case.
    async fn transactional_execute(queries: &[&str], params: Option<&Map>) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;
        let mut total_rows = 0;
        for query in queries {
            let (sql, values) = Query::prepare_query(query, params);
            let mut ctx = Self::before_scan(&sql).await?;
            ctx.set_query(sql);

            // Bind arguments are the prepared values rendered via `to_string_unquoted`.
            let mut arguments = values
                .iter()
                .map(|v| v.to_string_unquoted())
                .collect::<Vec<_>>();
            let rows_affected = connection
                .execute_with(ctx.query(), &arguments)
                .await?
                .rows_affected();
            total_rows += rows_affected;

            // Record the arguments on the context so the `after_scan` hook can see them.
            ctx.append_arguments(&mut arguments);
            ctx.set_query_result(rows_affected, true);
            Self::after_scan(&ctx).await?;
        }
        transaction.commit().await?;
        Ok(total_rows)
    }

    /// Inserts `self` and then all `associations` (as one multi-row `INSERT`) in a single
    /// transaction.
    ///
    /// For the primary model, the `before_insert`, `before_scan`, `after_scan` and
    /// `after_insert` hooks run around its single-row `INSERT`. Each association runs
    /// `before_insert` before being encoded. The batch statement runs
    /// `S::before_scan`/`S::after_scan`. When `associations` is empty, no second
    /// statement is issued. Returns the total affected-row count.
    ///
    /// # Errors
    ///
    /// Returns the first error from a hook, either `INSERT`, or the commit. Nothing is
    /// committed in that case.
    async fn transactional_insert<S: Schema>(mut self, associations: Vec<S>) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // Primary model. Auto-increment columns are omitted from the column list
        // (presumably so the database generates them).
        let model_data = self.before_insert().await?;
        let map = self.into_map();
        let columns = Self::columns();
        let mut fields = Vec::with_capacity(columns.len());
        let values = columns
            .iter()
            .filter_map(|col| {
                if col.auto_increment() {
                    None
                } else {
                    let name = col.name();
                    fields.push(name);
                    Some(col.encode_value(map.get(name)))
                }
            })
            .collect::<Vec<_>>()
            .join(", ");
        let fields = fields.join(", ");
        let table_name = Query::table_name_escaped::<Self>();
        let sql = format!("INSERT INTO {table_name} ({fields}) VALUES ({values});");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let mut total_rows = 0;
        let query_result = connection.execute(ctx.query()).await?;
        let (last_insert_id, rows_affected) = Query::parse_query_result(query_result);
        let success = rows_affected == 1;
        if let Some(last_insert_id) = last_insert_id {
            ctx.set_last_insert_id(last_insert_id);
        }
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, success);
        Self::after_scan(&ctx).await?;
        Self::after_insert(&ctx, model_data).await?;

        // Associations. A multi-row `INSERT ... VALUES` needs at least one row tuple;
        // with no associations the statement would be `VALUES ;` (a syntax error that
        // would discard the primary insert), so only commit the primary row.
        if !associations.is_empty() {
            let columns = S::columns();
            let mut values = Vec::with_capacity(associations.len());
            for mut association in associations.into_iter() {
                let _association_data = association.before_insert().await?;
                let map = association.into_map();
                let entries = columns
                    .iter()
                    .map(|col| col.encode_value(map.get(col.name())))
                    .collect::<Vec<_>>()
                    .join(", ");
                values.push(format!("({entries})"));
            }

            let table_name = Query::table_name_escaped::<S>();
            let fields = S::fields().join(", ");
            let values = values.join(", ");
            let sql = format!("INSERT INTO {table_name} ({fields}) VALUES {values};");
            let mut ctx = S::before_scan(&sql).await?;
            ctx.set_query(sql);

            let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
            total_rows += rows_affected;
            ctx.set_query_result(rows_affected, true);
            S::after_scan(&ctx).await?;
        }
        transaction.commit().await?;
        Ok(total_rows)
    }

    /// Applies `mutations.0` to the `Self` rows matched by `queries.0`, then
    /// `mutations.1` to the `S` rows matched by `queries.1`, in one transaction.
    ///
    /// Each `UPDATE` runs `before_mutation`, `before_scan`, `after_scan` and
    /// `after_mutation` for its model. Returns the summed affected-row count.
    ///
    /// # Errors
    ///
    /// Returns the first error from a hook, either `UPDATE`, or the commit. Nothing is
    /// committed in that case.
    async fn transactional_update<S: Schema>(
        queries: (&Query, &Query),
        mutations: (&mut Mutation, &mut Mutation),
    ) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // First update: the model itself.
        let query = queries.0;
        let mutation = mutations.0;
        Self::before_mutation(query, mutation).await?;

        let table_name = query.format_table_name::<Self>();
        let filters = query.format_filters::<Self>();
        let updates = mutation.format_updates::<Self>();
        let sql = format!("UPDATE {table_name} SET {updates} {filters};");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let mut total_rows = 0;
        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        Self::after_scan(&ctx).await?;
        Self::after_mutation(&ctx).await?;

        // Second update: the associated model `S`.
        let query = queries.1;
        let mutation = mutations.1;
        S::before_mutation(query, mutation).await?;

        let table_name = query.format_table_name::<S>();
        let filters = query.format_filters::<S>();
        let updates = mutation.format_updates::<S>();
        let sql = format!("UPDATE {table_name} SET {updates} {filters};");
        let mut ctx = S::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        S::after_scan(&ctx).await?;
        S::after_mutation(&ctx).await?;
        transaction.commit().await?;
        Ok(total_rows)
    }

    /// Deletes the `Self` rows matched by `queries.0`, then the `S` rows matched by
    /// `queries.1`, in one transaction.
    ///
    /// Each `DELETE` runs `before_query`, `before_scan`, `after_scan` and `after_query`
    /// for its model. Returns the summed affected-row count.
    ///
    /// # Errors
    ///
    /// Returns the first error from a hook, either `DELETE`, or the commit. Nothing is
    /// committed in that case.
    async fn transactional_delete<S: Schema>(queries: (&Query, &Query)) -> Result<u64, Error> {
        let mut transaction = Self::acquire_writer().await?.pool().begin().await?;
        let connection = transaction.acquire().await?;

        // First delete: the model itself.
        let query = queries.0;
        Self::before_query(query).await?;

        let table_name = query.format_table_name::<Self>();
        let filters = query.format_filters::<Self>();
        let sql = format!("DELETE FROM {table_name} {filters};");
        let mut ctx = Self::before_scan(&sql).await?;
        ctx.set_query(sql);

        let mut total_rows = 0;
        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        Self::after_scan(&ctx).await?;
        Self::after_query(&ctx).await?;

        // Second delete: the associated model `S`.
        let query = queries.1;
        S::before_query(query).await?;

        let table_name = query.format_table_name::<S>();
        let filters = query.format_filters::<S>();
        let sql = format!("DELETE FROM {table_name} {filters};");
        let mut ctx = S::before_scan(&sql).await?;
        ctx.set_query(sql);

        let rows_affected = connection.execute(ctx.query()).await?.rows_affected();
        total_rows += rows_affected;
        ctx.set_query_result(rows_affected, true);
        S::after_scan(&ctx).await?;
        S::after_query(&ctx).await?;
        transaction.commit().await?;
        Ok(total_rows)
    }
}