jxl_modular/
ma.rs

1use std::collections::VecDeque;
2use std::sync::Arc;
3
4use jxl_bitstream::{unpack_signed, Bitstream};
5use jxl_coding::Decoder;
6use jxl_grid::{AllocHandle, AllocTracker};
7use jxl_oxide_common::Bundle;
8
9use super::predictor::{Predictor, Properties};
10use crate::{sample::Sealed, Result, Sample};
11
/// Meta-adaptive tree configuration.
///
/// A meta-adaptive (MA) tree is a decision tree that controls how each sample is decoded in a
/// given context. The configuration consists of two components, both read from the bitstream:
/// the MA tree itself, and the distribution information of the entropy decoder used by its
/// leaves.
#[derive(Debug, Clone)]
pub struct MaConfig {
    /// Number of tree nodes read from the bitstream.
    num_tree_nodes: usize,
    /// Length of the longest root-to-leaf path, counted in decision nodes (a lone leaf has
    /// depth 0).
    tree_depth: usize,
    /// Decoded tree, shared between clones, together with the allocation handle (present only
    /// when a tracker was supplied) that accounts for its memory.
    tree: Arc<(MaTreeNode, Option<AllocHandle>)>,
    /// Entropy decoder for sample residuals; it has one context per tree leaf.
    decoder: Decoder,
}
24
25impl MaConfig {
26    /// Returns the entropy decoder.
27    ///
28    /// The decoder should be cloned to be used for decoding.
29    pub fn decoder(&self) -> &Decoder {
30        &self.decoder
31    }
32
33    /// Creates a simplified MA tree with given channel index and stream index, which then can be
34    /// used to decode samples.
35    ///
36    /// The method will evaluate the tree with the given information and prune branches which are
37    /// always not taken.
38    pub fn make_flat_tree(&self, channel: u32, stream_idx: u32, prev_channels: u32) -> FlatMaTree {
39        let nodes = self.tree.0.flatten(channel, stream_idx, prev_channels);
40        FlatMaTree::new(nodes)
41    }
42}
43
44impl MaConfig {
45    /// Returns the number of MA tree nodes.
46    #[inline]
47    pub fn num_tree_nodes(&self) -> usize {
48        self.num_tree_nodes
49    }
50
51    /// Returns the maximum distance from root to any leaf node.
52    #[inline]
53    pub fn tree_depth(&self) -> usize {
54        self.tree_depth
55    }
56}
57
/// Parameters for decoding [`MaConfig`].
#[derive(Debug, Copy, Clone)]
pub struct MaConfigParams<'a> {
    /// Allocation tracker used to account for the memory taken by the tree, if any.
    pub tracker: Option<&'a AllocTracker>,
    /// Maximum number of meta-adaptive tree nodes.
    ///
    /// Decoding fails with a profile conformance error if more than this many nodes have been
    /// read while the tree still has nodes left to read.
    pub node_limit: usize,
    /// Maximum depth of the decoded tree, where a leaf has depth 0 and a decision node is one
    /// deeper than its deeper child.
    ///
    /// Decoding fails with a profile conformance error if any subtree is deeper than this.
    pub depth_limit: usize,
}
67
impl Bundle<MaConfigParams<'_>> for MaConfig {
    type Error = crate::Error;

    /// Reads an MA tree, then the entropy decoder used by its leaves, from `bitstream`.
    ///
    /// The tree nodes are read with a 6-context decoder. The leaf decoder that follows has one
    /// context per leaf; each leaf's context is replaced by its cluster before the tree is
    /// stored.
    ///
    /// # Errors
    /// - `InvalidMaTree` if the tree distribution can only produce decision nodes, if `1 << 26`
    ///   nodes have been read, or if a leaf's multiplier parameters are out of range.
    /// - `ProfileConformance` if `node_limit` or `depth_limit` is exceeded.
    /// - Errors from entropy decoding, predictor conversion or allocation tracking.
    fn parse(bitstream: &mut Bitstream, params: MaConfigParams) -> crate::Result<Self> {
        // Leaf as read from the bitstream, still carrying its context index rather than a
        // cluster.
        struct FoldingTreeLeaf {
            ctx: u32,
            predictor: super::predictor::Predictor,
            offset: i32,
            multiplier: u32,
        }

        // Flat node as read from the bitstream; `Decision` is (property, value). Children are
        // linked later, when the flat list is folded into `MaTreeNode`s.
        enum FoldingTree {
            Decision(u32, i32),
            Leaf(FoldingTreeLeaf),
        }

        let MaConfigParams {
            tracker,
            node_limit,
            depth_limit,
        } = params;

        // Contexts used below: 0 = decision value, 1 = property (0 means leaf), 2 = predictor,
        // 3 = offset, 4 = multiplier log, 5 = multiplier bits.
        let mut tree_decoder = Decoder::parse(bitstream, 6)?;
        if is_infinite_tree_dist(&tree_decoder) {
            tracing::error!("Infinite MA tree");
            return Err(crate::Error::InvalidMaTree);
        }

        // Context index assigned to the next leaf.
        let mut ctx = 0u32;
        // Nodes still to be read; each decision node adds two children.
        let mut nodes_left = 1usize;
        // Tracked allocation for `nodes`, resized in step with the Vec below.
        let mut tmp_alloc_handle = tracker
            .map(|tracker| tracker.alloc::<FoldingTree>(16))
            .transpose()?;
        let mut nodes = Vec::with_capacity(16);
        // Peak number of pending nodes. Despite the name this is not the tree depth; it is
        // only used as a capacity hint for `tmp` below.
        let mut max_depth = 1usize;

        tree_decoder.begin(bitstream)?;
        while nodes_left > 0 {
            if nodes.len() >= (1 << 26) {
                return Err(crate::Error::InvalidMaTree);
            }
            if nodes.len() > node_limit {
                tracing::error!(node_limit, "MA tree limit exceeded");
                return Err(
                    jxl_bitstream::Error::ProfileConformance("MA tree limit exceeded").into(),
                );
            }

            // With a tracker, grow capacity explicitly (16 -> 256 -> 1024 -> doubling) so the
            // tracked size matches; without one, the Vec grows on its own.
            if nodes.len() == nodes.capacity() && tmp_alloc_handle.is_some() {
                let tracker = tracker.unwrap();
                let current_len = nodes.len();
                if current_len <= 16 {
                    drop(tmp_alloc_handle);
                    tmp_alloc_handle = Some(tracker.alloc::<FoldingTree>(256)?);
                    nodes.reserve(256 - current_len);
                } else if current_len <= 256 {
                    drop(tmp_alloc_handle);
                    tmp_alloc_handle = Some(tracker.alloc::<FoldingTree>(1024)?);
                    nodes.reserve(1024 - current_len);
                } else {
                    drop(tmp_alloc_handle);
                    tmp_alloc_handle = Some(tracker.alloc::<FoldingTree>(current_len * 2)?);
                    nodes.reserve(current_len);
                }
            }

            nodes_left -= 1;
            // 0 marks a leaf; otherwise the node decides on property `property - 1`.
            let property = tree_decoder.read_varint(bitstream, 1)?;
            let node = if let Some(property) = property.checked_sub(1) {
                let value = unpack_signed(tree_decoder.read_varint(bitstream, 0)?);
                let node = FoldingTree::Decision(property, value);
                nodes_left += 2;
                node
            } else {
                let predictor = tree_decoder.read_varint(bitstream, 2)?;
                let predictor = Predictor::try_from(predictor)?;
                let offset = unpack_signed(tree_decoder.read_varint(bitstream, 3)?);
                let mul_log = tree_decoder.read_varint(bitstream, 4)?;
                if mul_log > 30 {
                    return Err(crate::Error::InvalidMaTree);
                }
                let mul_bits = tree_decoder.read_varint(bitstream, 5)?;
                // Keeps `(mul_bits + 1) << mul_log` below 2^31, so the multiplier fits in i32.
                if mul_bits > (1 << (31 - mul_log)) - 2 {
                    return Err(crate::Error::InvalidMaTree);
                }
                let multiplier = (mul_bits + 1) << mul_log;
                let node = FoldingTree::Leaf(FoldingTreeLeaf {
                    ctx,
                    predictor,
                    offset,
                    multiplier,
                });
                ctx += 1;
                node
            };
            nodes.push(node);
            max_depth = max_depth.max(nodes_left);
        }
        tree_decoder.finalize()?;
        let num_tree_nodes = nodes.len();
        // Sample decoder: one context per leaf read above.
        let decoder = Decoder::parse(bitstream, ctx)?;
        let cluster_map = decoder.cluster_map();

        let tree_alloc_handle = tracker
            .map(|tracker| tracker.alloc::<FoldingTree>(nodes.len()))
            .transpose()?;
        // Fold the flat list bottom-up by walking it in reverse. `tmp` is a FIFO of completed
        // subtrees with their depths. Each decision takes the two oldest entries: the first
        // becomes the right child and the second the left child. This assumes the bitstream
        // lists nodes level by level, left child before right (presumably per the JPEG XL
        // spec). TODO confirm.
        let mut tmp = VecDeque::<(_, usize)>::with_capacity(max_depth);
        for node in nodes.into_iter().rev() {
            match node {
                FoldingTree::Decision(property, value) => {
                    let (right, dr) = tmp.pop_front().unwrap();
                    let (left, dl) = tmp.pop_front().unwrap();
                    let node = Box::new(MaTreeNode::Decision {
                        property,
                        value,
                        left,
                        right,
                    });
                    // Leaves have depth 0; a decision is one deeper than its deeper child.
                    let depth = dr.max(dl) + 1;
                    if depth > depth_limit {
                        tracing::error!(depth_limit, "Decoded MA tree is too deep");
                        return Err(jxl_bitstream::Error::ProfileConformance(
                            "decoded MA tree is too deep",
                        )
                        .into());
                    }

                    tmp.push_back((node, depth));
                }
                FoldingTree::Leaf(FoldingTreeLeaf {
                    ctx,
                    predictor,
                    offset,
                    multiplier,
                }) => {
                    // Resolve the leaf's context to its cluster in the sample decoder.
                    let cluster = cluster_map[ctx as usize];
                    let leaf = MaTreeLeafClustered {
                        cluster,
                        predictor,
                        offset,
                        multiplier,
                    };
                    let node = Box::new(MaTreeNode::Leaf(leaf));
                    tmp.push_back((node, 0));
                }
            }
        }
        // A complete tree folds into exactly one root.
        assert_eq!(tmp.len(), 1);
        let (tree, tree_depth) = tmp.pop_front().unwrap();
        let tree = *tree;

        Ok(Self {
            num_tree_nodes,
            tree_depth,
            tree: Arc::new((tree, tree_alloc_handle)),
            decoder,
        })
    }
}
227
228fn is_infinite_tree_dist(decoder: &Decoder) -> bool {
229    let cluster_map = decoder.cluster_map();
230
231    // Distribution #1 decides whether it's decision node or leaf node; if it reads 0 it's a leaf
232    // node. Therefore, the tree is infinitely large if the dist always reads token other than 0.
233    let cluster = cluster_map[1];
234    let Some(token) = decoder.single_token(cluster) else {
235        return false;
236    };
237    token != 0
238}
239
/// A "flat" meta-adaptive tree, constructed with [`MaConfig::make_flat_tree`].
///
/// All nodes live in one vector with the root at index 0; decision nodes refer to their
/// children by index into that vector.
#[derive(Debug)]
pub struct FlatMaTree {
    /// Flattened nodes; `nodes[0]` is the root.
    nodes: Vec<FlatMaTreeNode>,
    /// Whether any decision tests property 15 or any leaf uses the self-correcting predictor.
    need_self_correcting: bool,
    /// Smallest number of previous channels that covers every channel referenced by
    /// decisions on properties `>= 16`.
    max_prev_channel_depth: usize,
}
247
/// Node of a [`FlatMaTree`]; children are referenced by index into the tree's node vector.
#[derive(Debug)]
enum FlatMaTreeNode {
    /// Two levels of binary decisions fused into one node.
    ///
    /// Let `p0` be the value of `prop_level0`. The next node is
    /// `index_base + 2 * (p0 <= value_level0) + (p1 <= v1)`, where `p1` and `v1` are the
    /// level-1 property value and threshold. They come from the `.0` entries of
    /// `props_level1`/`values_level1` when `p0 > value_level0`, and from the `.1` entries
    /// otherwise.
    FusedDecision {
        prop_level0: u32,
        value_level0: i32,
        props_level1: (u32, u32),
        values_level1: (i32, i32),
        index_base: u32,
    },
    /// Multi-way decision on a single property.
    ///
    /// The property value minus `value_base` (saturating) is clamped to the bounds of
    /// `indices`, and the selected entry is the index of the next node.
    Table {
        prop: u32,
        value_base: i32,
        indices: Box<[u32]>,
    },
    /// Leaf holding the parameters used to decode a sample.
    Leaf(MaTreeLeafClustered),
}
264
/// Leaf of an MA tree, with its context already mapped to an entropy-decoder cluster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MaTreeLeafClustered {
    /// Cluster, in the sample decoder, of this leaf's context.
    pub(crate) cluster: u8,
    /// Predictor used for samples that reach this leaf.
    pub(crate) predictor: super::predictor::Predictor,
    /// Passed to `wrapping_muladd_i32` together with `multiplier` when reconstructing the
    /// residual (presumably `residual * multiplier + offset`; confirm against the `Sample` impl).
    pub(crate) offset: i32,
    /// Residual scale factor; the parser keeps it below 2^31 so it fits in an `i32`.
    pub(crate) multiplier: u32,
}
272
273impl FlatMaTree {
274    fn new(nodes: Vec<FlatMaTreeNode>) -> Self {
275        let need_self_correcting = nodes.iter().any(|node| match *node {
276            FlatMaTreeNode::FusedDecision {
277                prop_level0: p,
278                props_level1: (pl, pr),
279                ..
280            } => p == 15 || pl == 15 || pr == 15,
281            FlatMaTreeNode::Table { prop, .. } => prop == 15,
282            FlatMaTreeNode::Leaf(MaTreeLeafClustered { predictor, .. }) => {
283                predictor == Predictor::SelfCorrecting
284            }
285        });
286
287        let mut max_prev_channel_depth = 0usize;
288        for node in &nodes {
289            if let FlatMaTreeNode::FusedDecision {
290                prop_level0: p,
291                props_level1: (pl, pr),
292                ..
293            } = *node
294            {
295                if let Some(p) = p.checked_sub(16) {
296                    max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
297                }
298                if let Some(p) = pl.checked_sub(16) {
299                    max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
300                }
301                if let Some(p) = pr.checked_sub(16) {
302                    max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
303                }
304            } else if let FlatMaTreeNode::Table { prop, .. } = *node {
305                if let Some(p) = prop.checked_sub(16) {
306                    max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
307                }
308            }
309        }
310
311        Self {
312            nodes,
313            need_self_correcting,
314            max_prev_channel_depth,
315        }
316    }
317
318    pub(crate) fn get_leaf<S: Sample>(&self, properties: &Properties<S>) -> &MaTreeLeafClustered {
319        let mut current_node = &self.nodes[0];
320        loop {
321            match current_node {
322                &FlatMaTreeNode::FusedDecision {
323                    prop_level0: p,
324                    value_level0: v,
325                    props_level1: (pl, pr),
326                    values_level1: (vl, vr),
327                    index_base,
328                } => {
329                    let p0v = properties.get(p as usize);
330                    let plv = properties.get(pl as usize);
331                    let prv = properties.get(pr as usize);
332                    let high_bit = p0v <= v;
333                    let l = (plv <= vl) as u32;
334                    let r = 2 | (prv <= vr) as u32;
335                    let next_node = index_base + if high_bit { r } else { l };
336                    current_node = &self.nodes[next_node as usize];
337                }
338                &FlatMaTreeNode::Table {
339                    prop,
340                    value_base,
341                    ref indices,
342                } => {
343                    let v = properties.get(prop as usize);
344                    let idx = v
345                        .saturating_sub(value_base)
346                        .clamp(0, indices.len() as i32 - 1) as usize;
347                    let next_node = indices[idx];
348                    current_node = &self.nodes[next_node as usize];
349                }
350                FlatMaTreeNode::Leaf(leaf) => return leaf,
351            }
352        }
353    }
354}
355
356impl FlatMaTree {
357    /// Returns whether self-correcting predictor should be initialized.
358    ///
359    /// The return value of this method can be used to optimize the decoding process, since
360    /// self-correcting predictors are computationally heavy.
361    #[inline]
362    pub fn need_self_correcting(&self) -> bool {
363        self.need_self_correcting
364    }
365
366    /// Returns the number of previously decoded channels needed in order to traverse the MA tree.
367    #[inline]
368    pub fn max_prev_channel_depth(&self) -> usize {
369        self.max_prev_channel_depth
370    }
371
372    /// Decode a sample with the given state.
373    pub fn decode_sample<S: Sample>(
374        &self,
375        bitstream: &mut Bitstream,
376        decoder: &mut Decoder,
377        properties: &Properties<S>,
378        dist_multiplier: u32,
379    ) -> Result<(i32, super::predictor::Predictor)> {
380        let leaf = self.get_leaf(properties);
381        let diff = decoder.read_varint_with_multiplier_clustered(
382            bitstream,
383            leaf.cluster,
384            dist_multiplier,
385        )?;
386        let diff = unpack_signed(diff).wrapping_muladd_i32(leaf.multiplier as i32, leaf.offset);
387        Ok((diff, leaf.predictor))
388    }
389
390    #[inline]
391    pub(crate) fn single_node(&self) -> Option<&MaTreeLeafClustered> {
392        match self.nodes.first() {
393            Some(FlatMaTreeNode::Leaf(node)) => Some(node),
394            _ => None,
395        }
396    }
397
398    pub(crate) fn simple_table(&self) -> Option<SimpleMaTable> {
399        let Some(&FlatMaTreeNode::Table {
400            prop: decision_prop,
401            value_base,
402            ref indices,
403        }) = self.nodes.first()
404        else {
405            return None;
406        };
407
408        let mut state: Option<(Predictor, i32, u32)> = None;
409        let mut cluster_table = Vec::with_capacity(indices.len());
410        for &index in &**indices {
411            let node = &self.nodes[index as usize];
412            let FlatMaTreeNode::Leaf(leaf) = node else {
413                return None;
414            };
415
416            let leaf_props = (leaf.predictor, leaf.offset, leaf.multiplier);
417            let &mut state = state.get_or_insert(leaf_props);
418            if leaf_props != state {
419                return None;
420            }
421
422            cluster_table.push(leaf.cluster);
423        }
424
425        let (predictor, offset, multiplier) = state.unwrap();
426        Some(SimpleMaTable {
427            decision_prop,
428            value_base,
429            predictor,
430            offset,
431            multiplier,
432            cluster_table: cluster_table.into_boxed_slice(),
433        })
434    }
435}
436
/// A tree consisting of a single table decision whose targets are all leaves with the same
/// predictor, offset and multiplier; produced by [`FlatMaTree::simple_table`].
#[derive(Debug)]
pub(crate) struct SimpleMaTable {
    /// Property tested by the table.
    pub(crate) decision_prop: u32,
    /// Subtracted from the property value to find the table slot, as in the table node it
    /// was built from.
    pub(crate) value_base: i32,
    /// Predictor shared by every leaf.
    pub(crate) predictor: Predictor,
    /// Offset shared by every leaf.
    pub(crate) offset: i32,
    /// Multiplier shared by every leaf.
    pub(crate) multiplier: u32,
    /// Entropy-decoder cluster for each table slot.
    pub(crate) cluster_table: Box<[u8]>,
}
446
/// Node of a decoded (not yet flattened) MA tree.
#[derive(Debug)]
enum MaTreeNode {
    /// Goes to `left` if the value of `property` is greater than `value`, and to `right`
    /// otherwise.
    Decision {
        property: u32,
        value: i32,
        left: Box<MaTreeNode>,
        right: Box<MaTreeNode>,
    },
    /// Leaf whose context has been mapped to a cluster.
    Leaf(MaTreeLeafClustered),
}
457
458impl MaTreeNode {
459    fn next_decision_node(&self, channel: u32, stream_idx: u32, prev_channels: u32) -> &MaTreeNode {
460        match *self {
461            MaTreeNode::Decision {
462                property: property @ (0 | 1),
463                value,
464                ref left,
465                ref right,
466            } => {
467                let target = if property == 0 { channel } else { stream_idx };
468                let node = if target as i32 > value { left } else { right };
469                node.next_decision_node(channel, stream_idx, prev_channels)
470            }
471            ref node @ MaTreeNode::Decision {
472                property,
473                value,
474                ref left,
475                ref right,
476            } if property >= 16 => {
477                let prev_channel_idx = (property - 16) / 4;
478                if prev_channel_idx >= prev_channels {
479                    let node = if value < 0 { left } else { right };
480                    node.next_decision_node(channel, stream_idx, prev_channels)
481                } else {
482                    node
483                }
484            }
485            ref node => node,
486        }
487    }
488
489    fn try_compile_to_table(
490        &self,
491        channel: u32,
492        stream_idx: u32,
493        prev_channels: u32,
494        next_index_base: u32,
495    ) -> Option<(FlatMaTreeNode, Vec<&MaTreeNode>)> {
496        let &MaTreeNode::Decision {
497            property,
498            value,
499            ref left,
500            ref right,
501        } = self
502        else {
503            return None;
504        };
505
506        let mut lower_bound = value;
507        let mut upper_bound = value;
508        let mut stack = vec![
509            (&**left, (value + 1)..=i32::MAX),
510            (&**right, i32::MIN..=value),
511        ];
512        let mut range_nodes = Vec::new();
513        while let Some((node, range)) = stack.pop() {
514            let node = node.next_decision_node(channel, stream_idx, prev_channels);
515            let (value, left, right) = match node {
516                &MaTreeNode::Decision {
517                    property: target_property,
518                    value,
519                    ref left,
520                    ref right,
521                } if target_property == property => (value, left, right),
522                _ => {
523                    range_nodes.push((node, *range.end()));
524                    continue;
525                }
526            };
527            let new_lower_bound = lower_bound.min(value);
528            let new_upper_bound = upper_bound.max(value);
529            if new_upper_bound.abs_diff(new_lower_bound) > 1024 - 2 {
530                range_nodes.push((node, *range.end()));
531                continue;
532            }
533            lower_bound = new_lower_bound;
534            upper_bound = new_upper_bound;
535
536            let left_range = (value + 1)..=(*range.end());
537            let right_range = (*range.start())..=value;
538            if !left_range.is_empty() {
539                stack.push((&**left, left_range));
540            }
541            if !right_range.is_empty() {
542                stack.push((&**right, right_range));
543            }
544        }
545        if range_nodes.len() < 4 {
546            return None;
547        }
548
549        range_nodes.sort_unstable_by_key(|(_, range_end)| *range_end);
550
551        let index_count = upper_bound.abs_diff(lower_bound) as usize + 2;
552        let mut indices = vec![0u32; index_count];
553        let mut nodes = Vec::with_capacity(range_nodes.len());
554
555        let mut range_start = lower_bound - 1;
556        let mut next_index = 0usize;
557        for (idx, (node, range_end)) in range_nodes.into_iter().enumerate() {
558            if range_end == i32::MAX {
559                *indices.last_mut().unwrap() = next_index_base + idx as u32;
560                nodes.push(node);
561                break;
562            }
563            let len = range_end.abs_diff(range_start) as usize;
564            let end_index = next_index + len;
565            indices[next_index..end_index].fill(next_index_base + idx as u32);
566            nodes.push(node);
567            next_index = end_index;
568            range_start = range_end;
569        }
570
571        let node = FlatMaTreeNode::Table {
572            prop: property,
573            value_base: lower_bound,
574            indices: indices.into_boxed_slice(),
575        };
576        Some((node, nodes))
577    }
578
579    fn flatten(&self, channel: u32, stream_idx: u32, prev_channels: u32) -> Vec<FlatMaTreeNode> {
580        let target = self.next_decision_node(channel, stream_idx, prev_channels);
581        let mut q = std::collections::VecDeque::new();
582        q.push_back(target);
583
584        let mut out = Vec::new();
585        let mut next_base = 1u32;
586        while let Some(target) = q.pop_front() {
587            let target = target.next_decision_node(channel, stream_idx, prev_channels);
588            if let Some((out_node, nodes)) =
589                target.try_compile_to_table(channel, stream_idx, prev_channels, next_base)
590            {
591                let len = nodes.len() as u32;
592                out.push(out_node);
593                q.extend(nodes);
594                next_base += len;
595                continue;
596            }
597
598            match *target {
599                MaTreeNode::Decision {
600                    property,
601                    value,
602                    ref left,
603                    ref right,
604                } => {
605                    let left = left.next_decision_node(channel, stream_idx, prev_channels);
606                    let (lp, lv, ll, lr) = match left {
607                        &MaTreeNode::Decision {
608                            property,
609                            value,
610                            ref left,
611                            ref right,
612                        } => (property, value, &**left, &**right),
613                        node => (0, 0, node, node),
614                    };
615                    let right = right.next_decision_node(channel, stream_idx, prev_channels);
616                    let (rp, rv, rl, rr) = match right {
617                        &MaTreeNode::Decision {
618                            property,
619                            value,
620                            ref left,
621                            ref right,
622                        } => (property, value, &**left, &**right),
623                        node => (0, 0, node, node),
624                    };
625                    out.push(FlatMaTreeNode::FusedDecision {
626                        prop_level0: property,
627                        value_level0: value,
628                        props_level1: (lp, rp),
629                        values_level1: (lv, rv),
630                        index_base: next_base,
631                    });
632                    q.push_back(ll);
633                    q.push_back(lr);
634                    q.push_back(rl);
635                    q.push_back(rr);
636                    next_base += 4;
637                }
638                MaTreeNode::Leaf(ref leaf) => {
639                    out.push(FlatMaTreeNode::Leaf(leaf.clone()));
640                }
641            }
642        }
643
644        out
645    }
646}