jxl_modular/
image.rs

1use std::collections::HashMap;
2
3use jxl_bitstream::Bitstream;
4use jxl_coding::{Decoder, DecoderRleMode, RleToken};
5use jxl_grid::{AlignedGrid, AllocTracker, MutableSubgrid};
6
7use crate::{
8    ma::{FlatMaTree, MaTreeLeafClustered, SimpleMaTable},
9    predictor::{Predictor, PredictorState, Properties, WpHeader},
10    sample::Sample,
11    MaConfig, ModularChannelInfo, ModularChannels, ModularHeader, Result,
12};
13
/// Mutable view of one channel's samples, used while applying modular transforms.
///
/// `Single` wraps one subgrid. `Merged` keeps a `leader` subgrid together with `members` that
/// were attached via `merge` and can be split off again with `unmerge`; the `grid()` /
/// `grid_mut()` accessors only expose the leader.
#[derive(Debug)]
pub enum TransformedGrid<'dest, S: Sample> {
    /// A lone subgrid.
    Single(MutableSubgrid<'dest, S>),
    /// A leader subgrid with other grids attached to it.
    Merged {
        /// The grid returned by `grid()` / `grid_mut()`.
        leader: MutableSubgrid<'dest, S>,
        /// Grids attached by `merge`, in attachment order; `unmerge` removes from the end.
        members: Vec<TransformedGrid<'dest, S>>,
    },
}
22
23impl<'dest, S: Sample> From<MutableSubgrid<'dest, S>> for TransformedGrid<'dest, S> {
24    fn from(value: MutableSubgrid<'dest, S>) -> Self {
25        Self::Single(value)
26    }
27}
28
29impl<S: Sample> TransformedGrid<'_, S> {
30    fn reborrow(&mut self) -> TransformedGrid<S> {
31        match self {
32            TransformedGrid::Single(g) => TransformedGrid::Single(g.split_horizontal(0).1),
33            TransformedGrid::Merged { leader, .. } => {
34                TransformedGrid::Single(leader.split_horizontal(0).1)
35            }
36        }
37    }
38}
39
40impl<'dest, S: Sample> TransformedGrid<'dest, S> {
41    pub(crate) fn grid(&self) -> &MutableSubgrid<'dest, S> {
42        match self {
43            Self::Single(g) => g,
44            Self::Merged { leader, .. } => leader,
45        }
46    }
47
48    pub(crate) fn grid_mut(&mut self) -> &mut MutableSubgrid<'dest, S> {
49        match self {
50            Self::Single(g) => g,
51            Self::Merged { leader, .. } => leader,
52        }
53    }
54
55    pub(crate) fn merge(&mut self, members: Vec<TransformedGrid<'dest, S>>) {
56        if members.is_empty() {
57            return;
58        }
59
60        match self {
61            Self::Single(leader) => {
62                let tmp = MutableSubgrid::empty();
63                let leader = std::mem::replace(leader, tmp);
64                *self = Self::Merged { leader, members };
65            }
66            Self::Merged {
67                members: original_members,
68                ..
69            } => {
70                original_members.extend(members);
71            }
72        }
73    }
74
75    pub(crate) fn unmerge(&mut self, count: usize) -> Vec<TransformedGrid<'dest, S>> {
76        if count == 0 {
77            return Vec::new();
78        }
79
80        match self {
81            Self::Single(_) => panic!("cannot unmerge TransformedGrid::Single"),
82            Self::Merged { leader, members } => {
83                let len = members.len();
84                let members = members.drain((len - count)..).collect();
85                if len == count {
86                    let tmp = MutableSubgrid::empty();
87                    let leader = std::mem::replace(leader, tmp);
88                    *self = Self::Single(leader);
89                }
90                members
91            }
92        }
93    }
94}
95
/// Owned sample buffers for a Modular image, plus the parameters needed to decode into them.
#[derive(Debug)]
pub struct ModularImageDestination<S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    /// Group dimension; `prepare_gmodular` / `prepare_groups` assert it is non-zero.
    group_dim: u32,
    bit_depth: u32,
    /// Channel layout before transforms; `prepare_subimage` clones it and applies the transforms.
    channels: ModularChannels,
    /// Meta-channel buffers created by each transform's `prepare_meta_channels`.
    meta_channels: Vec<AlignedGrid<S>>,
    /// One `width x height` buffer per entry in `channels.info`.
    image_channels: Vec<AlignedGrid<S>>,
}
106
107impl<S: Sample> ModularImageDestination<S> {
108    pub(crate) fn new(
109        header: ModularHeader,
110        ma_ctx: MaConfig,
111        group_dim: u32,
112        bit_depth: u32,
113        channels: ModularChannels,
114        tracker: Option<&AllocTracker>,
115    ) -> Result<Self> {
116        let mut meta_channels = Vec::new();
117        for tr in &header.transform {
118            tr.prepare_meta_channels(&mut meta_channels, tracker)?;
119        }
120
121        let image_channels = channels
122            .info
123            .iter()
124            .map(|ch| {
125                AlignedGrid::with_alloc_tracker(ch.width as usize, ch.height as usize, tracker)
126            })
127            .collect::<std::result::Result<_, _>>()?;
128
129        Ok(Self {
130            header,
131            ma_ctx,
132            group_dim,
133            bit_depth,
134            channels,
135            meta_channels,
136            image_channels,
137        })
138    }
139
140    pub fn try_clone(&self) -> Result<Self> {
141        Ok(Self {
142            header: self.header.clone(),
143            ma_ctx: self.ma_ctx.clone(),
144            group_dim: self.group_dim,
145            bit_depth: self.bit_depth,
146            channels: self.channels.clone(),
147            meta_channels: self
148                .meta_channels
149                .iter()
150                .map(|x| x.try_clone())
151                .collect::<std::result::Result<_, _>>()?,
152            image_channels: self
153                .image_channels
154                .iter()
155                .map(|x| x.try_clone())
156                .collect::<std::result::Result<_, _>>()?,
157        })
158    }
159
160    pub fn image_channels(&self) -> &[AlignedGrid<S>] {
161        &self.image_channels
162    }
163
164    pub fn into_image_channels(self) -> Vec<AlignedGrid<S>> {
165        self.image_channels
166    }
167
168    pub fn into_image_channels_with_info(
169        self,
170    ) -> impl Iterator<Item = (AlignedGrid<S>, ModularChannelInfo)> {
171        self.image_channels.into_iter().zip(self.channels.info)
172    }
173
174    pub fn has_palette(&self) -> bool {
175        self.header.transform.iter().any(|tr| tr.is_palette())
176    }
177
178    pub fn has_squeeze(&self) -> bool {
179        self.header.transform.iter().any(|tr| tr.is_squeeze())
180    }
181}
182
183impl<S: Sample> ModularImageDestination<S> {
184    pub fn prepare_gmodular(&mut self) -> Result<TransformedModularSubimage<S>> {
185        assert_ne!(self.group_dim, 0);
186
187        let group_dim = self.group_dim;
188        let subimage = self.prepare_subimage()?;
189        let (channel_info, grids): (Vec<_>, Vec<_>) = subimage
190            .channel_info
191            .into_iter()
192            .zip(subimage.grid)
193            .enumerate()
194            .take_while(|&(i, (ref info, _))| {
195                i < subimage.nb_meta_channels
196                    || (info.width <= group_dim && info.height <= group_dim)
197            })
198            .map(|(_, x)| x)
199            .unzip();
200        let channel_indices = (0..channel_info.len()).collect();
201        Ok(TransformedModularSubimage {
202            channel_info,
203            channel_indices,
204            grid: grids,
205            ..subimage
206        })
207    }
208
209    pub fn prepare_groups(
210        &mut self,
211        pass_shifts: &std::collections::BTreeMap<u32, (i32, i32)>,
212    ) -> Result<TransformedGlobalModular<S>> {
213        assert_ne!(self.group_dim, 0);
214
215        let num_passes = *pass_shifts.last_key_value().unwrap().0 as usize + 1;
216
217        let group_dim = self.group_dim;
218        let group_dim_shift = group_dim.trailing_zeros();
219        let bit_depth = self.bit_depth;
220        let subimage = self.prepare_subimage()?;
221        let it = subimage
222            .channel_info
223            .into_iter()
224            .zip(subimage.grid)
225            .enumerate()
226            .skip_while(|&(i, (ref info, _))| {
227                i < subimage.nb_meta_channels
228                    || (info.width <= group_dim && info.height <= group_dim)
229            });
230
231        let mut lf_groups = Vec::new();
232        let mut pass_groups = Vec::with_capacity(num_passes);
233        pass_groups.resize_with(num_passes, Vec::new);
234        for (i, (info, grid)) in it {
235            let ModularChannelInfo {
236                original_width,
237                original_height,
238                hshift,
239                vshift,
240                original_shift,
241                ..
242            } = info;
243            assert!(hshift >= 0 && vshift >= 0);
244
245            let grid = match grid {
246                TransformedGrid::Single(g) => g,
247                TransformedGrid::Merged { leader, .. } => leader,
248            };
249            tracing::trace!(
250                i,
251                width = grid.width(),
252                height = grid.height(),
253                hshift,
254                vshift
255            );
256
257            let (groups, grids) = if hshift < 3 || vshift < 3 {
258                let shift = hshift.min(vshift); // shift < 3
259                let pass_idx = *pass_shifts
260                    .iter()
261                    .find(|(_, &(minshift, maxshift))| (minshift..maxshift).contains(&shift))
262                    .unwrap()
263                    .0;
264                let pass_idx = pass_idx as usize;
265
266                let group_width = group_dim >> hshift;
267                let group_height = group_dim >> vshift;
268                if group_width == 0 || group_height == 0 {
269                    tracing::error!(
270                        group_dim,
271                        hshift,
272                        vshift,
273                        "Channel shift value too large after transform"
274                    );
275                    return Err(crate::Error::InvalidSqueezeParams);
276                }
277
278                let grids = grid.into_groups_with_fixed_count(
279                    group_width as usize,
280                    group_height as usize,
281                    (original_width + group_dim - 1) as usize >> group_dim_shift,
282                    (original_height + group_dim - 1) as usize >> group_dim_shift,
283                );
284                (&mut pass_groups[pass_idx], grids)
285            } else {
286                // hshift >= 3 && vshift >= 3
287                let lf_group_width = group_dim >> (hshift - 3);
288                let lf_group_height = group_dim >> (vshift - 3);
289                if lf_group_width == 0 || lf_group_height == 0 {
290                    tracing::error!(
291                        group_dim,
292                        hshift,
293                        vshift,
294                        "Channel shift value too large after transform"
295                    );
296                    return Err(crate::Error::InvalidSqueezeParams);
297                }
298                let grids = grid.into_groups_with_fixed_count(
299                    lf_group_width as usize,
300                    lf_group_height as usize,
301                    (original_width + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
302                    (original_height + (group_dim << 3) - 1) as usize >> (group_dim_shift + 3),
303                );
304                (&mut lf_groups, grids)
305            };
306
307            if groups.is_empty() {
308                groups.resize_with(grids.len(), || {
309                    TransformedModularSubimage::empty(&subimage.header, &subimage.ma_ctx, bit_depth)
310                });
311            } else if groups.len() != grids.len() {
312                panic!();
313            }
314
315            for (subimage, grid) in groups.iter_mut().zip(grids) {
316                let width = grid.width() as u32;
317                let height = grid.height() as u32;
318                if width == 0 || height == 0 {
319                    continue;
320                }
321
322                subimage.channel_info.push(ModularChannelInfo {
323                    width,
324                    height,
325                    original_width: width << hshift,
326                    original_height: height << vshift,
327                    hshift,
328                    vshift,
329                    original_shift,
330                });
331                subimage.channel_indices.push(i);
332                subimage.grid.push(grid.into());
333                subimage.partial = true;
334            }
335        }
336
337        Ok(TransformedGlobalModular {
338            lf_groups,
339            pass_groups,
340        })
341    }
342
343    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
344        let mut channels = self.channels.clone();
345        let mut meta_channel_grids = self
346            .meta_channels
347            .iter_mut()
348            .map(MutableSubgrid::from)
349            .collect::<Vec<_>>();
350        let mut grids = self
351            .image_channels
352            .iter_mut()
353            .map(|g| g.as_subgrid_mut().into())
354            .collect::<Vec<_>>();
355        for tr in &self.header.transform {
356            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
357        }
358
359        let channel_info = channels.info;
360        let channel_indices = (0..channel_info.len()).collect();
361        Ok(TransformedModularSubimage {
362            header: self.header.clone(),
363            ma_ctx: self.ma_ctx.clone(),
364            bit_depth: self.bit_depth,
365            nb_meta_channels: channels.nb_meta_channels as usize,
366            channel_info,
367            channel_indices,
368            grid: grids,
369            partial: true,
370        })
371    }
372}
373
/// Per-group subimages produced by `ModularImageDestination::prepare_groups`.
#[derive(Debug)]
pub struct TransformedGlobalModular<'dest, S: Sample> {
    /// Subimages for channels with both `hshift >= 3` and `vshift >= 3`.
    pub lf_groups: Vec<TransformedModularSubimage<'dest, S>>,
    /// Subimages for the remaining channels, indexed first by pass index, then by group.
    pub pass_groups: Vec<Vec<TransformedModularSubimage<'dest, S>>>,
}
379
/// A set of channel views, borrowed from a destination, that is decoded as one unit.
#[derive(Debug)]
pub struct TransformedModularSubimage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    nb_meta_channels: usize,
    /// Layout of each channel in `grid`, in the same order.
    channel_info: Vec<ModularChannelInfo>,
    /// For each channel, its index in the transformed channel list it was taken from.
    channel_indices: Vec<usize>,
    /// Mutable views of the channel samples.
    grid: Vec<TransformedGrid<'dest, S>>,
    /// Cleared only by a fully successful `decode`; `finish` returns `!partial`.
    partial: bool,
}
391
392impl<S: Sample> TransformedModularSubimage<'_, S> {
393    fn empty(header: &ModularHeader, ma_ctx: &MaConfig, bit_depth: u32) -> Self {
394        Self {
395            header: header.clone(),
396            ma_ctx: ma_ctx.clone(),
397            bit_depth,
398            nb_meta_channels: 0,
399            channel_info: Vec::new(),
400            channel_indices: Vec::new(),
401            grid: Vec::new(),
402            partial: false,
403        }
404    }
405}
406
407impl<'dest, S: Sample> TransformedModularSubimage<'dest, S> {
408    pub fn is_empty(&self) -> bool {
409        self.channel_info.is_empty()
410    }
411
412    pub fn recursive(
413        self,
414        bitstream: &mut Bitstream,
415        global_ma_config: Option<&MaConfig>,
416        tracker: Option<&AllocTracker>,
417    ) -> Result<RecursiveModularImage<'dest, S>> {
418        let channels = crate::ModularChannels {
419            info: self.channel_info,
420            nb_meta_channels: 0,
421        };
422        let (header, ma_ctx) = crate::read_and_validate_local_modular_header(
423            bitstream,
424            &channels,
425            global_ma_config,
426            tracker,
427        )?;
428
429        let mut image = RecursiveModularImage {
430            header,
431            ma_ctx,
432            bit_depth: self.bit_depth,
433            channels,
434            meta_channels: Vec::new(),
435            image_channels: self.grid,
436        };
437        for tr in &image.header.transform {
438            tr.prepare_meta_channels(&mut image.meta_channels, tracker)?;
439        }
440        Ok(image)
441    }
442
443    pub fn finish(mut self, pool: &jxl_threadpool::JxlThreadPool) -> bool {
444        for tr in self.header.transform.iter().rev() {
445            tr.inverse(&mut self.grid, self.bit_depth, pool);
446        }
447        !self.partial
448    }
449}
450
impl<S: Sample> TransformedModularSubimage<'_, S> {
    /// Decodes the samples of every channel in this subimage from `bitstream`.
    ///
    /// The entropy decoder is cloned from `self.ma_ctx`, and one flat MA tree is built per
    /// non-empty channel. Channels with zero width or height are skipped. Entropy-decoding errors
    /// are propagated.
    fn decode_inner(&mut self, bitstream: &mut Bitstream, stream_index: u32) -> Result<()> {
        let span = tracing::span!(tracing::Level::TRACE, "decode channels", stream_index);
        let _guard = span.enter();

        // Widest channel in this subimage; forwarded to the decoder as the distance multiplier.
        let dist_multiplier = self
            .channel_info
            .iter()
            .map(|info| info.width)
            .max()
            .unwrap_or(0);

        let mut decoder = self.ma_ctx.decoder().clone();
        decoder.begin(bitstream)?;

        // One flat MA tree per channel; `None` marks empty channels, which are skipped below.
        let mut ma_tree_list = Vec::with_capacity(self.channel_info.len());
        for (i, info) in self.channel_info.iter().enumerate() {
            if info.width == 0 || info.height == 0 {
                ma_tree_list.push(None);
                continue;
            }

            // Number of earlier channels with identical size and shifts. Presumably this bounds
            // how many previous channels the tree may reference (TODO confirm in `make_flat_tree`).
            let filtered_prev_len = self.channel_info[..i]
                .iter()
                .filter(|prev_info| {
                    info.width == prev_info.width
                        && info.height == prev_info.height
                        && info.hshift == prev_info.hshift
                        && info.vshift == prev_info.vshift
                })
                .count();

            let ma_tree =
                self.ma_ctx
                    .make_flat_tree(i as u32, stream_index, filtered_prev_len as u32);
            ma_tree_list.push(Some(ma_tree));
        }

        // Fast path: RLE-mode decoder and every non-empty channel uses a single Gradient leaf
        // with no offset and unit multiplier.
        if let Some(mut rle_decoder) = decoder.as_rle() {
            let is_fast_lossless = ma_tree_list.iter().all(|ma_tree| {
                ma_tree
                    .as_ref()
                    .map(|ma_tree| {
                        matches!(
                            ma_tree.single_node(),
                            Some(MaTreeLeafClustered {
                                predictor: Predictor::Gradient,
                                offset: 0,
                                multiplier: 1,
                                ..
                            })
                        )
                    })
                    .unwrap_or(true)
            });

            if is_fast_lossless {
                tracing::trace!("libjxl fast-lossless");
                // Decoding errors are collected in `rle_state` and reported once at the end.
                let mut rle_state = RleState::<S>::new();

                for (ma_tree, grid) in ma_tree_list.into_iter().zip(&mut self.grid) {
                    let Some(ma_tree) = ma_tree else {
                        continue;
                    };

                    let node = ma_tree.single_node().unwrap();
                    let cluster = node.cluster;
                    decode_fast_lossless(
                        bitstream,
                        &mut rle_decoder,
                        &mut rle_state,
                        cluster,
                        grid.grid_mut(),
                    );
                }

                rle_state.check_error()?;
                // Prefix code doesn't have checksum
                // (which is why `decoder.finalize()` is not called on this path).
                return Ok(());
            }
        }

        let wp_header = &self.header.wp_params;
        let mut predictor = PredictorState::new();
        // Already-decoded channels grouped by (width, height, hshift, vshift); each list holds
        // the most recently decoded channel first.
        let mut prev_map = HashMap::new();
        for ((info, ma_tree), grid) in self
            .channel_info
            .iter()
            .zip(ma_tree_list)
            .zip(&mut self.grid)
        {
            let Some(ma_tree) = ma_tree else {
                continue;
            };
            let key = (info.width, info.height, info.hshift, info.vshift);

            let filtered_prev = prev_map.entry(key).or_insert_with(Vec::new);

            // Pick the most specialised decoder the tree shape allows.
            if let Some(node) = ma_tree.single_node() {
                decode_single_node(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &mut predictor,
                    wp_header,
                    grid.grid_mut(),
                    node,
                )?;
            } else if let Some(table) = ma_tree.simple_table() {
                decode_simple_table(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &mut predictor,
                    wp_header,
                    grid.grid_mut(),
                    &table,
                )?;
            } else {
                let grid = grid.grid_mut();
                // Only as many previous channels as the tree needs are handed to the predictor;
                // weighted-predictor parameters only when the tree uses self-correction.
                let filtered_prev = &filtered_prev[..ma_tree.max_prev_channel_depth()];
                let wp_header = ma_tree.need_self_correcting().then_some(wp_header);
                predictor.reset(grid.width() as u32, filtered_prev, wp_header);
                decode_slow(
                    bitstream,
                    &mut decoder,
                    dist_multiplier,
                    &ma_tree,
                    &mut predictor,
                    grid,
                )?;
            }

            filtered_prev.insert(0, grid.grid());
        }

        decoder.finalize()?;
        Ok(())
    }

    /// Decodes this subimage, optionally tolerating a truncated bitstream.
    ///
    /// On full success the subimage is no longer marked partial. If decoding hits an unexpected
    /// end of input and `allow_partial` is set, the error is swallowed, the partial flag is left
    /// unchanged, and `Ok(())` is returned. Any other error is returned as-is.
    pub fn decode(
        &mut self,
        bitstream: &mut Bitstream,
        stream_index: u32,
        allow_partial: bool,
    ) -> Result<()> {
        match self.decode_inner(bitstream, stream_index) {
            Err(e) if e.unexpected_eof() && allow_partial => {
                tracing::debug!("Partially decoded Modular image");
            }
            Err(e) => return Err(e),
            Ok(_) => {
                self.partial = false;
            }
        }
        Ok(())
    }
}
609
/// A subimage re-described by its own (nested) Modular header, as returned by
/// `TransformedModularSubimage::recursive`.
#[derive(Debug)]
pub struct RecursiveModularImage<'dest, S: Sample> {
    header: ModularHeader,
    ma_ctx: MaConfig,
    bit_depth: u32,
    channels: ModularChannels,
    /// Meta-channel buffers owned by this image, prepared by the nested header's transforms.
    meta_channels: Vec<AlignedGrid<S>>,
    /// Channel views borrowed from the parent subimage.
    image_channels: Vec<TransformedGrid<'dest, S>>,
}
619
620impl<S: Sample> RecursiveModularImage<'_, S> {
621    pub fn prepare_subimage(&mut self) -> Result<TransformedModularSubimage<S>> {
622        let mut channels = self.channels.clone();
623        let mut meta_channel_grids = self
624            .meta_channels
625            .iter_mut()
626            .map(|g| {
627                let width = g.width();
628                let height = g.height();
629                MutableSubgrid::from_buf(g.buf_mut(), width, height, width)
630            })
631            .collect::<Vec<_>>();
632        let mut grids = self
633            .image_channels
634            .iter_mut()
635            .map(|g| g.reborrow())
636            .collect();
637        for tr in &self.header.transform {
638            tr.transform_channels(&mut channels, &mut meta_channel_grids, &mut grids)?;
639        }
640
641        let channel_info = channels.info;
642        let channel_indices = (0..channel_info.len()).collect();
643        Ok(TransformedModularSubimage {
644            header: self.header.clone(),
645            ma_ctx: self.ma_ctx.clone(),
646            bit_depth: self.bit_depth,
647            nb_meta_channels: channels.nb_meta_channels as usize,
648            channel_info,
649            channel_indices,
650            grid: grids,
651            partial: true,
652        })
653    }
654}
655
/// Decoding state for the RLE fast path used by `decode_fast_lossless`.
struct RleState<S: Sample> {
    /// Most recently decoded sample value; re-emitted while `repeat` is non-zero.
    value: S,
    /// Remaining samples in the current run.
    repeat: u32,
    /// First decoding error encountered, reported later by `check_error`.
    error: Option<Box<jxl_coding::Error>>,
}
661
662impl<S: Sample> RleState<S> {
663    #[inline]
664    fn new() -> Self {
665        Self {
666            value: S::default(),
667            repeat: 0,
668            error: None,
669        }
670    }
671
672    #[inline(always)]
673    fn decode(
674        &mut self,
675        bitstream: &mut Bitstream,
676        decoder: &mut DecoderRleMode,
677        cluster: u8,
678    ) -> S {
679        if self.repeat == 0 {
680            let result = decoder.read_varint_clustered(bitstream, cluster);
681            match result {
682                Ok(RleToken::Value(v)) => {
683                    self.value = S::unpack_signed_u32(v);
684                    self.repeat = 1;
685                }
686                Ok(RleToken::Repeat(len)) => {
687                    self.repeat = len;
688                }
689                Err(e) if self.error.is_none() => {
690                    self.error = Some(Box::new(e));
691                }
692                _ => {}
693            }
694        }
695
696        self.repeat = self.repeat.wrapping_sub(1);
697        self.value
698    }
699
700    #[inline]
701    fn check_error(&mut self) -> Result<()> {
702        let error = self.error.take();
703        if let Some(error) = error {
704            let error = *error;
705            Err(error.into())
706        } else {
707            Ok(())
708        }
709    }
710}
711
/// Decodes a channel whose MA tree is a single leaf `node`, choosing the cheapest strategy.
///
/// - Zero predictor and a cluster that yields a single token: every row is filled with that
///   token's value; no decoder reads happen in this function.
/// - Zero predictor otherwise: each sample is its decoded residual transformed by
///   `wrapping_muladd_i32(multiplier, offset)`.
/// - Gradient predictor with `offset == 0 && multiplier == 1`: `decode_simple_grad`.
/// - Anything else: `predictor_state` is reset (weighted-predictor parameters are passed only
///   for `Predictor::SelfCorrecting`) and `decode_single_node_slow` does the work.
///
/// # Errors
/// Propagates entropy-decoding errors.
fn decode_single_node<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    dist_multiplier: u32,
    predictor_state: &mut PredictorState<S>,
    wp_header: &WpHeader,
    grid: &mut MutableSubgrid<S>,
    node: &MaTreeLeafClustered,
) -> Result<()> {
    let &MaTreeLeafClustered {
        cluster,
        predictor,
        offset,
        multiplier,
    } = node;
    tracing::trace!(cluster, ?predictor, "Single MA tree node");

    let height = grid.height();
    let single_token = decoder.single_token(cluster);
    match (predictor, single_token) {
        (Predictor::Zero, Some(token)) => {
            tracing::trace!("Single token in cluster, Zero predictor: hyper fast path");
            let value = S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
            for y in 0..height {
                grid.get_row_mut(y).fill(value);
            }
            Ok(())
        }
        (Predictor::Zero, None) => {
            tracing::trace!("Zero predictor: fast path");
            for y in 0..height {
                let row = grid.get_row_mut(y);
                for out in row {
                    let token = decoder.read_varint_with_multiplier_clustered(
                        bitstream,
                        cluster,
                        dist_multiplier,
                    )?;
                    *out =
                        S::unpack_signed_u32(token).wrapping_muladd_i32(multiplier as i32, offset);
                }
            }
            Ok(())
        }
        (Predictor::Gradient, _) if offset == 0 && multiplier == 1 => {
            tracing::trace!("Simple gradient: quite fast path");
            decode_simple_grad(bitstream, decoder, cluster, dist_multiplier, grid)
        }
        _ => {
            let wp_header = (predictor == Predictor::SelfCorrecting).then_some(wp_header);
            predictor_state.reset(grid.width() as u32, &[], wp_header);
            decode_single_node_slow(
                bitstream,
                decoder,
                dist_multiplier,
                node,
                predictor_state,
                grid,
            )
        }
    }
}
774
/// Decodes one channel with the gradient predictor from an RLE-mode decoder.
///
/// Row 0 predicts each sample from its left neighbour (starting from `S::default()`). In later
/// rows the first sample is predicted from the sample above, and the rest from
/// `S::grad_clamped(n, w, nw)` using the above, left and above-left samples.
///
/// Errors are not returned: they are recorded in `rle_state` and must be checked by the caller
/// via `rle_state.check_error()`.
///
/// NOTE(review): assumes `grid` has at least one row (`get_row_mut(0)` is called unconditionally);
/// callers are expected to skip empty channels.
#[inline(never)]
fn decode_fast_lossless<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut DecoderRleMode,
    rle_state: &mut RleState<S>,
    cluster: u8,
    grid: &mut MutableSubgrid<S>,
) {
    let height = grid.height();

    // First row: running sum of residuals (prediction = left neighbour).
    {
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in &mut *out_row {
            let token = rle_state.decode(bitstream, decoder, cluster);
            w = w.add(token);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while the current row is written.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        let token = rle_state.decode(bitstream, decoder, cluster);
        let mut w = token.add(prev_row[0]);
        out_row[0] = w;

        // Each window yields (NW, N) for the current column.
        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            let pred = S::grad_clamped(n, w, nw);

            let token = rle_state.decode(bitstream, decoder, cluster);
            w = token.add(pred);
            *out = w;
        }
    }
}
815
/// Decodes one channel with the gradient predictor (no residual scaling or offset).
///
/// Row 0 predicts each sample from its left neighbour (starting from `S::default()`). In later
/// rows the first sample is predicted from the sample above, and the rest from
/// `S::grad_clamped(n, w, nw)` using the above, left and above-left samples.
///
/// # Errors
/// Propagates entropy-decoding errors; samples written before the error remain in `grid`.
#[inline(never)]
fn decode_simple_grad<S: Sample>(
    bitstream: &mut Bitstream,
    decoder: &mut Decoder,
    cluster: u8,
    dist_multiplier: u32,
    grid: &mut MutableSubgrid<S>,
) -> Result<()> {
    let width = grid.width();
    let height = grid.height();

    // First row: running sum of residuals (prediction = left neighbour).
    {
        let mut w = S::default();
        let out_row = grid.get_row_mut(0);
        for out in out_row[..width].iter_mut() {
            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            w = S::unpack_signed_u32(token).add(w);
            *out = w;
        }
    }

    for y in 1..height {
        // Split so the previous row can be read while the current row is written.
        let (u, mut d) = grid.split_vertical(y);
        let prev_row = u.get_row(y - 1);
        let out_row = d.get_row_mut(0);

        let token =
            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
        out_row[0] = w;

        // Each window yields (NW, N) for the current column.
        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
            let nw = window[0];
            let n = window[1];
            let pred = S::grad_clamped(n, w, nw);

            let token = decoder.read_varint_with_multiplier_clustered(
                bitstream,
                cluster,
                dist_multiplier,
            )?;
            let value = S::unpack_signed_u32(token).add(pred);
            *out = value;
            w = value;
        }
    }

    Ok(())
}
869
870#[inline(always)]
871fn decode_one<S: Sample, const EDGE: bool>(
872    bitstream: &mut Bitstream,
873    decoder: &mut Decoder,
874    dist_multiplier: u32,
875    leaf: &MaTreeLeafClustered,
876    properties: &Properties<S>,
877) -> Result<S> {
878    let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
879        bitstream,
880        leaf.cluster,
881        dist_multiplier,
882    )?);
883    let diff = diff.wrapping_muladd_i32(leaf.multiplier as i32, leaf.offset);
884    let predictor = leaf.predictor;
885    let sample_prediction = predictor.predict::<_, EDGE>(properties);
886    Ok(diff.add(S::from_i32(sample_prediction)))
887}
888
889#[inline(never)]
890fn decode_single_node_slow<S: Sample>(
891    bitstream: &mut Bitstream,
892    decoder: &mut Decoder,
893    dist_multiplier: u32,
894    leaf: &MaTreeLeafClustered,
895    predictor: &mut PredictorState<S>,
896    grid: &mut MutableSubgrid<S>,
897) -> Result<()> {
898    let height = grid.height();
899    for y in 0..2usize.min(height) {
900        let row = grid.get_row_mut(y);
901
902        for out in row.iter_mut() {
903            let properties = predictor.properties::<true>();
904            let true_value =
905                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
906            *out = true_value;
907            properties.record(true_value.to_i32());
908        }
909    }
910
911    for y in 2..height {
912        let row = grid.get_row_mut(y);
913        let (row_left, row_middle, row_right) = if row.len() <= 4 {
914            (row, [].as_mut(), [].as_mut())
915        } else {
916            let (l, m) = row.split_at_mut(2);
917            let (m, r) = m.split_at_mut(m.len() - 2);
918            (l, m, r)
919        };
920
921        for out in row_left {
922            let properties = predictor.properties::<true>();
923            let true_value =
924                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
925            *out = true_value;
926            properties.record(true_value.to_i32());
927        }
928        for out in row_middle {
929            let properties = predictor.properties::<false>();
930            let true_value =
931                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
932            *out = true_value;
933            properties.record(true_value.to_i32());
934        }
935        for out in row_right {
936            let properties = predictor.properties::<true>();
937            let true_value =
938                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
939            *out = true_value;
940            properties.record(true_value.to_i32());
941        }
942    }
943
944    Ok(())
945}
946
947fn decode_simple_table<S: Sample>(
948    bitstream: &mut Bitstream,
949    decoder: &mut Decoder,
950    dist_multiplier: u32,
951    predictor_state: &mut PredictorState<S>,
952    wp_header: &WpHeader,
953    grid: &mut MutableSubgrid<S>,
954    table: &SimpleMaTable,
955) -> Result<()> {
956    let &SimpleMaTable {
957        decision_prop,
958        value_base,
959        predictor,
960        offset,
961        multiplier,
962        ref cluster_table,
963    } = table;
964
965    if offset == 0 && multiplier == 1 && decision_prop == 9 && predictor == Predictor::Gradient {
966        return decode_gradient_table(
967            bitstream,
968            decoder,
969            dist_multiplier,
970            grid,
971            value_base,
972            cluster_table,
973        );
974    }
975
976    decode_simple_table_slow(
977        bitstream,
978        decoder,
979        dist_multiplier,
980        predictor_state,
981        wp_header,
982        grid,
983        table,
984    )
985}
986
/// Maps a decision-property value to an entropy-coding cluster through a lookup table.
///
/// `cluster_table[i]` is the cluster for property value `value_base + i`. Values below
/// `value_base` use the first entry, and values past the end of the table use the last.
///
/// The offset is computed with saturating arithmetic, so extreme property values clamp to
/// the table ends instead of overflowing `i32`. One source of such values is the wrapping
/// `N + W - NW` computed by the gradient fast path. A plain subtraction would panic in
/// debug builds and, in release builds, could wrap around and select the wrong end of the
/// table.
///
/// # Panics
/// Panics if `cluster_table` is empty, because the clamp range would be `0..=-1`.
#[inline(always)]
fn cluster_from_table(sample: i32, value_base: i32, cluster_table: &[u8]) -> u8 {
    let last = cluster_table.len() as i32 - 1;
    let index = sample.saturating_sub(value_base).clamp(0, last);
    cluster_table[index as usize]
}
992
993fn decode_gradient_table<S: Sample>(
994    bitstream: &mut Bitstream,
995    decoder: &mut Decoder,
996    dist_multiplier: u32,
997    grid: &mut MutableSubgrid<S>,
998    value_base: i32,
999    cluster_table: &[u8],
1000) -> Result<()> {
1001    tracing::trace!("Gradient-only lookup table");
1002
1003    let width = grid.width();
1004    let height = grid.height();
1005
1006    {
1007        let mut w = S::default();
1008        let out_row = grid.get_row_mut(0);
1009        for out in out_row[..width].iter_mut() {
1010            let cluster = cluster_from_table(w.to_i32(), value_base, cluster_table);
1011            let token = decoder.read_varint_with_multiplier_clustered(
1012                bitstream,
1013                cluster,
1014                dist_multiplier,
1015            )?;
1016            w = S::unpack_signed_u32(token).add(w);
1017            *out = w;
1018        }
1019    }
1020
1021    for y in 1..height {
1022        let (u, mut d) = grid.split_vertical(y);
1023        let prev_row = u.get_row(y - 1);
1024        let out_row = d.get_row_mut(0);
1025
1026        let cluster = cluster_from_table(prev_row[0].to_i32(), value_base, cluster_table);
1027        let token =
1028            decoder.read_varint_with_multiplier_clustered(bitstream, cluster, dist_multiplier)?;
1029        let mut w = S::unpack_signed_u32(token).add(prev_row[0]);
1030        out_row[0] = w;
1031
1032        for (window, out) in prev_row.windows(2).zip(&mut out_row[1..]) {
1033            let nw = window[0];
1034            let n = window[1];
1035            let prop = n
1036                .to_i32()
1037                .wrapping_add(w.to_i32())
1038                .wrapping_sub(nw.to_i32());
1039            let pred = S::grad_clamped(n, w, nw);
1040
1041            let cluster = cluster_from_table(prop, value_base, cluster_table);
1042            let token = decoder.read_varint_with_multiplier_clustered(
1043                bitstream,
1044                cluster,
1045                dist_multiplier,
1046            )?;
1047            let value = S::unpack_signed_u32(token).add(pred);
1048            *out = value;
1049            w = value;
1050        }
1051    }
1052
1053    Ok(())
1054}
1055
1056#[inline(always)]
1057fn decode_table_one<S: Sample, const EDGE: bool>(
1058    bitstream: &mut Bitstream,
1059    decoder: &mut Decoder,
1060    dist_multiplier: u32,
1061    table: &SimpleMaTable,
1062    properties: &Properties<S>,
1063) -> Result<S> {
1064    let prop_value = properties.get(table.decision_prop as usize);
1065
1066    let cluster = cluster_from_table(prop_value, table.value_base, &table.cluster_table);
1067
1068    let diff = S::unpack_signed_u32(decoder.read_varint_with_multiplier_clustered(
1069        bitstream,
1070        cluster,
1071        dist_multiplier,
1072    )?);
1073    let diff = diff.wrapping_muladd_i32(table.multiplier as i32, table.offset);
1074    let predictor = table.predictor;
1075    let sample_prediction = predictor.predict::<_, EDGE>(properties);
1076    Ok(diff.add(S::from_i32(sample_prediction)))
1077}
1078
1079#[inline(never)]
1080fn decode_simple_table_slow<S: Sample>(
1081    bitstream: &mut Bitstream,
1082    decoder: &mut Decoder,
1083    dist_multiplier: u32,
1084    predictor_state: &mut PredictorState<S>,
1085    wp_header: &WpHeader,
1086    grid: &mut MutableSubgrid<S>,
1087    table: &SimpleMaTable,
1088) -> Result<()> {
1089    tracing::trace!("Slow lookup table");
1090
1091    let need_wp_header = table.decision_prop == 15 || table.predictor == Predictor::SelfCorrecting;
1092    let wp_header = need_wp_header.then_some(wp_header);
1093    predictor_state.reset(grid.width() as u32, &[], wp_header);
1094
1095    let height = grid.height();
1096    for y in 0..2usize.min(height) {
1097        let row = grid.get_row_mut(y);
1098
1099        for out in row.iter_mut() {
1100            let properties = predictor_state.properties::<true>();
1101            let true_value = decode_table_one::<_, true>(
1102                bitstream,
1103                decoder,
1104                dist_multiplier,
1105                table,
1106                &properties,
1107            )?;
1108            *out = true_value;
1109            properties.record(true_value.to_i32());
1110        }
1111    }
1112
1113    for y in 2..height {
1114        let row = grid.get_row_mut(y);
1115        let (row_left, row_middle, row_right) = if row.len() <= 4 {
1116            (row, [].as_mut(), [].as_mut())
1117        } else {
1118            let (l, m) = row.split_at_mut(2);
1119            let (m, r) = m.split_at_mut(m.len() - 2);
1120            (l, m, r)
1121        };
1122
1123        for out in row_left {
1124            let properties = predictor_state.properties::<true>();
1125            let true_value = decode_table_one::<_, true>(
1126                bitstream,
1127                decoder,
1128                dist_multiplier,
1129                table,
1130                &properties,
1131            )?;
1132            *out = true_value;
1133            properties.record(true_value.to_i32());
1134        }
1135        for out in row_middle {
1136            let properties = predictor_state.properties::<false>();
1137            let true_value = decode_table_one::<_, false>(
1138                bitstream,
1139                decoder,
1140                dist_multiplier,
1141                table,
1142                &properties,
1143            )?;
1144            *out = true_value;
1145            properties.record(true_value.to_i32());
1146        }
1147        for out in row_right {
1148            let properties = predictor_state.properties::<true>();
1149            let true_value = decode_table_one::<_, true>(
1150                bitstream,
1151                decoder,
1152                dist_multiplier,
1153                table,
1154                &properties,
1155            )?;
1156            *out = true_value;
1157            properties.record(true_value.to_i32());
1158        }
1159    }
1160
1161    Ok(())
1162}
1163
1164#[inline(never)]
1165fn decode_slow<S: Sample>(
1166    bitstream: &mut Bitstream,
1167    decoder: &mut Decoder,
1168    dist_multiplier: u32,
1169    ma_tree: &FlatMaTree,
1170    predictor: &mut PredictorState<S>,
1171    grid: &mut MutableSubgrid<S>,
1172) -> Result<()> {
1173    let height = grid.height();
1174    for y in 0..2usize.min(height) {
1175        let row = grid.get_row_mut(y);
1176
1177        for out in row.iter_mut() {
1178            let properties = predictor.properties::<true>();
1179            let leaf = ma_tree.get_leaf(&properties);
1180            let true_value =
1181                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1182            *out = true_value;
1183            properties.record(true_value.to_i32());
1184        }
1185    }
1186
1187    for y in 2..height {
1188        let row = grid.get_row_mut(y);
1189        let (row_left, row_middle, row_right) = if row.len() <= 4 {
1190            (row, [].as_mut(), [].as_mut())
1191        } else {
1192            let (l, m) = row.split_at_mut(2);
1193            let (m, r) = m.split_at_mut(m.len() - 2);
1194            (l, m, r)
1195        };
1196
1197        for out in row_left {
1198            let properties = predictor.properties::<true>();
1199            let leaf = ma_tree.get_leaf(&properties);
1200            let true_value =
1201                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1202            *out = true_value;
1203            properties.record(true_value.to_i32());
1204        }
1205        for out in row_middle {
1206            let properties = predictor.properties::<false>();
1207            let leaf = ma_tree.get_leaf(&properties);
1208            let true_value =
1209                decode_one::<_, false>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1210            *out = true_value;
1211            properties.record(true_value.to_i32());
1212        }
1213        for out in row_right {
1214            let properties = predictor.properties::<true>();
1215            let leaf = ma_tree.get_leaf(&properties);
1216            let true_value =
1217                decode_one::<_, true>(bitstream, decoder, dist_multiplier, leaf, &properties)?;
1218            *out = true_value;
1219            properties.record(true_value.to_i32());
1220        }
1221    }
1222
1223    Ok(())
1224}