1use {
2 proc_macro2::{Span, TokenStream},
3 std::{collections::HashMap, path::PathBuf},
4 syn::{
5 braced, bracketed,
6 parse::{Parse, ParseStream},
7 punctuated::Punctuated,
8 Error, Ident, LitStr, Result, Token,
9 },
10};
11
/// Fully-resolved macro configuration, produced by [`Config::build`] from a
/// brace-delimited, comma-separated list of [`ConfigField`]s.
#[derive(Debug, Clone)]
pub struct Config {
    /// Source of the witx document(s). This field is required.
    pub witx: WitxConf,
    /// ABI-error to rich-error mapping; empty when `errors` is omitted.
    pub errors: ErrorConf,
    /// Async/blocking function selection; selects nothing when omitted.
    pub async_: AsyncConf,
    /// Value of the `wasmtime:` field; defaults to `true` when omitted.
    pub wasmtime: bool,
    /// Tracing settings; defaults to enabled for every function.
    pub tracing: TracingConf,
    /// Value of the `mutable:` field; defaults to `true` when omitted.
    pub mutable: bool,
}
21
/// Custom keywords recognized by the configuration parsers in this file.
mod kw {
    syn::custom_keyword!(witx);
    syn::custom_keyword!(witx_literal);
    syn::custom_keyword!(block_on);
    syn::custom_keyword!(errors);
    // Only accepted by `WasmtimeConfigField`.
    syn::custom_keyword!(target);
    syn::custom_keyword!(wasmtime);
    syn::custom_keyword!(mutable);
    syn::custom_keyword!(tracing);
    // Introduces the exclusion list after `tracing: <bool>`.
    syn::custom_keyword!(disable_for);
    // Marks an `errors` entry as `AbiError => trappable RichError`.
    syn::custom_keyword!(trappable);
}
34
/// A single `key: value` entry inside the configuration braces.
#[derive(Debug, Clone)]
pub enum ConfigField {
    /// `witx: [...]` or `witx_literal: "..."`.
    Witx(WitxConf),
    /// `errors: { ... }`.
    Error(ErrorConf),
    /// `async: ...` or `block_on[...]: ...`.
    Async(AsyncConf),
    /// `wasmtime: <bool>`.
    Wasmtime(bool),
    /// `tracing: <bool> [disable_for { ... }]`.
    Tracing(TracingConf),
    /// `mutable: <bool>`.
    Mutable(bool),
}
44
45impl Parse for ConfigField {
46 fn parse(input: ParseStream) -> Result<Self> {
47 let lookahead = input.lookahead1();
48 if lookahead.peek(kw::witx) {
49 input.parse::<kw::witx>()?;
50 input.parse::<Token![:]>()?;
51 Ok(ConfigField::Witx(WitxConf::Paths(input.parse()?)))
52 } else if lookahead.peek(kw::witx_literal) {
53 input.parse::<kw::witx_literal>()?;
54 input.parse::<Token![:]>()?;
55 Ok(ConfigField::Witx(WitxConf::Literal(input.parse()?)))
56 } else if lookahead.peek(kw::errors) {
57 input.parse::<kw::errors>()?;
58 input.parse::<Token![:]>()?;
59 Ok(ConfigField::Error(input.parse()?))
60 } else if lookahead.peek(Token![async]) {
61 input.parse::<Token![async]>()?;
62 input.parse::<Token![:]>()?;
63 Ok(ConfigField::Async(AsyncConf {
64 block_with: None,
65 functions: input.parse()?,
66 }))
67 } else if lookahead.peek(kw::block_on) {
68 input.parse::<kw::block_on>()?;
69 let block_with = if input.peek(syn::token::Bracket) {
70 let content;
71 let _ = bracketed!(content in input);
72 content.parse()?
73 } else {
74 quote::quote!(wiggle::run_in_dummy_executor)
75 };
76 input.parse::<Token![:]>()?;
77 Ok(ConfigField::Async(AsyncConf {
78 block_with: Some(block_with),
79 functions: input.parse()?,
80 }))
81 } else if lookahead.peek(kw::wasmtime) {
82 input.parse::<kw::wasmtime>()?;
83 input.parse::<Token![:]>()?;
84 Ok(ConfigField::Wasmtime(input.parse::<syn::LitBool>()?.value))
85 } else if lookahead.peek(kw::tracing) {
86 input.parse::<kw::tracing>()?;
87 input.parse::<Token![:]>()?;
88 Ok(ConfigField::Tracing(input.parse()?))
89 } else if lookahead.peek(kw::mutable) {
90 input.parse::<kw::mutable>()?;
91 input.parse::<Token![:]>()?;
92 Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
93 } else {
94 Err(lookahead.error())
95 }
96 }
97}
98
99impl Config {
100 pub fn build(fields: impl Iterator<Item = ConfigField>, err_loc: Span) -> Result<Self> {
101 let mut witx = None;
102 let mut errors = None;
103 let mut async_ = None;
104 let mut wasmtime = None;
105 let mut tracing = None;
106 let mut mutable = None;
107 for f in fields {
108 match f {
109 ConfigField::Witx(c) => {
110 if witx.is_some() {
111 return Err(Error::new(err_loc, "duplicate `witx` field"));
112 }
113 witx = Some(c);
114 }
115 ConfigField::Error(c) => {
116 if errors.is_some() {
117 return Err(Error::new(err_loc, "duplicate `errors` field"));
118 }
119 errors = Some(c);
120 }
121 ConfigField::Async(c) => {
122 if async_.is_some() {
123 return Err(Error::new(err_loc, "duplicate `async` field"));
124 }
125 async_ = Some(c);
126 }
127 ConfigField::Wasmtime(c) => {
128 if wasmtime.is_some() {
129 return Err(Error::new(err_loc, "duplicate `wasmtime` field"));
130 }
131 wasmtime = Some(c);
132 }
133 ConfigField::Tracing(c) => {
134 if tracing.is_some() {
135 return Err(Error::new(err_loc, "duplicate `tracing` field"));
136 }
137 tracing = Some(c);
138 }
139 ConfigField::Mutable(c) => {
140 if mutable.is_some() {
141 return Err(Error::new(err_loc, "duplicate `mutable` field"));
142 }
143 mutable = Some(c);
144 }
145 }
146 }
147 Ok(Config {
148 witx: witx
149 .take()
150 .ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
151 errors: errors.take().unwrap_or_default(),
152 async_: async_.take().unwrap_or_default(),
153 wasmtime: wasmtime.unwrap_or(true),
154 tracing: tracing.unwrap_or_default(),
155 mutable: mutable.unwrap_or(true),
156 })
157 }
158
159 pub fn load_document(&self) -> witx::Document {
165 self.witx.load_document()
166 }
167}
168
169impl Parse for Config {
170 fn parse(input: ParseStream) -> Result<Self> {
171 let contents;
172 let _lbrace = braced!(contents in input);
173 let fields: Punctuated<ConfigField, Token![,]> =
174 contents.parse_terminated(ConfigField::parse, Token![,])?;
175 Ok(Config::build(fields.into_iter(), input.span())?)
176 }
177}
178
/// Where the witx document comes from.
#[derive(Debug, Clone)]
pub enum WitxConf {
    /// File paths given with `witx: [...]`.
    Paths(Paths),
    /// Inline witx source given with `witx_literal: "..."`.
    Literal(Literal),
}
191
192impl WitxConf {
193 pub fn load_document(&self) -> witx::Document {
200 match self {
201 Self::Paths(paths) => witx::load(paths.as_ref()).expect("loading witx"),
202 Self::Literal(doc) => witx::parse(doc.as_ref()).expect("parsing witx"),
203 }
204 }
205}
206
/// List of witx file paths, as given in the `witx: [...]` field.
///
/// The `Default` value is an empty list. It is derived, which is equivalent to
/// the previous hand-written impl.
#[derive(Debug, Clone, Default)]
pub struct Paths(Vec<PathBuf>);

