1use crate::{ParseResult, Parser, ParserExt, ParserHandle, ParserOutput, ParserRegistry};
2use std::{error::Error, sync::Arc};
3
4pub mod shorthand {
5 use super::*;
6
7 pub fn pratt(tokenizer_parser: ParserHandle, rules: Vec<Vec<PrattParserRule>>) -> ParserHandle {
8 let mut result = PrattParser::new(tokenizer_parser);
9 for rule in rules {
10 result.push_rules(rule);
11 }
12 result.into_handle()
13 }
14}
15
/// Associativity of an infix [`PrattParserRule`].
///
/// It decides how a chain of operators from the same precedence group is
/// grouped:
///
/// * `Left` (the default): `a op b op c` parses as `(a op b) op c`.
/// * `Right`: `a op b op c` parses as `a op (b op c)`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum PrattParserAssociativity {
    /// Group from the left: `(a op b) op c`.
    #[default]
    Left,
    /// Group from the right: `a op (b op c)`.
    Right,
}
22
23#[derive(Clone)]
24pub enum PrattParserRule {
25 Prefix {
26 operator: Arc<dyn Fn(&ParserOutput) -> bool + Send + Sync>,
27 transformer: Arc<dyn Fn(ParserOutput) -> ParserOutput + Send + Sync>,
28 },
29 Postfix {
30 operator: Arc<dyn Fn(&ParserOutput) -> bool + Send + Sync>,
31 transformer: Arc<dyn Fn(ParserOutput) -> ParserOutput + Send + Sync>,
32 },
33 Infix {
34 operator: Arc<dyn Fn(&ParserOutput) -> bool + Send + Sync>,
35 transformer: Arc<dyn Fn(ParserOutput, ParserOutput) -> ParserOutput + Send + Sync>,
36 associativity: PrattParserAssociativity,
37 },
38}
39
40impl PrattParserRule {
41 pub fn prefx_raw(
42 operator: impl Fn(&ParserOutput) -> bool + Send + Sync + 'static,
43 transformer: impl Fn(ParserOutput) -> ParserOutput + Send + Sync + 'static,
44 ) -> Self {
45 Self::Prefix {
46 operator: Arc::new(operator),
47 transformer: Arc::new(transformer),
48 }
49 }
50
51 pub fn prefix<O: PartialEq + Send + Sync + 'static, V: Send + Sync + 'static>(
52 operator: O,
53 transformer: impl Fn(V) -> V + Send + Sync + 'static,
54 ) -> Self {
55 Self::prefx_raw(
56 move |token| {
57 token
58 .read::<O>()
59 .map(|op| *op == operator)
60 .unwrap_or_default()
61 },
62 move |value| {
63 let value = value.consume::<V>().ok().unwrap();
64 let result = (transformer)(value);
65 ParserOutput::new(result).ok().unwrap()
66 },
67 )
68 }
69
70 pub fn postfix_raw(
71 operator: impl Fn(&ParserOutput) -> bool + Send + Sync + 'static,
72 transformer: impl Fn(ParserOutput) -> ParserOutput + Send + Sync + 'static,
73 ) -> Self {
74 Self::Postfix {
75 operator: Arc::new(operator),
76 transformer: Arc::new(transformer),
77 }
78 }
79
80 pub fn postfix<O: PartialEq + Send + Sync + 'static, V: Send + Sync + 'static>(
81 operator: O,
82 transformer: impl Fn(V) -> V + Send + Sync + 'static,
83 ) -> Self {
84 Self::postfix_raw(
85 move |token| {
86 token
87 .read::<O>()
88 .map(|op| *op == operator)
89 .unwrap_or_default()
90 },
91 move |value| {
92 let value = value.consume::<V>().ok().unwrap();
93 let result = (transformer)(value);
94 ParserOutput::new(result).ok().unwrap()
95 },
96 )
97 }
98
99 pub fn infix_raw(
100 operator: impl Fn(&ParserOutput) -> bool + Send + Sync + 'static,
101 transformer: impl Fn(ParserOutput, ParserOutput) -> ParserOutput + Send + Sync + 'static,
102 associativity: PrattParserAssociativity,
103 ) -> Self {
104 Self::Infix {
105 operator: Arc::new(operator),
106 transformer: Arc::new(transformer),
107 associativity,
108 }
109 }
110
111 pub fn infix<O: PartialEq + Send + Sync + 'static, V: Send + Sync + 'static>(
112 operator: O,
113 transformer: impl Fn(V, V) -> V + Send + Sync + 'static,
114 associativity: PrattParserAssociativity,
115 ) -> Self {
116 Self::infix_raw(
117 move |token| {
118 token
119 .read::<O>()
120 .map(|op| *op == operator)
121 .unwrap_or_default()
122 },
123 move |lhs, rhs| {
124 let lhs = lhs.consume::<V>().ok().unwrap();
125 let rhs = rhs.consume::<V>().ok().unwrap();
126 let result = (transformer)(lhs, rhs);
127 ParserOutput::new(result).ok().unwrap()
128 },
129 associativity,
130 )
131 }
132
133 fn flip_binding_power(&self) -> bool {
134 matches!(
135 self,
136 Self::Infix {
137 associativity: PrattParserAssociativity::Right,
138 ..
139 }
140 )
141 }
142}
143
144#[derive(Clone)]
145pub struct PrattParser {
146 tokenizer_parser: ParserHandle,
147 rules: Vec<(PrattParserRule, usize, usize)>,
149 binding_power_generator: usize,
150}
151
152impl PrattParser {
153 pub fn new(tokenizer_parser: ParserHandle) -> Self {
154 Self {
155 tokenizer_parser,
156 rules: vec![],
157 binding_power_generator: 0,
158 }
159 }
160
161 pub fn with_rules(mut self, rules: impl IntoIterator<Item = PrattParserRule>) -> Self {
162 self.push_rules(rules);
163 self
164 }
165
166 pub fn push_rules(&mut self, rules: impl IntoIterator<Item = PrattParserRule>) {
167 let low = self.binding_power_generator + 1;
168 let high = self.binding_power_generator + 2;
169 self.binding_power_generator += 2;
170 for rule in rules {
171 if rule.flip_binding_power() {
172 self.rules.push((rule, high, low));
173 } else {
174 self.rules.push((rule, low, high));
175 }
176 }
177 }
178
179 fn parse_inner(
180 &self,
181 tokens: &mut Vec<ParserOutput>,
182 min_bp: usize,
183 ) -> Result<ParserOutput, Box<dyn Error>> {
184 let Some(mut lhs) = tokens.pop() else {
185 return Err("Expected LHS token value".into());
186 };
187 if let Some((rule, _, rbp)) = self.find_prefix_rule(&lhs) {
188 let rhs = self.parse_inner(tokens, rbp)?;
189 if let PrattParserRule::Prefix { transformer, .. } = rule {
190 lhs = (*transformer)(rhs);
191 } else {
192 return Err("Expected prefix rule".into());
193 }
194 }
195 while let Some(op) = tokens.pop() {
196 if let Some((rule, lbp, _)) = self.find_postfix_rule(&op) {
197 if lbp < min_bp {
198 tokens.push(op);
199 break;
200 }
201 if let PrattParserRule::Postfix { transformer, .. } = rule {
202 lhs = (*transformer)(lhs);
203 } else {
204 return Err("Expected postfix rule".into());
205 }
206 continue;
207 }
208 if let Some((rule, lbp, rbp)) = self.find_infix_rule(&op) {
209 if lbp < min_bp {
210 tokens.push(op);
211 break;
212 }
213 let rhs = self.parse_inner(tokens, rbp)?;
214 if let PrattParserRule::Infix { transformer, .. } = rule {
215 lhs = (*transformer)(lhs, rhs);
216 } else {
217 return Err("Expected infix rule".into());
218 }
219 continue;
220 }
221 tokens.push(op);
222 break;
223 }
224 Ok(lhs)
225 }
226
227 fn find_prefix_rule(&self, token: &ParserOutput) -> Option<(&PrattParserRule, (), usize)> {
229 self.rules
230 .iter()
231 .find(|(rule, _, _)| match rule {
232 PrattParserRule::Prefix { operator, .. } => (*operator)(token),
233 _ => false,
234 })
235 .map(|(rule, _, rbp)| (rule, (), *rbp))
236 }
237
238 fn find_postfix_rule(&self, token: &ParserOutput) -> Option<(&PrattParserRule, usize, ())> {
240 self.rules
241 .iter()
242 .find(|(rule, _, _)| match rule {
243 PrattParserRule::Postfix { operator, .. } => (*operator)(token),
244 _ => false,
245 })
246 .map(|(rule, lbp, _)| (rule, *lbp, ()))
247 }
248
249 fn find_infix_rule(&self, token: &ParserOutput) -> Option<(&PrattParserRule, usize, usize)> {
251 self.rules
252 .iter()
253 .find(|(rule, _, _)| match rule {
254 PrattParserRule::Infix { operator, .. } => (*operator)(token),
255 _ => false,
256 })
257 .map(|(rule, lbp, rbp)| (rule, *lbp, *rbp))
258 }
259}
260
261impl Parser for PrattParser {
262 fn parse<'a>(&self, registry: &ParserRegistry, input: &'a str) -> ParseResult<'a> {
263 let (input, result) = self.tokenizer_parser.parse(registry, input)?;
264 let mut tokens = match result.consume::<Vec<ParserOutput>>() {
265 Ok(tokens) => tokens,
266 Err(_) => {
267 return Err("PrattParser expects `Vec<ParserOutput>` tokenization result".into())
268 }
269 };
270 tokens.reverse();
271 let result = self.parse_inner(&mut tokens, 0)?;
272 if !tokens.is_empty() {
273 return Err("PrattParser did not consumed all tokens".into());
274 }
275 Ok((input, result))
276 }
277}
278
#[cfg(test)]
mod tests {
    use crate::{
        pratt::{PrattParser, PrattParserAssociativity, PrattParserRule},
        shorthand::{
            alt, inject, list, lit, map, map_err, number_float, oc, ows, pratt, prefix, suffix,
        },
        ParserHandle, ParserRegistry,
    };

