// odbc_api_derive/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, DeriveInput};
/// Use this to derive the trait `FetchRow` for structs defined in the application logic.
///
/// # Example
///
/// ```
/// use odbc_api_derive::Fetch;
/// use odbc_api::{Connection, Error, Cursor, parameter::VarCharArray, buffers::RowVec};
///
/// #[derive(Default, Clone, Copy, Fetch)]
/// struct Person {
/// first_name: VarCharArray<255>,
/// last_name: VarCharArray<255>,
/// }
///
/// fn send_greetings(conn: &mut Connection) -> Result<(), Error> {
/// let max_rows_in_batch = 250;
/// let buffer = RowVec::<Person>::new(max_rows_in_batch);
/// let mut cursor = conn.execute("SELECT first_name, last_name FROM Persons", ())?
/// .expect("SELECT must yield a result set");
/// let mut block_cursor = cursor.bind_buffer(buffer)?;
///
/// while let Some(batch) = block_cursor.fetch()? {
/// for person in batch.iter() {
/// let first = person.first_name.as_str()
/// .expect("First name must be UTF-8")
/// .expect("First Name must not be NULL");
/// let last = person.last_name.as_str()
/// .expect("Last name must be UTF-8")
/// .expect("Last Name must not be NULL");
/// println!("Hello {first} {last}!")
/// }
/// }
/// Ok(())
/// }
/// ```
#[proc_macro_derive(Fetch)]
pub fn derive_fetch_row(item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as DeriveInput);
let struct_name = input.ident;
let struct_data = match input.data {
syn::Data::Struct(struct_data) => struct_data,
_ => panic!("Fetch can only be derived for structs"),
};
let fields = struct_data.fields;
let bindings = fields.iter().enumerate().map(|(index, field)| {
let field_name = field
.ident
.as_ref()
.expect("All struct members must be named");
let col_index = (index + 1) as u16;
quote! {
odbc_api::buffers::FetchRowMember::bind_to_col(
&mut self.#field_name,
#col_index,
&mut cursor
)?;
}
});
let find_truncation = fields.iter().enumerate().map(|(index, field)| {
let field_name = field
.ident
.as_ref()
.expect("All struct members must be named");
quote! {
let maybe_truncation = odbc_api::buffers::FetchRowMember::find_truncation(
&self.#field_name,
#index,
);
if let Some(truncation_info) = maybe_truncation {
return Some(truncation_info);
}
}
});
let expanded = quote! {
unsafe impl odbc_api::buffers::FetchRow for #struct_name {
unsafe fn bind_columns_to_cursor(
&mut self,
mut cursor: odbc_api::handles::StatementRef<'_>
) -> std::result::Result<(), odbc_api::Error> {
#(#bindings)*
Ok(())
}
fn find_truncation(&self) -> std::option::Option<odbc_api::TruncationInfo> {
#(#find_truncation)*
None
}
}
};
expanded.into()
}