anchor_syn/idl/
defined.rs

1use anyhow::{anyhow, Result};
2use proc_macro2::TokenStream;
3use quote::quote;
4
5use super::common::{get_idl_module_path, get_no_docs};
6use crate::parser::docs;
7
8/// Generate `IdlBuild` impl for a struct.
9pub fn impl_idl_build_struct(item: &syn::ItemStruct) -> TokenStream {
10    impl_idl_build(&item.ident, &item.generics, gen_idl_type_def_struct(item))
11}
12
13/// Generate `IdlBuild` impl for an enum.
14pub fn impl_idl_build_enum(item: &syn::ItemEnum) -> TokenStream {
15    impl_idl_build(&item.ident, &item.generics, gen_idl_type_def_enum(item))
16}
17
18/// Generate `IdlBuild` impl for a union.
19///
20/// Unions are not currently supported in the IDL.
21pub fn impl_idl_build_union(item: &syn::ItemUnion) -> TokenStream {
22    impl_idl_build(
23        &item.ident,
24        &item.generics,
25        Err(anyhow!("Unions are not supported")),
26    )
27}
28
29/// Generate `IdlBuild` implementation.
30fn impl_idl_build(
31    ident: &syn::Ident,
32    generics: &syn::Generics,
33    type_def: Result<(TokenStream, Vec<syn::TypePath>)>,
34) -> TokenStream {
35    let idl = get_idl_module_path();
36    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
37    let idl_build_trait = quote!(anchor_lang::idl::build::IdlBuild);
38
39    let (idl_type_def, insert_defined) = match type_def {
40        Ok((ts, defined)) => (
41            quote! { Some(#ts) },
42            quote! {
43                #(
44                    if let Some(ty) = <#defined>::create_type() {
45                        types.insert(<#defined>::get_full_path(), ty);
46                        <#defined>::insert_types(types);
47                    }
48                );*
49            },
50        ),
51        _ => (quote! { None }, quote! {}),
52    };
53
54    quote! {
55        impl #impl_generics #idl_build_trait for #ident #ty_generics #where_clause {
56            fn create_type() -> Option<#idl::IdlTypeDef> {
57                #idl_type_def
58            }
59
60            fn insert_types(
61                types: &mut std::collections::BTreeMap<String, #idl::IdlTypeDef>
62            ) {
63                #insert_defined
64            }
65
66            fn get_full_path() -> String {
67                format!("{}::{}", module_path!(), stringify!(#ident))
68            }
69        }
70    }
71}
72
73pub fn gen_idl_type_def_struct(
74    strct: &syn::ItemStruct,
75) -> Result<(TokenStream, Vec<syn::TypePath>)> {
76    gen_idl_type_def(&strct.attrs, &strct.generics, |generic_params| {
77        let no_docs = get_no_docs();
78        let idl = get_idl_module_path();
79
80        let (fields, defined) = match &strct.fields {
81            syn::Fields::Unit => (quote! { None }, vec![]),
82            syn::Fields::Named(fields) => {
83                let (fields, defined) = fields
84                    .named
85                    .iter()
86                    .map(|f| gen_idl_field(f, generic_params, no_docs))
87                    .collect::<Result<Vec<_>>>()?
88                    .into_iter()
89                    .unzip::<_, _, Vec<_>, Vec<_>>();
90
91                (
92                    quote! { Some(#idl::IdlDefinedFields::Named(vec![#(#fields),*])) },
93                    defined,
94                )
95            }
96            syn::Fields::Unnamed(fields) => {
97                let (types, defined) = fields
98                    .unnamed
99                    .iter()
100                    .map(|f| gen_idl_type(&f.ty, generic_params))
101                    .collect::<Result<Vec<_>>>()?
102                    .into_iter()
103                    .unzip::<_, Vec<_>, Vec<_>, Vec<_>>();
104
105                (
106                    quote! { Some(#idl::IdlDefinedFields::Tuple(vec![#(#types),*])) },
107                    defined,
108                )
109            }
110        };
111        let defined = defined.into_iter().flatten().collect::<Vec<_>>();
112
113        Ok((
114            quote! {
115                #idl::IdlTypeDefTy::Struct {
116                    fields: #fields,
117                }
118            },
119            defined,
120        ))
121    })
122}
123
124fn gen_idl_type_def_enum(enm: &syn::ItemEnum) -> Result<(TokenStream, Vec<syn::TypePath>)> {
125    gen_idl_type_def(&enm.attrs, &enm.generics, |generic_params| {
126        let no_docs = get_no_docs();
127        let idl = get_idl_module_path();
128
129        let (variants, defined) = enm
130            .variants
131            .iter()
132            .map(|variant| {
133                let name = variant.ident.to_string();
134                let (fields, defined) = match &variant.fields {
135                    syn::Fields::Unit => (quote! { None }, vec![]),
136                    syn::Fields::Named(fields) => {
137                        let (fields, defined) = fields
138                            .named
139                            .iter()
140                            .map(|f| gen_idl_field(f, generic_params, no_docs))
141                            .collect::<Result<Vec<_>>>()?
142                            .into_iter()
143                            .unzip::<_, Vec<_>, Vec<_>, Vec<_>>();
144                        let defined = defined.into_iter().flatten().collect::<Vec<_>>();
145
146                        (
147                            quote! { Some(#idl::IdlDefinedFields::Named(vec![#(#fields),*])) },
148                            defined,
149                        )
150                    }
151                    syn::Fields::Unnamed(fields) => {
152                        let (types, defined) = fields
153                            .unnamed
154                            .iter()
155                            .map(|f| gen_idl_type(&f.ty, generic_params))
156                            .collect::<Result<Vec<_>>>()?
157                            .into_iter()
158                            .unzip::<_, Vec<_>, Vec<_>, Vec<_>>();
159                        let defined = defined.into_iter().flatten().collect::<Vec<_>>();
160
161                        (
162                            quote! { Some(#idl::IdlDefinedFields::Tuple(vec![#(#types),*])) },
163                            defined,
164                        )
165                    }
166                };
167
168                Ok((
169                    quote! { #idl::IdlEnumVariant { name: #name.into(), fields: #fields } },
170                    defined,
171                ))
172            })
173            .collect::<Result<Vec<_>>>()?
174            .into_iter()
175            .unzip::<_, _, Vec<_>, Vec<_>>();
176        let defined = defined.into_iter().flatten().collect::<Vec<_>>();
177
178        Ok((
179            quote! {
180                #idl::IdlTypeDefTy::Enum {
181                    variants: vec![#(#variants),*],
182                }
183            },
184            defined,
185        ))
186    })
187}
188
189fn gen_idl_type_def<F>(
190    attrs: &[syn::Attribute],
191    generics: &syn::Generics,
192    create_fields: F,
193) -> Result<(TokenStream, Vec<syn::TypePath>)>
194where
195    F: Fn(&[syn::Ident]) -> Result<(TokenStream, Vec<syn::TypePath>)>,
196{
197    let no_docs = get_no_docs();
198    let idl = get_idl_module_path();
199
200    let docs = match docs::parse(attrs) {
201        Some(docs) if !no_docs => quote! { vec![#(#docs.into()),*] },
202        _ => quote! { vec![] },
203    };
204
205    let serialization = get_attr_str("derive", attrs)
206        .and_then(|derive| {
207            if derive.contains("bytemuck") {
208                if derive.to_lowercase().contains("unsafe") {
209                    Some(quote! { #idl::IdlSerialization::BytemuckUnsafe })
210                } else {
211                    Some(quote! { #idl::IdlSerialization::Bytemuck })
212                }
213            } else {
214                None
215            }
216        })
217        .unwrap_or_else(|| quote! { #idl::IdlSerialization::default() });
218
219    let repr = get_attr_str("repr", attrs)
220        .map(|repr| {
221            let packed = repr.contains("packed");
222            let align = repr
223                .find("align")
224                .and_then(|i| repr.get(i..))
225                .and_then(|align| {
226                    align
227                        .find('(')
228                        .and_then(|start| align.find(')').and_then(|end| align.get(start + 1..end)))
229                })
230                .and_then(|size| size.parse::<usize>().ok())
231                .map(|size| quote! { Some(#size) })
232                .unwrap_or_else(|| quote! { None });
233            let modifier = quote! {
234                #idl::IdlReprModifier {
235                    packed: #packed,
236                    align: #align,
237                }
238            };
239
240            if repr.contains("transparent") {
241                quote! { #idl::IdlRepr::Transparent }
242            } else if repr.contains('C') {
243                quote! { #idl::IdlRepr::C(#modifier) }
244            } else {
245                quote! { #idl::IdlRepr::Rust(#modifier) }
246            }
247        })
248        .map(|repr| quote! { Some(#repr) })
249        .unwrap_or_else(|| quote! { None });
250
251    let generic_params = generics
252        .params
253        .iter()
254        .filter_map(|p| match p {
255            syn::GenericParam::Type(ty) => Some(ty.ident.clone()),
256            syn::GenericParam::Const(c) => Some(c.ident.clone()),
257            _ => None,
258        })
259        .collect::<Vec<_>>();
260    let (ty, defined) = create_fields(&generic_params)?;
261
262    let generics = generics
263        .params
264        .iter()
265        .filter_map(|p| match p {
266            syn::GenericParam::Type(ty) => {
267                let name = ty.ident.to_string();
268                Some(quote! {
269                    #idl::IdlTypeDefGeneric::Type {
270                        name: #name.into(),
271                    }
272                })
273            }
274            syn::GenericParam::Const(c) => {
275                let name = c.ident.to_string();
276                let ty = match &c.ty {
277                    syn::Type::Path(path) => get_first_segment(path).ident.to_string(),
278                    _ => unreachable!("Const generic type can only be path"),
279                };
280                Some(quote! {
281                    #idl::IdlTypeDefGeneric::Const {
282                        name: #name.into(),
283                        ty: #ty.into(),
284                    }
285                })
286            }
287            _ => None,
288        })
289        .collect::<Vec<_>>();
290
291    Ok((
292        quote! {
293            #idl::IdlTypeDef {
294                name: Self::get_full_path(),
295                docs: #docs,
296                serialization: #serialization,
297                repr: #repr,
298                generics: vec![#(#generics.into()),*],
299                ty: #ty,
300            }
301        },
302        defined,
303    ))
304}
305
306fn get_attr_str(name: impl AsRef<str>, attrs: &[syn::Attribute]) -> Option<String> {
307    attrs
308        .iter()
309        .filter(|attr| {
310            attr.path
311                .segments
312                .first()
313                .filter(|seg| seg.ident == name)
314                .is_some()
315        })
316        .map(|attr| attr.tokens.to_string())
317        .reduce(|acc, cur| {
318            format!(
319                "{} , {}",
320                acc.get(..acc.len() - 1).unwrap(),
321                cur.get(1..).unwrap()
322            )
323        })
324}
325
326fn gen_idl_field(
327    field: &syn::Field,
328    generic_params: &[syn::Ident],
329    no_docs: bool,
330) -> Result<(TokenStream, Vec<syn::TypePath>)> {
331    let idl = get_idl_module_path();
332
333    let name = field.ident.as_ref().unwrap().to_string();
334    let docs = match docs::parse(&field.attrs) {
335        Some(docs) if !no_docs => quote! { vec![#(#docs.into()),*] },
336        _ => quote! { vec![] },
337    };
338    let (ty, defined) = gen_idl_type(&field.ty, generic_params)?;
339
340    Ok((
341        quote! {
342            #idl::IdlField {
343                name: #name.into(),
344                docs: #docs,
345                ty: #ty,
346            }
347        },
348        defined,
349    ))
350}
351
352pub fn gen_idl_type(
353    ty: &syn::Type,
354    generic_params: &[syn::Ident],
355) -> Result<(TokenStream, Vec<syn::TypePath>)> {
356    let idl = get_idl_module_path();
357
358    fn the_only_segment_is(path: &syn::TypePath, cmp: &str) -> bool {
359        if path.path.segments.len() != 1 {
360            return false;
361        };
362        return get_first_segment(path).ident == cmp;
363    }
364
365    fn get_angle_bracketed_type_args(seg: &syn::PathSegment) -> Vec<&syn::Type> {
366        match &seg.arguments {
367            syn::PathArguments::AngleBracketed(ab) => ab
368                .args
369                .iter()
370                .filter_map(|arg| match arg {
371                    syn::GenericArgument::Type(ty) => Some(ty),
372                    _ => None,
373                })
374                .collect(),
375            _ => panic!("No angle bracket for {seg:#?}"),
376        }
377    }
378
379    match ty {
380        syn::Type::Path(path) if the_only_segment_is(path, "bool") => {
381            Ok((quote! { #idl::IdlType::Bool }, vec![]))
382        }
383        syn::Type::Path(path) if the_only_segment_is(path, "u8") => {
384            Ok((quote! { #idl::IdlType::U8 }, vec![]))
385        }
386        syn::Type::Path(path) if the_only_segment_is(path, "i8") => {
387            Ok((quote! { #idl::IdlType::I8 }, vec![]))
388        }
389        syn::Type::Path(path) if the_only_segment_is(path, "u16") => {
390            Ok((quote! { #idl::IdlType::U16 }, vec![]))
391        }
392        syn::Type::Path(path) if the_only_segment_is(path, "i16") => {
393            Ok((quote! { #idl::IdlType::I16 }, vec![]))
394        }
395        syn::Type::Path(path) if the_only_segment_is(path, "u32") => {
396            Ok((quote! { #idl::IdlType::U32 }, vec![]))
397        }
398        syn::Type::Path(path) if the_only_segment_is(path, "i32") => {
399            Ok((quote! { #idl::IdlType::I32 }, vec![]))
400        }
401        syn::Type::Path(path) if the_only_segment_is(path, "f32") => {
402            Ok((quote! { #idl::IdlType::F32 }, vec![]))
403        }
404        syn::Type::Path(path) if the_only_segment_is(path, "u64") => {
405            Ok((quote! { #idl::IdlType::U64 }, vec![]))
406        }
407        syn::Type::Path(path) if the_only_segment_is(path, "i64") => {
408            Ok((quote! { #idl::IdlType::I64 }, vec![]))
409        }
410        syn::Type::Path(path) if the_only_segment_is(path, "f64") => {
411            Ok((quote! { #idl::IdlType::F64 }, vec![]))
412        }
413        syn::Type::Path(path) if the_only_segment_is(path, "u128") => {
414            Ok((quote! { #idl::IdlType::U128 }, vec![]))
415        }
416        syn::Type::Path(path) if the_only_segment_is(path, "i128") => {
417            Ok((quote! { #idl::IdlType::I128 }, vec![]))
418        }
419        syn::Type::Path(path)
420            if the_only_segment_is(path, "String") || the_only_segment_is(path, "str") =>
421        {
422            Ok((quote! { #idl::IdlType::String }, vec![]))
423        }
424        syn::Type::Path(path) if the_only_segment_is(path, "Pubkey") => {
425            Ok((quote! { #idl::IdlType::Pubkey }, vec![]))
426        }
427        syn::Type::Path(path) if the_only_segment_is(path, "Option") => {
428            let segment = get_first_segment(path);
429            let arg = get_angle_bracketed_type_args(segment)
430                .into_iter()
431                .next()
432                .unwrap();
433            let (inner, defined) = gen_idl_type(arg, generic_params)?;
434            Ok((quote! { #idl::IdlType::Option(Box::new(#inner)) }, defined))
435        }
436        syn::Type::Path(path) if the_only_segment_is(path, "Vec") => {
437            let segment = get_first_segment(path);
438            let arg = get_angle_bracketed_type_args(segment)
439                .into_iter()
440                .next()
441                .unwrap();
442            match arg {
443                syn::Type::Path(path) if the_only_segment_is(path, "u8") => {
444                    return Ok((quote! {#idl::IdlType::Bytes}, vec![]));
445                }
446                _ => (),
447            };
448            let (inner, defined) = gen_idl_type(arg, generic_params)?;
449            Ok((quote! { #idl::IdlType::Vec(Box::new(#inner)) }, defined))
450        }
451        syn::Type::Path(path) if the_only_segment_is(path, "Box") => {
452            let segment = get_first_segment(path);
453            let arg = get_angle_bracketed_type_args(segment)
454                .into_iter()
455                .next()
456                .unwrap();
457            gen_idl_type(arg, generic_params)
458        }
459        syn::Type::Array(arr) => {
460            let len = &arr.len;
461            let is_generic = generic_params.iter().any(|param| match len {
462                syn::Expr::Path(path) => path.path.is_ident(param),
463                _ => false,
464            });
465
466            let len = if is_generic {
467                match len {
468                    syn::Expr::Path(len) => {
469                        let len = len.path.get_ident().unwrap().to_string();
470                        quote! { #idl::IdlArrayLen::Generic(#len.into()) }
471                    }
472                    _ => unreachable!("Array length can only be a generic parameter"),
473                }
474            } else {
475                quote! { #idl::IdlArrayLen::Value(#len) }
476            };
477
478            let (inner, defined) = gen_idl_type(&arr.elem, generic_params)?;
479            Ok((
480                quote! { #idl::IdlType::Array(Box::new(#inner), #len) },
481                defined,
482            ))
483        }
484        // Defined
485        syn::Type::Path(path) => {
486            let is_generic_param = generic_params.iter().any(|param| path.path.is_ident(param));
487            if is_generic_param {
488                let generic = get_first_segment(path).ident.to_string();
489                return Ok((quote! { #idl::IdlType::Generic(#generic.into()) }, vec![]));
490            }
491
492            // Handle type aliases and external types
493            #[cfg(procmacro2_semver_exempt)]
494            {
495                use super::{common::find_path, external::get_external_type};
496                use crate::parser::context::CrateContext;
497                use quote::ToTokens;
498
499                let source_path = proc_macro2::Span::call_site().source_file().path();
500                if let Ok(Ok(ctx)) = find_path("lib.rs", &source_path).map(CrateContext::parse) {
501                    let name = path.path.segments.last().unwrap().ident.to_string();
502                    let alias = ctx.type_aliases().find(|ty| ty.ident == name);
503                    if let Some(alias) = alias {
504                        if let Some(segment) = path.path.segments.last() {
505                            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
506                                let inners = args
507                                    .args
508                                    .iter()
509                                    .map(|arg| match arg {
510                                        syn::GenericArgument::Type(ty) => match ty {
511                                            syn::Type::Path(inner_ty) => {
512                                                inner_ty.path.to_token_stream().to_string()
513                                            }
514                                            _ => {
515                                                unimplemented!("Inner type not implemented: {ty:?}")
516                                            }
517                                        },
518                                        syn::GenericArgument::Const(c) => {
519                                            c.to_token_stream().to_string()
520                                        }
521                                        _ => unimplemented!("Arg not implemented: {arg:?}"),
522                                    })
523                                    .collect::<Vec<_>>();
524
525                                let outer = match &*alias.ty {
526                                    syn::Type::Path(outer_ty) => outer_ty.path.to_token_stream(),
527                                    syn::Type::Array(outer_ty) => outer_ty.to_token_stream(),
528                                    _ => unimplemented!("Type not implemented: {:?}", alias.ty),
529                                }
530                                .to_string();
531
532                                let resolved_alias = alias
533                                    .generics
534                                    .params
535                                    .iter()
536                                    .map(|param| match param {
537                                        syn::GenericParam::Const(param) => param.ident.to_string(),
538                                        syn::GenericParam::Type(param) => param.ident.to_string(),
539                                        _ => panic!("Lifetime parameters are not allowed"),
540                                    })
541                                    .enumerate()
542                                    .fold(outer, |acc, (i, cur)| {
543                                        let inner = &inners[i];
544                                        // The spacing of the `outer` variable can differ between
545                                        // versions, e.g. `[T; N]` and `[T ; N]`
546                                        acc.replace(&format!(" {cur} "), &format!(" {inner} "))
547                                            .replace(&format!(" {cur},"), &format!(" {inner},"))
548                                            .replace(&format!("[{cur} "), &format!("[{inner} "))
549                                            .replace(&format!("[{cur};"), &format!("[{inner};"))
550                                            .replace(&format!(" {cur}]"), &format!(" {inner}]"))
551                                    });
552                                if let Ok(ty) = syn::parse_str(&resolved_alias) {
553                                    return gen_idl_type(&ty, generic_params);
554                                }
555                            }
556                        };
557
558                        // Non-generic type alias e.g. `type UnixTimestamp = i64`
559                        return gen_idl_type(&*alias.ty, generic_params);
560                    }
561
562                    // Handle external types
563                    let is_external = ctx
564                        .structs()
565                        .map(|s| s.ident.to_string())
566                        .chain(ctx.enums().map(|e| e.ident.to_string()))
567                        .find(|defined| defined == &name)
568                        .is_none();
569                    if is_external {
570                        if let Ok(Some(ty)) = get_external_type(&name, source_path) {
571                            return gen_idl_type(&ty, generic_params);
572                        }
573                    }
574                }
575            }
576
577            // Defined in crate
578            let mut generics = vec![];
579            let mut defined = vec![path.clone()];
580
581            if let Some(segment) = path.path.segments.last() {
582                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
583                    for arg in &args.args {
584                        let generic = match arg {
585                            syn::GenericArgument::Const(c) => {
586                                quote! { #idl::IdlGenericArg::Const { value: #c.to_string() } }
587                            }
588                            // `MY_CONST` in `Foo<MY_CONST>` is parsed as `GenericArgument::Type`
589                            // instead of `GenericArgument::Const` because they're indistinguishable
590                            // syntactically, as mentioned in
591                            // https://github.com/dtolnay/syn/blob/bfa790b8e445dc67b7ab94d75adb1a92d6296c9a/src/path.rs#L113-L115
592                            //
593                            // As a workaround, we're manually checking to see if it *looks* like a
594                            // constant identifier to fix the issue mentioned in
595                            // https://github.com/coral-xyz/anchor/issues/3520
596                            syn::GenericArgument::Type(syn::Type::Path(p))
597                                if p.path
598                                    .segments
599                                    .last()
600                                    .map(|seg| seg.ident.to_string())
601                                    .map(|ident| ident.len() > 1 && ident == ident.to_uppercase())
602                                    .unwrap_or_default() =>
603                            {
604                                quote! { #idl::IdlGenericArg::Const { value: #p.to_string() } }
605                            }
606                            syn::GenericArgument::Type(ty) => {
607                                let (ty, def) = gen_idl_type(ty, generic_params)?;
608                                defined.extend(def);
609                                quote! { #idl::IdlGenericArg::Type { ty: #ty } }
610                            }
611                            _ => return Err(anyhow!("Unsupported generic argument: {arg:#?}")),
612                        };
613                        generics.push(generic);
614                    }
615                }
616            }
617
618            Ok((
619                quote! {
620                    #idl::IdlType::Defined {
621                        name: <#path>::get_full_path(),
622                        generics: vec![#(#generics),*],
623                    }
624                },
625                defined,
626            ))
627        }
628        syn::Type::Reference(reference) => match reference.elem.as_ref() {
629            syn::Type::Slice(slice) if matches!(&*slice.elem, syn::Type::Path(path) if the_only_segment_is(path, "u8")) => {
630                Ok((quote! {#idl::IdlType::Bytes}, vec![]))
631            }
632            _ => gen_idl_type(&reference.elem, generic_params),
633        },
634        _ => Err(anyhow!("Unknown type: {ty:#?}")),
635    }
636}
637
638fn get_first_segment(type_path: &syn::TypePath) -> &syn::PathSegment {
639    type_path.path.segments.first().unwrap()
640}