1use std::collections::HashMap;
2
3use jxl_bitstream::Bitstream;
4use jxl_coding::{Decoder, DecoderRleMode, RleToken};
5use jxl_grid::{AlignedGrid, AllocTracker, MutableSubgrid};
6
7use crate::{
8 ma::{FlatMaTree, MaTreeLeafClustered, SimpleMaTable},
9 predictor::{Predictor, PredictorState, Properties, WpHeader},
10 sample::Sample,
11 MaConfig, ModularChannelInfo, ModularChannels, ModularHeader, Result,
12};
13
/// A mutable view into a channel buffer, possibly carrying views of other
/// channels that a transform merged into it.
#[derive(Debug)]
pub enum TransformedGrid<'dest, S: Sample> {
    /// A standalone channel view.
    Single(MutableSubgrid<'dest, S>),
    /// A channel view (`leader`) together with the views of the channels
    /// that were merged into it; they can be detached again with `unmerge`.
    Merged {
        leader: MutableSubgrid<'dest, S>,
        members: Vec<TransformedGrid<'dest, S>>,
    },
}
22
23impl<'dest, S: Sample> From<MutableSubgrid<'dest, S>> for TransformedGrid<'dest, S> {
24 fn from(value: MutableSubgrid<'dest, S>) -> Self {
25 Self::Single(value)
26 }
27}
28
29impl<S: Sample> TransformedGrid<'_, S> {
30 fn reborrow(&mut self) -> TransformedGrid<S> {
31 match self {
32 TransformedGrid::Single(g) => TransformedGrid::Single(g.split_horizontal(0).1),
33 TransformedGrid::Merged { leader, .. } => {
34 TransformedGrid::Single(leader.split_horizontal(0).1)
35 }
36 }
37 }
38}
39
40impl<'dest, S: Sample> TransformedGrid<'dest, S> {
41 pub(crate) fn grid(&self) -> &MutableSubgrid<'dest, S> {
42 match self {
43 Self::Single(g) => g,
44 Self::Merged { leader, .. } => leader,
45 }
46 }
47
48 pub(crate) fn grid_mut(&mut self) -> &mut MutableSubgrid<'dest, S> {
49 match self {
50 Self::Single(g) => g,
51 Self::Merged { leader, .. } => leader,
52 }
53 }
54
55 pub(crate) fn merge(&mut self, members: Vec<TransformedGrid<'dest, S>>) {
56 if members.is_empty() {
57 return;
58 }
59
60 match self {
61 Self::Single(leader) => {
62 let tmp = MutableSubgrid::empty();
63 let leader = std::mem::replace(leader, tmp);
64 *self = Self::Merged { leader, members };
65 }
66 Self::Merged {
67 members: original_members,
68 ..
69 } => {
70 original_members.extend(members);
71 }
72 }
73 }
74
75 pub(crate) fn unmerge(&mut self, count: usize) -> Vec<TransformedGrid<'dest, S>> {
76 if count == 0 {
77 return Vec::new();
78 }
79
80 match self {
81 Self::Single(_) => panic!("cannot unmerge TransformedGrid::Single"),
82 Self::Merged { leader, members } => {
83 let len = members.len();
84 let members = members.drain((len - count)..).collect();
85 if len == count {
86 let tmp = MutableSubgrid::empty();
87 let leader = std::mem::replace(leader, tmp);
88 *self = Self::Single(leader);
89 }
90 members
91 }
92 }
93 }
94}
95
/// Owning destination buffers for decoding a full Modular image.
#[derive(Debug)]
pub struct ModularImageDestination<S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    // Group dimension in pixels; `trailing_zeros` is used as a shift in
    // `prepare_groups`, so this is assumed to be a power of two.
    group_dim: u32,
    bit_depth: u32,
    channels: ModularChannels,
    // Scratch channels prepared by the transforms.
    meta_channels: Vec<AlignedGrid<S>>,
    // One backing buffer per declared image channel.
    image_channels: Vec<AlignedGrid<S>>,
}
106
107impl<S: Sample> ModularImageDestination<S> {
108 pub(crate) fn new(
109 header: ModularHeader,
110 ma_ctx: MaConfig,
111 group_dim: u32,
112 bit_depth: u32,
113 channels: ModularChannels,
114 tracker: Option<&AllocTracker>,
115 ) -> Result<Self> {
116 let mut meta_channels = Vec::new();
117 for tr in &header.transform {
118 tr.prepare_meta_channels(&mut meta_channels, tracker)?;
119 }
120
121 let image_channels = channels
122 .info
123 .iter()
124 .map(|ch| {
125 AlignedGrid::with_alloc_tracker(ch.width as usize, ch.height as usize, tracker)
126 })
127 .collect::<std::result::Result<_, _>>()?;
128
129 Ok(Self {
130 header,
131 ma_ctx,
132 group_dim,
133 bit_depth,
134 channels,
135 meta_channels,
136 image_channels,
137 })
138 }
139
140 pub fn try_clone(&self) -> Result<Self> {
141 Ok(Self {
142 header: self.header.clone(),
143 ma_ctx: self.ma_ctx.clone(),
144 group_dim: self.group_dim,
145 bit_depth: self.bit_depth,
146 channels: self.channels.clone(),
147 meta_channels: self
148 .meta_channels
149 .iter()
150 .map(|x| x.try_clone())
151 .collect::<std::result::Result<_, _>>()?,
152 image_channels: self
153 .image_channels
154 .iter()
155 .map(|x| x.try_clone())
156 .collect::<std::result::Result<_, _>>()?,
157 })
158 }
159
160 pub fn image_channels(&self) -> &[AlignedGrid<S>] {
161 &self.image_channels
162 }
163
164 pub fn into_image_channels(self) -> Vec<AlignedGrid<S>> {
165 self.image_channels
166 }
167
168 pub fn into_image_channels_with_info(
169 self,
170 ) -> impl Iterator<Item = (AlignedGrid<S>, ModularChannelInfo)> {
171 self.image_channels.into_iter().zip(self.channels.info)
172 }
173
174 pub fn has_palette(&self) -> bool {
175 self.header.transform.iter().any(|tr| tr.is_palette())
176 }
177
178 pub fn has_squeeze(&self) -> bool {
179 self.header.transform.iter().any(|tr| tr.is_squeeze())
180 }
181}
182
183impl<S: Sample> ModularImageDestination<S> {
184 pub fn prepare_gmodular(&mut self) -> Result<TransformedModularSubimage<S>> {
185 assert_ne!(self.group_dim, 0);
186
187 let group_dim = self.group_dim;
188 let subimage = self.prepare_subimage()?;
189 let (channel_info, grids): (Vec<_>, Vec<_>) = subimage
190 .channel_info
191 .into_iter()
192 .zip(subimage.grid)
193 .enumerate()
194 .take_while(|&(i, (ref info, _))| {
195 i < subimage.nb_meta_channels
196 || (info.width <= group_dim && info.height <= group_dim)
197 })
198 .map(|(_, x)| x)
199 .unzip();
200 let channel_indices = (0..channel_info.len()).collect();
201 Ok(TransformedModularSubimage {
202 channel_info,
203 channel_indices,
204 grid: grids,
205 ..subimage
206 })
207 }
208
209 pub fn prepare_groups(
210 &mut self,
211 pass_shifts: &std::collections::BTreeMap<u32, (i32, i32)>,
212 ) -> Result<TransformedGlobalModular<S>> {
213 assert_ne!(self.group_dim, 0);
214
215 let num_passes = *pass_shifts.last_key_value().unwrap().0 as usize + 1;
216
217 let group_dim = self.group_dim;
218 let group_dim_shift = group_dim.trailing_zeros();
219 let bit_depth = self.bit_depth;
220 let subimage = self.prepare_subimage()?;
221 let it = subimage
222 .channel_info
223 .into_iter()
224 .zip(subimage.grid)
225 .enumerate()
226 .skip_while(|&(i, (ref info, _))| {
227 i < subimage.nb_meta_channels
228 || (info.width <= group_dim && info.height <= group_dim)
229 });
230
231 let mut lf_groups = Vec::new();
232 let mut pass_groups = Vec::with_capacity(num_passes);
233 pass_groups.resize_with(num_passes, Vec::new);
234 for (i, (info, grid)) in it {
235 let ModularChannelInfo {
236 original_width,
237 original_height,
238 hshift,
239 vshift,
240 original_shift,
241 ..
242 } = info;
243 assert!(hshift >= 0 && vshift >= 0);
244
245 let grid = match grid {
246 TransformedGrid::Single(g) => g,
247 TransformedGrid::Merged { leader, .. } => leader,
248 };
249 tracing::trace!(
250 i,
251 width = grid.width(),
252 height = grid.height(),
253 hshift,
254 vshift
255 );
256
257 let (groups, grids) = if hshift < 3 || vshift < 3 {
258 let shift = hshift.min(vshift); let pass_idx = *pass_shifts
260 .iter()
261 .find(|(_, &(minshift, maxshift))| (minshift..maxshift).contains(&shift))
262 .unwrap()
263 .0;
264 let pass_idx = pass_idx as usize;
265
266 let group_width = group_dim >> hshift;
267 let group_height = group_dim >> vshift;
268 if group_width == 0 || group_height == 0 {
269 tracing::error!(
270 group_dim,
271 hshift,
272 vshift,
273 "Channel shift value too large after transform"
274 );
275 return Err(crate::Error::InvalidSqueezeParams);
276 }
277
278 let grids = grid.into_groups_with_fixed_count(
279 group_width as usize,
280 group_height as usize,
281 (original_width + group_dim - 1) as usize >> group_dim_shift,
282 (original_height + group_dim - 1) as usize >> group_dim_shift,
283 );
284 (&mut pass_groups[pass_idx], grids)
285 } else {
286 let lf_group_width = group_dim >> (hshift - 3);
288 let lf_group_height = group_dim >> (vshift - 3);
289 if lf_group_width == 0 || lf_group_height == 0 {
290 tracing::error!(
291 group_dim,
292 hshift,
293 vshift,
294 "Channel shift value too large after transform"
295 );
296 return Err(crate::Error::InvalidSqueezeParams);
297 }
298 let grids = grid.into_groups_with_fixed_count(
299 lf_group_width as usize,
300 lf_group_height as usize,
301 (original_width + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
302 (original_height + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
303 );
304 (&mut lf_groups, grids)
305 };
306
307 if groups.is_empty() {
308 groups.resize_with(grids.len(), || {
309 TransformedModularSubimage::empty(&subimage.header, &subimage.ma_ctx, bit_depth)
310 });
311 } else if groups.len() != grids.len() {
312 panic!();
313 }
314
315 for (subimage, grid) in groups.iter_mut().zip(grids) {
316 let width = grid.width() as u32;
317 let height = grid.height() as u32;
318 if width == 0 || height == 0 {
319 continue;
320 }
321
322 subimage.channel_info.push(ModularChannelInfo {
323 width,
324 height,
325 original_width: width << hshift,
326 original_height: height << vshift,
327 hshift,
328 vshift,
329 original_shift,
330 });
331 subimage.channel_indices.push(i);
332 subimage.grid.push(grid.into());
333 subimage.partial = true;
334 }
335 }
336
337 Ok(TransformedGlobalModular {
338 lf_groups,
339 pass_groups,
340 })
341 }
342
343 pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
344 let mut channels = self.channels.clone();
345 let mut meta_channel_grids = self
346 .meta_channels
347 .iter_mut()
348 .map(MutableSubgrid::from)
349 .collect::<Vec<_>>();
350 let mut grids = self
351 .image_channels
352 .iter_mut()
353 .map(|g| g.as_subgrid_mut().into())
354 .collect::<Vec<_>>();
355 for tr in &self.header.transform {
356 tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
357 }
358
359 let channel_info = channels.info;
360 let channel_indices = (0..channel_info.len()).collect();
361 Ok(TransformedModularSubimage {
362 header: self.header.clone(),
363 ma_ctx: self.ma_ctx.clone(),
364 bit_depth: self.bit_depth,
365 nb_meta_channels: channels.nb_meta_channels as usize,
366 channel_info,
367 channel_indices,
368 grid: grids,
369 partial: true,
370 })
371 }
372}
373
/// Per-group subimages of a global Modular image, grouped by where they are
/// coded in the stream.
#[derive(Debug)]
pub struct TransformedGlobalModular<'dest, S: Sample> {
    /// Subimages decoded in LF group sections.
    pub lf_groups: Vec<TransformedModularSubimage<'dest, S>>,
    /// Subimages decoded in pass group sections, indexed by pass then group.
    pub pass_groups: Vec<Vec<TransformedModularSubimage<'dest, S>>>,
}
379
/// A set of transformed channels, backed by views into a destination image,
/// that can be decoded as one unit.
#[derive(Debug)]
pub struct TransformedModularSubimage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    // Number of leading channels that are meta channels.
    nb_meta_channels: usize,
    channel_info: Vec<ModularChannelInfo>,
    // Index of each channel within the full (transformed) channel list.
    channel_indices: Vec<usize>,
    grid: Vec<TransformedGrid<'dest, S>>,
    // `true` until the subimage has been fully decoded.
    partial: bool,
}
391
392impl<S: Sample> TransformedModularSubimage<'_, S> {
393 fn empty(header: &ModularHeader, ma_ctx: &MaConfig, bit_depth: u32) -> Self {
394 Self {
395 header: header.clone(),
396 ma_ctx: ma_ctx.clone(),
397 bit_depth,
398 nb_meta_channels: 0,
399 channel_info: Vec::new(),
400 channel_indices: Vec::new(),
401 grid: Vec::new(),
402 partial: false,
403 }
404 }
405}
406
407impl<'dest, S: Sample> TransformedModularSubimage<'dest, S> {
408 pub fn is_empty(&self) -> bool {
409 self.channel_info.is_empty()
410 }
411
412 pub fn recursive(
413 self,
414 bitstream: &mut Bitstream,
415 global_ma_config: Option<&MaConfig>,
416 tracker: Option<&AllocTracker>,
417 ) -> Result<RecursiveModularImage<'dest, S>> {
418 let channels = crate::ModularChannels {
419 info: self.channel_info,
420 nb_meta_channels: 0,
421 };
422 let (header, ma_ctx) = crate::read_and_validate_local_modular_header(
423 bitstream,
424 &channels,
425 global_ma_config,
426 tracker,
427 )?;
428
429 let mut image = RecursiveModularImage {
430 header,
431 ma_ctx,
432 bit_depth: self.bit_depth,
433 channels,
434 meta_channels: Vec::new(),
435 image_channels: self.grid,
436 };
437 for tr in &image.header.transform {
438 tr.prepare_meta_channels(&mut image.meta_channels, tracker)?;
439 }
440 Ok(image)
441 }
442
443 pub fn finish(mut self, pool: &jxl_threadpool::JxlThreadPool) -> bool {
444 for tr in self.header.transform.iter().rev() {
445 tr.inverse(&mut self.grid, self.bit_depth, pool);
446 }
447 !self.partial
448 }
449}
450
impl<S: Sample> TransformedModularSubimage<'_, S> {
    /// Decodes all channels of the subimage from `bitstream`.
    ///
    /// Builds a flattened MA tree per channel, then picks one of several
    /// paths: the libjxl fast-lossless RLE path when every tree is a single
    /// gradient leaf, a single-node path, a simple-table path, or the fully
    /// generic per-sample path.
    fn decode_inner(&mut self, bitstream: &mut Bitstream, stream_index: u32) -> Result<()> {
        let span = tracing::span!(tracing::Level::TRACE, "decode channels", stream_index);
        let _guard = span.enter();

        // Widest channel; used to scale varint distributions.
        let dist_multiplier = self
            .channel_info
            .iter()
            .map(|info| info.width)
            .max()
            .unwrap_or(0);

        let mut decoder = self.ma_ctx.decoder().clone();
        decoder.begin(bitstream)?;

        // One flattened MA tree per channel; `None` for zero-sized channels,
        // which are skipped entirely.
        let mut ma_tree_list = Vec::with_capacity(self.channel_info.len());
        for (i, info) in self.channel_info.iter().enumerate() {
            if info.width == 0 || info.height == 0 {
                ma_tree_list.push(None);
                continue;
            }

            // Count earlier channels with identical geometry; they act as
            // "previous channel" references for this channel's tree.
            let filtered_prev_len = self.channel_info[..i]
                .iter()
                .filter(|prev_info| {
                    info.width == prev_info.width
                        && info.height == prev_info.height
                        && info.hshift == prev_info.hshift
                        && info.vshift == prev_info.vshift
                })
                .count();

            let ma_tree =
                self.ma_ctx
                    .make_flat_tree(i as u32, stream_index, filtered_prev_len as u32);
            ma_tree_list.push(Some(ma_tree));
        }

        if let Some(mut rle_decoder) = decoder.as_rle() {
            // Fast-lossless applies when every channel's tree is a single
            // gradient leaf with no offset/multiplier adjustment.
            let is_fast_lossless = ma_tree_list.iter().all(|ma_tree| {
                ma_tree
                    .as_ref()
                    .map(|ma_tree| {
                        matches!(
                            ma_tree.single_node(),
                            Some(MaTreeLeafClustered {
                                predictor: Predictor::Gradient,
                                offset: 0,
                                multiplier: 1,
                                ..
                            })
                        )
                    })
                    .unwrap_or(true)
            });

            if is_fast_lossless {
                tracing::trace!("libjxl fast-lossless");
                let mut rle_state = RleState::<S>::new();

                for (ma_tree, grid) in ma_tree_list.into_iter().zip(&mut self.grid) {
                    let Some(ma_tree) = ma_tree else {
                        continue;
                    };

                    let node = ma_tree.single_node().unwrap();
                    let cluster = node.cluster;
                    decode_fast_lossless(
                        bitstream,
                        &mut rle_decoder,
                        &mut rle_state,
                        cluster,
                        grid.grid_mut(),
                    );
                }

                // Errors are latched in `rle_state` during decoding and
                // surfaced here.
                rle_state.check_error()?;
                return Ok(());
            }
        }

        let wp_header = &self.header.wp_params;
        let mut predictor = PredictorState::new();
        // Previously decoded channels grouped by geometry, newest first; used
        // as reference channels by the generic path.
        let mut prev_map = HashMap::new();
        for ((info, ma_tree), grid) in self
            .channel_info
            .iter()
            .zip(ma_tree_list)
            .zip(&mut self.grid)
        {
            let Some(ma_tree) = ma_tree else {
                continue;
            };
            let key = (info.width, info.height, info.hshift, info.vshift);

            let filtered_prev = prev_map.entry(key).or_insert_with(Vec::new);

            if let Some(node) = ma_tree.single_node() {
                decode_single_node(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &mut predictor,
                    wp_header,
                    grid.grid_mut(),
                    node,
                )?;
            } else if let Some(table) = ma_tree.simple_table() {
                decode_simple_table(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &mut predictor,
                    wp_header,
                    grid.grid_mut(),
                    &table,
                )?;
            } else {
                let grid = grid.grid_mut();
                // Only as many reference channels as the tree actually reads;
                // assumes `max_prev_channel_depth` never exceeds the number of
                // same-geometry channels decoded so far.
                let filtered_prev = &filtered_prev[..ma_tree.max_prev_channel_depth()];
                // The weighted-predictor state is only needed when the tree
                // uses the self-correcting predictor.
                let wp_header = ma_tree.need_self_correcting().then_some(wp_header);
                predictor.reset(grid.width() as u32, filtered_prev, wp_header);
                decode_slow(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &ma_tree,
                    &mut predictor,
                    grid,
                )?;
            }

            // Record this channel as the most recent reference for its
            // geometry class.
            filtered_prev.insert(0, grid.grid());
        }

        decoder.finalize()?;
        Ok(())
    }

    /// Decodes the subimage; with `allow_partial`, an unexpected end of
    /// stream leaves the subimage partially decoded instead of failing.
    pub fn decode(
        &mut self,
        bitstream: &mut Bitstream,
        stream_index: u32,
        allow_partial: bool,
    ) -> Result<()> {
        match self.decode_inner(bitstream, stream_index) {
            Err(e) if e.unexpected_eof() && allow_partial => {
                tracing::debug!("Partially decoded Modular image");
            }
            Err(e) => return Err(e),
            Ok(_) => {
                self.partial = false;
            }
        }
        Ok(())
    }
}
609
/// A Modular image nested inside another image's stream, decoding into
/// borrowed grids of the outer image.
#[derive(Debug)]
pub struct RecursiveModularImage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    channels: ModularChannels,
    // Scratch channels prepared by this image's own transforms.
    meta_channels: Vec<AlignedGrid<S>>,
    // Views into the parent image's channel buffers.
    image_channels: Vec<TransformedGrid<'dest, S>>,
}
619
impl<S: Sample> RecursiveModularImage<'_, S> {
    /// Runs this image's forward transforms and returns a subimage viewing
    /// every transformed channel, ready to be decoded.
    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
        let mut channels = self.channels.clone();
        // Wrap each meta channel buffer as a full-size mutable view
        // (stride == width).
        let mut meta_channel_grids = self
            .meta_channels
            .iter_mut()
            .map(|g| {
                let width = g.width();
                let height = g.height();
                MutableSubgrid::from_buf(g.buf_mut(), width, height, width)
            })
            .collect::<Vec<_>>();
        // Reborrow the parent's grids so they can be handed to the
        // transforms without consuming them.
        let mut grids = self
            .image_channels
            .iter_mut()
            .map(|g| g.reborrow())
            .collect();
        for tr in &self.header.transform {
            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
        }

        let channel_info = channels.info;
        let channel_indices = (0..channel_info.len()).collect();
        Ok(TransformedModularSubimage {
            header: self.header.clone(),
            ma_ctx: self.ma_ctx.clone(),
            bit_depth: self.bit_depth,
            nb_meta_channels: channels.nb_meta_channels as usize,
            channel_info,
            channel_indices,
            grid: grids,
            partial: true,
        })
    }
}
655
/// Shared decoder state for the fast-lossless (RLE) path.
struct RleState<S: Sample> {
    // Most recently decoded sample value; repeated while `repeat` > 0.
    value: S,
    // Remaining number of times `value` should be emitted.
    repeat: u32,
    // First decoding error encountered; surfaced later via `check_error`.
    error: Option<Box<jxl_coding::Error>>,
}
661
662impl<S: Sample> RleState<S> {
663 #[inline]
664 fn new() -> Self {
665 Self {
666 value: S::default(),
667 repeat: 0,
668 error: None,
669 }
670 }
671
672 #[inline(always)]
673 fn decode(
674 &mut self,
675 bitstream: &mut Bitstream,
676 decoder: &mut DecoderRleMode,
677 cluster: u8,
678 ) -> S {
679 if self.repeat == 0 {
680 let result = decoder.read_varint_clustered(bitstream, cluster);
681 match result {
682 Ok(RleToken::Value(v)) => {
683 self.value = S::unpack_signed_u32(v);
684 self.repeat = 1;
685 }
686 Ok(RleToken::Repeat(len)) => {
687 self.repeat = len;
688 }
689 Err(e) if self.error.is_none() => {
690 self.error = Some(Box::new(e));
691 }
692 _ => {}
693 }
694 }
695
696 self.repeat = self.repeat.wrapping_sub(1);
697 self.value
698 }
699
700 #[inline]
701 fn check_error(&mut self) -> Result<()> {
702 let error = self.error.take();
703 if let Some(error) = error {
704 let error = *error;
705 Err(error.into())
706 } else {
707 Ok(())
708 }
709 }
710}
711
/// Decodes one channel whose MA tree is a single leaf, dispatching to
/// increasingly specialized fast paths based on the leaf's predictor.
fn decode_single_node<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    predictor_state: &mut PredictorState<S>,
    wp_header: &WpHeader,
    grid: &mut MutableSubgrid<S>,
    node: &MaTreeLeafClustered,
) -> Result<()> {
    let &MaTreeLeafClustered {
        cluster,
        predictor,
        offset,
        multiplier,
    } = node;
    tracing::trace!(cluster, ?predictor, "Single MA tree node");

    let height = grid.height();
    // `single_token` is `Some` when the cluster can only ever emit one
    // symbol, i.e. every residual in the channel is identical.
    let single_token = decoder.single_token(cluster);
    match (predictor, single_token) {
        (Predictor::Zero, Some(token)) => {
            tracing::trace!("Single token in cluster, Zero predictor: hyper fast path");
            // Constant channel: compute the value once and fill every row.
            let value = S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
            for y in 0..height {
                grid.get_row_mut(y).fill(value);
            }
            Ok(())
        }
        (Predictor::Zero, None) => {
            tracing::trace!("Zero predictor: fast path");
            // No prediction: each sample is just its scaled residual.
            for y in 0..height {
                let row = grid.get_row_mut(y);
                for out in row {
                    let token = decoder.read_varint_with_multiplier_clustered(
                        bitstream,
                        cluster,
                        dist_multiplier,
                    )?;
                    *out =
                        S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
                }
            }
            Ok(())
        }
        (Predictor::Gradient, _) if offset == 0 && multiplier == 1 => {
            tracing::trace!("Simple gradient: quite fast path");
            decode_simple_grad(bitstream, decoder, cluster, dist_multiplier, grid)
        }
        _ => {
            // Generic path. The weighted-predictor state is only required by
            // the self-correcting predictor.
            let wp_header = (predictor == Predictor::SelfCorrecting).then_some(wp_header);
            predictor_state.reset(grid.width() as u32, &[], wp_header);
            decode_single_node_slow(
                bitstream,
                decoder,
                dist_multiplier,
                node,
                predictor_state,
                grid,
            )
        }
    }
}
774
/// Decodes one channel in the libjxl fast-lossless mode: RLE-coded residuals
/// with an implicit gradient predictor (offset 0, multiplier 1).
///
/// Errors are not returned here; they are latched inside `rle_state` and
/// must be surfaced afterwards via `RleState::check_error`.
#[inline(never)]
fn decode_fast_lossless<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut DecoderRleMode,
    rle_state: &mut RleState<S>,
    cluster: u8,
    grid: &mut MutableSubgrid<S>,
) {
    let height = grid.height();

    {
        // First row: no row above, so prediction reduces to the running left
        // neighbor, starting from zero.
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in &mut *out_row {
            let token = rle_state.decode(bitstream, decoder, cluster);
            w = w.add(token);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while the current row is
        // written.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        // First column: no left neighbor, predict from the sample above.
        let token = rle_state.decode(bitstream, decoder, cluster);
        let mut w = token.add(prev_row[0]);
        out_row[0] = w;

        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            // Clamped gradient prediction from north, west and northwest.
            let pred = S::grad_clamped(n, w, nw);

            let token = rle_state.decode(bitstream, decoder, cluster);
            w = token.add(pred);
            *out = w;
        }
    }
}
815
/// Decodes one channel with the gradient predictor and no offset/multiplier
/// adjustment: the "quite fast" path for single-leaf gradient trees.
#[inline(never)]
fn decode_simple_grad<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    cluster: u8,
    dist_multiplier: u32,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let width = grid.width();
    let height = grid.height();

    {
        // First row: no row above, prediction reduces to the running left
        // neighbor, starting from zero.
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in out_row[..width].iter_mut() {
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            w = S::unpack_signed_u32(token).add(w);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while writing the current one.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        // First column: predict from the sample directly above.
        let token =
            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
        out_row[0] = w;

        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            // Clamped gradient prediction from north, west and northwest.
            let pred = S::grad_clamped(n, w, nw);

            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            let value = S::unpack_signed_u32(token).add(pred);
            *out = value;
            w = value;
        }
    }

    Ok(())
}
869
870#[inline(always)]
871fn decode_one<S: Sample, const EDGE: bool>(
872 bitstream: &mut Bitstream,
873 decoder: &mut Decoder,
874 dist_multiplier: u32,
875 leaf: &MaTreeLeafClustered,
876 properties: &Properties<S>,
877) -> Result<S> {
878 let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
879 bitstream,
880 leaf.cluster,
881 dist_multiplier,
882 )?);
883 let diff = diff.wrapping_muladd_i32(leaf.multiplier as i32, leaf.offset);
884 let predictor = leaf.predictor;
885 let sample_prediction = predictor.predict::<_, EDGE>(properties);
886 Ok(diff.add(S::from_i32(sample_prediction)))
887}
888
/// Generic per-sample decoding for a single-leaf MA tree, running the full
/// predictor/property machinery.
///
/// The first two rows, and the two samples at each end of every later row,
/// use the `EDGE = true` variant (border-aware property computation); row
/// interiors use the cheaper `EDGE = false` variant.
#[inline(never)]
fn decode_single_node_slow<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    leaf: &MaTreeLeafClustered,
    predictor: &mut PredictorState<S>,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let height = grid.height();
    // Top two rows: every sample is treated as an edge sample.
    for y in 0..2usize.min(height) {
        let row = grid.get_row_mut(y);

        for out in row.iter_mut() {
            let properties = predictor.properties::<true>();
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    for y in 2..height {
        let row = grid.get_row_mut(y);
        // Split the row into a 2-sample left edge, interior, and 2-sample
        // right edge; short rows are processed entirely as edge samples.
        let (row_left, row_middle, row_right) = if row.len() <= 4 {
            (row, [].as_mut(), [].as_mut())
        } else {
            let (l, m) = row.split_at_mut(2);
            let (m, r) = m.split_at_mut(m.len() - 2);
            (l, m, r)
        };

        for out in row_left {
            let properties = predictor.properties::<true>();
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_middle {
            let properties = predictor.properties::<false>();
            let true_value =
                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_right {
            let properties = predictor.properties::<true>();
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    Ok(())
}
946
947fn decode_simple_table<S: Sample>(
948 bitstream: &mut Bitstream,
949 decoder: &mut Decoder,
950 dist_multiplier: u32,
951 predictor_state: &mut PredictorState<S>,
952 wp_header: &WpHeader,
953 grid: &mut MutableSubgrid<S>,
954 table: &SimpleMaTable,
955) -> Result<()> {
956 let &SimpleMaTable {
957 decision_prop,
958 value_base,
959 predictor,
960 offset,
961 multiplier,
962 ref cluster_table,
963 } = table;
964
965 if offset == 0 && multiplier == 1 && decision_prop == 9 && predictor == Predictor::Gradient {
966 return decode_gradient_table(
967 bitstream,
968 decoder,
969 dist_multiplier,
970 grid,
971 value_base,
972 cluster_table,
973 );
974 }
975
976 decode_simple_table_slow(
977 bitstream,
978 decoder,
979 dist_multiplier,
980 predictor_state,
981 wp_header,
982 grid,
983 table,
984 )
985}
986
/// Maps a property value to a cluster index via the lookup table.
///
/// The table is indexed by `sample - value_base`, clamped to the table's
/// bounds so out-of-range property values reuse the first/last entry.
#[inline(always)]
fn cluster_from_table(sample: i32, value_base: i32, cluster_table: &[u8]) -> u8 {
    let max_index = cluster_table.len() as i32 - 1;
    let index = (sample - value_base).clamp(0, max_index) as usize;
    cluster_table[index]
}
992
/// Fast path for a lookup-table MA tree keyed on the unclamped gradient
/// property (property 9) with a plain gradient predictor.
///
/// The cluster for each sample is chosen from `cluster_table`, indexed by
/// the gradient value relative to `value_base`.
fn decode_gradient_table<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    grid: &mut MutableSubgrid<S>,
    value_base: i32,
    cluster_table: &[u8],
) -> Result<()> {
    tracing::trace!("Gradient-only lookup table");

    let width = grid.width();
    let height = grid.height();

    {
        // First row: no row above, so both the decision property and the
        // prediction reduce to the running left neighbor (starting at zero).
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in out_row[..width].iter_mut() {
            let cluster = cluster_from_table(w.to_i32(), value_base, cluster_table);
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            w = S::unpack_signed_u32(token).add(w);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while writing the current one.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        // First column: the sample above serves as both decision property
        // and prediction.
        let cluster = cluster_from_table(prev_row[0].to_i32(), value_base, cluster_table);
        let token =
            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
        out_row[0] = w;

        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            // Decision property: unclamped gradient n + w - nw (wrapping);
            // prediction: clamped gradient of the same neighbors.
            let prop = n
                .to_i32()
                .wrapping_add(w.to_i32())
                .wrapping_sub(nw.to_i32());
            let pred = S::grad_clamped(n, w, nw);

            let cluster = cluster_from_table(prop, value_base, cluster_table);
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            let value = S::unpack_signed_u32(token).add(pred);
            *out = value;
            w = value;
        }
    }

    Ok(())
}
1055
1056#[inline(always)]
1057fn decode_table_one<S: Sample, const EDGE: bool>(
1058 bitstream: &mut Bitstream,
1059 decoder: &mut Decoder,
1060 dist_multiplier: u32,
1061 table: &SimpleMaTable,
1062 properties: &Properties<S>,
1063) -> Result<S> {
1064 let prop_value = properties.get(table.decision_prop as usize);
1065
1066 let cluster = cluster_from_table(prop_value, table.value_base, &table.cluster_table);
1067
1068 let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
1069 bitstream,
1070 cluster,
1071 dist_multiplier,
1072 )?);
1073 let diff = diff.wrapping_muladd_i32(table.multiplier as i32, table.offset);
1074 let predictor = table.predictor;
1075 let sample_prediction = predictor.predict::<_, EDGE>(properties);
1076 Ok(diff.add(S::from_i32(sample_prediction)))
1077}
1078
/// Generic per-sample decoding for a simple lookup table, running the full
/// predictor/property machinery.
///
/// Edge handling mirrors `decode_single_node_slow`: the first two rows and
/// the two samples at each row end use `EDGE = true`.
#[inline(never)]
fn decode_simple_table_slow<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    predictor_state: &mut PredictorState<S>,
    wp_header: &WpHeader,
    grid: &mut MutableSubgrid<S>,
    table: &SimpleMaTable,
) -> Result<()> {
    tracing::trace!("Slow lookup table");

    // Weighted-predictor state is needed when the table decides on
    // property 15 or predicts with the self-correcting predictor.
    let need_wp_header = table.decision_prop == 15 || table.predictor == Predictor::SelfCorrecting;
    let wp_header = need_wp_header.then_some(wp_header);
    predictor_state.reset(grid.width() as u32, &[], wp_header);

    let height = grid.height();
    // Top two rows: every sample treated as an edge sample.
    for y in 0..2usize.min(height) {
        let row = grid.get_row_mut(y);

        for out in row.iter_mut() {
            let properties = predictor_state.properties::<true>();
            let true_value = decode_table_one::<_, true>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    for y in 2..height {
        let row = grid.get_row_mut(y);
        // Split into 2-sample edges and interior; short rows are all edge.
        let (row_left, row_middle, row_right) = if row.len() <= 4 {
            (row, [].as_mut(), [].as_mut())
        } else {
            let (l, m) = row.split_at_mut(2);
            let (m, r) = m.split_at_mut(m.len() - 2);
            (l, m, r)
        };

        for out in row_left {
            let properties = predictor_state.properties::<true>();
            let true_value = decode_table_one::<_, true>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_middle {
            let properties = predictor_state.properties::<false>();
            let true_value = decode_table_one::<_, false>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_right {
            let properties = predictor_state.properties::<true>();
            let true_value = decode_table_one::<_, true>(
                bitstream,
                decoder,
                dist_multiplier,
                table,
                &properties,
            )?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    Ok(())
}
1163
/// Fully generic decoding path: walks the flattened MA tree for every sample
/// to find the leaf, then decodes with that leaf's cluster and predictor.
///
/// Edge handling mirrors `decode_single_node_slow`: the first two rows and
/// the two samples at each row end use `EDGE = true`.
#[inline(never)]
fn decode_slow<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    ma_tree: &FlatMaTree,
    predictor: &mut PredictorState<S>,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let height = grid.height();
    // Top two rows: every sample treated as an edge sample.
    for y in 0..2usize.min(height) {
        let row = grid.get_row_mut(y);

        for out in row.iter_mut() {
            let properties = predictor.properties::<true>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    for y in 2..height {
        let row = grid.get_row_mut(y);
        // Split into 2-sample edges and interior; short rows are all edge.
        let (row_left, row_middle, row_right) = if row.len() <= 4 {
            (row, [].as_mut(), [].as_mut())
        } else {
            let (l, m) = row.split_at_mut(2);
            let (m, r) = m.split_at_mut(m.len() - 2);
            (l, m, r)
        };

        for out in row_left {
            let properties = predictor.properties::<true>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_middle {
            let properties = predictor.properties::<false>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
        for out in row_right {
            let properties = predictor.properties::<true>();
            let leaf = ma_tree.get_leaf(&properties);
            let true_value =
                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
            *out = true_value;
            properties.record(true_value.to_i32());
        }
    }

    Ok(())
}