intuicio_parser/
pratt.rs

1use crate::{ParseResult, Parser, ParserExt, ParserHandle, ParserOutput, ParserRegistry};
2use std::{error::Error, sync::Arc};
3
4pub mod shorthand {
5    use super::*;
6
7    pub fn pratt(tokenizer_parser: ParserHandle, rules: Vec<Vec<PrattParserRule>>) -> ParserHandle {
8        let mut result = PrattParser::new(tokenizer_parser);
9        for rule in rules {
10            result.push_rules(rule);
11        }
12        result.into_handle()
13    }
14}
15
/// Grouping direction for chains of infix operators that share a
/// precedence level.
#[derive(Debug, Default, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
pub enum PrattParserAssociativity {
    /// Equal-precedence operators group from the left: `a - b - c` is
    /// `(a - b) - c`. This is the default.
    #[default]
    Left,
    /// Equal-precedence operators group from the right: `a ^ b ^ c` is
    /// `a ^ (b ^ c)`. Infix rules with this associativity get their left
    /// and right binding powers swapped when registered.
    Right,
}
22
23#[derive(Clone)]
24pub enum PrattParserRule {
25    Prefix {
26        operator: Arc<dyn Fn(&ParserOutput) -> bool + Send + Sync>,
27        transformer: Arc<dyn Fn(ParserOutput) -> ParserOutput + Send + Sync>,
28    },
29    Postfix {
30        operator: Arc<dyn Fn(&ParserOutput) -> bool + Send + Sync>,
31        transformer: Arc<dyn Fn(ParserOutput) -> ParserOutput + Send + Sync>,
32    },
33    Infix {
34        operator: Arc<dyn Fn(&ParserOutput) -> bool + Send + Sync>,
35        transformer: Arc<dyn Fn(ParserOutput, ParserOutput) -> ParserOutput + Send + Sync>,
36        associativity: PrattParserAssociativity,
37    },
38}
39
40impl PrattParserRule {
41    pub fn prefx_raw(
42        operator: impl Fn(&ParserOutput) -> bool + Send + Sync + 'static,
43        transformer: impl Fn(ParserOutput) -> ParserOutput + Send + Sync + 'static,
44    ) -> Self {
45        Self::Prefix {
46            operator: Arc::new(operator),
47            transformer: Arc::new(transformer),
48        }
49    }
50
51    pub fn prefix<O: PartialEq + Send + Sync + 'static, V: Send + Sync + 'static>(
52        operator: O,
53        transformer: impl Fn(V) -> V + Send + Sync + 'static,
54    ) -> Self {
55        Self::prefx_raw(
56            move |token| {
57                token
58                    .read::<O>()
59                    .map(|op| *op == operator)
60                    .unwrap_or_default()
61            },
62            move |value| {
63                let value = value.consume::<V>().ok().unwrap();
64                let result = (transformer)(value);
65                ParserOutput::new(result).ok().unwrap()
66            },
67        )
68    }
69
70    pub fn postfix_raw(
71        operator: impl Fn(&ParserOutput) -> bool + Send + Sync + 'static,
72        transformer: impl Fn(ParserOutput) -> ParserOutput + Send + Sync + 'static,
73    ) -> Self {
74        Self::Postfix {
75            operator: Arc::new(operator),
76            transformer: Arc::new(transformer),
77        }
78    }
79
80    pub fn postfix<O: PartialEq + Send + Sync + 'static, V: Send + Sync + 'static>(
81        operator: O,
82        transformer: impl Fn(V) -> V + Send + Sync + 'static,
83    ) -> Self {
84        Self::postfix_raw(
85            move |token| {
86                token
87                    .read::<O>()
88                    .map(|op| *op == operator)
89                    .unwrap_or_default()
90            },
91            move |value| {
92                let value = value.consume::<V>().ok().unwrap();
93                let result = (transformer)(value);
94                ParserOutput::new(result).ok().unwrap()
95            },
96        )
97    }
98
99    pub fn infix_raw(
100        operator: impl Fn(&ParserOutput) -> bool + Send + Sync + 'static,
101        transformer: impl Fn(ParserOutput, ParserOutput) -> ParserOutput + Send + Sync + 'static,
102        associativity: PrattParserAssociativity,
103    ) -> Self {
104        Self::Infix {
105            operator: Arc::new(operator),
106            transformer: Arc::new(transformer),
107            associativity,
108        }
109    }
110
111    pub fn infix<O: PartialEq + Send + Sync + 'static, V: Send + Sync + 'static>(
112        operator: O,
113        transformer: impl Fn(V, V) -> V + Send + Sync + 'static,
114        associativity: PrattParserAssociativity,
115    ) -> Self {
116        Self::infix_raw(
117            move |token| {
118                token
119                    .read::<O>()
120                    .map(|op| *op == operator)
121                    .unwrap_or_default()
122            },
123            move |lhs, rhs| {
124                let lhs = lhs.consume::<V>().ok().unwrap();
125                let rhs = rhs.consume::<V>().ok().unwrap();
126                let result = (transformer)(lhs, rhs);
127                ParserOutput::new(result).ok().unwrap()
128            },
129            associativity,
130        )
131    }
132
133    fn flip_binding_power(&self) -> bool {
134        matches!(
135            self,
136            Self::Infix {
137                associativity: PrattParserAssociativity::Right,
138                ..
139            }
140        )
141    }
142}
143
144#[derive(Clone)]
145pub struct PrattParser {
146    tokenizer_parser: ParserHandle,
147    /// [(rule, left binding power, right binding power)]
148    rules: Vec<(PrattParserRule, usize, usize)>,
149    binding_power_generator: usize,
150}
151
152impl PrattParser {
153    pub fn new(tokenizer_parser: ParserHandle) -> Self {
154        Self {
155            tokenizer_parser,
156            rules: vec![],
157            binding_power_generator: 0,
158        }
159    }
160
161    pub fn with_rules(mut self, rules: impl IntoIterator<Item = PrattParserRule>) -> Self {
162        self.push_rules(rules);
163        self
164    }
165
166    pub fn push_rules(&mut self, rules: impl IntoIterator<Item = PrattParserRule>) {
167        let low = self.binding_power_generator + 1;
168        let high = self.binding_power_generator + 2;
169        self.binding_power_generator += 2;
170        for rule in rules {
171            if rule.flip_binding_power() {
172                self.rules.push((rule, high, low));
173            } else {
174                self.rules.push((rule, low, high));
175            }
176        }
177    }
178
179    fn parse_inner(
180        &self,
181        tokens: &mut Vec<ParserOutput>,
182        min_bp: usize,
183    ) -> Result<ParserOutput, Box<dyn Error>> {
184        let Some(mut lhs) = tokens.pop() else {
185            return Err("Expected LHS token value".into());
186        };
187        if let Some((rule, _, rbp)) = self.find_prefix_rule(&lhs) {
188            let rhs = self.parse_inner(tokens, rbp)?;
189            if let PrattParserRule::Prefix { transformer, .. } = rule {
190                lhs = (*transformer)(rhs);
191            } else {
192                return Err("Expected prefix rule".into());
193            }
194        }
195        while let Some(op) = tokens.pop() {
196            if let Some((rule, lbp, _)) = self.find_postfix_rule(&op) {
197                if lbp < min_bp {
198                    tokens.push(op);
199                    break;
200                }
201                if let PrattParserRule::Postfix { transformer, .. } = rule {
202                    lhs = (*transformer)(lhs);
203                } else {
204                    return Err("Expected postfix rule".into());
205                }
206                continue;
207            }
208            if let Some((rule, lbp, rbp)) = self.find_infix_rule(&op) {
209                if lbp < min_bp {
210                    tokens.push(op);
211                    break;
212                }
213                let rhs = self.parse_inner(tokens, rbp)?;
214                if let PrattParserRule::Infix { transformer, .. } = rule {
215                    lhs = (*transformer)(lhs, rhs);
216                } else {
217                    return Err("Expected infix rule".into());
218                }
219                continue;
220            }
221            tokens.push(op);
222            break;
223        }
224        Ok(lhs)
225    }
226
227    /// (rule, _, right binding power)
228    fn find_prefix_rule(&self, token: &ParserOutput) -> Option<(&PrattParserRule, (), usize)> {
229        self.rules
230            .iter()
231            .find(|(rule, _, _)| match rule {
232                PrattParserRule::Prefix { operator, .. } => (*operator)(token),
233                _ => false,
234            })
235            .map(|(rule, _, rbp)| (rule, (), *rbp))
236    }
237
238    /// (rule, left binding power, _)
239    fn find_postfix_rule(&self, token: &ParserOutput) -> Option<(&PrattParserRule, usize, ())> {
240        self.rules
241            .iter()
242            .find(|(rule, _, _)| match rule {
243                PrattParserRule::Postfix { operator, .. } => (*operator)(token),
244                _ => false,
245            })
246            .map(|(rule, lbp, _)| (rule, *lbp, ()))
247    }
248
249    /// (rule, left binding power, right binding power)
250    fn find_infix_rule(&self, token: &ParserOutput) -> Option<(&PrattParserRule, usize, usize)> {
251        self.rules
252            .iter()
253            .find(|(rule, _, _)| match rule {
254                PrattParserRule::Infix { operator, .. } => (*operator)(token),
255                _ => false,
256            })
257            .map(|(rule, lbp, rbp)| (rule, *lbp, *rbp))
258    }
259}
260
261impl Parser for PrattParser {
262    fn parse<'a>(&self, registry: &ParserRegistry, input: &'a str) -> ParseResult<'a> {
263        let (input, result) = self.tokenizer_parser.parse(registry, input)?;
264        let mut tokens = match result.consume::<Vec<ParserOutput>>() {
265            Ok(tokens) => tokens,
266            Err(_) => {
267                return Err("PrattParser expects `Vec<ParserOutput>` tokenization result".into())
268            }
269        };
270        tokens.reverse();
271        let result = self.parse_inner(&mut tokens, 0)?;
272        if !tokens.is_empty() {
273            return Err("PrattParser did not consumed all tokens".into());
274        }
275        Ok((input, result))
276    }
277}
278
#[cfg(test)]
mod tests {
    use crate::{
        pratt::{PrattParser, PrattParserAssociativity, PrattParserRule},
        shorthand::{
            alt, inject, list, lit, map, map_err, number_float, oc, ows, pratt, prefix, suffix,
        },
        ParserHandle, ParserRegistry,
    };

