1use jxl_bitstream::{unpack_signed, Bitstream};
2use jxl_image::ImageHeader;
3use jxl_oxide_common::Bundle;
4
5use crate::{FrameHeader, Result};
6
7#[derive(Debug)]
8pub struct Patches {
9 pub patches: Vec<PatchRef>,
10}
11
12#[derive(Debug)]
13pub struct PatchRef {
14 pub ref_idx: u32,
15 pub x0: u32,
16 pub y0: u32,
17 pub width: u32,
18 pub height: u32,
19 pub patch_targets: Vec<PatchTarget>,
20}
21
22#[derive(Debug)]
23pub struct PatchTarget {
24 pub x: i32,
25 pub y: i32,
26 pub blending: Vec<BlendingModeInformation>,
27}
28
29#[derive(Debug)]
30pub struct BlendingModeInformation {
31 pub mode: PatchBlendMode,
32 pub alpha_channel: u32,
33 pub clamp: bool,
34}
35
/// Blend mode of a patch, numbered exactly as it is coded in the bitstream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PatchBlendMode {
    None = 0,
    Replace = 1,
    Add = 2,
    Mul = 3,
    BlendAbove = 4,
    BlendBelow = 5,
    MulAddAbove = 6,
    MulAddBelow = 7,
}
48
49impl TryFrom<u32> for PatchBlendMode {
50 type Error = jxl_bitstream::Error;
51
52 fn try_from(value: u32) -> std::result::Result<Self, Self::Error> {
53 use PatchBlendMode::*;
54
55 Ok(match value {
56 0 => PatchBlendMode::None,
57 1 => Replace,
58 2 => Add,
59 3 => Mul,
60 4 => BlendAbove,
61 5 => BlendBelow,
62 6 => MulAddAbove,
63 7 => MulAddBelow,
64 _ => {
65 return Err(jxl_bitstream::Error::InvalidEnum {
66 name: "PatchBlendMode",
67 value,
68 })
69 }
70 })
71 }
72}
73
74impl PatchBlendMode {
75 #[inline]
76 pub fn use_alpha(self) -> bool {
77 matches!(
78 self,
79 Self::BlendAbove | Self::BlendBelow | Self::MulAddAbove | Self::MulAddBelow
80 )
81 }
82}
83
84impl Bundle<(&ImageHeader, &FrameHeader)> for Patches {
85 type Error = crate::Error;
86
87 fn parse(
88 bitstream: &mut Bitstream,
89 (image_header, frame_header): (&ImageHeader, &FrameHeader),
90 ) -> Result<Self> {
91 let num_extra = image_header.metadata.ec_info.len();
92 let alpha_channel_indices = image_header
93 .metadata
94 .ec_info
95 .iter()
96 .enumerate()
97 .filter_map(|(idx, info)| info.is_alpha().then_some(idx as u32))
98 .collect::<Vec<_>>();
99
100 let mut decoder = jxl_coding::Decoder::parse(bitstream, 10)?;
101 decoder.begin(bitstream)?;
102
103 let frame_width = frame_header.width;
104 let frame_height = frame_header.height;
105 let max_num_patch_refs =
106 (1 << 24).min((frame_width as u64 * frame_height as u64 / 16) as u32);
107 let max_num_patches = max_num_patch_refs * 4; let num_patch_refs = decoder.read_varint(bitstream, 0)?;
110 tracing::trace!(num_patch_refs, "Patch ref");
111 if num_patch_refs > max_num_patch_refs {
112 tracing::error!(num_patch_refs, max_num_patch_refs, "Too many patches");
113 return Err(jxl_bitstream::Error::ProfileConformance("too many patches").into());
114 }
115
116 let mut total_patches = 0u32;
117 let patches = std::iter::repeat_with(|| -> Result<_> {
118 let ref_idx = decoder.read_varint(bitstream, 1)?;
119 if ref_idx >= 4 {
120 tracing::error!(ref_idx, "PatchRef index out of bounds");
121 return Err(
122 jxl_bitstream::Error::ValidationFailed("PatchRef index out of bounds").into(),
123 );
124 }
125 let x0 = decoder.read_varint(bitstream, 3)?;
126 let y0 = decoder.read_varint(bitstream, 3)?;
127 let width = decoder.read_varint(bitstream, 2)? + 1;
128 let height = decoder.read_varint(bitstream, 2)? + 1;
129 let count = decoder.read_varint(bitstream, 7)? + 1;
130 tracing::trace!(ref_idx, x0, y0, width, height, count, "Patch target");
131
132 total_patches += count;
133 if total_patches > max_num_patches {
134 tracing::error!(total_patches, max_num_patches, "Too many patches");
135 return Err(jxl_bitstream::Error::ProfileConformance("too many patches").into());
136 }
137
138 let mut prev_xy = None;
139 let patch_targets = std::iter::repeat_with(|| -> Result<_> {
140 let (x, y) = if let Some((px, py)) = prev_xy {
141 let dx = decoder.read_varint(bitstream, 6)?;
142 let dy = decoder.read_varint(bitstream, 6)?;
143 let dx = unpack_signed(dx);
144 let dy = unpack_signed(dy);
145 let x = dx.checked_add(px);
146 let y = dy.checked_add(py);
147 let (Some(x), Some(y)) = (x, y) else {
148 tracing::error!(px, py, dx, dy, "Patch coord overflow");
149 return Err(
150 jxl_bitstream::Error::ValidationFailed("patch coord overflow").into(),
151 );
152 };
153 (x, y)
154 } else {
155 (
156 decoder.read_varint(bitstream, 4)? as i32,
157 decoder.read_varint(bitstream, 4)? as i32,
158 )
159 };
160 prev_xy = Some((x, y));
161
162 let blending = std::iter::repeat_with(|| -> Result<_> {
163 let raw_mode = decoder.read_varint(bitstream, 5)?;
164 let mode = PatchBlendMode::try_from(raw_mode)?;
165 let alpha_channel = if raw_mode >= 4 && alpha_channel_indices.len() >= 2 {
166 decoder.read_varint(bitstream, 8)?
167 } else {
168 alpha_channel_indices.first().copied().unwrap_or_default()
169 };
170 let clamp = if raw_mode >= 3 {
171 decoder.read_varint(bitstream, 9)? != 0
172 } else {
173 false
174 };
175
176 Ok(BlendingModeInformation {
177 mode,
178 alpha_channel,
179 clamp,
180 })
181 })
182 .take(num_extra + 1)
183 .collect::<Result<Vec<_>>>()?;
184
185 Ok(PatchTarget { x, y, blending })
186 })
187 .take(count as usize)
188 .collect::<Result<Vec<_>>>()?;
189
190 Ok(PatchRef {
191 ref_idx,
192 x0,
193 y0,
194 width,
195 height,
196 patch_targets,
197 })
198 })
199 .take(num_patch_refs as usize)
200 .collect::<Result<Vec<_>>>()?;
201
202 decoder.finalize()?;
203 Ok(Self { patches })
204 }
205}