anchor_attribute_account/
lib.rs

1extern crate proc_macro;
2
3use anchor_syn::{codegen::program::common::gen_discriminator, Overrides};
4use quote::{quote, ToTokens};
5use syn::{
6    parenthesized,
7    parse::{Parse, ParseStream},
8    parse_macro_input,
9    token::{Comma, Paren},
10    Ident, LitStr,
11};
12
13mod id;
14
15#[cfg(feature = "lazy-account")]
16mod lazy;
17
18/// An attribute for a data structure representing a Solana account.
19///
20/// `#[account]` generates trait implementations for the following traits:
21///
22/// - [`AccountSerialize`](./trait.AccountSerialize.html)
23/// - [`AccountDeserialize`](./trait.AccountDeserialize.html)
24/// - [`AnchorSerialize`](./trait.AnchorSerialize.html)
25/// - [`AnchorDeserialize`](./trait.AnchorDeserialize.html)
26/// - [`Clone`](https://doc.rust-lang.org/std/clone/trait.Clone.html)
27/// - [`Discriminator`](./trait.Discriminator.html)
28/// - [`Owner`](./trait.Owner.html)
29///
30/// When implementing account serialization traits the first 8 bytes are
31/// reserved for a unique account discriminator by default, self described by
32/// the first 8 bytes of the SHA256 of the account's Rust ident. This is unless
33/// the discriminator was overridden with the `discriminator` argument (see
34/// [Arguments](#arguments)).
35///
36/// As a result, any calls to `AccountDeserialize`'s `try_deserialize` will
37/// check this discriminator. If it doesn't match, an invalid account was given,
38/// and the account deserialization will exit with an error.
39///
40/// # Arguments
41///
42/// - `discriminator`: Override the default 8-byte discriminator
43///
44///     **Usage:** `discriminator = <CONST_EXPR>`
45///
46///     All constant expressions are supported.
47///
48///     **Examples:**
49///
50///     - `discriminator = 1` (shortcut for `[1]`)
51///     - `discriminator = [1, 2, 3, 4]`
52///     - `discriminator = b"hi"`
53///     - `discriminator = MY_DISC`
54///     - `discriminator = get_disc(...)`
55///
56/// # Zero Copy Deserialization
57///
58/// **WARNING**: Zero copy deserialization is an experimental feature. It's
59/// recommended to use it only when necessary, i.e., when you have extremely
60/// large accounts that cannot be Borsh deserialized without hitting stack or
61/// heap limits.
62///
63/// ## Usage
64///
65/// To enable zero-copy-deserialization, one can pass in the `zero_copy`
66/// argument to the macro as follows:
67///
68/// ```ignore
69/// #[account(zero_copy)]
70/// ```
71///
72/// This can be used to conveniently implement
73/// [`ZeroCopy`](./trait.ZeroCopy.html) so that the account can be used
74/// with [`AccountLoader`](./accounts/account_loader/struct.AccountLoader.html).
75///
76/// Other than being more efficient, the most salient benefit this provides is
77/// the ability to define account types larger than the max stack or heap size.
78/// When using borsh, the account has to be copied and deserialized into a new
79/// data structure and thus is constrained by stack and heap limits imposed by
80/// the BPF VM. With zero copy deserialization, all bytes from the account's
81/// backing `RefCell<&mut [u8]>` are simply re-interpreted as a reference to
82/// the data structure. No allocations or copies necessary. Hence the ability
83/// to get around stack and heap limitations.
84///
85/// To facilitate this, all fields in an account must be constrained to be
86/// "plain old  data", i.e., they must implement
87/// [`Pod`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html). Please review the
88/// [`safety`](https://docs.rs/bytemuck/latest/bytemuck/trait.Pod.html#safety)
89/// section before using.
90///
91/// Using `zero_copy` requires adding the following dependency to your `Cargo.toml` file:
92///
93/// ```toml
94/// bytemuck = { version = "1.17", features = ["derive", "min_const_generics"] }
95/// ```
96#[proc_macro_attribute]
97pub fn account(
98    args: proc_macro::TokenStream,
99    input: proc_macro::TokenStream,
100) -> proc_macro::TokenStream {
101    let args = parse_macro_input!(args as AccountArgs);
102    let namespace = args.namespace.unwrap_or_default();
103    let is_zero_copy = args.zero_copy.is_some();
104    let unsafe_bytemuck = args.zero_copy.unwrap_or_default();
105
106    let account_strct = parse_macro_input!(input as syn::ItemStruct);
107    let account_name = &account_strct.ident;
108    let account_name_str = account_name.to_string();
109    let (impl_gen, type_gen, where_clause) = account_strct.generics.split_for_impl();
110
111    let discriminator = args
112        .overrides
113        .and_then(|ov| ov.discriminator)
114        .unwrap_or_else(|| {
115            // Namespace the discriminator to prevent collisions.
116            let namespace = if namespace.is_empty() {
117                "account"
118            } else {
119                &namespace
120            };
121
122            gen_discriminator(namespace, account_name)
123        });
124    let disc = if account_strct.generics.lt_token.is_some() {
125        quote! { #account_name::#type_gen::DISCRIMINATOR }
126    } else {
127        quote! { #account_name::DISCRIMINATOR }
128    };
129
130    let owner_impl = {
131        if namespace.is_empty() {
132            quote! {
133                #[automatically_derived]
134                impl #impl_gen anchor_lang::Owner for #account_name #type_gen #where_clause {
135                    fn owner() -> Pubkey {
136                        crate::ID
137                    }
138                }
139            }
140        } else {
141            quote! {}
142        }
143    };
144
145    let unsafe_bytemuck_impl = {
146        if unsafe_bytemuck {
147            quote! {
148                #[automatically_derived]
149                unsafe impl #impl_gen anchor_lang::__private::bytemuck::Pod for #account_name #type_gen #where_clause {}
150                #[automatically_derived]
151                unsafe impl #impl_gen anchor_lang::__private::bytemuck::Zeroable for #account_name #type_gen #where_clause {}
152            }
153        } else {
154            quote! {}
155        }
156    };
157
158    let bytemuck_derives = {
159        if !unsafe_bytemuck {
160            quote! {
161                #[zero_copy]
162            }
163        } else {
164            quote! {
165                #[zero_copy(unsafe)]
166            }
167        }
168    };
169
170    proc_macro::TokenStream::from({
171        if is_zero_copy {
172            quote! {
173                #bytemuck_derives
174                #account_strct
175
176                #unsafe_bytemuck_impl
177
178                #[automatically_derived]
179                impl #impl_gen anchor_lang::ZeroCopy for #account_name #type_gen #where_clause {}
180
181                #[automatically_derived]
182                impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
183                    const DISCRIMINATOR: &'static [u8] = #discriminator;
184                }
185
186                // This trait is useful for clients deserializing accounts.
187                // It's expected on-chain programs deserialize via zero-copy.
188                #[automatically_derived]
189                impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
190                    fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
191                        if buf.len() < #disc.len() {
192                            return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
193                        }
194                        let given_disc = &buf[..#disc.len()];
195                        if #disc != given_disc {
196                            return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
197                        }
198                        Self::try_deserialize_unchecked(buf)
199                    }
200
201                    fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
202                        let data: &[u8] = &buf[#disc.len()..];
203                        // Re-interpret raw bytes into the POD data structure.
204                        let account = anchor_lang::__private::bytemuck::from_bytes(data);
205                        // Copy out the bytes into a new, owned data structure.
206                        Ok(*account)
207                    }
208                }
209
210                #owner_impl
211            }
212        } else {
213            let lazy = {
214                #[cfg(feature = "lazy-account")]
215                match namespace.is_empty().then(|| lazy::gen_lazy(&account_strct)) {
216                    Some(Ok(lazy)) => lazy,
217                    // If lazy codegen fails for whatever reason, return empty tokenstream which
218                    // will make the account unusable with `LazyAccount<T>`
219                    _ => Default::default(),
220                }
221                #[cfg(not(feature = "lazy-account"))]
222                proc_macro2::TokenStream::default()
223            };
224            quote! {
225                #[derive(AnchorSerialize, AnchorDeserialize, Clone)]
226                #account_strct
227
228                #[automatically_derived]
229                impl #impl_gen anchor_lang::AccountSerialize for #account_name #type_gen #where_clause {
230                    fn try_serialize<W: std::io::Write>(&self, writer: &mut W) -> anchor_lang::Result<()> {
231                        if writer.write_all(#disc).is_err() {
232                            return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
233                        }
234
235                        if AnchorSerialize::serialize(self, writer).is_err() {
236                            return Err(anchor_lang::error::ErrorCode::AccountDidNotSerialize.into());
237                        }
238                        Ok(())
239                    }
240                }
241
242                #[automatically_derived]
243                impl #impl_gen anchor_lang::AccountDeserialize for #account_name #type_gen #where_clause {
244                    fn try_deserialize(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
245                        if buf.len() < #disc.len() {
246                            return Err(anchor_lang::error::ErrorCode::AccountDiscriminatorNotFound.into());
247                        }
248                        let given_disc = &buf[..#disc.len()];
249                        if #disc != given_disc {
250                            return Err(anchor_lang::error!(anchor_lang::error::ErrorCode::AccountDiscriminatorMismatch).with_account_name(#account_name_str));
251                        }
252                        Self::try_deserialize_unchecked(buf)
253                    }
254
255                    fn try_deserialize_unchecked(buf: &mut &[u8]) -> anchor_lang::Result<Self> {
256                        let mut data: &[u8] = &buf[#disc.len()..];
257                        AnchorDeserialize::deserialize(&mut data)
258                            .map_err(|_| anchor_lang::error::ErrorCode::AccountDidNotDeserialize.into())
259                    }
260                }
261
262                #[automatically_derived]
263                impl #impl_gen anchor_lang::Discriminator for #account_name #type_gen #where_clause {
264                    const DISCRIMINATOR: &'static [u8] = #discriminator;
265                }
266
267                #owner_impl
268
269                #lazy
270            }
271        }
272    })
273}
274
275#[derive(Debug, Default)]
276struct AccountArgs {
277    /// `bool` is for deciding whether to use `unsafe` e.g. `Some(true)` for `zero_copy(unsafe)`
278    zero_copy: Option<bool>,
279    /// Account namespace override, `account` if not specified
280    namespace: Option<String>,
281    /// Named overrides
282    overrides: Option<Overrides>,
283}
284
285impl Parse for AccountArgs {
286    fn parse(input: ParseStream) -> syn::Result<Self> {
287        let mut parsed = Self::default();
288        let args = input.parse_terminated::<_, Comma>(AccountArg::parse)?;
289        for arg in args {
290            match arg {
291                AccountArg::ZeroCopy { is_unsafe } => {
292                    parsed.zero_copy.replace(is_unsafe);
293                }
294                AccountArg::Namespace(ns) => {
295                    parsed.namespace.replace(ns);
296                }
297                AccountArg::Overrides(ov) => {
298                    parsed.overrides.replace(ov);
299                }
300            }
301        }
302
303        Ok(parsed)
304    }
305}
306
307enum AccountArg {
308    ZeroCopy { is_unsafe: bool },
309    Namespace(String),
310    Overrides(Overrides),
311}
312
313impl Parse for AccountArg {
314    fn parse(input: ParseStream) -> syn::Result<Self> {
315        // Namespace
316        if let Ok(ns) = input.parse::<LitStr>() {
317            return Ok(Self::Namespace(
318                ns.to_token_stream().to_string().replace('\"', ""),
319            ));
320        }
321
322        // Zero copy
323        if input.fork().parse::<Ident>()? == "zero_copy" {
324            input.parse::<Ident>()?;
325            let is_unsafe = if input.peek(Paren) {
326                let content;
327                parenthesized!(content in input);
328                let content = content.parse::<proc_macro2::TokenStream>()?;
329                if content.to_string().as_str().trim() != "unsafe" {
330                    return Err(syn::Error::new(
331                        syn::spanned::Spanned::span(&content),
332                        "Expected `unsafe`",
333                    ));
334                }
335
336                true
337            } else {
338                false
339            };
340
341            return Ok(Self::ZeroCopy { is_unsafe });
342        };
343
344        // Overrides
345        input.parse::<Overrides>().map(Self::Overrides)
346    }
347}
348
349#[proc_macro_derive(ZeroCopyAccessor, attributes(accessor))]
350pub fn derive_zero_copy_accessor(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
351    let account_strct = parse_macro_input!(item as syn::ItemStruct);
352    let account_name = &account_strct.ident;
353    let (impl_gen, ty_gen, where_clause) = account_strct.generics.split_for_impl();
354
355    let fields = match &account_strct.fields {
356        syn::Fields::Named(n) => n,
357        _ => panic!("Fields must be named"),
358    };
359    let methods: Vec<proc_macro2::TokenStream> = fields
360        .named
361        .iter()
362        .filter_map(|field: &syn::Field| {
363            field
364                .attrs
365                .iter()
366                .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "accessor")
367                .map(|attr| {
368                    let mut tts = attr.tokens.clone().into_iter();
369                    let g_stream = match tts.next().expect("Must have a token group") {
370                        proc_macro2::TokenTree::Group(g) => g.stream(),
371                        _ => panic!("Invalid syntax"),
372                    };
373                    let accessor_ty = match g_stream.into_iter().next() {
374                        Some(token) => token,
375                        _ => panic!("Missing accessor type"),
376                    };
377
378                    let field_name = field.ident.as_ref().unwrap();
379
380                    let get_field: proc_macro2::TokenStream =
381                        format!("get_{field_name}").parse().unwrap();
382                    let set_field: proc_macro2::TokenStream =
383                        format!("set_{field_name}").parse().unwrap();
384
385                    quote! {
386                        pub fn #get_field(&self) -> #accessor_ty {
387                            anchor_lang::__private::ZeroCopyAccessor::get(&self.#field_name)
388                        }
389                        pub fn #set_field(&mut self, input: &#accessor_ty) {
390                            self.#field_name = anchor_lang::__private::ZeroCopyAccessor::set(input);
391                        }
392                    }
393                })
394        })
395        .collect();
396    proc_macro::TokenStream::from(quote! {
397        #[automatically_derived]
398        impl #impl_gen #account_name #ty_gen #where_clause {
399            #(#methods)*
400        }
401    })
402}
403
404/// A data structure that can be used as an internal field for a zero copy
405/// deserialized account, i.e., a struct marked with `#[account(zero_copy)]`.
406///
407/// `#[zero_copy]` is just a convenient alias for
408///
409/// ```ignore
410/// #[derive(Copy, Clone)]
411/// #[derive(bytemuck::Zeroable)]
412/// #[derive(bytemuck::Pod)]
413/// #[repr(C)]
414/// struct MyStruct {...}
415/// ```
416#[proc_macro_attribute]
417pub fn zero_copy(
418    args: proc_macro::TokenStream,
419    item: proc_macro::TokenStream,
420) -> proc_macro::TokenStream {
421    let mut is_unsafe = false;
422    for arg in args.into_iter() {
423        match arg {
424            proc_macro::TokenTree::Ident(ident) => {
425                if ident.to_string() == "unsafe" {
426                    // `#[zero_copy(unsafe)]` maintains the old behaviour
427                    //
428                    // ```ignore
429                    // #[derive(Copy, Clone)]
430                    // #[repr(packed)]
431                    // struct MyStruct {...}
432                    // ```
433                    is_unsafe = true;
434                } else {
435                    // TODO: how to return a compile error with a span (can't return prase error because expected type TokenStream)
436                    panic!("expected single ident `unsafe`");
437                }
438            }
439            _ => {
440                panic!("expected single ident `unsafe`");
441            }
442        }
443    }
444
445    let account_strct = parse_macro_input!(item as syn::ItemStruct);
446
447    // Takes the first repr. It's assumed that more than one are not on the
448    // struct.
449    let attr = account_strct
450        .attrs
451        .iter()
452        .find(|attr| anchor_syn::parser::tts_to_string(&attr.path) == "repr");
453
454    let repr = match attr {
455        // Users might want to manually specify repr modifiers e.g. repr(C, packed)
456        Some(_attr) => quote! {},
457        None => {
458            if is_unsafe {
459                quote! {#[repr(Rust, packed)]}
460            } else {
461                quote! {#[repr(C)]}
462            }
463        }
464    };
465
466    let mut has_pod_attr = false;
467    let mut has_zeroable_attr = false;
468    for attr in account_strct.attrs.iter() {
469        let token_string = attr.tokens.to_string();
470        if token_string.contains("bytemuck :: Pod") {
471            has_pod_attr = true;
472        }
473        if token_string.contains("bytemuck :: Zeroable") {
474            has_zeroable_attr = true;
475        }
476    }
477
478    // Once the Pod derive macro is expanded the compiler has to use the local crate's
479    // bytemuck `::bytemuck::Pod` anyway, so we're no longer using the privately
480    // exported anchor bytemuck `__private::bytemuck`, so that there won't be any
481    // possible disparity between the anchor version and the local crate's version.
482    let pod = if has_pod_attr || is_unsafe {
483        quote! {}
484    } else {
485        quote! {#[derive(::bytemuck::Pod)]}
486    };
487    let zeroable = if has_zeroable_attr || is_unsafe {
488        quote! {}
489    } else {
490        quote! {#[derive(::bytemuck::Zeroable)]}
491    };
492
493    let ret = quote! {
494        #[derive(anchor_lang::__private::ZeroCopyAccessor, Copy, Clone)]
495        #repr
496        #pod
497        #zeroable
498        #account_strct
499    };
500
501    #[cfg(feature = "idl-build")]
502    {
503        let derive_unsafe = if is_unsafe {
504            // Not a real proc-macro but exists in order to pass the serialization info
505            quote! { #[derive(bytemuck::Unsafe)] }
506        } else {
507            quote! {}
508        };
509        let zc_struct = syn::parse2(quote! {
510            #derive_unsafe
511            #ret
512        })
513        .unwrap();
514        let idl_build_impl = anchor_syn::idl::impl_idl_build_struct(&zc_struct);
515        return proc_macro::TokenStream::from(quote! {
516            #ret
517            #idl_build_impl
518        });
519    }
520
521    #[allow(unreachable_code)]
522    proc_macro::TokenStream::from(ret)
523}
524
525/// Convenience macro to define a static public key.
526///
527/// Input: a single literal base58 string representation of a Pubkey.
528#[proc_macro]
529pub fn pubkey(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
530    let pk = parse_macro_input!(input as id::Pubkey);
531    proc_macro::TokenStream::from(quote! {#pk})
532}
533
534/// Defines the program's ID. This should be used at the root of all Anchor
535/// based programs.
536#[proc_macro]
537pub fn declare_id(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
538    #[cfg(feature = "idl-build")]
539    let address = input.clone().to_string();
540
541    let id = parse_macro_input!(input as id::Id);
542    let ret = quote! { #id };
543
544    #[cfg(feature = "idl-build")]
545    {
546        let idl_print = anchor_syn::idl::gen_idl_print_fn_address(address);
547        return proc_macro::TokenStream::from(quote! {
548            #ret
549            #idl_print
550        });
551    }
552
553    #[allow(unreachable_code)]
554    proc_macro::TokenStream::from(ret)
555}