wiggle_generate/
config.rs

1use {
2    proc_macro2::{Span, TokenStream},
3    std::{collections::HashMap, path::PathBuf},
4    syn::{
5        braced, bracketed,
6        parse::{Parse, ParseStream},
7        punctuated::Punctuated,
8        Error, Ident, LitStr, Result, Token,
9    },
10};
11
12#[derive(Debug, Clone)]
13pub struct Config {
14    pub witx: WitxConf,
15    pub errors: ErrorConf,
16    pub async_: AsyncConf,
17    pub wasmtime: bool,
18    pub tracing: TracingConf,
19    pub mutable: bool,
20}
21
22mod kw {
23    syn::custom_keyword!(witx);
24    syn::custom_keyword!(witx_literal);
25    syn::custom_keyword!(block_on);
26    syn::custom_keyword!(errors);
27    syn::custom_keyword!(target);
28    syn::custom_keyword!(wasmtime);
29    syn::custom_keyword!(mutable);
30    syn::custom_keyword!(tracing);
31    syn::custom_keyword!(disable_for);
32    syn::custom_keyword!(trappable);
33}
34
35#[derive(Debug, Clone)]
36pub enum ConfigField {
37    Witx(WitxConf),
38    Error(ErrorConf),
39    Async(AsyncConf),
40    Wasmtime(bool),
41    Tracing(TracingConf),
42    Mutable(bool),
43}
44
45impl Parse for ConfigField {
46    fn parse(input: ParseStream) -> Result<Self> {
47        let lookahead = input.lookahead1();
48        if lookahead.peek(kw::witx) {
49            input.parse::<kw::witx>()?;
50            input.parse::<Token![:]>()?;
51            Ok(ConfigField::Witx(WitxConf::Paths(input.parse()?)))
52        } else if lookahead.peek(kw::witx_literal) {
53            input.parse::<kw::witx_literal>()?;
54            input.parse::<Token![:]>()?;
55            Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?)))
56        } else if lookahead.peek(kw::errors) {
57            input.parse::<kw::errors>()?;
58            input.parse::<Token![:]>()?;
59            Ok(ConfigField::Error(input.parse()?))
60        } else if lookahead.peek(Token![async]) {
61            input.parse::<Token![async]>()?;
62            input.parse::<Token![:]>()?;
63            Ok(ConfigField::Async(AsyncConf {
64                block_with: None,
65                functions: input.parse()?,
66            }))
67        } else if lookahead.peek(kw::block_on) {
68            input.parse::<kw::block_on>()?;
69            let block_with = if input.peek(syn::token::Bracket) {
70                let content;
71                let _ = bracketed!(content in input);
72                content.parse()?
73            } else {
74                quote::quote!(wiggle::run_in_dummy_executor)
75            };
76            input.parse::<Token![:]>()?;
77            Ok(ConfigField::Async(AsyncConf {
78                block_with: Some(block_with),
79                functions: input.parse()?,
80            }))
81        } else if lookahead.peek(kw::wasmtime) {
82            input.parse::<kw::wasmtime>()?;
83            input.parse::<Token![:]>()?;
84            Ok(ConfigField::Wasmtime(input.parse::<syn::LitBool>()?.value))
85        } else if lookahead.peek(kw::tracing) {
86            input.parse::<kw::tracing>()?;
87            input.parse::<Token![:]>()?;
88            Ok(ConfigField::Tracing(input.parse()?))
89        } else if lookahead.peek(kw::mutable) {
90            input.parse::<kw::mutable>()?;
91            input.parse::<Token![:]>()?;
92            Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
93        } else {
94            Err(lookahead.error())
95        }
96    }
97}
98
99impl Config {
100    pub fn build(fields: impl Iterator<Item = ConfigField>, err_loc: Span) -> Result<Self> {
101        let mut witx = None;
102        let mut errors = None;
103        let mut async_ = None;
104        let mut wasmtime = None;
105        let mut tracing = None;
106        let mut mutable = None;
107        for f in fields {
108            match f {
109                ConfigField::Witx(c) => {
110                    if witx.is_some() {
111                        return Err(Error::new(err_loc, "duplicate `witx` field"));
112                    }
113                    witx = Some(c);
114                }
115                ConfigField::Error(c) => {
116                    if errors.is_some() {
117                        return Err(Error::new(err_loc, "duplicate `errors` field"));
118                    }
119                    errors = Some(c);
120                }
121                ConfigField::Async(c) => {
122                    if async_.is_some() {
123                        return Err(Error::new(err_loc, "duplicate `async` field"));
124                    }
125                    async_ = Some(c);
126                }
127                ConfigField::Wasmtime(c) => {
128                    if wasmtime.is_some() {
129                        return Err(Error::new(err_loc, "duplicate `wasmtime` field"));
130                    }
131                    wasmtime = Some(c);
132                }
133                ConfigField::Tracing(c) => {
134                    if tracing.is_some() {
135                        return Err(Error::new(err_loc, "duplicate `tracing` field"));
136                    }
137                    tracing = Some(c);
138                }
139                ConfigField::Mutable(c) => {
140                    if mutable.is_some() {
141                        return Err(Error::new(err_loc, "duplicate `mutable` field"));
142                    }
143                    mutable = Some(c);
144                }
145            }
146        }
147        Ok(Config {
148            witx: witx
149                .take()
150                .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
151            errors: errors.take().unwrap_or_default(),
152            async_: async_.take().unwrap_or_default(),
153            wasmtime: wasmtime.unwrap_or(true),
154            tracing: tracing.unwrap_or_default(),
155            mutable: mutable.unwrap_or(true),
156        })
157    }
158
159    /// Load the `witx` document for the configuration.
160    ///
161    /// # Panics
162    ///
163    /// This method will panic if the paths given in the `witx` field were not valid documents.
164    pub fn load_document(&self) -> witx::Document {
165        self.witx.load_document()
166    }
167}
168
169impl Parse for Config {
170    fn parse(input: ParseStream) -> Result<Self> {
171        let contents;
172        let _lbrace = braced!(contents in input);
173        let fields: Punctuated<ConfigField, Token![,]> =
174            contents.parse_terminated(ConfigField::parse, Token![,])?;
175        Ok(Config::build(fields.into_iter(), input.span())?)
176    }
177}
178
179/// The witx document(s) that will be loaded from a [`Config`](struct.Config.html).
180///
181/// A witx interface definition can be provided either as a collection of relative paths to
182/// documents, or as a single inlined string literal. Note that `(use ...)` directives are not
183/// permitted when providing a string literal.
184#[derive(Debug, Clone)]
185pub enum WitxConf {
186    /// A collection of paths pointing to witx files.
187    Paths(Paths),
188    /// A single witx document, provided as a string literal.
189    Literal(Literal),
190}
191
192impl WitxConf {
193    /// Load the `witx` document.
194    ///
195    /// # Panics
196    ///
197    /// This method will panic if the paths given in the `witx` field were not valid documents, or
198    /// if any of the given documents were not syntactically valid.
199    pub fn load_document(&self) -> witx::Document {
200        match self {
201            Self::Paths(paths) => witx::load(paths.as_ref()).expect("loading witx"),
202            Self::Literal(doc) => witx::parse(doc.as_ref()).expect("parsing witx"),
203        }
204    }
205}
206
/// A collection of paths, each pointing to a witx document.
#[derive(Clone, Debug, Default)]
pub struct Paths(Vec<PathBuf>);

