intuicio_parser/
extension.rs

1use crate::{ParseResult, Parser, ParserExt, ParserHandle, ParserRegistry};
2use std::sync::Arc;
3
4pub mod shorthand {
5    use super::*;
6
7    pub fn ext<T: Send + Sync + 'static>(
8        f: impl Fn(Arc<T>) -> ParserHandle + Send + Sync + 'static,
9    ) -> ParserHandle {
10        ExtensionParser::new(f).into_handle()
11    }
12}
13
14#[derive(Clone)]
15pub struct ExtensionParser<T: Send + Sync + 'static> {
16    parser_generator: Arc<dyn Fn(Arc<T>) -> ParserHandle + Send + Sync>,
17}
18
19impl<T: Send + Sync + 'static> ExtensionParser<T> {
20    pub fn new(f: impl Fn(Arc<T>) -> ParserHandle + Send + Sync + 'static) -> Self {
21        Self {
22            parser_generator: Arc::new(f),
23        }
24    }
25}
26
27impl<T: Send + Sync + 'static> Parser for ExtensionParser<T> {
28    fn parse<'a>(&self, registry: &ParserRegistry, input: &'a str) -> ParseResult<'a> {
29        if let Some(extension) = registry.extension::<T>() {
30            (self.parser_generator)(extension).parse(registry, input)
31        } else {
32            Err("Could not get ExtensionParser extension!".into())
33        }
34    }
35}
36
#[cfg(test)]
mod tests {
    use std::sync::RwLock;

    use crate::{
        extension::ExtensionParser,
        shorthand::{ext, lit},
        ParserRegistry,
    };

    /// Compile-time check that `T` is `Send + Sync`; the body is intentionally empty.
    fn assert_send_sync<T: Send + Sync>() {}

    /// Test extension that counts how many times the parser generator has run.
    #[derive(Default)]
    struct Extension {
        pub counter: RwLock<usize>,
    }

    #[test]
    fn test_extension() {
        assert_send_sync::<ExtensionParser<()>>();

        let registry = ParserRegistry::default().with_extension(Extension::default());
        let parser = ext::<Extension>(|extension| {
            *extension.counter.write().unwrap() += 1;
            lit("foo")
        });
        let (rest, _) = parser.parse(&registry, "foo").unwrap();
        assert_eq!(rest, "");
        assert_eq!(
            *registry
                .extension::<Extension>()
                .unwrap()
                .counter
                .read()
                .unwrap(),
            1
        );

        // The generator runs on every parse; nothing is cached between calls.
        let (rest, _) = parser.parse(&registry, "foo").unwrap();
        assert_eq!(rest, "");
        assert_eq!(
            *registry
                .extension::<Extension>()
                .unwrap()
                .counter
                .read()
                .unwrap(),
            2
        );
    }

    #[test]
    fn test_extension_missing() {
        // A registry without the requested extension must make parsing fail.
        let registry = ParserRegistry::default();
        let parser = ext::<Extension>(|_| lit("foo"));
        assert!(parser.parse(&registry, "foo").is_err());
    }
}