extern crate proc_macro;

use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
use quote::{format_ident, quote, ToTokens};
use syn::{
    parenthesized,
    parse::{Parse, ParseStream},
    parse_macro_input,
    token::{Comma, Paren},
    Ident, LitStr,
};

mod id;

#[cfg(feature = "lazy-account")]
mod lazy;

18#[proc_macro_attribute]
97pub fn account(
98 args: proc_macro::TokenStream,
99 input: proc_macro::TokenStream,
100) -> proc_macro::TokenStream {
101 let args = parse_macro_input!(args as AccountArgs);
102 let namespace = args.namespace.unwrap_or_default();
103 let is_zero_copy = args.zero_copy.is_some();
104 let unsafe_bytemuck = args.zero_copy.unwrap_or_default();
105
106 let account_strct = parse_macro_input!(input as syn::ItemStruct);
107 let account_name = &account_strct.ident;
108 let account_name_str = account_name.to_string();
109 let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
110
111 let discriminator = args
112 .overrides
113 .and_then(|ov| ov.discriminator)
114 .unwrap_or_else(|| {
115 let namespace = if namespace.is_empty() {
117 "account"
118 } else {
119 &namespace
120 };
121
122 gen_discriminator(namespace, account_name)
123 });
124 let disc = if account_strct.generics.lt_token.is_some() {
125 quote! { #account_name::#type_gen::DISCRIMINATOR }
126 } else {
127 quote! { #account_name::DISCRIMINATOR }
128 };
129
130 let owner_impl = {
131 if namespace.is_empty() {
132 quote! {
133 #[automatically_derived]
134 impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
135 fn owner() -> Pubkey {
136 crate::ID
137 }
138 }
139 }
140 } else {
141 quote! {}
142 }
143 };
144
145 let unsafe_bytemuck_impl = {
146 if unsafe_bytemuck {
147 quote! {
148 #[automatically_derived]
149 unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
150 #[automatically_derived]
151 unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
152 }
153 } else {
154 quote! {}
155 }
156 };
157
158 let bytemuck_derives = {
159 if !unsafe_bytemuck {
160 quote! {
161 #[zero_copy]
162 }
163 } else {
164 quote! {
165 #[zero_copy(unsafe)]
166 }
167 }
168 };
169
170 proc_macro::TokenStream::from({
171 if is_zero_copy {
172 quote! {
173 #bytemuck_derives
174 #account_strct
175
176 #unsafe_bytemuck_impl
177
178 #[automatically_derived]
179 impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
180
181 #[automatically_derived]
182 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
183 const DISCRIMINATOR: &'static [u8] = #discriminator;
184 }
185
186 #[automatically_derived]
189 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
190 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
191 if buf.len() < #disc.len() {
192 return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
193 }
194 let given_disc = &buf[..#disc.len()];
195 if #disc != given_disc {
196 return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
197 }
198 Self::try_deserialize_unchecked(buf)
199 }
200
201 fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
202 let data: &[u8] = &buf[#disc.len()..];
203 let account = anchor_lang::__private::bytemuck::from_bytes(data);
205 Ok(*account)
207 }
208 }
209
210 #owner_impl
211 }
212 } else {
213 let lazy = {
214 #[cfg(feature = "lazy-account")]
215 match namespace.is_empty().then(|| lazy::gen_lazy(&account_strct)) {
216 Some(Ok(lazy)) => lazy,
217 _ => Default::default(),
220 }
221 #[cfg(not(feature = "lazy-account"))]
222 proc_macro2::TokenStream::default()
223 };
224 quote! {
225 #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
226 #account_strct
227
228 #[automatically_derived]
229 impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
230 fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
231 if writer.write_all(#disc).is_err() {
232 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
233 }
234
235 if AnchorSerialize::serialize(self, writer).is_err() {
236 return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
237 }
238 Ok(())
239 }
240 }
241
242 #[automatically_derived]
243 impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
244 fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
245 if buf.len() < #disc.len() {
246 return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
247 }
248 let given_disc = &buf[..#disc.len()];
249 if #disc != given_disc {
250 return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
251 }
252 Self::try_deserialize_unchecked(buf)
253 }
254
255 fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
256 let mut data: &[u8] = &buf[#disc.len()..];
257 AnchorDeserialize::deserialize(&mut data)
258 .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
259 }
260 }
261
262 #[automatically_derived]
263 impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
264 const DISCRIMINATOR: &'static [u8] = #discriminator;
265 }
266
267 #owner_impl
268
269 #lazy
270 }
271 }
272 })
273}
274
275#[derive(Debug, Default)]
276struct AccountArgs {
277 zero_copy: Option<bool>,
279 namespace: Option<String>,
281 overrides: Option<Overrides>,
283}
284
285impl Parse for AccountArgs {
286 fn parse(input: ParseStream) -> syn::Result<Self> {
287 let mut parsed = Self::default();
288 let args = input.parse_terminated::<_, Comma>(AccountArg::parse)?;
289 for arg in args {
290 match arg {
291 AccountArg::ZeroCopy { is_unsafe } => {
292 parsed.zero_copy.replace(is_unsafe);
293 }
294 AccountArg::Namespace(ns) => {
295 parsed.namespace.replace(ns);
296 }
297 AccountArg::Overrides(ov) => {
298 parsed.overrides.replace(ov);
299 }
300 }
301 }
302
303 Ok(parsed)
304 }
305}
306
307enum AccountArg {
308 ZeroCopy { is_unsafe: bool },
309 Namespace(String),
310 Overrides(Overrides),
311}
312
313impl Parse for AccountArg {
314 fn parse(input: ParseStream) -> syn::Result<Self> {
315 if let Ok(ns) = input.parse::<LitStr>() {
317 return Ok(Self::Namespace(
318 ns.to_token_stream().to_string().replace('\"', ""),
319 ));
320 }
321
322 if input.fork().parse::<Ident>()? == "zero_copy" {
324 input.parse::<Ident>()?;
325 let is_unsafe = if input.peek(Paren) {
326 let content;
327 parenthesized!(content in input);
328 let content = content.parse::<proc_macro2::TokenStream>()?;
329 if content.to_string().as_str().trim() != "unsafe" {
330 return Err(syn::Error::new(
331 syn::spanned::Spanned::span(&content),
332 "Expected `unsafe`",
333 ));
334 }
335
336 true
337 } else {
338 false
339 };
340
341 return Ok(Self::ZeroCopy { is_unsafe });
342 };
343
344 input.parse::<Overrides>().map(Self::Overrides)
346 }
347}
348
349#[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
350pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
351 let account_strct = parse_macro_input!(item as syn::ItemStruct);
352 let account_name = &account_strct.ident;
353 let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
354
355 let fields = match &account_strct.fields {
356 syn::Fields::Named(n) => n,
357 _ => panic!("Fields must be named"),
358 };
359 let methods: Vec<proc_macro2::TokenStream> = fields
360 .named
361 .iter()
362 .filter_map(|field: &syn::Field| {
363 field
364 .attrs
365 .iter()
366 .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
367 .map(|attr| {
368 let mut tts = attr.tokens.clone().into_iter();
369 let g_stream = match tts.next().expect("Must have a token group") {
370 proc_macro2::TokenTree::Group(g) => g.stream(),
371 _ => panic!("Invalid syntax"),
372 };
373 let accessor_ty = match g_stream.into_iter().next() {
374 Some(token) => token,
375 _ => panic!("Missing accessor type"),
376 };
377
378 let field_name = field.ident.as_ref().unwrap();
379
380 let get_field: proc_macro2::TokenStream =
381 format!("get_{field_name}").parse().unwrap();
382 let set_field: proc_macro2::TokenStream =
383 format!("set_{field_name}").parse().unwrap();
384
385 quote! {
386 pub fn #get_field(&self) -> #accessor_ty {
387 anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
388 }
389 pub fn #set_field(&mut self, input: &#accessor_ty) {
390 self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
391 }
392 }
393 })
394 })
395 .collect();
396 proc_macro::TokenStream::from(quote! {
397 #[automatically_derived]
398 impl #impl_gen #account_name #ty_gen #where_clause {
399 #(#methods)*
400 }
401 })
402}
403
404#[proc_macro_attribute]
417pub fn zero_copy(
418 args: proc_macro::TokenStream,
419 item: proc_macro::TokenStream,
420) -> proc_macro::TokenStream {
421 let mut is_unsafe = false;
422 for arg in args.into_iter() {
423 match arg {
424 proc_macro::TokenTree::Ident(ident) => {
425 if ident.to_string() == "unsafe" {
426 is_unsafe = true;
434 } else {
435 panic!("expected single ident `unsafe`");
437 }
438 }
439 _ => {
440 panic!("expected single ident `unsafe`");
441 }
442 }
443 }
444
445 let account_strct = parse_macro_input!(item as syn::ItemStruct);
446
447 let attr = account_strct
450 .attrs
451 .iter()
452 .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
453
454 let repr = match attr {
455 Some(_attr) => quote! {},
457 None => {
458 if is_unsafe {
459 quote! {#[repr(Rust, packed)]}
460 } else {
461 quote! {#[repr(C)]}
462 }
463 }
464 };
465
466 let mut has_pod_attr = false;
467 let mut has_zeroable_attr = false;
468 for attr in account_strct.attrs.iter() {
469 let token_string = attr.tokens.to_string();
470 if token_string.contains("bytemuck :: Pod") {
471 has_pod_attr = true;
472 }
473 if token_string.contains("bytemuck :: Zeroable") {
474 has_zeroable_attr = true;
475 }
476 }
477
478 let pod = if has_pod_attr || is_unsafe {
483 quote! {}
484 } else {
485 quote! {#[derive(::bytemuck::Pod)]}
486 };
487 let zeroable = if has_zeroable_attr || is_unsafe {
488 quote! {}
489 } else {
490 quote! {#[derive(::bytemuck::Zeroable)]}
491 };
492
493 let ret = quote! {
494 #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
495 #repr
496 #pod
497 #zeroable
498 #account_strct
499 };
500
501 #[cfg(feature = "idl-build")]
502 {
503 let derive_unsafe = if is_unsafe {
504 quote! { #[derive(bytemuck::Unsafe)] }
506 } else {
507 quote! {}
508 };
509 let zc_struct = syn::parse2(quote! {
510 #derive_unsafe
511 #ret
512 })
513 .unwrap();
514 let idl_build_impl = anchor_syn::idl::impl_idl_build_struct(&zc_struct);
515 return proc_macro::TokenStream::from(quote! {
516 #ret
517 #idl_build_impl
518 });
519 }
520
521 #[allow(unreachable_code)]
522 proc_macro::TokenStream::from(ret)
523}
524
525#[proc_macro]
529pub fn pubkey(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
530 let pk = parse_macro_input!(input as id::Pubkey);
531 proc_macro::TokenStream::from(quote! {#pk})
532}
533
534#[proc_macro]
537pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
538 #[cfg(feature = "idl-build")]
539 let address = input.clone().to_string();
540
541 let id = parse_macro_input!(input as id::Id);
542 let ret = quote! { #id };
543
544 #[cfg(feature = "idl-build")]
545 {
546 let idl_print = anchor_syn::idl::gen_idl_print_fn_address(address);
547 return proc_macro::TokenStream::from(quote! {
548 #ret
549 #idl_print
550 });
551 }
552
553 #[allow(unreachable_code)]
554 proc_macro::TokenStream::from(ret)
555}