impl Paths {
    /// Creates a new, empty collection of paths.
    pub fn new() -> Self {
        Self::default()
    }
}

impl AsRef<[PathBuf]> for Paths {
    fn as_ref(&self) -> &[PathBuf] {
        &self.0
    }
}

impl AsMut<[PathBuf]> for Paths {
    fn as_mut(&mut self) -> &mut [PathBuf] {
        &mut self.0
    }
}

impl FromIterator<PathBuf> for Paths {
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        Self(iter.into_iter().collect())
    }
}
244
245impl Parse for Paths {
246    fn parse(input: ParseStream) -> Result<Self> {
247        let content;
248        let _ = bracketed!(content in input);
249        let path_lits: Punctuated<LitStr, Token![,]> =
250            content.parse_terminated(Parse::parse, Token![,])?;
251
252        let expanded_paths = path_lits
253            .iter()
254            .map(|lit| {
255                PathBuf::from(
256                    shellexpand::env(&lit.value())
257                        .expect("shell expansion")
258                        .as_ref(),
259                )
260            })
261            .collect::<Vec<PathBuf>>();
262
263        Ok(Paths(expanded_paths))
264    }
265}
266
/// One witx document given inline as a string literal.
#[derive(Clone, Debug)]
pub struct Literal(String);

impl AsRef<str> for Literal {
    fn as_ref(&self) -> &str {
        &self.0
    }
}
276
277impl Parse for Literal {
278    fn parse(input: ParseStream) -> Result<Self> {
279        Ok(Self(input.parse::<syn::LitStr>()?.value()))
280    }
281}
282
283#[derive(Clone, Default, Debug)]
284/// Map from abi error type to rich error type
285pub struct ErrorConf(HashMap<Ident, ErrorConfField>);
286
287impl ErrorConf {
288    pub fn iter(&self) -> impl Iterator<Item = (&Ident, &ErrorConfField)> {
289        self.0.iter()
290    }
291}
292
293impl Parse for ErrorConf {
294    fn parse(input: ParseStream) -> Result<Self> {
295        let content;
296        let _ = braced!(content in input);
297        let items: Punctuated<ErrorConfField, Token![,]> =
298            content.parse_terminated(Parse::parse, Token![,])?;
299        let mut m = HashMap::new();
300        for i in items {
301            match m.insert(i.abi_error().clone(), i.clone()) {
302                None => {}
303                Some(prev_def) => {
304                    return Err(Error::new(
305                        *i.err_loc(),
306                        format!(
307                        "duplicate definition of rich error type for {:?}: previously defined at {:?}",
308                        i.abi_error(), prev_def.err_loc(),
309                    ),
310                    ))
311                }
312            }
313        }
314        Ok(ErrorConf(m))
315    }
316}
317
318#[derive(Debug, Clone)]
319pub enum ErrorConfField {
320    Trappable(TrappableErrorConfField),
321    User(UserErrorConfField),
322}
323impl ErrorConfField {
324    pub fn abi_error(&self) -> &Ident {
325        match self {
326            Self::Trappable(t) => &t.abi_error,
327            Self::User(u) => &u.abi_error,
328        }
329    }
330    pub fn err_loc(&self) -> &Span {
331        match self {
332            Self::Trappable(t) => &t.err_loc,
333            Self::User(u) => &u.err_loc,
334        }
335    }
336}
337
338impl Parse for ErrorConfField {
339    fn parse(input: ParseStream) -> Result<Self> {
340        let err_loc = input.span();
341        let abi_error = input.parse::<Ident>()?;
342        let _arrow: Token![=>] = input.parse()?;
343
344        let lookahead = input.lookahead1();
345        if lookahead.peek(kw::trappable) {
346            let _ = input.parse::<kw::trappable>()?;
347            let rich_error = input.parse()?;
348            Ok(ErrorConfField::Trappable(TrappableErrorConfField {
349                abi_error,
350                rich_error,
351                err_loc,
352            }))
353        } else {
354            let rich_error = input.parse::<syn::Path>()?;
355            Ok(ErrorConfField::User(UserErrorConfField {
356                abi_error,
357                rich_error,
358                err_loc,
359            }))
360        }
361    }
362}
363
364#[derive(Clone, Debug)]
365pub struct TrappableErrorConfField {
366    pub abi_error: Ident,
367    pub rich_error: Ident,
368    pub err_loc: Span,
369}
370
371#[derive(Clone)]
372pub struct UserErrorConfField {
373    pub abi_error: Ident,
374    pub rich_error: syn::Path,
375    pub err_loc: Span,
376}
377
378impl std::fmt::Debug for UserErrorConfField {
379    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380        f.debug_struct("ErrorConfField")
381            .field("abi_error", &self.abi_error)
382            .field("rich_error", &"(...)")
383            .field("err_loc", &self.err_loc)
384            .finish()
385    }
386}
387
388#[derive(Clone, Default, Debug)]
389/// Modules and funcs that have async signatures
390pub struct AsyncConf {
391    block_with: Option<TokenStream>,
392    functions: AsyncFunctions,
393}
394
395#[derive(Clone, Debug)]
396pub enum Asyncness {
397    /// Wiggle function is synchronous, wasmtime Func is synchronous
398    Sync,
399    /// Wiggle function is asynchronous, but wasmtime Func is synchronous
400    Blocking { block_with: TokenStream },
401    /// Wiggle function and wasmtime Func are asynchronous.
402    Async,
403}
404
405impl Asyncness {
406    pub fn is_async(&self) -> bool {
407        match self {
408            Self::Async => true,
409            _ => false,
410        }
411    }
412    pub fn blocking(&self) -> Option<&TokenStream> {
413        match self {
414            Self::Blocking { block_with } => Some(block_with),
415            _ => None,
416        }
417    }
418    pub fn is_sync(&self) -> bool {
419        match self {
420            Self::Sync => true,
421            _ => false,
422        }
423    }
424}
425
/// Which functions are configured as async.
#[derive(Clone, Debug)]
pub enum AsyncFunctions {
    /// Only the listed functions: module name -> function names.
    Some(HashMap<String, Vec<String>>),
    /// Every function (written `*`).
    All,
}

