1use std::sync::atomic::AtomicI32;
2
3use jxl_bitstream::Bitstream;
4use jxl_grid::{AllocTracker, SharedSubgrid};
5use jxl_modular::{image::TransformedModularSubimage, ChannelShift, MaConfig, Sample};
6use jxl_threadpool::JxlThreadPool;
7use jxl_vardct::{write_hf_coeff, HfCoeffParams};
8
9use super::{HfGlobal, LfGlobalVarDct, LfGroup};
10use crate::{FrameHeader, Result};
11
12#[derive(Debug)]
13pub struct PassGroupParams<'frame, 'buf, 'g, 'tracker, S: Sample> {
14 pub frame_header: &'frame FrameHeader,
15 pub lf_group: &'frame LfGroup<S>,
16 pub pass_idx: u32,
17 pub group_idx: u32,
18 pub global_ma_config: Option<&'frame MaConfig>,
19 pub modular: Option<TransformedModularSubimage<'g, S>>,
20 pub vardct: Option<PassGroupParamsVardct<'frame, 'buf, 'g>>,
21 pub allow_partial: bool,
22 pub tracker: Option<&'tracker AllocTracker>,
23 pub pool: &'frame JxlThreadPool,
24}
25
/// VarDCT-specific inputs for decoding one pass group with `decode_pass_group`.
#[derive(Debug)]
pub struct PassGroupParamsVardct<'frame, 'buf, 'g> {
    /// VarDCT LF-global data; `decode_pass_group` reads its `hf_block_ctx`.
    pub lf_vardct: &'frame LfGlobalVarDct,
    /// HF-global data; supplies the per-pass entries in `hf_passes` and `num_hf_presets`.
    pub hf_global: &'frame HfGlobal,
    /// The three grids that `write_hf_coeff` writes decoded HF coefficients into
    /// (presumably one per color channel — confirm against `write_hf_coeff`).
    pub hf_coeff_output: &'buf [SharedSubgrid<'g, AtomicI32>; 3],
}
32
33pub fn decode_pass_group<S: Sample>(
34 bitstream: &mut Bitstream,
35 params: PassGroupParams<S>,
36) -> Result<()> {
37 let PassGroupParams {
38 frame_header,
39 lf_group,
40 pass_idx,
41 group_idx,
42 global_ma_config,
43 modular,
44 vardct,
45 allow_partial,
46 tracker,
47 pool,
48 } = params;
49
50 if let (
51 Some(PassGroupParamsVardct {
52 lf_vardct,
53 hf_global,
54 hf_coeff_output,
55 }),
56 Some(hf_meta),
57 ) = (vardct, &lf_group.hf_meta)
58 {
59 let hf_pass = &hf_global.hf_passes[pass_idx as usize];
60 let coeff_shift = frame_header
61 .passes
62 .shift
63 .get(pass_idx as usize)
64 .copied()
65 .unwrap_or(0);
66
67 let group_col = group_idx % frame_header.groups_per_row();
68 let group_row = group_idx / frame_header.groups_per_row();
69 let lf_col = (group_col % 8) as usize;
70 let lf_row = (group_row % 8) as usize;
71 let group_dim_blocks = (frame_header.group_dim() / 8) as usize;
72
73 let block_info = &hf_meta.block_info;
74
75 let block_left = lf_col * group_dim_blocks;
76 let block_top = lf_row * group_dim_blocks;
77 let block_width = (block_info.width() - block_left).min(group_dim_blocks);
78 let block_height = (block_info.height() - block_top).min(group_dim_blocks);
79
80 let jpeg_upsampling = frame_header.jpeg_upsampling;
81 let block_info = block_info.as_subgrid().subgrid(
82 block_left..(block_left + block_width),
83 block_top..(block_top + block_height),
84 );
85 let lf_quant: Option<[_; 3]> = lf_group.lf_coeff.as_ref().map(|lf_coeff| {
86 let lf_quant_channels = lf_coeff.lf_quant.image().unwrap().image_channels();
87 std::array::from_fn(|idx| {
88 let lf_quant = &lf_quant_channels[[1, 0, 2][idx]];
89 let shift = ChannelShift::from_jpeg_upsampling(jpeg_upsampling, idx);
90
91 let block_left = block_left >> shift.hshift();
92 let block_top = block_top >> shift.vshift();
93 let (block_width, block_height) =
94 shift.shift_size((block_width as u32, block_height as u32));
95 lf_quant.as_subgrid().subgrid(
96 block_left..(block_left + block_width as usize),
97 block_top..(block_top + block_height as usize),
98 )
99 })
100 });
101
102 let params = HfCoeffParams {
103 num_hf_presets: hf_global.num_hf_presets,
104 hf_block_ctx: &lf_vardct.hf_block_ctx,
105 block_info,
106 jpeg_upsampling,
107 lf_quant,
108 hf_pass,
109 coeff_shift,
110 tracker,
111 };
112
113 match write_hf_coeff(bitstream, params, hf_coeff_output) {
114 Err(e) if e.unexpected_eof() && allow_partial => {
115 tracing::debug!("Partially decoded HfCoeff");
116 return Ok(());
117 }
118 Err(e) => return Err(e.into()),
119 Ok(_) => {}
120 };
121 }
122
123 if let Some(modular) = modular {
124 decode_pass_group_modular(
125 bitstream,
126 frame_header,
127 global_ma_config,
128 pass_idx,
129 group_idx,
130 modular,
131 allow_partial,
132 tracker,
133 pool,
134 )?;
135 }
136
137 Ok(())
138}
139
140#[allow(clippy::too_many_arguments)]
141pub fn decode_pass_group_modular<S: Sample>(
142 bitstream: &mut Bitstream,
143 frame_header: &FrameHeader,
144 global_ma_config: Option<&MaConfig>,
145 pass_idx: u32,
146 group_idx: u32,
147 modular: TransformedModularSubimage<S>,
148 allow_partial: bool,
149 tracker: Option<&AllocTracker>,
150 pool: &JxlThreadPool,
151) -> Result<()> {
152 if modular.is_empty() {
153 return Ok(());
154 }
155
156 let mut modular = modular.recursive(bitstream, global_ma_config, tracker)?;
157 let mut subimage = modular.prepare_subimage()?;
158 subimage.decode(
159 bitstream,
160 1 + 3 * frame_header.num_lf_groups()
161 + 17
162 + pass_idx * frame_header.num_groups()
163 + group_idx,
164 allow_partial,
165 )?;
166 subimage.finish(pool);
167 Ok(())
168}