impl Paths {
    /// Creates an empty path list.
    pub fn new() -> Self {
        Self::default()
    }
}

impl AsRef<[PathBuf]> for Paths {
    fn as_ref(&self) -> &[PathBuf] {
        &self.0
    }
}

impl AsMut<[PathBuf]> for Paths {
    fn as_mut(&mut self) -> &mut [PathBuf] {
        &mut self.0
    }
}

impl FromIterator<PathBuf> for Paths {
    /// Collects paths in iteration order.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = PathBuf>,
    {
        Self(iter.into_iter().collect())
    }
}
244
245impl Parse for Paths {
246 fn parse(input: ParseStream) -> Result<Self> {
247 let content;
248 let _ = bracketed!(content in input);
249 let path_lits: Punctuated<LitStr, Token![,]> =
250 content.parse_terminated(Parse::parse, Token![,])?;
251
252 let expanded_paths = path_lits
253 .iter()
254 .map(|lit| {
255 PathBuf::from(
256 shellexpand::env(&lit.value())
257 .expect("shell expansion")
258 .as_ref(),
259 )
260 })
261 .collect::<Vec<PathBuf>>();
262
263 Ok(Paths(expanded_paths))
264 }
265}
266
267#[derive(Debug, Clone)]
269pub struct Literal(String);
270
271impl AsRef<str> for Literal {
272 fn as_ref(&self) -> &str {
273 self.0.as_ref()
274 }
275}
276
277impl Parse for Literal {
278 fn parse(input: ParseStream) -> Result<Self> {
279 Ok(Self(input.parse::<syn::LitStr>()?.value()))
280 }
281}
282
283#[derive(Clone, Default, Debug)]
284pub struct ErrorConf(HashMap<Ident, ErrorConfField>);
286
287impl ErrorConf {
288 pub fn iter(&self) -> impl Iterator<Item = (&Ident, &ErrorConfField)> {
289 self.0.iter()
290 }
291}
292
293impl Parse for ErrorConf {
294 fn parse(input: ParseStream) -> Result<Self> {
295 let content;
296 let _ = braced!(content in input);
297 let items: Punctuated<ErrorConfField, Token![,]> =
298 content.parse_terminated(Parse::parse, Token![,])?;
299 let mut m = HashMap::new();
300 for i in items {
301 match m.insert(i.abi_error().clone(), i.clone()) {
302 None => {}
303 Some(prev_def) => {
304 return Err(Error::new(
305 *i.err_loc(),
306 format!(
307 "duplicate definition of rich error type for {:?}: previously defined at {:?}",
308 i.abi_error(), prev_def.err_loc(),
309 ),
310 ))
311 }
312 }
313 }
314 Ok(ErrorConf(m))
315 }
316}
317
318#[derive(Debug, Clone)]
319pub enum ErrorConfField {
320 Trappable(TrappableErrorConfField),
321 User(UserErrorConfField),
322}
323impl ErrorConfField {
324 pub fn abi_error(&self) -> &Ident {
325 match self {
326 Self::Trappable(t) => &t.abi_error,
327 Self::User(u) => &u.abi_error,
328 }
329 }
330 pub fn err_loc(&self) -> &Span {
331 match self {
332 Self::Trappable(t) => &t.err_loc,
333 Self::User(u) => &u.err_loc,
334 }
335 }
336}
337
338impl Parse for ErrorConfField {
339 fn parse(input: ParseStream) -> Result<Self> {
340 let err_loc = input.span();
341 let abi_error = input.parse::<Ident>()?;
342 let _arrow: Token![=>] = input.parse()?;
343
344 let lookahead = input.lookahead1();
345 if lookahead.peek(kw::trappable) {
346 let _ = input.parse::<kw::trappable>()?;
347 let rich_error = input.parse()?;
348 Ok(ErrorConfField::Trappable(TrappableErrorConfField {
349 abi_error,
350 rich_error,
351 err_loc,
352 }))
353 } else {
354 let rich_error = input.parse::<syn::Path>()?;
355 Ok(ErrorConfField::User(UserErrorConfField {
356 abi_error,
357 rich_error,
358 err_loc,
359 }))
360 }
361 }
362}
363
/// `errors` entry of the form `AbiError => trappable RichError`.
#[derive(Clone, Debug)]
pub struct TrappableErrorConfField {
    /// ABI error type name (left of `=>`).
    pub abi_error: Ident,
    /// Rich error type name (after `trappable`).
    pub rich_error: Ident,
    /// Span at the start of the entry, used for diagnostics.
    pub err_loc: Span,
}
370
371#[derive(Clone)]
372pub struct UserErrorConfField {
373 pub abi_error: Ident,
374 pub rich_error: syn::Path,
375 pub err_loc: Span,
376}
377
378impl std::fmt::Debug for UserErrorConfField {
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 f.debug_struct("ErrorConfField")
381 .field("abi_error", &self.abi_error)
382 .field("rich_error", &"(...)")
383 .field("err_loc", &self.err_loc)
384 .finish()
385 }
386}
387
/// Async/blocking function selection from the `async:` or `block_on[...]:` field.
#[derive(Clone, Default, Debug)]
pub struct AsyncConf {
    /// Executor tokens used for blocking calls. `Some` when configured via
    /// `block_on` (defaulting to `wiggle::run_in_dummy_executor`), `None` for `async`.
    block_with: Option<TokenStream>,
    /// Functions the setting applies to.
    functions: AsyncFunctions,
}
394
395#[derive(Clone, Debug)]
396pub enum Asyncness {
397 Sync,
399 Blocking { block_with: TokenStream },
401 Async,
403}
404
405impl Asyncness {
406 pub fn is_async(&self) -> bool {
407 match self {
408 Self::Async => true,
409 _ => false,
410 }
411 }
412 pub fn blocking(&self) -> Option<&TokenStream> {
413 match self {
414 Self::Blocking { block_with } => Some(block_with),
415 _ => None,
416 }
417 }
418 pub fn is_sync(&self) -> bool {
419 match self {
420 Self::Sync => true,
421 _ => false,
422 }
423 }
424}
425
/// Functions selected by an `async:` / `block_on:` field.
#[derive(Clone, Debug)]
pub enum AsyncFunctions {
    /// Explicit selection: module name -> function names within that module.
    Some(HashMap<String, Vec<String>>),
    /// `*`: every function of every module.
    All,
}
impl Default for AsyncFunctions {
    /// Selects no functions.
    fn default() -> Self {
        Self::Some(HashMap::new())
    }
}
436
437impl AsyncConf {
438 pub fn get(&self, module: &str, function: &str) -> Asyncness {
439 let a = match &self.block_with {
440 Some(block_with) => Asyncness::Blocking {
441 block_with: block_with.clone(),
442 },
443 None => Asyncness::Async,
444 };
445 match &self.functions {
446 AsyncFunctions::Some(fs) => {
447 if fs
448 .get(module)
449 .and_then(|fs| fs.iter().find(|f| *f == function))
450 .is_some()
451 {
452 a
453 } else {
454 Asyncness::Sync
455 }
456 }
457 AsyncFunctions::All => a,
458 }
459 }
460
461 pub fn contains_async(&self, module: &witx::Module) -> bool {
462 for f in module.funcs() {
463 if self.get(module.name.as_str(), f.name.as_str()).is_async() {
464 return true;
465 }
466 }
467 false
468 }
469}
470
471impl Parse for AsyncFunctions {
472 fn parse(input: ParseStream) -> Result<Self> {
473 let content;
474 let lookahead = input.lookahead1();
475 if lookahead.peek(syn::token::Brace) {
476 let _ = braced!(content in input);
477 let items: Punctuated<FunctionField, Token![,]> =
478 content.parse_terminated(Parse::parse, Token![,])?;
479 let mut functions: HashMap<String, Vec<String>> = HashMap::new();
480 use std::collections::hash_map::Entry;
481 for i in items {
482 let function_names = i
483 .function_names
484 .iter()
485 .map(|i| i.to_string())
486 .collect::<Vec<String>>();
487 match functions.entry(i.module_name.to_string()) {
488 Entry::Occupied(o) => o.into_mut().extend(function_names),
489 Entry::Vacant(v) => {
490 v.insert(function_names);
491 }
492 }
493 }
494 Ok(AsyncFunctions::Some(functions))
495 } else if lookahead.peek(Token![*]) {
496 let _: Token![*] = input.parse().unwrap();
497 Ok(AsyncFunctions::All)
498 } else {
499 Err(lookahead.error())
500 }
501 }
502}
503
504#[derive(Clone)]
505pub struct FunctionField {
506 pub module_name: Ident,
507 pub function_names: Vec<Ident>,
508 pub err_loc: Span,
509}
510
511impl Parse for FunctionField {
512 fn parse(input: ParseStream) -> Result<Self> {
513 let err_loc = input.span();
514 let module_name = input.parse::<Ident>()?;
515 let _doublecolon: Token![::] = input.parse()?;
516 let lookahead = input.lookahead1();
517 if lookahead.peek(syn::token::Brace) {
518 let content;
519 let _ = braced!(content in input);
520 let function_names: Punctuated<Ident, Token![,]> =
521 content.parse_terminated(Parse::parse, Token![,])?;
522 Ok(FunctionField {
523 module_name,
524 function_names: function_names.iter().cloned().collect(),
525 err_loc,
526 })
527 } else if lookahead.peek(Ident) {
528 let name = input.parse()?;
529 Ok(FunctionField {
530 module_name,
531 function_names: vec![name],
532 err_loc,
533 })
534 } else {
535 Err(lookahead.error())
536 }
537 }
538}
539
/// A core [`Config`] together with a required `target` path.
#[derive(Clone)]
pub struct WasmtimeConfig {
    /// Core configuration built from every non-`target` field.
    pub c: Config,
    /// Path given in the `target:` field.
    pub target: syn::Path,
}

