1use std::collections::VecDeque;
2use std::sync::Arc;
3
4use jxl_bitstream::{unpack_signed, Bitstream};
5use jxl_coding::Decoder;
6use jxl_grid::{AllocHandle, AllocTracker};
7use jxl_oxide_common::Bundle;
8
9use super::predictor::{Predictor, Properties};
10use crate::{sample::Sealed, Result, Sample};
11
/// MA tree configuration read from the bitstream.
///
/// Holds the parsed decision tree together with the entropy decoder that was parsed right
/// after it. The tree is kept behind an `Arc`, so cloning the configuration shares the tree
/// nodes instead of copying them.
#[derive(Debug, Clone)]
pub struct MaConfig {
    /// Number of nodes (decisions and leaves) that were read while parsing the tree.
    num_tree_nodes: usize,
    /// Depth of the parsed tree; a tree made of a single leaf has depth 0.
    tree_depth: usize,
    /// Root node of the tree, paired with the allocation-tracker handle (if a tracker was
    /// supplied) that is kept alive for as long as the tree is.
    tree: Arc<(MaTreeNode, Option<AllocHandle>)>,
    /// Entropy decoder parsed with one context per tree leaf.
    decoder: Decoder,
}
24
25impl MaConfig {
26 pub fn decoder(&self) -> &Decoder {
30 &self.decoder
31 }
32
33 pub fn make_flat_tree(&self, channel: u32, stream_idx: u32, prev_channels: u32) -> FlatMaTree {
39 let nodes = self.tree.0.flatten(channel, stream_idx, prev_channels);
40 FlatMaTree::new(nodes)
41 }
42}
43
44impl MaConfig {
45 #[inline]
47 pub fn num_tree_nodes(&self) -> usize {
48 self.num_tree_nodes
49 }
50
51 #[inline]
53 pub fn tree_depth(&self) -> usize {
54 self.tree_depth
55 }
56}
57
/// Parameters that control how an [`MaConfig`] is parsed.
#[derive(Debug, Copy, Clone)]
pub struct MaConfigParams<'a> {
    /// Allocation tracker used to account for the memory of the tree; `None` disables
    /// tracking.
    pub tracker: Option<&'a AllocTracker>,
    /// Node count threshold: parsing fails with a profile conformance error when more than
    /// this many nodes have been read and the tree is still incomplete.
    pub node_limit: usize,
    /// Maximum depth the decoded tree may have; deeper trees are rejected with a profile
    /// conformance error.
    pub depth_limit: usize,
}
67
68impl Bundle<MaConfigParams<'_>> for MaConfig {
69 type Error = crate::Error;
70
71 fn parse(bitstream: &mut Bitstream, params: MaConfigParams) -> crate::Result<Self> {
72 struct FoldingTreeLeaf {
73 ctx: u32,
74 predictor: super::predictor::Predictor,
75 offset: i32,
76 multiplier: u32,
77 }
78
79 enum FoldingTree {
80 Decision(u32, i32),
81 Leaf(FoldingTreeLeaf),
82 }
83
84 let MaConfigParams {
85 tracker,
86 node_limit,
87 depth_limit,
88 } = params;
89
90 let mut tree_decoder = Decoder::parse(bitstream, 6)?;
91 if is_infinite_tree_dist(&tree_decoder) {
92 tracing::error!("Infinite MA tree");
93 return Err(crate::Error::InvalidMaTree);
94 }
95
96 let mut ctx = 0u32;
97 let mut nodes_left = 1usize;
98 let mut tmp_alloc_handle = tracker
99 .map(|tracker| tracker.alloc::<FoldingTree>(16))
100 .transpose()?;
101 let mut nodes = Vec::with_capacity(16);
102 let mut max_depth = 1usize;
103
104 tree_decoder.begin(bitstream)?;
105 while nodes_left > 0 {
106 if nodes.len() >= (1 << 26) {
107 return Err(crate::Error::InvalidMaTree);
108 }
109 if nodes.len() > node_limit {
110 tracing::error!(node_limit, "MA tree limit exceeded");
111 return Err(
112 jxl_bitstream::Error::ProfileConformance("MA tree limit exceeded").into(),
113 );
114 }
115
116 if nodes.len() == nodes.capacity() && tmp_alloc_handle.is_some() {
117 let tracker = tracker.unwrap();
118 let current_len = nodes.len();
119 if current_len <= 16 {
120 drop(tmp_alloc_handle);
121 tmp_alloc_handle = Some(tracker.alloc::<FoldingTree>(256)?);
122 nodes.reserve(256 - current_len);
123 } else if current_len <= 256 {
124 drop(tmp_alloc_handle);
125 tmp_alloc_handle = Some(tracker.alloc::<FoldingTree>(1024)?);
126 nodes.reserve(1024 - current_len);
127 } else {
128 drop(tmp_alloc_handle);
129 tmp_alloc_handle = Some(tracker.alloc::<FoldingTree>(current_len * 2)?);
130 nodes.reserve(current_len);
131 }
132 }
133
134 nodes_left -= 1;
135 let property = tree_decoder.read_varint(bitstream, 1)?;
136 let node = if let Some(property) = property.checked_sub(1) {
137 let value = unpack_signed(tree_decoder.read_varint(bitstream, 0)?);
138 let node = FoldingTree::Decision(property, value);
139 nodes_left += 2;
140 node
141 } else {
142 let predictor = tree_decoder.read_varint(bitstream, 2)?;
143 let predictor = Predictor::try_from(predictor)?;
144 let offset = unpack_signed(tree_decoder.read_varint(bitstream, 3)?);
145 let mul_log = tree_decoder.read_varint(bitstream, 4)?;
146 if mul_log > 30 {
147 return Err(crate::Error::InvalidMaTree);
148 }
149 let mul_bits = tree_decoder.read_varint(bitstream, 5)?;
150 if mul_bits > (1 << (31 - mul_log)) - 2 {
151 return Err(crate::Error::InvalidMaTree);
152 }
153 let multiplier = (mul_bits + 1) << mul_log;
154 let node = FoldingTree::Leaf(FoldingTreeLeaf {
155 ctx,
156 predictor,
157 offset,
158 multiplier,
159 });
160 ctx += 1;
161 node
162 };
163 nodes.push(node);
164 max_depth = max_depth.max(nodes_left);
165 }
166 tree_decoder.finalize()?;
167 let num_tree_nodes = nodes.len();
168 let decoder = Decoder::parse(bitstream, ctx)?;
169 let cluster_map = decoder.cluster_map();
170
171 let tree_alloc_handle = tracker
172 .map(|tracker| tracker.alloc::<FoldingTree>(nodes.len()))
173 .transpose()?;
174 let mut tmp = VecDeque::<(_, usize)>::with_capacity(max_depth);
175 for node in nodes.into_iter().rev() {
176 match node {
177 FoldingTree::Decision(property, value) => {
178 let (right, dr) = tmp.pop_front().unwrap();
179 let (left, dl) = tmp.pop_front().unwrap();
180 let node = Box::new(MaTreeNode::Decision {
181 property,
182 value,
183 left,
184 right,
185 });
186 let depth = dr.max(dl) + 1;
187 if depth > depth_limit {
188 tracing::error!(depth_limit, "Decoded MA tree is too deep");
189 return Err(jxl_bitstream::Error::ProfileConformance(
190 "decoded MA tree is too deep",
191 )
192 .into());
193 }
194
195 tmp.push_back((node, depth));
196 }
197 FoldingTree::Leaf(FoldingTreeLeaf {
198 ctx,
199 predictor,
200 offset,
201 multiplier,
202 }) => {
203 let cluster = cluster_map[ctx as usize];
204 let leaf = MaTreeLeafClustered {
205 cluster,
206 predictor,
207 offset,
208 multiplier,
209 };
210 let node = Box::new(MaTreeNode::Leaf(leaf));
211 tmp.push_back((node, 0));
212 }
213 }
214 }
215 assert_eq!(tmp.len(), 1);
216 let (tree, tree_depth) = tmp.pop_front().unwrap();
217 let tree = *tree;
218
219 Ok(Self {
220 num_tree_nodes,
221 tree_depth,
222 tree: Arc::new((tree, tree_alloc_handle)),
223 decoder,
224 })
225 }
226}
227
228fn is_infinite_tree_dist(decoder: &Decoder) -> bool {
229 let cluster_map = decoder.cluster_map();
230
231 let cluster = cluster_map[1];
234 let Some(token) = decoder.single_token(cluster) else {
235 return false;
236 };
237 token != 0
238}
239
/// Flattened MA tree, as produced by [`MaConfig::make_flat_tree`].
///
/// Nodes are stored in a single vector and refer to each other by index; node 0 is the
/// root.
#[derive(Debug)]
pub struct FlatMaTree {
    /// Flattened nodes; index 0 is where lookups start.
    nodes: Vec<FlatMaTreeNode>,
    /// `true` if any node tests property 15 or any leaf uses `Predictor::SelfCorrecting`.
    need_self_correcting: bool,
    /// Largest `(property - 16) / 4 + 1` over all tested properties that are 16 or greater,
    /// or 0 if there is no such property.
    max_prev_channel_depth: usize,
}
247
/// Node of a [`FlatMaTree`].
#[derive(Debug)]
enum FlatMaTreeNode {
    /// Two levels of binary decisions evaluated in a single step.
    ///
    /// The result is one of four consecutive nodes starting at `index_base`. If the level 0
    /// property is greater than `value_level0`, the left level 1 test picks offset 0
    /// (property greater than its value) or 1; otherwise the right level 1 test picks
    /// offset 2 or 3 in the same way.
    FusedDecision {
        /// Property tested at level 0.
        prop_level0: u32,
        /// Threshold of the level 0 test.
        value_level0: i32,
        /// Properties tested at level 1, as `(left, right)`.
        props_level1: (u32, u32),
        /// Thresholds of the level 1 tests, as `(left, right)`.
        values_level1: (i32, i32),
        /// Index of the first of the four possible next nodes.
        index_base: u32,
    },
    /// Lookup table on a single property.
    ///
    /// The next node is `indices[property - value_base]`, with the subscript clamped to the
    /// bounds of the table.
    Table {
        /// Property used for the lookup.
        prop: u32,
        /// Property value that maps to the first table entry.
        value_base: i32,
        /// Node index for each table slot.
        indices: Box<[u32]>,
    },
    /// Leaf node; ends the lookup.
    Leaf(MaTreeLeafClustered),
}
264
/// Leaf of an MA tree, with its entropy coding context already mapped to a cluster.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct MaTreeLeafClustered {
    /// Cluster index passed to the entropy decoder when decoding a sample with this leaf.
    pub(crate) cluster: u8,
    /// Predictor associated with this leaf.
    pub(crate) predictor: super::predictor::Predictor,
    /// Value added to the decoded token after it is multiplied by `multiplier`.
    pub(crate) offset: i32,
    /// Factor the decoded token is multiplied by.
    pub(crate) multiplier: u32,
}
272
273impl FlatMaTree {
274 fn new(nodes: Vec<FlatMaTreeNode>) -> Self {
275 let need_self_correcting = nodes.iter().any(|node| match *node {
276 FlatMaTreeNode::FusedDecision {
277 prop_level0: p,
278 props_level1: (pl, pr),
279 ..
280 } => p == 15 || pl == 15 || pr == 15,
281 FlatMaTreeNode::Table { prop, .. } => prop == 15,
282 FlatMaTreeNode::Leaf(MaTreeLeafClustered { predictor, .. }) => {
283 predictor == Predictor::SelfCorrecting
284 }
285 });
286
287 let mut max_prev_channel_depth = 0usize;
288 for node in &nodes {
289 if let FlatMaTreeNode::FusedDecision {
290 prop_level0: p,
291 props_level1: (pl, pr),
292 ..
293 } = *node
294 {
295 if let Some(p) = p.checked_sub(16) {
296 max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
297 }
298 if let Some(p) = pl.checked_sub(16) {
299 max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
300 }
301 if let Some(p) = pr.checked_sub(16) {
302 max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
303 }
304 } else if let FlatMaTreeNode::Table { prop, .. } = *node {
305 if let Some(p) = prop.checked_sub(16) {
306 max_prev_channel_depth = max_prev_channel_depth.max((p as usize / 4) + 1);
307 }
308 }
309 }
310
311 Self {
312 nodes,
313 need_self_correcting,
314 max_prev_channel_depth,
315 }
316 }
317
318 pub(crate) fn get_leaf<S: Sample>(&self, properties: &Properties<S>) -> &MaTreeLeafClustered {
319 let mut current_node = &self.nodes[0];
320 loop {
321 match current_node {
322 &FlatMaTreeNode::FusedDecision {
323 prop_level0: p,
324 value_level0: v,
325 props_level1: (pl, pr),
326 values_level1: (vl, vr),
327 index_base,
328 } => {
329 let p0v = properties.get(p as usize);
330 let plv = properties.get(pl as usize);
331 let prv = properties.get(pr as usize);
332 let high_bit = p0v <= v;
333 let l = (plv <= vl) as u32;
334 let r = 2 | (prv <= vr) as u32;
335 let next_node = index_base + if high_bit { r } else { l };
336 current_node = &self.nodes[next_node as usize];
337 }
338 &FlatMaTreeNode::Table {
339 prop,
340 value_base,
341 ref indices,
342 } => {
343 let v = properties.get(prop as usize);
344 let idx = v
345 .saturating_sub(value_base)
346 .clamp(0, indices.len() as i32 - 1) as usize;
347 let next_node = indices[idx];
348 current_node = &self.nodes[next_node as usize];
349 }
350 FlatMaTreeNode::Leaf(leaf) => return leaf,
351 }
352 }
353 }
354}
355
356impl FlatMaTree {
357 #[inline]
362 pub fn need_self_correcting(&self) -> bool {
363 self.need_self_correcting
364 }
365
366 #[inline]
368 pub fn max_prev_channel_depth(&self) -> usize {
369 self.max_prev_channel_depth
370 }
371
372 pub fn decode_sample<S: Sample>(
374 &self,
375 bitstream: &mut Bitstream,
376 decoder: &mut Decoder,
377 properties: &Properties<S>,
378 dist_multiplier: u32,
379 ) -> Result<(i32, super::predictor::Predictor)> {
380 let leaf = self.get_leaf(properties);
381 let diff = decoder.read_varint_with_multiplier_clustered(
382 bitstream,
383 leaf.cluster,
384 dist_multiplier,
385 )?;
386 let diff = unpack_signed(diff).wrapping_muladd_i32(leaf.multiplier as i32, leaf.offset);
387 Ok((diff, leaf.predictor))
388 }
389
390 #[inline]
391 pub(crate) fn single_node(&self) -> Option<&MaTreeLeafClustered> {
392 match self.nodes.first() {
393 Some(FlatMaTreeNode::Leaf(node)) => Some(node),
394 _ => None,
395 }
396 }
397
398 pub(crate) fn simple_table(&self) -> Option<SimpleMaTable> {
399 let Some(&FlatMaTreeNode::Table {
400 prop: decision_prop,
401 value_base,
402 ref indices,
403 }) = self.nodes.first()
404 else {
405 return None;
406 };
407
408 let mut state: Option<(Predictor, i32, u32)> = None;
409 let mut cluster_table = Vec::with_capacity(indices.len());
410 for &index in &**indices {
411 let node = &self.nodes[index as usize];
412 let FlatMaTreeNode::Leaf(leaf) = node else {
413 return None;
414 };
415
416 let leaf_props = (leaf.predictor, leaf.offset, leaf.multiplier);
417 let &mut state = state.get_or_insert(leaf_props);
418 if leaf_props != state {
419 return None;
420 }
421
422 cluster_table.push(leaf.cluster);
423 }
424
425 let (predictor, offset, multiplier) = state.unwrap();
426 Some(SimpleMaTable {
427 decision_prop,
428 value_base,
429 predictor,
430 offset,
431 multiplier,
432 cluster_table: cluster_table.into_boxed_slice(),
433 })
434 }
435}
436
/// Simplified form of a tree whose root is a lookup table that leads only to leaves sharing
/// the same predictor, offset and multiplier (see `FlatMaTree::simple_table`).
#[derive(Debug)]
pub(crate) struct SimpleMaTable {
    /// Property used for the table lookup.
    pub(crate) decision_prop: u32,
    /// Property value that maps to the first entry of `cluster_table`.
    pub(crate) value_base: i32,
    /// Predictor shared by all leaves.
    pub(crate) predictor: Predictor,
    /// Offset shared by all leaves.
    pub(crate) offset: i32,
    /// Multiplier shared by all leaves.
    pub(crate) multiplier: u32,
    /// Cluster index of the leaf behind each table slot.
    pub(crate) cluster_table: Box<[u8]>,
}
446
/// Node of the MA tree as parsed from the bitstream, with children linked by `Box`.
#[derive(Debug)]
enum MaTreeNode {
    /// Binary decision on a single property.
    Decision {
        /// Index of the property that is tested.
        property: u32,
        /// Threshold the property is compared with.
        value: i32,
        /// Subtree taken when the property is greater than `value`.
        left: Box<MaTreeNode>,
        /// Subtree taken when the property is less than or equal to `value`.
        right: Box<MaTreeNode>,
    },
    /// Leaf node.
    Leaf(MaTreeLeafClustered),
}
457
458impl MaTreeNode {
459 fn next_decision_node(&self, channel: u32, stream_idx: u32, prev_channels: u32) -> &MaTreeNode {
460 match *self {
461 MaTreeNode::Decision {
462 property: property @ (0 | 1),
463 value,
464 ref left,
465 ref right,
466 } => {
467 let target = if property == 0 { channel } else { stream_idx };
468 let node = if target as i32 > value { left } else { right };
469 node.next_decision_node(channel, stream_idx, prev_channels)
470 }
471 ref node @ MaTreeNode::Decision {
472 property,
473 value,
474 ref left,
475 ref right,
476 } if property >= 16 => {
477 let prev_channel_idx = (property - 16) / 4;
478 if prev_channel_idx >= prev_channels {
479 let node = if value < 0 { left } else { right };
480 node.next_decision_node(channel, stream_idx, prev_channels)
481 } else {
482 node
483 }
484 }
485 ref node => node,
486 }
487 }
488
489 fn try_compile_to_table(
490 &self,
491 channel: u32,
492 stream_idx: u32,
493 prev_channels: u32,
494 next_index_base: u32,
495 ) -> Option<(FlatMaTreeNode, Vec<&MaTreeNode>)> {
496 let &MaTreeNode::Decision {
497 property,
498 value,
499 ref left,
500 ref right,
501 } = self
502 else {
503 return None;
504 };
505
506 let mut lower_bound = value;
507 let mut upper_bound = value;
508 let mut stack = vec![
509 (&**left, (value + 1)..=i32::MAX),
510 (&**right, i32::MIN..=value),
511 ];
512 let mut range_nodes = Vec::new();
513 while let Some((node, range)) = stack.pop() {
514 let node = node.next_decision_node(channel, stream_idx, prev_channels);
515 let (value, left, right) = match node {
516 &MaTreeNode::Decision {
517 property: target_property,
518 value,
519 ref left,
520 ref right,
521 } if target_property == property => (value, left, right),
522 _ => {
523 range_nodes.push((node, *range.end()));
524 continue;
525 }
526 };
527 let new_lower_bound = lower_bound.min(value);
528 let new_upper_bound = upper_bound.max(value);
529 if new_upper_bound.abs_diff(new_lower_bound) > 1024 - 2 {
530 range_nodes.push((node, *range.end()));
531 continue;
532 }
533 lower_bound = new_lower_bound;
534 upper_bound = new_upper_bound;
535
536 let left_range = (value + 1)..=(*range.end());
537 let right_range = (*range.start())..=value;
538 if !left_range.is_empty() {
539 stack.push((&**left, left_range));
540 }
541 if !right_range.is_empty() {
542 stack.push((&**right, right_range));
543 }
544 }
545 if range_nodes.len() < 4 {
546 return None;
547 }
548
549 range_nodes.sort_unstable_by_key(|(_, range_end)| *range_end);
550
551 let index_count = upper_bound.abs_diff(lower_bound) as usize + 2;
552 let mut indices = vec![0u32; index_count];
553 let mut nodes = Vec::with_capacity(range_nodes.len());
554
555 let mut range_start = lower_bound - 1;
556 let mut next_index = 0usize;
557 for (idx, (node, range_end)) in range_nodes.into_iter().enumerate() {
558 if range_end == i32::MAX {
559 *indices.last_mut().unwrap() = next_index_base + idx as u32;
560 nodes.push(node);
561 break;
562 }
563 let len = range_end.abs_diff(range_start) as usize;
564 let end_index = next_index + len;
565 indices[next_index..end_index].fill(next_index_base + idx as u32);
566 nodes.push(node);
567 next_index = end_index;
568 range_start = range_end;
569 }
570
571 let node = FlatMaTreeNode::Table {
572 prop: property,
573 value_base: lower_bound,
574 indices: indices.into_boxed_slice(),
575 };
576 Some((node, nodes))
577 }
578
579 fn flatten(&self, channel: u32, stream_idx: u32, prev_channels: u32) -> Vec<FlatMaTreeNode> {
580 let target = self.next_decision_node(channel, stream_idx, prev_channels);
581 let mut q = std::collections::VecDeque::new();
582 q.push_back(target);
583
584 let mut out = Vec::new();
585 let mut next_base = 1u32;
586 while let Some(target) = q.pop_front() {
587 let target = target.next_decision_node(channel, stream_idx, prev_channels);
588 if let Some((out_node, nodes)) =
589 target.try_compile_to_table(channel, stream_idx, prev_channels, next_base)
590 {
591 let len = nodes.len() as u32;
592 out.push(out_node);
593 q.extend(nodes);
594 next_base += len;
595 continue;
596 }
597
598 match *target {
599 MaTreeNode::Decision {
600 property,
601 value,
602 ref left,
603 ref right,
604 } => {
605 let left = left.next_decision_node(channel, stream_idx, prev_channels);
606 let (lp, lv, ll, lr) = match left {
607 &MaTreeNode::Decision {
608 property,
609 value,
610 ref left,
611 ref right,
612 } => (property, value, &**left, &**right),
613 node => (0, 0, node, node),
614 };
615 let right = right.next_decision_node(channel, stream_idx, prev_channels);
616 let (rp, rv, rl, rr) = match right {
617 &MaTreeNode::Decision {
618 property,
619 value,
620 ref left,
621 ref right,
622 } => (property, value, &**left, &**right),
623 node => (0, 0, node, node),
624 };
625 out.push(FlatMaTreeNode::FusedDecision {
626 prop_level0: property,
627 value_level0: value,
628 props_level1: (lp, rp),
629 values_level1: (lv, rv),
630 index_base: next_base,
631 });
632 q.push_back(ll);
633 q.push_back(lr);
634 q.push_back(rl);
635 q.push_back(rr);
636 next_base += 4;
637 }
638 MaTreeNode::Leaf(ref leaf) => {
639 out.push(FlatMaTreeNode::Leaf(leaf.clone()));
640 }
641 }
642 }
643
644 out
645 }
646}