impl Default for AsyncFunctions {
    /// By default no function is async.
    fn default() -> Self {
        Self::Some(HashMap::new())
    }
}
436
437impl AsyncConf {
438    pub fn get(&self, module: &str, function: &str) -> Asyncness {
439        let a = match &self.block_with {
440            Some(block_with) => Asyncness::Blocking {
441                block_with: block_with.clone(),
442            },
443            None => Asyncness::Async,
444        };
445        match &self.functions {
446            AsyncFunctions::Some(fs) => {
447                if fs
448                    .get(module)
449                    .and_then(|fs| fs.iter().find(|f| *f == function))
450                    .is_some()
451                {
452                    a
453                } else {
454                    Asyncness::Sync
455                }
456            }
457            AsyncFunctions::All => a,
458        }
459    }
460
461    pub fn contains_async(&self, module: &witx::Module) -> bool {
462        for f in module.funcs() {
463            if self.get(module.name.as_str(), f.name.as_str()).is_async() {
464                return true;
465            }
466        }
467        false
468    }
469}
470
471impl Parse for AsyncFunctions {
472    fn parse(input: ParseStream) -> Result<Self> {
473        let content;
474        let lookahead = input.lookahead1();
475        if lookahead.peek(syn::token::Brace) {
476            let _ = braced!(content in input);
477            let items: Punctuated<FunctionField, Token![,]> =
478                content.parse_terminated(Parse::parse, Token![,])?;
479            let mut functions: HashMap<String, Vec<String>> = HashMap::new();
480            use std::collections::hash_map::Entry;
481            for i in items {
482                let function_names = i
483                    .function_names
484                    .iter()
485                    .map(|i| i.to_string())
486                    .collect::<Vec<String>>();
487                match functions.entry(i.module_name.to_string()) {
488                    Entry::Occupied(o) => o.into_mut().extend(function_names),
489                    Entry::Vacant(v) => {
490                        v.insert(function_names);
491                    }
492                }
493            }
494            Ok(AsyncFunctions::Some(functions))
495        } else if lookahead.peek(Token![*]) {
496            let _: Token![*] = input.parse().unwrap();
497            Ok(AsyncFunctions::All)
498        } else {
499            Err(lookahead.error())
500        }
501    }
502}
503
504#[derive(Clone)]
505pub struct FunctionField {
506    pub module_name: Ident,
507    pub function_names: Vec<Ident>,
508    pub err_loc: Span,
509}
510
511impl Parse for FunctionField {
512    fn parse(input: ParseStream) -> Result<Self> {
513        let err_loc = input.span();
514        let module_name = input.parse::<Ident>()?;
515        let _doublecolon: Token![::] = input.parse()?;
516        let lookahead = input.lookahead1();
517        if lookahead.peek(syn::token::Brace) {
518            let content;
519            let _ = braced!(content in input);
520            let function_names: Punctuated<Ident, Token![,]> =
521                content.parse_terminated(Parse::parse, Token![,])?;
522            Ok(FunctionField {
523                module_name,
524                function_names: function_names.iter().cloned().collect(),
525                err_loc,
526            })
527        } else if lookahead.peek(Ident) {
528            let name = input.parse()?;
529            Ok(FunctionField {
530                module_name,
531                function_names: vec![name],
532                err_loc,
533            })
534        } else {
535            Err(lookahead.error())
536        }
537    }
538}
539
540#[derive(Clone)]
541pub struct WasmtimeConfig {
542    pub c: Config,
543    pub target: syn::Path,
544}
545
546#[derive(Clone)]
547pub enum WasmtimeConfigField {
548    Core(ConfigField),
549    Target(syn::Path),
550}
551impl WasmtimeConfig {
552    pub fn build(fields: impl Iterator<Item = WasmtimeConfigField>, err_loc: Span) -> Result<Self> {
553        let mut target = None;
554        let mut cs = Vec::new();
555        for f in fields {
556            match f {
557                WasmtimeConfigField::Target(c) => {
558                    if target.is_some() {
559                        return Err(Error::new(err_loc, "duplicate `target` field"));
560                    }
561                    target = Some(c);
562                }
563                WasmtimeConfigField::Core(c) => cs.push(c),
564            }
565        }
566        let c = Config::build(cs.into_iter(), err_loc)?;
567        Ok(WasmtimeConfig {
568            c,
569            target: target
570                .take()
571                .ok_or_else(|| Error::new(err_loc, "`target` field required"))?,
572        })
573    }
574}
575
576impl Parse for WasmtimeConfig {
577    fn parse(input: ParseStream) -> Result<Self> {
578        let contents;
579        let _lbrace = braced!(contents in input);
580        let fields: Punctuated<WasmtimeConfigField, Token![,]> =
581            contents.parse_terminated(WasmtimeConfigField::parse, Token![,])?;
582        Ok(WasmtimeConfig::build(fields.into_iter(), input.span())?)
583    }
584}
585
586impl Parse for WasmtimeConfigField {
587    fn parse(input: ParseStream) -> Result<Self> {
588        if input.peek(kw::target) {
589            input.parse::<kw::target>()?;
590            input.parse::<Token![:]>()?;
591            Ok(WasmtimeConfigField::Target(input.parse()?))
592        } else {
593            Ok(WasmtimeConfigField::Core(input.parse()?))
594        }
595    }
596}
597
/// Tracing configuration: a global switch plus a per-function exclusion list.
#[derive(Clone, Debug)]
pub struct TracingConf {
    enabled: bool,
    excluded_functions: HashMap<String, Vec<String>>,
}

