wiggle_generate/
funcs.rs

1use crate::codegen_settings::{CodegenSettings, ErrorType};
2use crate::lifetimes::anon_lifetime;
3use crate::module_trait::passed_by_reference;
4use crate::names;
5use crate::types::WiggleType;
6use proc_macro2::{Ident, Span, TokenStream};
7use quote::quote;
8use std::mem;
9use witx::Instruction;
10
11pub fn define_func(
12    module: &witx::Module,
13    func: &witx::InterfaceFunc,
14    settings: &CodegenSettings,
15) -> TokenStream {
16    let (ts, _bounds) = _define_func(module, func, settings);
17    ts
18}
19
20pub fn func_bounds(
21    module: &witx::Module,
22    func: &witx::InterfaceFunc,
23    settings: &CodegenSettings,
24) -> Vec<Ident> {
25    let (_ts, bounds) = _define_func(module, func, settings);
26    bounds
27}
28
29fn _define_func(
30    module: &witx::Module,
31    func: &witx::InterfaceFunc,
32    settings: &CodegenSettings,
33) -> (TokenStream, Vec<Ident>) {
34    let ident = names::func(&func.name);
35
36    let (wasm_params, wasm_results) = func.wasm_signature();
37    let param_names = (0..wasm_params.len())
38        .map(|i| Ident::new(&format!("arg{i}"), Span::call_site()))
39        .collect::<Vec<_>>();
40    let abi_params = wasm_params.iter().zip(&param_names).map(|(arg, name)| {
41        let wasm = names::wasm_type(*arg);
42        quote!(#name : #wasm)
43    });
44
45    let abi_ret = match wasm_results.len() {
46        0 => quote!(()),
47        1 => {
48            let ty = names::wasm_type(wasm_results[0]);
49            quote!(#ty)
50        }
51        _ => unimplemented!(),
52    };
53
54    let mut body = TokenStream::new();
55    let mut bounds = vec![names::trait_name(&module.name)];
56    func.call_interface(
57        &module.name,
58        &mut Rust {
59            src: &mut body,
60            params: &param_names,
61            block_storage: Vec::new(),
62            blocks: Vec::new(),
63            module,
64            funcname: func.name.as_str(),
65            settings,
66            bounds: &mut bounds,
67        },
68    );
69
70    let mod_name = &module.name.as_str();
71    let func_name = &func.name.as_str();
72    let mk_span = quote!(
73        let _span = wiggle::tracing::span!(
74            wiggle::tracing::Level::TRACE,
75            "wiggle abi",
76            module = #mod_name,
77            function = #func_name
78        );
79    );
80    let ctx_type = if settings.mutable {
81        quote!(&mut)
82    } else {
83        quote!(&)
84    };
85    if settings.get_async(&module, &func).is_sync() {
86        let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) {
87            quote!(
88                #mk_span
89                _span.in_scope(|| {
90                  #body
91                })
92            )
93        } else {
94            quote!(#body)
95        };
96        (
97            quote!(
98                #[allow(unreachable_code)] // deals with warnings in noreturn functions
99                pub fn #ident(
100                    ctx: #ctx_type (impl #(#bounds)+*),
101                    memory: &mut wiggle::GuestMemory<'_>,
102                    #(#abi_params),*
103                ) -> wiggle::anyhow::Result<#abi_ret> {
104                    use std::convert::TryFrom as _;
105                    #traced_body
106                }
107            ),
108            bounds,
109        )
110    } else {
111        let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) {
112            quote!(
113                use wiggle::tracing::Instrument as _;
114                #mk_span
115                async move {
116                    #body
117                }.instrument(_span).await
118            )
119        } else {
120            quote!(
121                #body
122            )
123        };
124        (
125            quote!(
126                #[allow(unreachable_code)] // deals with warnings in noreturn functions
127                pub async fn #ident(
128                    ctx: #ctx_type (impl #(#bounds)+*),
129                    memory: &mut wiggle::GuestMemory<'_>,
130                    #(#abi_params),*
131                ) -> wiggle::anyhow::Result<#abi_ret> {
132                    use std::convert::TryFrom as _;
133                    #traced_body
134                }
135            ),
136            bounds,
137        )
138    }
139}
140
141struct Rust<'a> {
142    src: &'a mut TokenStream,
143    params: &'a [Ident],
144    block_storage: Vec<TokenStream>,
145    blocks: Vec<TokenStream>,
146    module: &'a witx::Module,
147    funcname: &'a str,
148    settings: &'a CodegenSettings,
149    bounds: &'a mut Vec<Ident>,
150}
151
152impl Rust<'_> {
153    fn bound(&mut self, i: Ident) {
154        if !self.bounds.contains(&i) {
155            self.bounds.push(i);
156        }
157    }
158}
159
160impl witx::Bindgen for Rust<'_> {
161    type Operand = TokenStream;
162
163    fn push_block(&mut self) {
164        let prev = mem::replace(self.src, TokenStream::new());
165        self.block_storage.push(prev);
166    }
167
168    fn finish_block(&mut self, operand: Option<TokenStream>) {
169        let to_restore = self.block_storage.pop().unwrap();
170        let src = mem::replace(self.src, to_restore);
171        match operand {
172            None => self.blocks.push(src),
173            Some(s) => {
174                if src.is_empty() {
175                    self.blocks.push(s);
176                } else {
177                    self.blocks.push(quote!({ #src; #s }));
178                }
179            }
180        }
181    }
182
183    // This is only used for `call_wasm` at this time.
184    fn allocate_space(&mut self, _: usize, _: &witx::NamedType) {
185        unimplemented!()
186    }
187
188    fn emit(
189        &mut self,
190        inst: &Instruction<'_>,
191        operands: &mut Vec<TokenStream>,
192        results: &mut Vec<TokenStream>,
193    ) {
194        let wrap_err = |location: &str| {
195            let modulename = self.module.name.as_str();
196            let funcname = self.funcname;
197            quote! {
198                |e| {
199                    wiggle::GuestError::InFunc {
200                        modulename: #modulename,
201                        funcname: #funcname,
202                        location: #location,
203                        err: Box::new(wiggle::GuestError::from(e)),
204                    }
205                }
206            }
207        };
208
209        let mut try_from = |ty: TokenStream| {
210            let val = operands.pop().unwrap();
211            let wrap_err = wrap_err(&format!("convert {ty}"));
212            results.push(quote!(#ty::try_from(#val).map_err(#wrap_err)?));
213        };
214
215        match inst {
216            Instruction::GetArg { nth } => {
217                let param = &self.params[*nth];
218                results.push(quote!(#param));
219            }
220
221            Instruction::PointerFromI32 { ty } | Instruction::ConstPointerFromI32 { ty } => {
222                let val = operands.pop().unwrap();
223                let pointee_type = names::type_ref(ty, anon_lifetime());
224                results.push(quote! {
225                    wiggle::GuestPtr::<#pointee_type>::new(#val as u32)
226                });
227            }
228
229            Instruction::ListFromPointerLength { ty } => {
230                let ptr = &operands[0];
231                let len = &operands[1];
232                let ty = match &**ty.type_() {
233                    witx::Type::Builtin(witx::BuiltinType::Char) => quote!(str),
234                    _ => {
235                        let ty = names::type_ref(ty, anon_lifetime());
236                        quote!([#ty])
237                    }
238                };
239                results.push(quote! {
240                    wiggle::GuestPtr::<#ty>::new((#ptr as u32, #len as u32));
241                })
242            }
243
244            Instruction::CallInterface { func, .. } => {
245                // Use the `tracing` crate to log all arguments that are going
246                // out, and afterwards we call the function with those bindings.
247                let mut args = Vec::new();
248                for (i, param) in func.params.iter().enumerate() {
249                    let name = names::func_param(&param.name);
250                    let val = &operands[i];
251                    self.src.extend(quote!(let #name = #val;));
252                    if passed_by_reference(param.tref.type_()) {
253                        args.push(quote!(&#name));
254                    } else {
255                        args.push(quote!(#name));
256                    }
257                }
258                if self
259                    .settings
260                    .tracing
261                    .enabled_for(self.module.name.as_str(), self.funcname)
262                    && func.params.len() > 0
263                {
264                    let args = func
265                        .params
266                        .iter()
267                        .map(|param| {
268                            let name = names::func_param(&param.name);
269                            if param.impls_display() {
270                                quote!( #name = wiggle::tracing::field::display(&#name) )
271                            } else {
272                                quote!( #name = wiggle::tracing::field::debug(&#name) )
273                            }
274                        })
275                        .collect::<Vec<_>>();
276                    self.src.extend(quote! {
277                        wiggle::tracing::event!(wiggle::tracing::Level::TRACE, #(#args),*);
278                    });
279                }
280
281                let trait_name = names::trait_name(&self.module.name);
282                let ident = names::func(&func.name);
283                if self.settings.get_async(&self.module, &func).is_sync() {
284                    self.src.extend(quote! {
285                        let ret = #trait_name::#ident(ctx, memory, #(#args),*);
286                    })
287                } else {
288                    self.src.extend(quote! {
289                        let ret = #trait_name::#ident(ctx, memory, #(#args),*).await;
290                    })
291                };
292                if self
293                    .settings
294                    .tracing
295                    .enabled_for(self.module.name.as_str(), self.funcname)
296                {
297                    self.src.extend(quote! {
298                        wiggle::tracing::event!(
299                            wiggle::tracing::Level::TRACE,
300                            result = wiggle::tracing::field::debug(&ret),
301                        );
302                    });
303                }
304
305                if func.results.len() > 0 {
306                    results.push(quote!(ret));
307                } else if func.noreturn {
308                    self.src.extend(quote!(return Err(ret);));
309                }
310            }
311
312            // Lowering an enum is typically simple but if we have an error
313            // transformation registered for this enum then what we're actually
314            // doing is lowering from a user-defined error type to the error
315            // enum, and *then* we lower to an i32.
316            Instruction::EnumLower { ty } => {
317                let val = operands.pop().unwrap();
318                let val = match self.settings.errors.for_name(ty) {
319                    Some(ErrorType::User(custom)) => {
320                        let method = names::user_error_conversion_method(&custom);
321                        self.bound(quote::format_ident!("UserErrorConversion"));
322                        quote!(UserErrorConversion::#method(ctx, #val)?)
323                    }
324                    Some(ErrorType::Generated(_)) => quote!(#val.downcast()?),
325                    None => val,
326                };
327                results.push(quote!(#val as i32));
328            }
329
330            Instruction::ResultLower { err: err_ty, .. } => {
331                let err = self.blocks.pop().unwrap();
332                let ok = self.blocks.pop().unwrap();
333                let val = operands.pop().unwrap();
334                let err_typename = names::type_ref(err_ty.unwrap(), anon_lifetime());
335                results.push(quote! {
336                    match #val {
337                        Ok(e) => { #ok; <#err_typename as wiggle::GuestErrorType>::success() as i32 }
338                        Err(e) => { #err }
339                    }
340                });
341            }
342
343            Instruction::VariantPayload => results.push(quote!(e)),
344
345            Instruction::Return { amt: 0 } => {
346                self.src.extend(quote!(return Ok(())));
347            }
348            Instruction::Return { amt: 1 } => {
349                let val = operands.pop().unwrap();
350                self.src.extend(quote!(return Ok(#val)));
351            }
352            Instruction::Return { .. } => unimplemented!(),
353
354            Instruction::TupleLower { amt } => {
355                let names = (0..*amt)
356                    .map(|i| Ident::new(&format!("t{i}"), Span::call_site()))
357                    .collect::<Vec<_>>();
358                let val = operands.pop().unwrap();
359                self.src.extend(quote!( let (#(#names,)*) = #val;));
360                results.extend(names.iter().map(|i| quote!(#i)));
361            }
362
363            Instruction::Store { ty } => {
364                let ptr = operands.pop().unwrap();
365                let val = operands.pop().unwrap();
366                let wrap_err = wrap_err(&format!("write {}", ty.name.as_str()));
367                let pointee_type = names::type_(&ty.name);
368                self.src.extend(quote! {
369                    memory.write(
370                        wiggle::GuestPtr::<#pointee_type>::new(#ptr as u32),
371                        #val,
372                    )
373                    .map_err(#wrap_err)?;
374                });
375            }
376
377            Instruction::Load { ty } => {
378                let ptr = operands.pop().unwrap();
379                let wrap_err = wrap_err(&format!("read {}", ty.name.as_str()));
380                let pointee_type = names::type_(&ty.name);
381                results.push(quote! {
382                    memory.read(wiggle::GuestPtr::<#pointee_type>::new(#ptr as u32))
383                        .map_err(#wrap_err)?
384                });
385            }
386
387            Instruction::HandleFromI32 { ty } => {
388                let val = operands.pop().unwrap();
389                let ty = names::type_(&ty.name);
390                results.push(quote!(#ty::from(#val)));
391            }
392
393            // Smaller-than-32 numerical conversions are done with `TryFrom` to
394            // ensure we're not losing bits.
395            Instruction::U8FromI32 => try_from(quote!(u8)),
396            Instruction::S8FromI32 => try_from(quote!(i8)),
397            Instruction::Char8FromI32 => try_from(quote!(u8)),
398            Instruction::U16FromI32 => try_from(quote!(u16)),
399            Instruction::S16FromI32 => try_from(quote!(i16)),
400
401            // Conversions with matching bit-widths but different signededness
402            // use `as` since we're basically just reinterpreting the bits.
403            Instruction::U32FromI32 | Instruction::UsizeFromI32 => {
404                let val = operands.pop().unwrap();
405                results.push(quote!(#val as u32));
406            }
407            Instruction::U64FromI64 => {
408                let val = operands.pop().unwrap();
409                results.push(quote!(#val as u64));
410            }
411
412            // Conversions to enums/bitflags use `TryFrom` to ensure that the
413            // values are valid coming in.
414            Instruction::EnumLift { ty }
415            | Instruction::BitflagsFromI64 { ty }
416            | Instruction::BitflagsFromI32 { ty } => {
417                let ty = names::type_(&ty.name);
418                try_from(quote!(#ty))
419            }
420
421            // No conversions necessary for these, the native wasm type matches
422            // our own representation.
423            Instruction::If32FromF32
424            | Instruction::If64FromF64
425            | Instruction::S32FromI32
426            | Instruction::S64FromI64 => results.push(operands.pop().unwrap()),
427
428            // There's a number of other instructions we could implement but
429            // they're not exercised by WASI at this time. As necessary we can
430            // add code to implement them.
431            other => panic!("no implementation for {other:?}"),
432        }
433    }
434}