wiggle_generate/
module_trait.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3
4use crate::codegen_settings::{CodegenSettings, ErrorType};
5use crate::names;
6use witx::Module;
7
8pub fn passed_by_reference(ty: &witx::Type) -> bool {
9    match ty {
10        witx::Type::Record(r) => r.bitflags_repr().is_none(),
11        witx::Type::Variant(v) => !v.is_enum(),
12        _ => false,
13    }
14}
15
16pub fn define_module_trait(m: &Module, settings: &CodegenSettings) -> TokenStream {
17    let traitname = names::trait_name(&m.name);
18    let traitmethods = m.funcs().map(|f| {
19        let funcname = names::func(&f.name);
20        let args = f.params.iter().map(|arg| {
21            let arg_name = names::func_param(&arg.name);
22            let arg_typename = names::type_ref(&arg.tref, quote!());
23            let arg_type = if passed_by_reference(&*arg.tref.type_()) {
24                quote!(&#arg_typename)
25            } else {
26                quote!(#arg_typename)
27            };
28            quote!(#arg_name: #arg_type)
29        });
30
31        let result = match f.results.len() {
32            0 if f.noreturn => quote!(wiggle::anyhow::Error),
33            0 => quote!(()),
34            1 => {
35                let (ok, err) = match &**f.results[0].tref.type_() {
36                    witx::Type::Variant(v) => match v.as_expected() {
37                        Some(p) => p,
38                        None => unimplemented!("anonymous variant ref {:?}", v),
39                    },
40                    _ => unimplemented!(),
41                };
42
43                let ok = match ok {
44                    Some(ty) => names::type_ref(ty, quote!()),
45                    None => quote!(()),
46                };
47                let err = match err {
48                    Some(ty) => match settings.errors.for_abi_error(ty) {
49                        Some(ErrorType::User(custom)) => {
50                            let tn = custom.typename();
51                            quote!(super::#tn)
52                        }
53                        Some(ErrorType::Generated(g)) => g.typename(),
54                        None => names::type_ref(ty, quote!()),
55                    },
56                    None => quote!(()),
57                };
58                quote!(Result<#ok, #err>)
59            }
60            _ => unimplemented!(),
61        };
62
63        let asyncness = if settings.get_async(&m, &f).is_sync() {
64            quote!()
65        } else {
66            quote!(async)
67        };
68
69        let self_ = if settings.mutable {
70            quote!(&mut self)
71        } else {
72            quote!(&self)
73        };
74        quote!(
75            #asyncness fn #funcname(
76                #self_,
77                mem: &mut wiggle::GuestMemory<'_>,
78                #(#args),*
79            ) -> #result;
80        )
81    });
82
83    quote! {
84        #[wiggle::async_trait]
85        pub trait #traitname {
86            #(#traitmethods)*
87        }
88    }
89}