wit_bindgen_rust_macro/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use quote::ToTokens;
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};
6use syn::parse::{Error, Parse, ParseStream, Result};
7use syn::punctuated::Punctuated;
8use syn::spanned::Spanned;
9use syn::{braced, token, LitStr, Token};
10use wit_bindgen_core::wit_parser::{PackageId, Resolve, UnresolvedPackageGroup, WorldId};
11use wit_bindgen_rust::{AsyncConfig, Opts, Ownership, WithOption};
12
13#[proc_macro]
14pub fn generate(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
15    syn::parse_macro_input!(input as Config)
16        .expand()
17        .unwrap_or_else(Error::into_compile_error)
18        .into()
19}
20
21fn anyhow_to_syn(span: Span, err: anyhow::Error) -> Error {
22    let err = attach_with_context(err);
23    let mut msg = err.to_string();
24    for cause in err.chain().skip(1) {
25        msg.push_str(&format!("\n\nCaused by:\n  {cause}"));
26    }
27    Error::new(span, msg)
28}
29
30fn attach_with_context(err: anyhow::Error) -> anyhow::Error {
31    if let Some(e) = err.downcast_ref::<wit_bindgen_rust::MissingWith>() {
32        let option = e.0.clone();
33        return err.context(format!(
34            "missing one of:\n\
35            * `generate_all` option\n\
36            * `with: {{ \"{option}\": path::to::bindings, }}`\n\
37            * `with: {{ \"{option}\": generate, }}`\
38            "
39        ));
40    }
41    err
42}
43
44struct Config {
45    opts: Opts,
46    resolve: Resolve,
47    world: WorldId,
48    files: Vec<PathBuf>,
49    debug: bool,
50}
51
/// Where the WIT package definitions come from.
enum Source {
    /// One or more paths, resolved relative to the crate's manifest directory,
    /// from which WIT packages are loaded.
    Paths(Vec<PathBuf>),
    /// WIT text written directly in the macro, plus optional paths whose
    /// packages are loaded first so the inline text can depend on them.
    Inline(String, Option<Vec<PathBuf>>),
}
59
60impl Parse for Config {
61    fn parse(input: ParseStream<'_>) -> Result<Self> {
62        let call_site = Span::call_site();
63        let mut opts = Opts::default();
64        let mut world = None;
65        let mut source = None;
66        let mut features = Vec::new();
67        let mut async_configured = false;
68        let mut debug = false;
69
70        if input.peek(token::Brace) {
71            let content;
72            syn::braced!(content in input);
73            let fields = Punctuated::<Opt, Token![,]>::parse_terminated(&content)?;
74            for field in fields.into_pairs() {
75                match field.into_value() {
76                    Opt::Path(span, p) => {
77                        let paths = p.into_iter().map(|f| PathBuf::from(f.value())).collect();
78
79                        source = Some(match source {
80                            Some(Source::Paths(_)) | Some(Source::Inline(_, Some(_))) => {
81                                return Err(Error::new(span, "cannot specify second source"));
82                            }
83                            Some(Source::Inline(i, None)) => Source::Inline(i, Some(paths)),
84                            None => Source::Paths(paths),
85                        })
86                    }
87                    Opt::World(s) => {
88                        if world.is_some() {
89                            return Err(Error::new(s.span(), "cannot specify second world"));
90                        }
91                        world = Some(s.value());
92                    }
93                    Opt::Inline(s) => {
94                        source = Some(match source {
95                            Some(Source::Inline(_, _)) => {
96                                return Err(Error::new(s.span(), "cannot specify second source"));
97                            }
98                            Some(Source::Paths(p)) => Source::Inline(s.value(), Some(p)),
99                            None => Source::Inline(s.value(), None),
100                        })
101                    }
102                    Opt::UseStdFeature => opts.std_feature = true,
103                    Opt::RawStrings => opts.raw_strings = true,
104                    Opt::Ownership(ownership) => opts.ownership = ownership,
105                    Opt::Skip(list) => opts.skip.extend(list.iter().map(|i| i.value())),
106                    Opt::RuntimePath(path) => opts.runtime_path = Some(path.value()),
107                    Opt::BitflagsPath(path) => opts.bitflags_path = Some(path.value()),
108                    Opt::Stubs => {
109                        opts.stubs = true;
110                    }
111                    Opt::ExportPrefix(prefix) => opts.export_prefix = Some(prefix.value()),
112                    Opt::AdditionalDerives(paths) => {
113                        opts.additional_derive_attributes = paths
114                            .into_iter()
115                            .map(|p| p.into_token_stream().to_string())
116                            .collect()
117                    }
118                    Opt::With(with) => opts.with.extend(with),
119                    Opt::GenerateAll => {
120                        opts.generate_all = true;
121                    }
122                    Opt::TypeSectionSuffix(suffix) => {
123                        opts.type_section_suffix = Some(suffix.value());
124                    }
125                    Opt::DisableRunCtorsOnceWorkaround(enable) => {
126                        opts.disable_run_ctors_once_workaround = enable.value();
127                    }
128                    Opt::DefaultBindingsModule(enable) => {
129                        opts.default_bindings_module = Some(enable.value());
130                    }
131                    Opt::ExportMacroName(name) => {
132                        opts.export_macro_name = Some(name.value());
133                    }
134                    Opt::PubExportMacro(enable) => {
135                        opts.pub_export_macro = enable.value();
136                    }
137                    Opt::GenerateUnusedTypes(enable) => {
138                        opts.generate_unused_types = enable.value();
139                    }
140                    Opt::Features(f) => {
141                        features.extend(f.into_iter().map(|f| f.value()));
142                    }
143                    Opt::DisableCustomSectionLinkHelpers(disable) => {
144                        opts.disable_custom_section_link_helpers = disable.value();
145                    }
146                    Opt::Debug(enable) => {
147                        debug = enable.value();
148                    }
149                    Opt::Async(val, span) => {
150                        if async_configured {
151                            return Err(Error::new(span, "cannot specify second async config"));
152                        }
153                        async_configured = true;
154                        if !matches!(val, AsyncConfig::None) && !cfg!(feature = "async") {
155                            return Err(Error::new(
156                                span,
157                                "must enable `async` feature to enable async imports and/or exports",
158                            ));
159                        }
160                        opts.async_ = val;
161                    }
162                }
163            }
164        } else {
165            world = input.parse::<Option<syn::LitStr>>()?.map(|s| s.value());
166            if input.parse::<Option<syn::token::In>>()?.is_some() {
167                source = Some(Source::Paths(vec![PathBuf::from(
168                    input.parse::<syn::LitStr>()?.value(),
169                )]));
170            }
171        }
172        let (resolve, pkgs, files) =
173            parse_source(&source, &features).map_err(|err| anyhow_to_syn(call_site, err))?;
174        let world = select_world(&resolve, &pkgs, world.as_deref())
175            .map_err(|e| anyhow_to_syn(call_site, e))?;
176        Ok(Config {
177            opts,
178            resolve,
179            world,
180            files,
181            debug,
182        })
183    }
184}
185
186fn select_world(
187    resolve: &Resolve,
188    pkgs: &[PackageId],
189    world: Option<&str>,
190) -> anyhow::Result<WorldId> {
191    if pkgs.len() == 1 {
192        resolve.select_world(pkgs[0], world)
193    } else {
194        assert!(!pkgs.is_empty());
195        match world {
196            Some(name) => {
197                if !name.contains(":") {
198                    anyhow::bail!(
199                        "with multiple packages a fully qualified \
200                         world name must be specified"
201                    )
202                }
203
204                // This will ignore the package argument due to the fully
205                // qualified name being used.
206                resolve.select_world(pkgs[0], world)
207            }
208            None => {
209                let worlds = pkgs
210                    .iter()
211                    .filter_map(|p| resolve.select_world(*p, None).ok())
212                    .collect::<Vec<_>>();
213                match &worlds[..] {
214                    [] => anyhow::bail!("no packages have a world"),
215                    [world] => Ok(*world),
216                    _ => anyhow::bail!("multiple packages have a world, must specify which to use"),
217                }
218            }
219        }
220    }
221}
222
223/// Parse the source
224fn parse_source(
225    source: &Option<Source>,
226    features: &[String],
227) -> anyhow::Result<(Resolve, Vec<PackageId>, Vec<PathBuf>)> {
228    let mut resolve = Resolve::default();
229    resolve.features.extend(features.iter().cloned());
230    let mut files = Vec::new();
231    let mut pkgs = Vec::new();
232    let root = PathBuf::from(std::env::var("CARGO_MANIFEST_DIR").unwrap());
233    let mut parse = |paths: &[PathBuf]| -> anyhow::Result<()> {
234        for path in paths {
235            let p = root.join(path);
236            // Try to normalize the path to make the error message more understandable when
237            // the path is not correct. Fallback to the original path if normalization fails
238            // (probably return an error somewhere else).
239            let normalized_path = match std::fs::canonicalize(&p) {
240                Ok(p) => p,
241                Err(_) => p.to_path_buf(),
242            };
243            let (pkg, sources) = resolve.push_path(normalized_path)?;
244            pkgs.push(pkg);
245            files.extend(sources.paths().map(|p| p.to_owned()));
246        }
247        Ok(())
248    };
249    match source {
250        Some(Source::Inline(s, path)) => {
251            if let Some(p) = path {
252                parse(p)?;
253            }
254            pkgs.push(resolve.push_group(UnresolvedPackageGroup::parse("macro-input", s)?)?);
255        }
256        Some(Source::Paths(p)) => parse(p)?,
257        None => parse(&vec![root.join("wit")])?,
258    };
259
260    Ok((resolve, pkgs, files))
261}
262
263impl Config {
264    fn expand(self) -> Result<TokenStream> {
265        let mut files = Default::default();
266        let mut generator = self.opts.build();
267        generator
268            .generate(&self.resolve, self.world, &mut files)
269            .map_err(|e| anyhow_to_syn(Span::call_site(), e))?;
270        let (_, src) = files.iter().next().unwrap();
271        let mut src = std::str::from_utf8(src).unwrap().to_string();
272
273        // If a magical `WIT_BINDGEN_DEBUG` environment variable is set then
274        // place a formatted version of the expanded code into a file. This file
275        // will then show up in rustc error messages for any codegen issues and can
276        // be inspected manually.
277        if std::env::var("WIT_BINDGEN_DEBUG").is_ok() || self.debug {
278            static INVOCATION: AtomicUsize = AtomicUsize::new(0);
279            let root = Path::new(env!("DEBUG_OUTPUT_DIR"));
280            let world_name = &self.resolve.worlds[self.world].name;
281            let n = INVOCATION.fetch_add(1, Relaxed);
282            let path = root.join(format!("{world_name}{n}.rs"));
283
284            // optimistically format the code but don't require success
285            let contents = match fmt(&src) {
286                Ok(formatted) => formatted,
287                Err(_) => src.clone(),
288            };
289            std::fs::write(&path, contents.as_bytes()).unwrap();
290
291            src = format!("include!({path:?});");
292        }
293        let mut contents = src.parse::<TokenStream>().unwrap();
294
295        // Include a dummy `include_bytes!` for any files we read so rustc knows that
296        // we depend on the contents of those files.
297        for file in self.files.iter() {
298            contents.extend(
299                format!(
300                    "const _: &[u8] = include_bytes!(r#\"{}\"#);\n",
301                    file.display()
302                )
303                .parse::<TokenStream>()
304                .unwrap(),
305            );
306        }
307
308        Ok(contents)
309    }
310}
311
312mod kw {
313    syn::custom_keyword!(std_feature);
314    syn::custom_keyword!(raw_strings);
315    syn::custom_keyword!(skip);
316    syn::custom_keyword!(world);
317    syn::custom_keyword!(path);
318    syn::custom_keyword!(inline);
319    syn::custom_keyword!(ownership);
320    syn::custom_keyword!(runtime_path);
321    syn::custom_keyword!(bitflags_path);
322    syn::custom_keyword!(exports);
323    syn::custom_keyword!(stubs);
324    syn::custom_keyword!(export_prefix);
325    syn::custom_keyword!(additional_derives);
326    syn::custom_keyword!(with);
327    syn::custom_keyword!(generate_all);
328    syn::custom_keyword!(type_section_suffix);
329    syn::custom_keyword!(disable_run_ctors_once_workaround);
330    syn::custom_keyword!(default_bindings_module);
331    syn::custom_keyword!(export_macro_name);
332    syn::custom_keyword!(pub_export_macro);
333    syn::custom_keyword!(generate_unused_types);
334    syn::custom_keyword!(features);
335    syn::custom_keyword!(disable_custom_section_link_helpers);
336    syn::custom_keyword!(imports);
337    syn::custom_keyword!(debug);
338}
339
340#[derive(Clone)]
341enum ExportKey {
342    World,
343    Name(syn::LitStr),
344}
345
346impl Parse for ExportKey {
347    fn parse(input: ParseStream<'_>) -> Result<Self> {
348        let l = input.lookahead1();
349        Ok(if l.peek(kw::world) {
350            input.parse::<kw::world>()?;
351            Self::World
352        } else {
353            Self::Name(input.parse()?)
354        })
355    }
356}
357
358impl From<ExportKey> for wit_bindgen_rust::ExportKey {
359    fn from(key: ExportKey) -> Self {
360        match key {
361            ExportKey::World => Self::World,
362            ExportKey::Name(s) => Self::Name(s.value()),
363        }
364    }
365}
366
/// Which list an entry of `async: { imports: [...], exports: [...] }` sets.
enum AsyncConfigSomeKind {
    /// The `imports: [...]` entry.
    Imports,
    /// The `exports: [...]` entry.
    Exports,
}
371
372enum Opt {
373    World(syn::LitStr),
374    Path(Span, Vec<syn::LitStr>),
375    Inline(syn::LitStr),
376    UseStdFeature,
377    RawStrings,
378    Skip(Vec<syn::LitStr>),
379    Ownership(Ownership),
380    RuntimePath(syn::LitStr),
381    BitflagsPath(syn::LitStr),
382    Stubs,
383    ExportPrefix(syn::LitStr),
384    // Parse as paths so we can take the concrete types/macro names rather than raw strings
385    AdditionalDerives(Vec<syn::Path>),
386    With(HashMap<String, WithOption>),
387    GenerateAll,
388    TypeSectionSuffix(syn::LitStr),
389    DisableRunCtorsOnceWorkaround(syn::LitBool),
390    DefaultBindingsModule(syn::LitStr),
391    ExportMacroName(syn::LitStr),
392    PubExportMacro(syn::LitBool),
393    GenerateUnusedTypes(syn::LitBool),
394    Features(Vec<syn::LitStr>),
395    DisableCustomSectionLinkHelpers(syn::LitBool),
396    Async(AsyncConfig, Span),
397    Debug(syn::LitBool),
398}
399
400impl Parse for Opt {
401    fn parse(input: ParseStream<'_>) -> Result<Self> {
402        let l = input.lookahead1();
403        if l.peek(kw::path) {
404            input.parse::<kw::path>()?;
405            input.parse::<Token![:]>()?;
406            // the `path` supports two forms:
407            // * path: "xxx"
408            // * path: ["aaa", "bbb"]
409            if input.peek(token::Bracket) {
410                let contents;
411                syn::bracketed!(contents in input);
412                let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
413                Ok(Opt::Path(list.span(), list.into_iter().collect()))
414            } else {
415                let path: LitStr = input.parse()?;
416                Ok(Opt::Path(path.span(), vec![path]))
417            }
418        } else if l.peek(kw::inline) {
419            input.parse::<kw::inline>()?;
420            input.parse::<Token![:]>()?;
421            Ok(Opt::Inline(input.parse()?))
422        } else if l.peek(kw::world) {
423            input.parse::<kw::world>()?;
424            input.parse::<Token![:]>()?;
425            Ok(Opt::World(input.parse()?))
426        } else if l.peek(kw::std_feature) {
427            input.parse::<kw::std_feature>()?;
428            Ok(Opt::UseStdFeature)
429        } else if l.peek(kw::raw_strings) {
430            input.parse::<kw::raw_strings>()?;
431            Ok(Opt::RawStrings)
432        } else if l.peek(kw::ownership) {
433            input.parse::<kw::ownership>()?;
434            input.parse::<Token![:]>()?;
435            let ownership = input.parse::<syn::Ident>()?;
436            Ok(Opt::Ownership(match ownership.to_string().as_str() {
437                "Owning" => Ownership::Owning,
438                "Borrowing" => Ownership::Borrowing {
439                    duplicate_if_necessary: {
440                        let contents;
441                        braced!(contents in input);
442                        let field = contents.parse::<syn::Ident>()?;
443                        match field.to_string().as_str() {
444                            "duplicate_if_necessary" => {
445                                contents.parse::<Token![:]>()?;
446                                contents.parse::<syn::LitBool>()?.value
447                            }
448                            name => {
449                                return Err(Error::new(
450                                    field.span(),
451                                    format!(
452                                        "unrecognized `Ownership::Borrowing` field: `{name}`; \
453                                         expected `duplicate_if_necessary`"
454                                    ),
455                                ));
456                            }
457                        }
458                    },
459                },
460                name => {
461                    return Err(Error::new(
462                        ownership.span(),
463                        format!(
464                            "unrecognized ownership: `{name}`; \
465                             expected `Owning` or `Borrowing`"
466                        ),
467                    ));
468                }
469            }))
470        } else if l.peek(kw::skip) {
471            input.parse::<kw::skip>()?;
472            input.parse::<Token![:]>()?;
473            let contents;
474            syn::bracketed!(contents in input);
475            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
476            Ok(Opt::Skip(list.iter().cloned().collect()))
477        } else if l.peek(kw::runtime_path) {
478            input.parse::<kw::runtime_path>()?;
479            input.parse::<Token![:]>()?;
480            Ok(Opt::RuntimePath(input.parse()?))
481        } else if l.peek(kw::bitflags_path) {
482            input.parse::<kw::bitflags_path>()?;
483            input.parse::<Token![:]>()?;
484            Ok(Opt::BitflagsPath(input.parse()?))
485        } else if l.peek(kw::stubs) {
486            input.parse::<kw::stubs>()?;
487            Ok(Opt::Stubs)
488        } else if l.peek(kw::export_prefix) {
489            input.parse::<kw::export_prefix>()?;
490            input.parse::<Token![:]>()?;
491            Ok(Opt::ExportPrefix(input.parse()?))
492        } else if l.peek(kw::additional_derives) {
493            input.parse::<kw::additional_derives>()?;
494            input.parse::<Token![:]>()?;
495            let contents;
496            syn::bracketed!(contents in input);
497            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
498            Ok(Opt::AdditionalDerives(list.iter().cloned().collect()))
499        } else if l.peek(kw::with) {
500            input.parse::<kw::with>()?;
501            input.parse::<Token![:]>()?;
502            let contents;
503            let _lbrace = braced!(contents in input);
504            let fields: Punctuated<_, Token![,]> =
505                contents.parse_terminated(with_field_parse, Token![,])?;
506            Ok(Opt::With(HashMap::from_iter(fields.into_iter())))
507        } else if l.peek(kw::generate_all) {
508            input.parse::<kw::generate_all>()?;
509            Ok(Opt::GenerateAll)
510        } else if l.peek(kw::type_section_suffix) {
511            input.parse::<kw::type_section_suffix>()?;
512            input.parse::<Token![:]>()?;
513            Ok(Opt::TypeSectionSuffix(input.parse()?))
514        } else if l.peek(kw::disable_run_ctors_once_workaround) {
515            input.parse::<kw::disable_run_ctors_once_workaround>()?;
516            input.parse::<Token![:]>()?;
517            Ok(Opt::DisableRunCtorsOnceWorkaround(input.parse()?))
518        } else if l.peek(kw::default_bindings_module) {
519            input.parse::<kw::default_bindings_module>()?;
520            input.parse::<Token![:]>()?;
521            Ok(Opt::DefaultBindingsModule(input.parse()?))
522        } else if l.peek(kw::export_macro_name) {
523            input.parse::<kw::export_macro_name>()?;
524            input.parse::<Token![:]>()?;
525            Ok(Opt::ExportMacroName(input.parse()?))
526        } else if l.peek(kw::pub_export_macro) {
527            input.parse::<kw::pub_export_macro>()?;
528            input.parse::<Token![:]>()?;
529            Ok(Opt::PubExportMacro(input.parse()?))
530        } else if l.peek(kw::generate_unused_types) {
531            input.parse::<kw::generate_unused_types>()?;
532            input.parse::<Token![:]>()?;
533            Ok(Opt::GenerateUnusedTypes(input.parse()?))
534        } else if l.peek(kw::features) {
535            input.parse::<kw::features>()?;
536            input.parse::<Token![:]>()?;
537            let contents;
538            syn::bracketed!(contents in input);
539            let list = Punctuated::<_, Token![,]>::parse_terminated(&contents)?;
540            Ok(Opt::Features(list.into_iter().collect()))
541        } else if l.peek(kw::disable_custom_section_link_helpers) {
542            input.parse::<kw::disable_custom_section_link_helpers>()?;
543            input.parse::<Token![:]>()?;
544            Ok(Opt::DisableCustomSectionLinkHelpers(input.parse()?))
545        } else if l.peek(kw::debug) {
546            input.parse::<kw::debug>()?;
547            input.parse::<Token![:]>()?;
548            Ok(Opt::Debug(input.parse()?))
549        } else if l.peek(Token![async]) {
550            let span = input.parse::<Token![async]>()?.span;
551            input.parse::<Token![:]>()?;
552            if input.peek(syn::LitBool) {
553                if input.parse::<syn::LitBool>()?.value {
554                    Ok(Opt::Async(AsyncConfig::All, span))
555                } else {
556                    Ok(Opt::Async(AsyncConfig::None, span))
557                }
558            } else {
559                let mut imports = Vec::new();
560                let mut exports = Vec::new();
561                let contents;
562                syn::braced!(contents in input);
563                for (kind, values) in
564                    contents.parse_terminated(parse_async_some_field, Token![,])?
565                {
566                    match kind {
567                        AsyncConfigSomeKind::Imports => imports = values,
568                        AsyncConfigSomeKind::Exports => exports = values,
569                    }
570                }
571                Ok(Opt::Async(AsyncConfig::Some { imports, exports }, span))
572            }
573        } else {
574            Err(l.error())
575        }
576    }
577}
578
579fn with_field_parse(input: ParseStream<'_>) -> Result<(String, WithOption)> {
580    let interface = input.parse::<syn::LitStr>()?.value();
581    input.parse::<Token![:]>()?;
582    let start = input.span();
583    let path = input.parse::<syn::Path>()?;
584
585    // It's not possible for the segments of a path to be empty
586    let span = start
587        .join(path.segments.last().unwrap().ident.span())
588        .unwrap_or(start);
589
590    if path.is_ident("generate") {
591        return Ok((interface, WithOption::Generate));
592    }
593
594    let mut buf = String::new();
595    let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> {
596        if !segment.arguments.is_none() {
597            return Err(Error::new(
598                span,
599                "Module path must not contain angles or parens",
600            ));
601        }
602
603        buf.push_str(&segment.ident.to_string());
604
605        Ok(())
606    };
607
608    if path.leading_colon.is_some() {
609        buf.push_str("::");
610    }
611
612    let mut segments = path.segments.into_iter();
613
614    if let Some(segment) = segments.next() {
615        append(&mut buf, segment)?;
616    }
617
618    for segment in segments {
619        buf.push_str("::");
620        append(&mut buf, segment)?;
621    }
622
623    Ok((interface, WithOption::Path(buf)))
624}
625
626/// Format a valid Rust string
627fn fmt(input: &str) -> Result<String> {
628    let syntax_tree = syn::parse_file(&input)?;
629    Ok(prettyplease::unparse(&syntax_tree))
630}
631
632fn parse_async_some_field(input: ParseStream<'_>) -> Result<(AsyncConfigSomeKind, Vec<String>)> {
633    let lookahead = input.lookahead1();
634    let kind = if lookahead.peek(kw::imports) {
635        input.parse::<kw::imports>()?;
636        input.parse::<Token![:]>()?;
637        AsyncConfigSomeKind::Imports
638    } else if lookahead.peek(kw::exports) {
639        input.parse::<kw::exports>()?;
640        input.parse::<Token![:]>()?;
641        AsyncConfigSomeKind::Exports
642    } else {
643        return Err(lookahead.error());
644    };
645
646    let list;
647    syn::bracketed!(list in input);
648    let fields = list.parse_terminated(Parse::parse, Token![,])?;
649
650    Ok((
651        kind,
652        fields.iter().map(|s: &syn::LitStr| s.value()).collect(),
653    ))
654}