/// One entry of a [`WasmtimeConfig`]: either `target: path` or a core field.
#[derive(Clone)]
pub enum WasmtimeConfigField {
    /// Any field accepted by [`ConfigField`].
    Core(ConfigField),
    /// `target: some::path`.
    Target(syn::Path),
}
551impl WasmtimeConfig {
552 pub fn build(fields: impl Iterator<Item = WasmtimeConfigField>, err_loc: Span) -> Result<Self> {
553 let mut target = None;
554 let mut cs = Vec::new();
555 for f in fields {
556 match f {
557 WasmtimeConfigField::Target(c) => {
558 if target.is_some() {
559 return Err(Error::new(err_loc, "duplicate `target` field"));
560 }
561 target = Some(c);
562 }
563 WasmtimeConfigField::Core(c) => cs.push(c),
564 }
565 }
566 let c = Config::build(cs.into_iter(), err_loc)?;
567 Ok(WasmtimeConfig {
568 c,
569 target: target
570 .take()
571 .ok_or_else(|| Error::new(err_loc, "`target` field required"))?,
572 })
573 }
574}
575
576impl Parse for WasmtimeConfig {
577 fn parse(input: ParseStream) -> Result<Self> {
578 let contents;
579 let _lbrace = braced!(contents in input);
580 let fields: Punctuated<WasmtimeConfigField, Token![,]> =
581 contents.parse_terminated(WasmtimeConfigField::parse, Token![,])?;
582 Ok(WasmtimeConfig::build(fields.into_iter(), input.span())?)
583 }
584}
585
586impl Parse for WasmtimeConfigField {
587 fn parse(input: ParseStream) -> Result<Self> {
588 if input.peek(kw::target) {
589 input.parse::<kw::target>()?;
590 input.parse::<Token![:]>()?;
591 Ok(WasmtimeConfigField::Target(input.parse()?))
592 } else {
593 Ok(WasmtimeConfigField::Core(input.parse()?))
594 }
595 }
596}
597
/// Tracing settings from the `tracing:` field.
#[derive(Clone, Debug)]
pub struct TracingConf {
    /// Global on/off switch.
    enabled: bool,
    /// Module name -> functions for which tracing is disabled.
    excluded_functions: HashMap<String, Vec<String>>,
}