    /// Operators accepted by the test grammar.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    enum Operator {
        Add,
        Sub,
        Mul,
        Div,
        Hash,
        Bang,
    }

    impl Operator {
        /// Source symbol of the operator.
        fn symbol(self) -> &'static str {
            match self {
                Self::Add => "+",
                Self::Sub => "-",
                Self::Mul => "*",
                Self::Div => "/",
                Self::Hash => "#",
                Self::Bang => "!",
            }
        }

        /// Inverse of [`Operator::symbol`]; only called with symbols the `op()` parser accepted.
        fn from_symbol(symbol: &str) -> Self {
            match symbol {
                "+" => Self::Add,
                "-" => Self::Sub,
                "*" => Self::Mul,
                "/" => Self::Div,
                "#" => Self::Hash,
                "!" => Self::Bang,
                _ => unreachable!(),
            }
        }
    }

    impl std::fmt::Display for Operator {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str(self.symbol())
        }
    }

    /// Expression tree produced by the Pratt parser under test.
    #[derive(Debug)]
    enum Expression {
        Number(f32),
        UnaryOperation {
            op: Operator,
            value: Box<Expression>,
        },
        BinaryOperation {
            op: Operator,
            lhs: Box<Expression>,
            rhs: Box<Expression>,
        },
    }

    impl Expression {
        /// Evaluates the tree: `#` floors, `!` takes the fractional part.
        fn eval(&self) -> f32 {
            match self {
                Self::Number(value) => *value,
                Self::UnaryOperation { op, value } => {
                    let apply: fn(f32) -> f32 = match op {
                        Operator::Hash => f32::floor,
                        Operator::Bang => f32::fract,
                        _ => unreachable!(),
                    };
                    apply(value.eval())
                }
                Self::BinaryOperation { op, lhs, rhs } => {
                    let combine: fn(f32, f32) -> f32 = match op {
                        Operator::Add => |a, b| a + b,
                        Operator::Sub => |a, b| a - b,
                        Operator::Mul => |a, b| a * b,
                        Operator::Div => |a, b| a / b,
                        _ => unreachable!(),
                    };
                    combine(lhs.eval(), rhs.eval())
                }
            }
        }
    }

    impl std::fmt::Display for Expression {
        /// Renders the tree in prefix (S-expression) notation.
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            match self {
                Self::Number(value) => write!(f, "{value}"),
                Self::UnaryOperation { op, value } => write!(f, "({op} {value})"),
                Self::BinaryOperation { op, lhs, rhs } => write!(f, "({op} {lhs} {rhs})"),
            }
        }
    }

    /// Transformer building a binary node for `op`.
    fn binary(op: Operator) -> impl Fn(Expression, Expression) -> Expression + Send + Sync + 'static {
        move |lhs: Expression, rhs: Expression| Expression::BinaryOperation {
            op,
            lhs: Box::new(lhs),
            rhs: Box::new(rhs),
        }
    }

    /// Transformer building a unary node for `op`.
    fn unary(op: Operator) -> impl Fn(Expression) -> Expression + Send + Sync + 'static {
        move |value: Expression| Expression::UnaryOperation {
            op,
            value: Box::new(value),
        }
    }

    fn number() -> ParserHandle {
        let parsed = map(number_float(), |text: String| {
            Expression::Number(text.parse().unwrap())
        });
        map_err(parsed, |_| "Expected number".into())
    }

    fn op() -> ParserHandle {
        let symbols = alt([lit("+"), lit("-"), lit("*"), lit("/"), lit("#"), lit("!")]);
        let operators = map(symbols, |value: String| Operator::from_symbol(&value));
        map_err(operators, |_| "Expected operator".into())
    }

    fn sub_expr() -> ParserHandle {
        let open = suffix(lit("("), ows());
        let close = prefix(lit(")"), ows());
        map_err(oc(inject("expr"), open, close), |_| {
            "Expected sub-expression".into()
        })
    }

    fn item() -> ParserHandle {
        let choices = [inject("number"), inject("op"), inject("sub_expr")];
        alt(choices)
    }

    fn expr_tokenizer() -> ParserHandle {
        let separator = ows();
        list(inject("item"), separator, true)
    }

    fn expr() -> ParserHandle {
        let left_infix = |op: Operator| {
            PrattParserRule::infix(op, binary(op), PrattParserAssociativity::Left)
        };
        let groups = vec![
            vec![left_infix(Operator::Add), left_infix(Operator::Sub)],
            vec![left_infix(Operator::Mul), left_infix(Operator::Div)],
            vec![PrattParserRule::prefix(Operator::Hash, unary(Operator::Hash))],
            vec![PrattParserRule::postfix(Operator::Bang, unary(Operator::Bang))],
        ];
        pratt(inject("expr_tokenizer"), groups)
    }

    fn is_async<T: Send + Sync>() {}

    /// Registry wiring every grammar parser under its lookup name.
    fn build_registry() -> ParserRegistry {
        ParserRegistry::default()
            .with_parser("number", number())
            .with_parser("op", op())
            .with_parser("sub_expr", sub_expr())
            .with_parser("item", item())
            .with_parser("expr_tokenizer", expr_tokenizer())
            .with_parser("expr", expr())
    }

    /// Parses `input` fully and checks both its rendering and its value.
    fn check(registry: &ParserRegistry, input: &str, expected_tree: &str, expected_value: f32) {
        let (rest, output) = registry.parse("expr", input).unwrap();
        assert_eq!(rest, "");
        let expression = output.consume::<Expression>().ok().unwrap();
        assert_eq!(expression.to_string(), expected_tree);
        assert_eq!(expression.eval(), expected_value);
    }

    #[test]
    fn test_pratt() {
        is_async::<PrattParser>();

        let registry = build_registry();
        check(&registry, "(((0)))", "0", 0.0);
        check(&registry, "(3 + 4) * 2 - 1 / 5", "(- (* (+ 3 4) 2) (/ 1 5))", 13.8);
        check(&registry, "#1.2 + 3.4!", "(+ (# 1.2) (! 3.4))", 1.4000001);
        check(&registry, "#(1.2 - 3.4)!", "(# (! (- 1.2 3.4)))", -1.0);
    }
}