/*!
This module provides APIs for dealing with the alphabets of finite state
machines.

There are two principal types in this module, [`ByteClasses`] and [`Unit`].
The former defines the alphabet of a finite state machine while the latter
represents an element of that alphabet.

To a first approximation, the alphabet of all automata in this crate is just
a `u8`. Namely, every distinct byte value. All 256 of them. In practice, this
can be quite wasteful when building a transition table for a DFA, since it
requires storing a state identifier for each element in the alphabet. Instead,
we collapse the alphabet of an automaton down into equivalence classes, where
any two bytes in the same equivalence class are interchangeable: replacing one
with the other never changes whether a match is found. For example, the regex
`[a-z]+` could be considered to have an alphabet consisting of two
equivalence classes: `a-z` and everything else. The transitions of an
automaton for this regex then need not represent every distinct byte. Just
the equivalence classes.

The downside of equivalence classes is that, of course, searching a haystack
deals with individual byte values. Those byte values need to be mapped to
their corresponding equivalence class. This is what `ByteClasses` does. In
practice, doing this for every state transition has negligible impact on modern
CPUs. Moreover, it helps make more efficient use of the CPU cache by (possibly
considerably) shrinking the size of the transition table.

One last hiccup concerns `Unit`. Namely, because of look-around and how the
DFAs in this crate work, we need to add a sentinel value to our alphabet
of equivalence classes that represents the "end" of a search. We call that
sentinel [`Unit::eoi`] or "end of input." Thus, a `Unit` is either an
equivalence class corresponding to a set of bytes, or it is a special "end of
input" sentinel.

In general, you should not expect to need either of these types unless you're
doing lower-level shenanigans with DFAs, or even building your own DFAs.
(Although, you don't have to use these types to build your own DFAs of course.)
For example, if you're walking a DFA's state graph, it's probably useful to
make use of [`ByteClasses`] to visit each element in the DFA's alphabet instead
of just visiting every distinct `u8` value. The latter isn't necessarily wrong,
but it could be very wasteful.
*/
use crate::util::{
    escape::DebugByte,
    wire::{self, DeserializeError, SerializeError},
};

/// Unit represents a single unit of haystack for DFA-based regex engines.
///
/// It is not expected for consumers of this crate to need to use this type
/// unless they are implementing their own DFA. And even then, it's not
/// required: implementors may use other techniques to handle haystack units.
///
/// Typically, a single unit of haystack for a DFA would be a single byte.
/// However, for the DFAs in this crate, matches are delayed by a single byte
/// in order to handle look-ahead assertions (`\b`, `$` and `\z`). Thus, once
/// we have consumed the haystack, we must run the DFA through one additional
/// transition using a unit that indicates the haystack has ended.
///
/// There is no way to represent a sentinel with a `u8` since all possible
/// values *may* be valid haystack units to a DFA. Therefore, this type
/// explicitly adds room for a sentinel value.
///
/// The sentinel EOI value is always its own equivalence class and is
/// ultimately represented by adding 1 to the maximum equivalence class value.
/// So for example, the regex `^[a-z]+$` might be split into the following
/// equivalence classes:
///
/// ```text
/// 0 => [\x00-`]
/// 1 => [a-z]
/// 2 => [{-\xFF]
/// 3 => [EOI]
/// ```
///
/// Here, EOI is the special sentinel value that is always in its own
/// singleton equivalence class.
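///
/// # Example
///
/// A small sketch of how the two kinds of units behave. The byte `b'a'` and
/// the sentinel value `3` are arbitrary choices for illustration; `3` is the
/// EOI value for an alphabet with three byte equivalence classes, as in the
/// table above.
///
/// ```
/// use regex_automata::util::alphabet::Unit;
///
/// // A unit wrapping a byte (or an equivalence class written as a byte).
/// let byte = Unit::u8(b'a');
/// assert_eq!(byte.as_u8(), Some(b'a'));
/// assert_eq!(byte.as_eoi(), None);
/// assert!(!byte.is_eoi());
///
/// // An "end of input" sentinel.
/// let eoi = Unit::eoi(3);
/// assert_eq!(eoi.as_u8(), None);
/// assert_eq!(eoi.as_eoi(), Some(3));
/// assert!(eoi.is_eoi());
/// ```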
#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord)]
pub struct Unit(UnitKind);

#[derive(Clone, Copy, Eq, PartialEq, PartialOrd, Ord)]
enum UnitKind {
    /// Represents a byte value, or more typically, an equivalence class
    /// represented as a byte value.
    U8(u8),
    /// Represents the "end of input" sentinel. We regrettably use a `u16`
    /// here since the maximum sentinel value is `256`. Thankfully, we don't
    /// actually store a `Unit` anywhere, so this extra space shouldn't be too
    /// bad.
    EOI(u16),
}

impl Unit {
    /// Create a new haystack unit from a byte value.
    ///
    /// All possible byte values are legal. However, when creating a haystack
    /// unit for a specific DFA, one should be careful to only construct units
    /// that are in that DFA's alphabet. Namely, one way to compact a DFA's
    /// in-memory representation is to collapse its transitions from a set of
    /// all possible byte values to a set of equivalence classes. If a DFA
    /// uses equivalence classes instead of byte values, then the byte given
    /// here should be the equivalence class.
    pub fn u8(byte: u8) -> Unit {
        Unit(UnitKind::U8(byte))
    }

    /// Create a new "end of input" haystack unit.
    ///
    /// The value given is the sentinel value used by this unit to represent
    /// the "end of input." The value should be the total number of byte
    /// equivalence classes in the corresponding alphabet, not counting the
    /// sentinel itself. Its maximum value is `256`, which occurs when every
    /// byte is its own equivalence class.
    ///
    /// # Panics
    ///
    /// This panics when `num_byte_equiv_classes` is greater than `256`.
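    ///
    /// # Example
    ///
    /// A sketch of the panic condition: `257` exceeds the largest possible
    /// number of byte equivalence classes, so construction fails.
    ///
    /// ```should_panic
    /// use regex_automata::util::alphabet::Unit;
    ///
    /// // Even with every byte in its own class there are only 256 classes,
    /// // so this panics.
    /// let _ = Unit::eoi(257);
    /// ```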
    pub fn eoi(num_byte_equiv_classes: usize) -> Unit {
        assert!(
            num_byte_equiv_classes <= 256,
            "max number of byte-based equivalence classes is 256, but got {}",
            num_byte_equiv_classes,
        );
        Unit(UnitKind::EOI(u16::try_from(num_byte_equiv_classes).unwrap()))
    }

    /// If this unit is not an "end of input" sentinel, then return its
    /// underlying byte value. Otherwise return `None`.
    pub fn as_u8(self) -> Option<u8> {
        match self.0 {
            UnitKind::U8(b) => Some(b),
            UnitKind::EOI(_) => None,
        }
    }

    /// If this unit is an "end of input" sentinel, then return the underlying
    /// sentinel value that was given to [`Unit::eoi`]. Otherwise return
    /// `None`.
    pub fn as_eoi(self) -> Option<u16> {
        match self.0 {
            UnitKind::U8(_) => None,
            UnitKind::EOI(sentinel) => Some(sentinel),
        }
    }

    /// Return this unit as a `usize`, regardless of whether it is a byte value
    /// or an "end of input" sentinel. In the latter case, the underlying
    /// sentinel value given to [`Unit::eoi`] is returned.
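    ///
    /// # Example
    ///
    /// Both kinds of units map into the same integer space, so a byte unit
    /// and an EOI unit can produce the same `usize` while still being
    /// distinct units. The value `3` is an arbitrary illustration. Within an
    /// alphabet built by [`ByteClasses`], the sentinel is always one past the
    /// largest byte class, so such a collision does not arise there.
    ///
    /// ```
    /// use regex_automata::util::alphabet::Unit;
    ///
    /// assert_eq!(Unit::u8(3).as_usize(), 3);
    /// assert_eq!(Unit::eoi(3).as_usize(), 3);
    /// // Same integer, but one is a byte and the other is the sentinel.
    /// assert_ne!(Unit::u8(3), Unit::eoi(3));
    /// ```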
    pub fn as_usize(self) -> usize {
        match self.0 {
            UnitKind::U8(b) => usize::from(b),
            UnitKind::EOI(eoi) => usize::from(eoi),
        }
    }

    /// Returns true if and only if this unit is a byte value equal to the
    /// given byte. This always returns false when this is an "end of input"
    /// sentinel.
    pub fn is_byte(self, byte: u8) -> bool {
        self.as_u8().map_or(false, |b| b == byte)
    }

    /// Returns true when this unit represents an "end of input" sentinel.
    pub fn is_eoi(self) -> bool {
        self.as_eoi().is_some()
    }

    /// Returns true when this unit corresponds to an ASCII word byte.
    ///
    /// This always returns false when this unit represents an "end of input"
    /// sentinel.
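    ///
    /// # Example
    ///
    /// A sketch assuming the usual ASCII definition of a word byte, that is,
    /// `[0-9A-Za-z_]`:
    ///
    /// ```
    /// use regex_automata::util::alphabet::Unit;
    ///
    /// assert!(Unit::u8(b'a').is_word_byte());
    /// assert!(Unit::u8(b'_').is_word_byte());
    /// assert!(!Unit::u8(b'-').is_word_byte());
    /// // The sentinel is never a word byte, whatever its value.
    /// assert!(!Unit::eoi(256).is_word_byte());
    /// ```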
    pub fn is_word_byte(self) -> bool {
        self.as_u8().map_or(false, crate::util::utf8::is_word_byte)
    }
}

impl core::fmt::Debug for Unit {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        match self.0 {
            UnitKind::U8(b) => write!(f, "{:?}", DebugByte(b)),
            UnitKind::EOI(_) => write!(f, "EOI"),
        }
    }
}

/// A representation of byte-oriented equivalence classes.
///
/// This is used in a DFA to reduce the size of the transition table. This can
/// have a particularly large impact not only on the total size of a dense DFA,
/// but also on compile times.
///
/// The essential idea here is that the alphabet of a DFA is shrunk from the
/// usual 256 distinct byte values down to a set of equivalence classes. The
/// guarantee you get is that any byte belonging to the same equivalence class
/// can be treated as if it were any other byte in the same class, and the
/// result of a search wouldn't change.
///
/// # Example
///
/// This example shows how to get byte classes from an
/// [`NFA`](crate::nfa::thompson::NFA) and ask for the class of various bytes.
///
/// ```
/// use regex_automata::nfa::thompson::NFA;
///
/// let nfa = NFA::new("[a-z]+")?;
/// let classes = nfa.byte_classes();
/// // 'a' and 'z' are in the same class for this regex.
/// assert_eq!(classes.get(b'a'), classes.get(b'z'));
/// // But 'a' and 'A' are not.
/// assert_ne!(classes.get(b'a'), classes.get(b'A'));
///
/// # Ok::<(), Box<dyn std::error::Error>>(())
/// ```
#[derive(Clone, Copy)]
pub struct ByteClasses([u8; 256]);

impl ByteClasses {
    /// Creates a new set of equivalence classes where all bytes are mapped to
    /// the same class.
    #[inline]
    pub fn empty() -> ByteClasses {
        ByteClasses([0; 256])
    }

    /// Creates a new set of equivalence classes where each byte belongs to
    /// its own equivalence class.
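    ///
    /// # Example
    ///
    /// A sketch contrasting the two extremes: singletons versus
    /// [`ByteClasses::empty`], where every byte shares class `0`. In both
    /// cases, the alphabet length includes one extra element for the EOI
    /// sentinel.
    ///
    /// ```
    /// use regex_automata::util::alphabet::ByteClasses;
    ///
    /// let singletons = ByteClasses::singletons();
    /// assert!(singletons.is_singleton());
    /// assert_eq!(singletons.get(b'z'), b'z');
    /// // 256 byte classes plus the EOI sentinel.
    /// assert_eq!(singletons.alphabet_len(), 257);
    ///
    /// let empty = ByteClasses::empty();
    /// assert!(!empty.is_singleton());
    /// // One class for all bytes plus the EOI sentinel.
    /// assert_eq!(empty.alphabet_len(), 2);
    /// ```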
    #[inline]
    pub fn singletons() -> ByteClasses {
        let mut classes = ByteClasses::empty();
        for b in 0..=255 {
            classes.set(b, b);
        }
        classes
    }

    /// Deserializes a byte class map from the given slice. If the slice is of
    /// insufficient length or otherwise contains an impossible mapping, then
    /// an error is returned. Upon success, the number of bytes read along with
    /// the map are returned. The number of bytes read is always a multiple of
    /// 8.
    pub(crate) fn from_bytes(
        slice: &[u8],
    ) -> Result<(ByteClasses, usize), DeserializeError> {
        wire::check_slice_len(slice, 256, "byte class map")?;
        let mut classes = ByteClasses::empty();
        for (b, &class) in slice[..256].iter().enumerate() {
            classes.set(u8::try_from(b).unwrap(), class);
        }
        // We specifically don't use 'classes.iter()' here because that
        // iterator depends on 'classes.alphabet_len()' being correct. But that
        // is precisely the thing we're trying to verify below!
        for &b in classes.0.iter() {
            if usize::from(b) >= classes.alphabet_len() {
                return Err(DeserializeError::generic(
                    "found equivalence class greater than alphabet len",
                ));
            }
        }
        Ok((classes, 256))
    }

    /// Writes this byte class map to the given byte buffer. If the given
    /// buffer is too small, then an error is returned. Upon success, the total
    /// number of bytes written is returned. The number of bytes written is
    /// guaranteed to be a multiple of 8.
    pub(crate) fn write_to(
        &self,
        mut dst: &mut [u8],
    ) -> Result<usize, SerializeError> {
        let nwrite = self.write_to_len();
        if dst.len() < nwrite {
            return Err(SerializeError::buffer_too_small("byte class map"));
        }
        for b in 0..=255 {
            dst[0] = self.get(b);
            dst = &mut dst[1..];
        }
        Ok(nwrite)
    }

    /// Returns the total number of bytes written by `write_to`.
    pub(crate) fn write_to_len(&self) -> usize {
        256
    }

    /// Set the equivalence class for the given byte.
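    ///
    /// # Example
    ///
    /// A sketch of building classes by hand, reproducing the `[a-z]` style
    /// partition shown in the [`Unit`] docs. Note that
    /// [`ByteClasses::alphabet_len`] is derived from the class of byte `255`,
    /// so this sketch numbers classes in increasing byte order, which puts
    /// byte `255` in the largest class.
    ///
    /// ```
    /// use regex_automata::util::alphabet::ByteClasses;
    ///
    /// // Start with every byte in class 0, i.e., [\x00-`] => 0.
    /// let mut classes = ByteClasses::empty();
    /// // [a-z] => 1
    /// for b in b'a'..=b'z' {
    ///     classes.set(b, 1);
    /// }
    /// // [{-\xFF] => 2
    /// for b in b'{'..=255 {
    ///     classes.set(b, 2);
    /// }
    /// assert_eq!(classes.get(b'`'), 0);
    /// assert_eq!(classes.get(b'm'), 1);
    /// assert_eq!(classes.get(b'\xFF'), 2);
    /// // Three byte classes plus the EOI sentinel.
    /// assert_eq!(classes.alphabet_len(), 4);
    /// ```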
    #[inline]
    pub fn set(&mut self, byte: u8, class: u8) {
        self.0[usize::from(byte)] = class;
    }

    /// Get the equivalence class for the given byte.
    #[inline]
    pub fn get(&self, byte: u8) -> u8 {
        self.0[usize::from(byte)]
    }

    /// Get the equivalence class for the given haystack unit and return the
    /// class as a `usize`.
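    ///
    /// # Example
    ///
    /// A sketch using singleton classes, where each byte is its own class:
    /// byte units are mapped through the class table, while the EOI sentinel
    /// value is returned as is.
    ///
    /// ```
    /// use regex_automata::util::alphabet::{ByteClasses, Unit};
    ///
    /// let classes = ByteClasses::singletons();
    /// assert_eq!(classes.get_by_unit(Unit::u8(b'a')), usize::from(b'a'));
    /// // The sentinel comes right after the 256 byte classes.
    /// assert_eq!(classes.get_by_unit(classes.eoi()), 256);
    /// ```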
    #[inline]
    pub fn get_by_unit(&self, unit: Unit) -> usize {
        match unit.0 {
            UnitKind::U8(b) => usize::from(self.get(b)),
            UnitKind::EOI(b) => usize::from(b),
        }
    }

    /// Create a unit that represents the "end of input" sentinel based on the
    /// number of equivalence classes.
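    ///
    /// # Example
    ///
    /// A sketch with [`ByteClasses::empty`]: all bytes share class `0`, so
    /// the sentinel takes the next value, `1`, which is also the last index
    /// of the alphabet.
    ///
    /// ```
    /// use regex_automata::util::alphabet::{ByteClasses, Unit};
    ///
    /// let classes = ByteClasses::empty();
    /// assert_eq!(classes.eoi(), Unit::eoi(1));
    /// assert_eq!(classes.eoi().as_usize(), classes.alphabet_len() - 1);
    /// ```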
    #[inline]
    pub fn eoi(&self) -> Unit {
        // The alphabet length already includes the EOI sentinel, hence why
        // we subtract 1.
        Unit::eoi(self.alphabet_len().checked_sub(1).unwrap())
    }

    /// Return the total number of elements in the alphabet represented by
    /// these equivalence classes. Equivalently, this returns the total number
    /// of equivalence classes, including the one for the "end of input"
    /// sentinel.
    #[inline]
    pub fn alphabet_len(&self) -> usize {
        // Add one since the number of equivalence classes is one bigger than
        // the last one. But add another to account for the final EOI class
        // that isn't explicitly represented.
        usize::from(self.0[255]) + 1 + 1
    }

    /// Returns the stride, as a base-2 exponent, required for these
    /// equivalence classes.
    ///
    /// The stride is always the smallest power of 2 that is greater than or
    /// equal to the alphabet length, and the `stride2` returned here is the
    /// base-2 exponent of that power. This is done so that converting between
    /// premultiplied state IDs and indices can be done with shifts alone,
    /// which is much faster than integer division.
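    ///
    /// # Example
    ///
    /// A sketch of the rounding: 257 elements (singletons) round up to a
    /// stride of 512, while 2 elements (every byte in one class) need a stride
    /// of 2. The state index `3` is an arbitrary value used only to show the
    /// shift-based conversion.
    ///
    /// ```
    /// use regex_automata::util::alphabet::ByteClasses;
    ///
    /// let classes = ByteClasses::singletons();
    /// // 257 rounds up to 512, which is 1 << 9.
    /// assert_eq!(classes.stride2(), 9);
    /// assert_eq!(ByteClasses::empty().stride2(), 1);
    ///
    /// // Premultiplying and un-premultiplying are then plain shifts.
    /// let index = 3usize;
    /// let premultiplied = index << classes.stride2();
    /// assert_eq!(premultiplied, 3 * 512);
    /// assert_eq!(premultiplied >> classes.stride2(), index);
    /// ```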
    #[inline]
    pub fn stride2(&self) -> usize {
        let zeros = self.alphabet_len().next_power_of_two().trailing_zeros();
        usize::try_from(zeros).unwrap()
    }

    /// Returns true if and only if every byte in this set maps to its own
    /// equivalence class. Equivalently, there are 257 equivalence classes
    /// and each class contains either exactly one byte or corresponds to the
    /// singleton class containing the "end of input" sentinel.
    #[inline]
    pub fn is_singleton(&self) -> bool {
        self.alphabet_len() == 257
    }

    /// Returns an iterator over all equivalence classes in this set.
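    ///
    /// # Example
    ///
    /// A sketch with a hand-built split into two classes, ASCII
    /// (`[\x00-\x7F]`) and non-ASCII (`[\x80-\xFF]`). The split is an
    /// arbitrary choice for illustration. The iterator yields each class as a
    /// byte unit, followed by the EOI sentinel.
    ///
    /// ```
    /// use regex_automata::util::alphabet::{ByteClasses, Unit};
    ///
    /// let mut classes = ByteClasses::empty();
    /// for b in 0x80..=0xFF {
    ///     classes.set(b, 1);
    /// }
    /// let units: Vec<Unit> = classes.iter().collect();
    /// assert_eq!(units, vec![Unit::u8(0), Unit::u8(1), Unit::eoi(2)]);
    /// ```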
    #[inline]
    pub fn iter(&self) -> ByteClassIter<'_> {
        ByteClassIter { classes: self, i: 0 }
    }

    /// Returns an iterator over a sequence of representative bytes from each
    /// equivalence class within the range of bytes given.
    ///
    /// When the given range is unbounded on both sides, the iterator yields
    /// exactly N items, where N is equal to the number of equivalence
    /// classes. Each item is an arbitrary byte drawn from a distinct
    /// equivalence class, except for the last, which is the "end of input"
    /// sentinel.
    ///
    /// This is useful when one is determinizing an NFA and the NFA's alphabet
    /// hasn't been converted to equivalence classes. Picking an arbitrary byte
    /// from each equivalence class then permits a full exploration of the NFA
    /// instead of using every possible byte value and thus potentially saves
    /// quite a lot of redundant work.
    ///
    /// # Example
    ///
    /// This shows what a complete sequence of representatives might look like
    /// for a real regex.
    ///
    /// ```
    /// use regex_automata::{nfa::thompson::NFA, util::alphabet::Unit};
    ///
    /// let nfa = NFA::new("[a-z]+")?;
    /// let classes = nfa.byte_classes();
    /// let reps: Vec<Unit> = classes.representatives(..).collect();
    /// // Note that the specific byte values yielded are not guaranteed!
    /// let expected = vec![
    ///     Unit::u8(b'\x00'),
    ///     Unit::u8(b'a'),
    ///     Unit::u8(b'{'),
    ///     Unit::eoi(3),
    /// ];
    /// assert_eq!(expected, reps);
    ///
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    ///
    /// Note, though, that you can ask for an arbitrary range of bytes, and
    /// only representatives for that range will be returned:
    ///
    /// ```
    /// use regex_automata::{nfa::thompson::NFA, util::alphabet::Unit};
    ///
    /// let nfa = NFA::new("[a-z]+")?;
    /// let classes = nfa.byte_classes();
    /// let reps: Vec<Unit> = classes.representatives(b'A'..=b'z').collect();
    /// // Note that the specific byte values yielded are not guaranteed!
    /// let expected = vec![
    ///     Unit::u8(b'A'),
    ///     Unit::u8(b'a'),
    /// ];
    /// assert_eq!(expected, reps);
    ///
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    pub fn representatives<R: core::ops::RangeBounds<u8>>(
        &self,
        range: R,
    ) -> ByteClassRepresentatives<'_> {
        use core::ops::Bound;

        let cur_byte = match range.start_bound() {
            Bound::Included(&i) => usize::from(i),
            Bound::Excluded(&i) => usize::from(i).checked_add(1).unwrap(),
            Bound::Unbounded => 0,
        };
        let end_byte = match range.end_bound() {
            Bound::Included(&i) => {
                Some(usize::from(i).checked_add(1).unwrap())
            }
            Bound::Excluded(&i) => Some(usize::from(i)),
            Bound::Unbounded => None,
        };
        assert_ne!(
            cur_byte,
            usize::MAX,
            "start range must be less than usize::MAX",
        );
        ByteClassRepresentatives {
            classes: self,
            cur_byte,
            end_byte,
            last_class: None,
        }
    }

    /// Returns an iterator of the bytes in the given equivalence class.
    ///
    /// This is useful when one needs to know the actual bytes that belong to
    /// an equivalence class. For example, conceptually speaking, accelerating
    /// a DFA state occurs when a state only has a few outgoing transitions.
    /// But in reality, what is required is that there are only a small
    /// number of distinct bytes that can lead to an outgoing transition. The
    /// difference is that any one transition can correspond to an equivalence
    /// class which may contain many bytes. Therefore, DFA state acceleration
    /// considers the actual elements in each equivalence class of each
    /// outgoing transition.
    ///
    /// # Example
    ///
    /// This shows an example of how to get all of the elements in an
    /// equivalence class.
    ///
    /// ```
    /// use regex_automata::{nfa::thompson::NFA, util::alphabet::Unit};
    ///
    /// let nfa = NFA::new("[a-z]+")?;
    /// let classes = nfa.byte_classes();
    /// let elements: Vec<Unit> = classes.elements(Unit::u8(1)).collect();
    /// let expected: Vec<Unit> = (b'a'..=b'z').map(Unit::u8).collect();
    /// assert_eq!(expected, elements);
    ///
    /// # Ok::<(), Box<dyn std::error::Error>>(())
    /// ```
    #[inline]
    pub fn elements(&self, class: Unit) -> ByteClassElements {
        ByteClassElements { classes: self, class, byte: 0 }
    }

    /// Returns an iterator of byte ranges in the given equivalence class.
    ///
    /// That is, a sequence of contiguous ranges is returned. Typically, every
    /// class maps to a single contiguous range.
    fn element_ranges(&self, class: Unit) -> ByteClassElementRanges {
        ByteClassElementRanges { elements: self.elements(class), range: None }
    }
}

impl Default for ByteClasses {
    fn default() -> ByteClasses {
        ByteClasses::singletons()
    }
}

impl core::fmt::Debug for ByteClasses {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        if self.is_singleton() {
            write!(f, "ByteClasses({{singletons}})")
        } else {
            write!(f, "ByteClasses(")?;
            for (i, class) in self.iter().enumerate() {
                if i > 0 {
                    write!(f, ", ")?;
                }
                write!(f, "{:?} => [", class.as_usize())?;
                for (start, end) in self.element_ranges(class) {
                    if start == end {
                        write!(f, "{:?}", start)?;
                    } else {
                        write!(f, "{:?}-{:?}", start, end)?;
                    }
                }
                write!(f, "]")?;
            }
            write!(f, ")")
        }
    }
}

/// An iterator over each equivalence class.
///
/// The last element in this iterator always corresponds to [`Unit::eoi`].
///
/// This is created by the [`ByteClasses::iter`] method.
///
/// The lifetime `'a` refers to the lifetime of the byte classes that this
/// iterator was created from.
#[derive(Debug)]
pub struct ByteClassIter<'a> {
    classes: &'a ByteClasses,
    i: usize,
}

impl<'a> Iterator for ByteClassIter<'a> {
    type Item = Unit;

    fn next(&mut self) -> Option<Unit> {
        if self.i + 1 == self.classes.alphabet_len() {
            self.i += 1;
            Some(self.classes.eoi())
        } else if self.i < self.classes.alphabet_len() {
            let class = u8::try_from(self.i).unwrap();
            self.i += 1;
            Some(Unit::u8(class))
        } else {
            None
        }
    }
}

/// An iterator over representative bytes from each equivalence class.
///
/// This is created by the [`ByteClasses::representatives`] method.
///
/// The lifetime `'a` refers to the lifetime of the byte classes that this
/// iterator was created from.
#[derive(Debug)]
pub struct ByteClassRepresentatives<'a> {
    classes: &'a ByteClasses,
    cur_byte: usize,
    end_byte: Option<usize>,
    last_class: Option<u8>,
}

impl<'a> Iterator for ByteClassRepresentatives<'a> {
    type Item = Unit;

    fn next(&mut self) -> Option<Unit> {
        while self.cur_byte < self.end_byte.unwrap_or(256) {
            let byte = u8::try_from(self.cur_byte).unwrap();
            let class = self.classes.get(byte);
            self.cur_byte += 1;

            if self.last_class != Some(class) {
                self.last_class = Some(class);
                return Some(Unit::u8(byte));
            }
        }
        if self.cur_byte != usize::MAX && self.end_byte.is_none() {
            // Using usize::MAX as a sentinel is OK because we ban usize::MAX
            // from appearing as a start bound in iterator construction. But
            // why do it this way? Well, we want to return the EOI class
            // whenever the end of the given range is unbounded because EOI
            // isn't really a "byte" per se, so the only way it should be
            // excluded is if there is a bounded end to the range. Therefore,
            // when the end is unbounded, we just need to know whether we've
            // reported EOI or not. When we do, we set cur_byte to a value it
            // can never otherwise be.
            self.cur_byte = usize::MAX;
            return Some(self.classes.eoi());
        }
        None
    }
}

/// An iterator over all elements in an equivalence class.
///
/// This is created by the [`ByteClasses::elements`] method.
///
/// The lifetime `'a` refers to the lifetime of the byte classes that this
/// iterator was created from.
#[derive(Debug)]
pub struct ByteClassElements<'a> {
    classes: &'a ByteClasses,
    class: Unit,
    byte: usize,
}

impl<'a> Iterator for ByteClassElements<'a> {
    type Item = Unit;

    fn next(&mut self) -> Option<Unit> {
        while self.byte < 256 {
            let byte = u8::try_from(self.byte).unwrap();
            self.byte += 1;
            if self.class.is_byte(self.classes.get(byte)) {
                return Some(Unit::u8(byte));
            }
        }
        if self.byte < 257 {
            self.byte += 1;
            if self.class.is_eoi() {
                return Some(Unit::eoi(256));
            }
        }
        None
    }
}

/// An iterator over all elements in an equivalence class expressed as a
/// sequence of contiguous ranges.
#[derive(Debug)]
struct ByteClassElementRanges<'a> {
    elements: ByteClassElements<'a>,
    range: Option<(Unit, Unit)>,
}

impl<'a> Iterator for ByteClassElementRanges<'a> {
    type Item = (Unit, Unit);

    fn next(&mut self) -> Option<(Unit, Unit)> {
        loop {
            let element = match self.elements.next() {
                None => return self.range.take(),
                Some(element) => element,
            };
            match self.range.take() {
                None => {
                    self.range = Some((element, element));
                }
                Some((start, end)) => {
                    if end.as_usize() + 1 != element.as_usize()
                        || element.is_eoi()
                    {
                        self.range = Some((element, element));
                        return Some((start, end));
                    }
                    self.range = Some((start, element));
                }
            }
        }
    }
}

/// A partitioning of bytes into equivalence classes.
///
/// A byte class set keeps track of an *approximation* of equivalence classes
/// of bytes during NFA construction. That is, no two bytes in the same
/// equivalence class can discriminate between a match and a non-match.
///
/// For example, in the regex `[ab]+`, the bytes `a` and `b` would be in the
/// same equivalence class because it never matters whether an `a` or a `b` is
/// seen, and no combination of `a`s and `b`s in the text can discriminate a
/// match.
///
/// Note, though, that this does not compute the minimal set of equivalence
/// classes. For example, in the regex `[ac]+`, both `a` and `c` are in the
/// same equivalence class for the same reason that `a` and `b` are in the
/// same equivalence class in the aforementioned regex. However, in this
/// implementation, `a` and `c` are put into distinct equivalence classes. The
/// reason for this is implementation complexity. In the future, we should
/// endeavor to compute the minimal equivalence classes since they can have a
/// rather large impact on the size of the DFA. (Doing this will likely require
/// rethinking how equivalence classes are computed, including changing the
/// representation here, which is only able to group contiguous bytes into the
/// same equivalence class.)
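///
/// To make this concrete, here is a sketch of the partition this
/// implementation would compute for `[ac]+`, assuming the singleton ranges
/// `a` and `c` are the only ranges added via `set_range`. The minimal
/// partition described above would instead put `a` and `c` in the same class.
///
/// ```text
/// 0 => [\x00-`]
/// 1 => [a]
/// 2 => [b]
/// 3 => [c]
/// 4 => [d-\xFF]
/// ```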
#[cfg(feature = "alloc")]
#[derive(Clone, Debug)]
pub(crate) struct ByteClassSet(ByteSet);

#[cfg(feature = "alloc")]
impl Default for ByteClassSet {
    fn default() -> ByteClassSet {
        ByteClassSet::empty()
    }
}

#[cfg(feature = "alloc")]
impl ByteClassSet {
    /// Create a new set of byte classes where all bytes are part of the same
    /// equivalence class.
    pub(crate) fn empty() -> Self {
        ByteClassSet(ByteSet::empty())
    }

    /// Indicate that the given range of bytes (inclusive) can discriminate a
    /// match between it and all other bytes outside of the range.
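    ///
    /// Concretely, the set records class *boundaries*: the byte just before
    /// `start` (when `start > 0`) and `end` itself. For example, marking
    /// `[a-z]` records bytes `0x60` (the byte just before `a`) and `0x7A`
    /// (`z`). Assuming no other ranges were added, `byte_classes` would then
    /// produce:
    ///
    /// ```text
    /// set_range(b'a', b'z')  =>  boundaries {0x60, 0x7A}
    ///
    /// 0 => [\x00-`]
    /// 1 => [a-z]
    /// 2 => [{-\xFF]
    /// ```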
    pub(crate) fn set_range(&mut self, start: u8, end: u8) {
        debug_assert!(start <= end);
        if start > 0 {
            self.0.add(start - 1);
        }
        self.0.add(end);
    }

    /// Add the contiguous ranges in the given set to this byte class set.
    pub(crate) fn add_set(&mut self, set: &ByteSet) {
        for (start, end) in set.iter_ranges() {
            self.set_range(start, end);
        }
    }

    /// Convert this boolean set to a map that maps all byte values to their
    /// corresponding equivalence class. The last mapping indicates the largest
    /// equivalence class identifier (which is never bigger than 255).
    pub(crate) fn byte_classes(&self) -> ByteClasses {
        let mut classes = ByteClasses::empty();
        let mut class = 0u8;
        let mut b = 0u8;
        loop {
            classes.set(b, class);
            if b == 255 {
                break;
            }
            if self.0.contains(b) {
                class = class.checked_add(1).unwrap();
            }
            b = b.checked_add(1).unwrap();
        }
        classes
    }
}

/// A simple set of bytes that is reasonably cheap to copy and allocation-free.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub(crate) struct ByteSet {
    bits: BitSet,
}

/// The representation of a byte set. Split out so that we can define a
/// convenient Debug impl for it while keeping "ByteSet" in the output.
#[derive(Clone, Copy, Default, Eq, PartialEq)]
struct BitSet([u128; 2]);

impl ByteSet {
    /// Create an empty set of bytes.
    pub(crate) fn empty() -> ByteSet {
        ByteSet { bits: BitSet([0; 2]) }
    }

    /// Add a byte to this set.
    ///
    /// If the given byte already belongs to this set, then this is a no-op.
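    ///
    /// The set is stored as two `u128` buckets: `byte / 128` selects the
    /// bucket and `byte % 128` selects the bit within it. For example:
    ///
    /// ```text
    /// b'a'  (97)  => bucket 0, bit 97
    /// 0xC8 (200)  => bucket 1, bit 72   (200 - 128)
    /// ```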
760    pub(crate) fn add(&mut self, byte: u8) {
761        let bucket = byte / 128;
762        let bit = byte % 128;
763        self.bits.0[usize::from(bucket)] |= 1 << bit;
764    }
765
766    /// Remove a byte from this set.
767    ///
768    /// If the given byte is not in this set, then this is a no-op.
769    pub(crate) fn remove(&mut self, byte: u8) {
770        let bucket = byte / 128;
771        let bit = byte % 128;
772        self.bits.0[usize::from(bucket)] &= !(1 << bit);
773    }
774
775    /// Return true if and only if the given byte is in this set.
776    pub(crate) fn contains(&self, byte: u8) -> bool {
777        let bucket = byte / 128;
778        let bit = byte % 128;
779        self.bits.0[usize::from(bucket)] & (1 << bit) > 0
780    }
781
782    /// Return true if and only if the given inclusive range of bytes is in
783    /// this set.
784    pub(crate) fn contains_range(&self, start: u8, end: u8) -> bool {
785        (start..=end).all(|b| self.contains(b))
786    }
787
788    /// Returns an iterator over all bytes in this set.
789    pub(crate) fn iter(&self) -> ByteSetIter {
790        ByteSetIter { set: self, b: 0 }
791    }
792
793    /// Returns an iterator over all contiguous ranges of bytes in this set.
794    pub(crate) fn iter_ranges(&self) -> ByteSetRangeIter {
795        ByteSetRangeIter { set: self, b: 0 }
796    }
797
798    /// Return true if and only if this set is empty.
799    #[cfg_attr(feature = "perf-inline", inline(always))]
800    pub(crate) fn is_empty(&self) -> bool {
801        self.bits.0 == [0, 0]
802    }
803
    /// Deserializes a byte set from the given slice. If the slice is of
    /// incorrect length or is otherwise malformed, then an error is returned.
    /// Upon success, the number of bytes read along with the set are returned.
    /// The number of bytes read is always a multiple of 8.
    pub(crate) fn from_bytes(
        slice: &[u8],
    ) -> Result<(ByteSet, usize), DeserializeError> {
        use core::mem::size_of;

        wire::check_slice_len(slice, 2 * size_of::<u128>(), "byte set")?;
        let mut nread = 0;
        let (low, nr) = wire::try_read_u128(slice, "byte set low bucket")?;
        nread += nr;
        // Advance past the low bucket so that the high bucket is read from
        // the next 16 bytes rather than from the start of the slice again.
        // The length check above guarantees this slicing is in bounds.
        let (high, nr) =
            wire::try_read_u128(&slice[nread..], "byte set high bucket")?;
        nread += nr;
        Ok((ByteSet { bits: BitSet([low, high]) }, nread))
    }

    /// Writes this byte set to the given byte buffer. If the given buffer is
    /// too small, then an error is returned. Upon success, the total number of
    /// bytes written is returned. The number of bytes written is guaranteed to
    /// be a multiple of 8.
    pub(crate) fn write_to<E: crate::util::wire::Endian>(
        &self,
        dst: &mut [u8],
    ) -> Result<usize, SerializeError> {
        use core::mem::size_of;

        let nwrite = self.write_to_len();
        if dst.len() < nwrite {
            return Err(SerializeError::buffer_too_small("byte set"));
        }
        let mut nw = 0;
        E::write_u128(self.bits.0[0], &mut dst[nw..]);
        nw += size_of::<u128>();
        E::write_u128(self.bits.0[1], &mut dst[nw..]);
        nw += size_of::<u128>();
        assert_eq!(nwrite, nw, "expected to write certain number of bytes",);
        assert_eq!(
            nw % 8,
            0,
            "expected to write multiple of 8 bytes for byte set",
        );
        Ok(nw)
    }

    /// Returns the total number of bytes written by `write_to`.
    pub(crate) fn write_to_len(&self) -> usize {
        2 * core::mem::size_of::<u128>()
    }
}

impl core::fmt::Debug for BitSet {
    fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
        let mut fmtd = f.debug_set();
        for b in 0u8..=255 {
            if (ByteSet { bits: *self }).contains(b) {
                fmtd.entry(&b);
            }
        }
        fmtd.finish()
    }
}

#[derive(Debug)]
pub(crate) struct ByteSetIter<'a> {
    set: &'a ByteSet,
    b: usize,
}

impl<'a> Iterator for ByteSetIter<'a> {
    type Item = u8;

    fn next(&mut self) -> Option<u8> {
        while self.b <= 255 {
            let b = u8::try_from(self.b).unwrap();
            self.b += 1;
            if self.set.contains(b) {
                return Some(b);
            }
        }
        None
    }
}

#[derive(Debug)]
pub(crate) struct ByteSetRangeIter<'a> {
    set: &'a ByteSet,
    b: usize,
}

impl<'a> Iterator for ByteSetRangeIter<'a> {
    type Item = (u8, u8);

    fn next(&mut self) -> Option<(u8, u8)> {
        let asu8 = |n: usize| u8::try_from(n).unwrap();
        while self.b <= 255 {
            let start = asu8(self.b);
            self.b += 1;
            if !self.set.contains(start) {
                continue;
            }

            let mut end = start;
            while self.b <= 255 && self.set.contains(asu8(self.b)) {
                end = asu8(self.b);
                self.b += 1;
            }
            return Some((start, end));
        }
        None
    }
}

#[cfg(all(test, feature = "alloc"))]
mod tests {
    use alloc::{vec, vec::Vec};

    use super::*;

    #[test]
    fn byte_classes() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'a', b'z');

        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(b'a' - 1), 0);
        assert_eq!(classes.get(b'a'), 1);
        assert_eq!(classes.get(b'm'), 1);
        assert_eq!(classes.get(b'z'), 1);
        assert_eq!(classes.get(b'z' + 1), 2);
        assert_eq!(classes.get(254), 2);
        assert_eq!(classes.get(255), 2);

        let mut set = ByteClassSet::empty();
        set.set_range(0, 2);
        set.set_range(4, 6);
        let classes = set.byte_classes();
        assert_eq!(classes.get(0), 0);
        assert_eq!(classes.get(1), 0);
        assert_eq!(classes.get(2), 0);
        assert_eq!(classes.get(3), 1);
        assert_eq!(classes.get(4), 2);
        assert_eq!(classes.get(5), 2);
        assert_eq!(classes.get(6), 2);
        assert_eq!(classes.get(7), 3);
        assert_eq!(classes.get(255), 3);
    }

    #[test]
    fn full_byte_classes() {
        let mut set = ByteClassSet::empty();
        for b in 0u8..=255 {
            set.set_range(b, b);
        }
        assert_eq!(set.byte_classes().alphabet_len(), 257);
    }
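
    // Illustrative check of ByteSet's add/remove/contains/is_empty, using
    // only the methods defined above. Bytes 127 and 128 sit on either side of
    // the boundary between the two u128 buckets, so they exercise the
    // `byte / 128` (bucket) and `byte % 128` (bit offset) arithmetic.
    #[test]
    fn byte_set_add_remove_contains() {
        let mut set = ByteSet::empty();
        assert!(set.is_empty());

        for &b in &[0u8, 127, 128, 255] {
            set.add(b);
            assert!(set.contains(b));
        }
        assert!(!set.is_empty());
        // Neighbors of the added bytes must not be affected.
        assert!(!set.contains(1));
        assert!(!set.contains(126));
        assert!(!set.contains(129));
        assert!(!set.contains(254));

        // Adding a byte that is already present is a no-op.
        set.add(128);
        assert!(set.contains(128));

        for &b in &[0u8, 127, 128, 255] {
            set.remove(b);
            assert!(!set.contains(b));
        }
        // Removing a byte that is not present is also a no-op.
        set.remove(42);
        assert!(set.is_empty());
    }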

    #[test]
    fn elements_typical() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();
        // class 0: \x00-a
        // class 1: b-d
        // class 2: e-f
        // class 3: g-m
        // class 4: n-y
        // class 5: z-z
        // class 6: \x7B-\xFF
        // class 7: EOI
        assert_eq!(classes.alphabet_len(), 8);

        let elements = classes.elements(Unit::u8(0)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 98);
        assert_eq!(elements[0], Unit::u8(b'\x00'));
        assert_eq!(elements[97], Unit::u8(b'a'));

        let elements = classes.elements(Unit::u8(1)).collect::<Vec<_>>();
        assert_eq!(
            elements,
            vec![Unit::u8(b'b'), Unit::u8(b'c'), Unit::u8(b'd')],
        );

        let elements = classes.elements(Unit::u8(2)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::u8(b'e'), Unit::u8(b'f')],);

        let elements = classes.elements(Unit::u8(3)).collect::<Vec<_>>();
        assert_eq!(
            elements,
            vec![
                Unit::u8(b'g'),
                Unit::u8(b'h'),
                Unit::u8(b'i'),
                Unit::u8(b'j'),
                Unit::u8(b'k'),
                Unit::u8(b'l'),
                Unit::u8(b'm'),
            ],
        );

        let elements = classes.elements(Unit::u8(4)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 12);
        assert_eq!(elements[0], Unit::u8(b'n'));
        assert_eq!(elements[11], Unit::u8(b'y'));

        let elements = classes.elements(Unit::u8(5)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::u8(b'z')]);

        let elements = classes.elements(Unit::u8(6)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 133);
        assert_eq!(elements[0], Unit::u8(b'\x7B'));
        assert_eq!(elements[132], Unit::u8(b'\xFF'));

        let elements = classes.elements(Unit::eoi(7)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::eoi(256)]);
    }

    #[test]
    fn elements_singletons() {
        let classes = ByteClasses::singletons();
        assert_eq!(classes.alphabet_len(), 257);

        let elements = classes.elements(Unit::u8(b'a')).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::u8(b'a')]);

        let elements = classes.elements(Unit::eoi(5)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::eoi(256)]);
    }

    #[test]
    fn elements_empty() {
        let classes = ByteClasses::empty();
        assert_eq!(classes.alphabet_len(), 2);

        let elements = classes.elements(Unit::u8(0)).collect::<Vec<_>>();
        assert_eq!(elements.len(), 256);
        assert_eq!(elements[0], Unit::u8(b'\x00'));
        assert_eq!(elements[255], Unit::u8(b'\xFF'));

        let elements = classes.elements(Unit::eoi(1)).collect::<Vec<_>>();
        assert_eq!(elements, vec![Unit::eoi(256)]);
    }
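
    // Illustrative check of ByteSet::iter, ByteSet::iter_ranges and
    // ByteSet::contains_range. The run 127-128 spans both u128 buckets and
    // the run 250-255 ends at the last byte value; these are the two edge
    // cases for the range iterator's `self.b <= 255` loop bound.
    #[test]
    fn byte_set_iter_and_ranges() {
        let mut set = ByteSet::empty();
        for b in b'a'..=b'c' {
            set.add(b);
        }
        set.add(b'x');
        set.add(127);
        set.add(128);
        for b in 250u8..=255 {
            set.add(b);
        }

        // `iter` yields every member in ascending order.
        let bytes: Vec<u8> = set.iter().collect();
        assert_eq!(
            bytes,
            vec![
                b'a', b'b', b'c', b'x', 127, 128, 250, 251, 252, 253, 254, 255,
            ],
        );

        // `iter_ranges` coalesces contiguous members into inclusive ranges.
        let ranges: Vec<(u8, u8)> = set.iter_ranges().collect();
        assert_eq!(
            ranges,
            vec![(b'a', b'c'), (b'x', b'x'), (127, 128), (250, 255)],
        );

        assert!(set.contains_range(b'a', b'c'));
        assert!(!set.contains_range(b'a', b'd'));
        assert!(set.contains_range(127, 128));
        assert!(set.contains_range(250, 255));
    }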

    #[test]
    fn representatives() {
        let mut set = ByteClassSet::empty();
        set.set_range(b'b', b'd');
        set.set_range(b'g', b'm');
        set.set_range(b'z', b'z');
        let classes = set.byte_classes();

        let got: Vec<Unit> = classes.representatives(..).collect();
        let expected = vec![
            Unit::u8(b'\x00'),
            Unit::u8(b'b'),
            Unit::u8(b'e'),
            Unit::u8(b'g'),
            Unit::u8(b'n'),
            Unit::u8(b'z'),
            Unit::u8(b'\x7B'),
            Unit::eoi(7),
        ];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(..0).collect();
        assert!(got.is_empty());
        let got: Vec<Unit> = classes.representatives(1..1).collect();
        assert!(got.is_empty());
        let got: Vec<Unit> = classes.representatives(255..255).collect();
        assert!(got.is_empty());

        // A weird case: excluding all possible byte values is the only
        // guaranteed way to get an iterator over just the EOI class.
        let got: Vec<Unit> = classes
            .representatives((
                core::ops::Bound::Excluded(255),
                core::ops::Bound::Unbounded,
            ))
            .collect();
        let expected = vec![Unit::eoi(7)];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(..=255).collect();
        let expected = vec![
            Unit::u8(b'\x00'),
            Unit::u8(b'b'),
            Unit::u8(b'e'),
            Unit::u8(b'g'),
            Unit::u8(b'n'),
            Unit::u8(b'z'),
            Unit::u8(b'\x7B'),
        ];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'b'..=b'd').collect();
        let expected = vec![Unit::u8(b'b')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'a'..=b'd').collect();
        let expected = vec![Unit::u8(b'a'), Unit::u8(b'b')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'b'..=b'e').collect();
        let expected = vec![Unit::u8(b'b'), Unit::u8(b'e')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'A'..=b'Z').collect();
        let expected = vec![Unit::u8(b'A')];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'A'..=b'z').collect();
        let expected = vec![
            Unit::u8(b'A'),
            Unit::u8(b'b'),
            Unit::u8(b'e'),
            Unit::u8(b'g'),
            Unit::u8(b'n'),
            Unit::u8(b'z'),
        ];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'z'..).collect();
        let expected = vec![Unit::u8(b'z'), Unit::u8(b'\x7B'), Unit::eoi(7)];
        assert_eq!(expected, got);

        let got: Vec<Unit> = classes.representatives(b'z'..=0xFF).collect();
        let expected = vec![Unit::u8(b'z'), Unit::u8(b'\x7B')];
        assert_eq!(expected, got);
    }
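
    // Illustrative check of the Debug impl on BitSet defined above, which
    // prints the set's members as decimal byte values in ascending order via
    // `debug_set`. Uses `alloc::format!`, assuming (as the `alloc::vec`
    // import above implies) that the `alloc` crate is linked in these tests.
    #[test]
    fn byte_set_debug() {
        let mut set = ByteSet::empty();
        assert_eq!(alloc::format!("{:?}", set.bits), "{}");

        // Insertion order does not matter; output is sorted by byte value.
        set.add(200);
        set.add(b'a');
        assert_eq!(alloc::format!("{:?}", set.bits), "{97, 200}");
    }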
}