impl TracingConf {
    /// Returns `true` when tracing is enabled globally and `module::function`
    /// is not listed under `disable_for`.
    pub fn enabled_for(&self, module: &str, function: &str) -> bool {
        self.enabled
            && !self
                .excluded_functions
                .get(module)
                .map_or(false, |names| names.iter().any(|name| name == function))
    }
}

impl Default for TracingConf {
    /// Tracing is enabled for every function by default.
    fn default() -> Self {
        TracingConf {
            enabled: true,
            excluded_functions: HashMap::default(),
        }
    }
}
624
625impl Parse for TracingConf {
626    fn parse(input: ParseStream) -> Result<Self> {
627        let enabled = input.parse::<syn::LitBool>()?.value;
628
629        let lookahead = input.lookahead1();
630        if lookahead.peek(kw::disable_for) {
631            input.parse::<kw::disable_for>()?;
632            let content;
633            let _ = braced!(content in input);
634            let items: Punctuated<FunctionField, Token![,]> =
635                content.parse_terminated(Parse::parse, Token![,])?;
636            let mut functions: HashMap<String, Vec<String>> = HashMap::new();
637            use std::collections::hash_map::Entry;
638            for i in items {
639                let function_names = i
640                    .function_names
641                    .iter()
642                    .map(|i| i.to_string())
643                    .collect::<Vec<String>>();
644                match functions.entry(i.module_name.to_string()) {
645                    Entry::Occupied(o) => o.into_mut().extend(function_names),
646                    Entry::Vacant(v) => {
647                        v.insert(function_names);
648                    }
649                }
650            }
651
652            Ok(TracingConf {
653                enabled,
654                excluded_functions: functions,
655            })
656        } else {
657            Ok(TracingConf {
658                enabled,
659                excluded_functions: HashMap::new(),
660            })
661        }
662    }
663}