use super::{column::ColumnExt, query::QueryExt, schema::Schema, DatabaseDriver};
use futures::TryStreamExt;
use sqlx::{Decode, Row, Type};
use std::{fmt::Display, sync::atomic::Ordering::Relaxed};
use zino_core::{error::Error, extension::JsonValueExt, model::Query, Map};
pub trait ScalarQuery<K>: Schema<PrimaryKey = K>
where
K: Default + Display + PartialEq,
{
/// Fetches a single scalar value selected by `query`.
///
/// The statement has the shape
/// `SELECT {projection} FROM {table} {filters} {sort} LIMIT 1;`, with every clause
/// rendered from `query`. `before_query` and `before_scan` run ahead of the fetch;
/// `after_scan` and `after_query` run once the value has been read.
///
/// # Errors
///
/// Propagates any failure from the hooks, from acquiring the reader, or from
/// `fetch_one` (which, per sqlx, also fails when no row is returned).
async fn find_scalar<T>(query: &Query) -> Result<T, Error>
where
    T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    Self::before_query(query).await?;

    // Render each clause directly into the statement text.
    let sql = format!(
        "SELECT {projection} FROM {table_name} {filters} {sort} LIMIT 1;",
        table_name = query.format_table_name::<Self>(),
        projection = query.format_projection(),
        filters = query.format_filters::<Self>(),
        sort = query.format_sort(),
    );
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(sql);

    let reader_pool = Self::acquire_reader().await?.pool();
    let value = sqlx::query_scalar(ctx.query())
        .fetch_one(reader_pool)
        .await?;

    // Exactly one value has been fetched when this point is reached.
    ctx.set_query_result(1, true);
    Self::after_scan(&ctx).await?;
    Self::after_query(&ctx).await?;
    Ok(value)
}
/// Fetches the first column of the rows selected by `query`.
///
/// Rows are streamed from
/// `SELECT {projection} FROM {table} {filters} {sort} {pagination};` and collection
/// stops once `super::MAX_ROWS` values have been gathered. The number of collected
/// values is recorded on the query context before the `after_*` hooks run.
///
/// # Errors
///
/// Propagates any failure from the hooks, from acquiring the reader, from the row
/// stream, or from decoding column 0.
async fn find_scalars<T>(query: &Query) -> Result<Vec<T>, Error>
where
    T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    Self::before_query(query).await?;

    let sql = format!(
        "SELECT {projection} FROM {table_name} {filters} {sort} {pagination};",
        table_name = query.format_table_name::<Self>(),
        projection = query.format_projection(),
        filters = query.format_filters::<Self>(),
        sort = query.format_sort(),
        pagination = query.format_pagination(),
    );
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(&sql);

    let reader_pool = Self::acquire_reader().await?.pool();
    let mut stream = sqlx::query(&sql).fetch(reader_pool);
    let mut values = Vec::new();
    let mut remaining = super::MAX_ROWS.load(Relaxed);
    while let Some(row) = stream.try_next().await? {
        // The cap is checked after the row has been pulled from the stream.
        if remaining == 0 {
            break;
        }
        // `try_get_unchecked` skips sqlx's type-compatibility check for the column.
        values.push(row.try_get_unchecked(0)?);
        remaining -= 1;
    }

    let num_rows = u64::try_from(values.len())?;
    ctx.set_query_result(num_rows, true);
    Self::after_scan(&ctx).await?;
    Self::after_query(&ctx).await?;
    Ok(values)
}
/// Fetches the distinct first-column values of the rows selected by `query`.
///
/// Works like the non-distinct variant but issues
/// `SELECT DISTINCT {projection} FROM {table} {filters} {sort} {pagination};`.
/// Collection stops once `super::MAX_ROWS` values have been gathered.
///
/// # Errors
///
/// Propagates any failure from the hooks, from acquiring the reader, from the row
/// stream, or from decoding column 0.
async fn find_distinct_scalars<T>(query: &Query) -> Result<Vec<T>, Error>
where
    T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    Self::before_query(query).await?;

    let table = query.format_table_name::<Self>();
    let columns = query.format_projection();
    let conditions = query.format_filters::<Self>();
    let ordering = query.format_sort();
    let paging = query.format_pagination();
    let sql = format!("SELECT DISTINCT {columns} FROM {table} {conditions} {ordering} {paging};");
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(&sql);

    let reader_pool = Self::acquire_reader().await?.pool();
    let mut stream = sqlx::query(&sql).fetch(reader_pool);
    let mut values = Vec::new();
    let mut remaining = super::MAX_ROWS.load(Relaxed);
    while let Some(row) = stream.try_next().await? {
        // The cap is checked after the row has been pulled from the stream.
        if remaining == 0 {
            break;
        }
        values.push(row.try_get_unchecked(0)?);
        remaining -= 1;
    }

    let num_rows = u64::try_from(values.len())?;
    ctx.set_query_result(num_rows, true);
    Self::after_scan(&ctx).await?;
    Self::after_query(&ctx).await?;
    Ok(values)
}
/// Runs a caller-supplied statement and returns a single scalar value.
///
/// `Query::prepare_query(query, params)` yields the SQL text together with the values
/// to bind. Each value is bound in its unquoted string form, and the same strings are
/// handed to the query context after the fetch. Only the scan hooks (`before_scan` /
/// `after_scan`) are invoked here.
///
/// # Errors
///
/// Propagates any failure from the scan hooks, from acquiring the reader, or from
/// `fetch_one` (which, per sqlx, also fails when no row is returned).
async fn query_scalar<T>(query: &str, params: Option<&Map>) -> Result<T, Error>
where
    T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    let (sql, values) = Query::prepare_query(query, params);
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(sql);

    let mut statement = sqlx::query_scalar(ctx.query());
    let mut arguments = Vec::with_capacity(values.len());
    for value in values {
        // Serialize once; the bound value and the recorded value are the same string.
        let argument = value.to_string_unquoted();
        statement = statement.bind(argument.clone());
        arguments.push(argument);
    }

    let pool = Self::acquire_reader().await?.pool();
    let scalar = statement.fetch_one(pool).await?;
    ctx.append_arguments(&mut arguments);
    ctx.set_query_result(1, true);
    Self::after_scan(&ctx).await?;
    Ok(scalar)
}
/// Runs a caller-supplied statement and returns the first column of its rows.
///
/// `Query::prepare_query(query, params)` yields the SQL text together with the values
/// to bind. Each value is bound in its unquoted string form, and the same strings are
/// handed to the query context after the rows have been read. Collection stops once
/// `super::MAX_ROWS` values have been gathered. Only the scan hooks (`before_scan` /
/// `after_scan`) are invoked here.
///
/// # Errors
///
/// Propagates any failure from the scan hooks, from acquiring the reader, from the
/// row stream, or from decoding column 0.
async fn query_scalars<T>(query: &str, params: Option<&Map>) -> Result<Vec<T>, Error>
where
    T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    let (sql, values) = Query::prepare_query(query, params);
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(sql.as_ref());

    let mut statement = sqlx::query(&sql);
    let mut arguments = Vec::with_capacity(values.len());
    for value in values {
        // Serialize once; the bound value and the recorded value are the same string.
        let argument = value.to_string_unquoted();
        statement = statement.bind(argument.clone());
        arguments.push(argument);
    }

    let pool = Self::acquire_reader().await?.pool();
    let mut rows = statement.fetch(pool);
    let mut data = Vec::new();
    let mut max_rows = super::MAX_ROWS.load(Relaxed);
    while let Some(row) = rows.try_next().await? {
        if max_rows > 0 {
            // `try_get_unchecked` skips sqlx's type-compatibility check for the column.
            data.push(row.try_get_unchecked(0)?);
            max_rows -= 1;
        } else {
            break;
        }
    }
    ctx.append_arguments(&mut arguments);
    ctx.set_query_result(u64::try_from(data.len())?, true);
    Self::after_scan(&ctx).await?;
    Ok(data)
}
/// Fetches the value of `column` for the row whose primary key equals `primary_key`.
///
/// The key is bound as a string to the first placeholder. When the `orm-postgres`
/// feature is enabled, the placeholder is wrapped in parentheses and followed by the
/// primary key column's type annotation. The bound key is also handed to the query
/// context, so the hooks receive the statement together with its argument, just as
/// they do for the raw `query_scalar` / `query_scalars` helpers.
///
/// # Errors
///
/// Propagates any failure from the hooks, from acquiring the reader, or from
/// `fetch_one` (which, per sqlx, also fails when no row is returned).
async fn find_scalar_by_id<C, T>(primary_key: &Self::PrimaryKey, column: C) -> Result<T, Error>
where
    C: AsRef<str>,
    T: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    let primary_key_name = Self::PRIMARY_KEY_NAME;
    let table_name = Query::table_name_escaped::<Self>();
    let projection = Query::format_field(column.as_ref());
    let placeholder = Query::placeholder(1);
    let sql = if cfg!(feature = "orm-postgres") {
        let type_annotation = Self::primary_key_column().type_annotation();
        format!(
            "SELECT {projection} FROM {table_name} \
                WHERE {primary_key_name} = ({placeholder}){type_annotation};"
        )
    } else {
        format!(
            "SELECT {projection} FROM {table_name} WHERE {primary_key_name} = {placeholder};"
        )
    };
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(sql);

    // Stringify the key once: one copy is bound, the other is recorded on the context.
    let key = primary_key.to_string();
    let mut arguments = vec![key.clone()];

    let pool = Self::acquire_reader().await?.pool();
    let query = sqlx::query_scalar(ctx.query()).bind(key);
    let scalar = query.fetch_one(pool).await?;
    ctx.append_arguments(&mut arguments);
    ctx.set_query_result(1, true);
    Self::after_scan(&ctx).await?;
    Self::after_query(&ctx).await?;
    Ok(scalar)
}
/// Fetches the primary key of the first row selected by `query`.
///
/// The projection is fixed to `Self::PRIMARY_KEY_NAME`; table, filters and sort come
/// from `query`, and the statement ends with `LIMIT 1`.
///
/// # Errors
///
/// Propagates any failure from the hooks, from acquiring the reader, or from
/// `fetch_one` (which, per sqlx, also fails when no row is returned).
async fn find_primary_key(query: &Query) -> Result<K, Error>
where
    K: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    Self::before_query(query).await?;

    let sql = format!(
        "SELECT {projection} FROM {table_name} {filters} {sort} LIMIT 1;",
        projection = Self::PRIMARY_KEY_NAME,
        table_name = query.format_table_name::<Self>(),
        filters = query.format_filters::<Self>(),
        sort = query.format_sort(),
    );
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(sql);

    let reader_pool = Self::acquire_reader().await?.pool();
    let key = sqlx::query_scalar(ctx.query())
        .fetch_one(reader_pool)
        .await?;

    // Exactly one key has been fetched when this point is reached.
    ctx.set_query_result(1, true);
    Self::after_scan(&ctx).await?;
    Self::after_query(&ctx).await?;
    Ok(key)
}
/// Fetches the primary keys of the rows selected by `query`.
///
/// The projection is fixed to `Self::PRIMARY_KEY_NAME`; table, filters, sort and
/// pagination come from `query`. Collection stops once `super::MAX_ROWS` keys have
/// been gathered, and the number of collected keys is recorded on the query context.
///
/// # Errors
///
/// Propagates any failure from the hooks, from acquiring the reader, from the row
/// stream, or from decoding column 0.
async fn find_primary_keys(query: &Query) -> Result<Vec<K>, Error>
where
    K: Send + Unpin + Type<DatabaseDriver> + for<'r> Decode<'r, DatabaseDriver>,
{
    Self::before_query(query).await?;

    let sql = format!(
        "SELECT {projection} FROM {table_name} {filters} {sort} {pagination};",
        projection = Self::PRIMARY_KEY_NAME,
        table_name = query.format_table_name::<Self>(),
        filters = query.format_filters::<Self>(),
        sort = query.format_sort(),
        pagination = query.format_pagination(),
    );
    let mut ctx = Self::before_scan(&sql).await?;
    ctx.set_query(&sql);

    let reader_pool = Self::acquire_reader().await?.pool();
    let mut stream = sqlx::query(&sql).fetch(reader_pool);
    let mut keys = Vec::new();
    let mut remaining = super::MAX_ROWS.load(Relaxed);
    while let Some(row) = stream.try_next().await? {
        // The cap is checked after the row has been pulled from the stream.
        if remaining == 0 {
            break;
        }
        keys.push(row.try_get_unchecked(0)?);
        remaining -= 1;
    }

    let num_rows = u64::try_from(keys.len())?;
    ctx.set_query_result(num_rows, true);
    Self::after_scan(&ctx).await?;
    Self::after_query(&ctx).await?;
    Ok(keys)
}
}
/// Blanket implementation: every type `M` that implements `Schema` with a primary key
/// `K` satisfying the bounds below gets all `ScalarQuery` methods through the trait's
/// default bodies, so the impl itself is intentionally empty.
impl<M, K> ScalarQuery<K> for M
where
M: Schema<PrimaryKey = K>,
K: Default + Display + PartialEq,
{
}