impl TracingConf {
    /// Returns `true` when tracing is enabled and `module::function` is not
    /// listed in the exclusions.
    pub fn enabled_for(&self, module: &str, function: &str) -> bool {
        if !self.enabled {
            return false;
        }
        let excluded = match self.excluded_functions.get(module) {
            Some(names) => names.iter().any(|name| name.as_str() == function),
            None => false,
        };
        !excluded
    }
}

impl Default for TracingConf {
    /// Tracing enabled for every function, with no exclusions.
    fn default() -> Self {
        Self {
            enabled: true,
            excluded_functions: HashMap::new(),
        }
    }
}
624
625impl Parse for TracingConf {
626 fn parse(input: ParseStream) -> Result<Self> {
627 let enabled = input.parse::<syn::LitBool>()?.value;
628
629 let lookahead = input.lookahead1();
630 if lookahead.peek(kw::disable_for) {
631 input.parse::<kw::disable_for>()?;
632 let content;
633 let _ = braced!(content in input);
634 let items: Punctuated<FunctionField, Token![,]> =
635 content.parse_terminated(Parse::parse, Token![,])?;
636 let mut functions: HashMap<String, Vec<String>> = HashMap::new();
637 use std::collections::hash_map::Entry;
638 for i in items {
639 let function_names = i
640 .function_names
641 .iter()
642 .map(|i| i.to_string())
643 .collect::<Vec<String>>();
644 match functions.entry(i.module_name.to_string()) {
645 Entry::Occupied(o) => o.into_mut().extend(function_names),
646 Entry::Vacant(v) => {
647 v.insert(function_names);
648 }
649 }
650 }
651
652 Ok(TracingConf {
653 enabled,
654 excluded_functions: functions,
655 })
656 } else {
657 Ok(TracingConf {
658 enabled,
659 excluded_functions: HashMap::new(),
660 })
661 }
662 }
663}