    /// Operators of the small arithmetic language used by the test.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum Operator {
        Add,
        Sub,
        Mul,
        Div,
        /// Prefix operator: floors its operand.
        Hash,
        /// Postfix operator: keeps the fractional part of its operand.
        Bang,
    }

    impl std::fmt::Display for Operator {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            let symbol = match self {
                Self::Add => "+",
                Self::Sub => "-",
                Self::Mul => "*",
                Self::Div => "/",
                Self::Hash => "#",
                Self::Bang => "!",
            };
            f.write_str(symbol)
        }
    }

    /// Expression tree produced by the Pratt rules.
    #[derive(Debug)]
    enum Expression {
        Number(f32),
        UnaryOperation {
            op: Operator,
            value: Box<Expression>,
        },
        BinaryOperation {
            op: Operator,
            lhs: Box<Expression>,
            rhs: Box<Expression>,
        },
    }

    impl Expression {
        /// Node applying the unary operator `op` to `value`.
        fn unary(op: Operator, value: Expression) -> Self {
            Self::UnaryOperation {
                op,
                value: Box::new(value),
            }
        }

        /// Node applying the binary operator `op` to `lhs` and `rhs`.
        fn binary(op: Operator, lhs: Expression, rhs: Expression) -> Self {
            Self::BinaryOperation {
                op,
                lhs: Box::new(lhs),
                rhs: Box::new(rhs),
            }
        }

        /// Evaluates the tree to a number.
        fn eval(&self) -> f32 {
            match self {
                Self::Number(value) => *value,
                Self::UnaryOperation {
                    op: Operator::Hash,
                    value,
                } => value.eval().floor(),
                Self::UnaryOperation {
                    op: Operator::Bang,
                    value,
                } => value.eval().fract(),
                Self::UnaryOperation { .. } => unreachable!(),
                Self::BinaryOperation { op, lhs, rhs } => match op {
                    Operator::Add => lhs.eval() + rhs.eval(),
                    Operator::Sub => lhs.eval() - rhs.eval(),
                    Operator::Mul => lhs.eval() * rhs.eval(),
                    Operator::Div => lhs.eval() / rhs.eval(),
                    _ => unreachable!(),
                },
            }
        }
    }

    impl std::fmt::Display for Expression {
        /// Prints the tree in prefix (S-expression like) notation.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::Number(value) => write!(f, "{value}"),
                Self::UnaryOperation { value, op } => write!(f, "({op} {value})"),
                Self::BinaryOperation { op, lhs, rhs } => write!(f, "({op} {lhs} {rhs})"),
            }
        }
    }

    /// Number literal token.
    fn number() -> ParserHandle {
        let parsed = map(number_float(), |value: String| {
            Expression::Number(value.parse().unwrap())
        });
        map_err(parsed, |_| "Expected number".into())
    }

    /// Operator token.
    fn op() -> ParserHandle {
        let symbol = alt([lit("+"), lit("-"), lit("*"), lit("/"), lit("#"), lit("!")]);
        let parsed = map(symbol, |value: String| match value.as_str() {
            "+" => Operator::Add,
            "-" => Operator::Sub,
            "*" => Operator::Mul,
            "/" => Operator::Div,
            "#" => Operator::Hash,
            "!" => Operator::Bang,
            _ => unreachable!(),
        });
        map_err(parsed, |_| "Expected operator".into())
    }

    /// Parenthesized expression, parsed recursively through the registry.
    fn sub_expr() -> ParserHandle {
        let open = suffix(lit("("), ows());
        let close = prefix(lit(")"), ows());
        map_err(oc(inject("expr"), open, close), |_| {
            "Expected sub-expression".into()
        })
    }

    /// Any single token of the expression language.
    fn item() -> ParserHandle {
        alt([inject("number"), inject("op"), inject("sub_expr")])
    }

    /// Token list fed into the Pratt parser.
    fn expr_tokenizer() -> ParserHandle {
        list(inject("item"), ows(), true)
    }

    /// Left-associative infix rule building a binary node for `op`.
    fn infix_left(op: Operator) -> PrattParserRule {
        PrattParserRule::infix(
            op,
            move |lhs, rhs| Expression::binary(op, lhs, rhs),
            PrattParserAssociativity::Left,
        )
    }

    /// Full expression parser; rule groups are listed from the loosest to
    /// the tightest binding.
    fn expr() -> ParserHandle {
        let additive = vec![infix_left(Operator::Add), infix_left(Operator::Sub)];
        let multiplicative = vec![infix_left(Operator::Mul), infix_left(Operator::Div)];
        let floor = vec![PrattParserRule::prefix(Operator::Hash, |value| {
            Expression::unary(Operator::Hash, value)
        })];
        let fraction = vec![PrattParserRule::postfix(Operator::Bang, |value| {
            Expression::unary(Operator::Bang, value)
        })];
        pratt(
            inject("expr_tokenizer"),
            vec![additive, multiplicative, floor, fraction],
        )
    }

    /// Compile-time check that `T` is `Send + Sync`.
    fn is_async<T: Send + Sync>() {}

    /// Parses `input` as `expr` and checks that it is fully consumed, prints
    /// as `printed` and evaluates to `value`.
    fn assert_expr(registry: &ParserRegistry, input: &str, printed: &str, value: f32) {
        let (rest, result) = registry.parse("expr", input).unwrap();
        assert_eq!(rest, "");
        let result = result.consume::<Expression>().ok().unwrap();
        assert_eq!(result.to_string(), printed);
        assert_eq!(result.eval(), value);
    }

    #[test]
    fn test_pratt() {
        is_async::<PrattParser>();

        let registry = ParserRegistry::default()
            .with_parser("number", number())
            .with_parser("op", op())
            .with_parser("sub_expr", sub_expr())
            .with_parser("item", item())
            .with_parser("expr_tokenizer", expr_tokenizer())
            .with_parser("expr", expr());

        assert_expr(&registry, "(((0)))", "0", 0.0);
        assert_expr(
            &registry,
            "(3 + 4) * 2 - 1 / 5",
            "(- (* (+ 3 4) 2) (/ 1 5))",
            13.8,
        );
        assert_expr(&registry, "#1.2 + 3.4!", "(+ (# 1.2) (! 3.4))", 1.4000001);
        assert_expr(&registry, "#(1.2 - 3.4)!", "(# (! (- 1.2 3.4)))", -1.0);
    }
}
498}