1use jxl_bitstream::{unpack_signed, Bitstream};
2use jxl_coding::Decoder;
3use jxl_oxide_common::Bundle;
4
5use crate::{FrameHeader, Result};
6
/// Hard upper bound (2^24) on the number of splines in one frame.
///
/// `Splines::parse` further lowers the bound to a quarter of the frame's pixel count.
const MAX_NUM_SPLINES: usize = 0x0100_0000;
/// Hard upper bound (2^20) on the accumulated number of spline control points in one frame.
///
/// `QuantSpline::parse` further lowers the bound to half of the frame's pixel count.
const MAX_NUM_CONTROL_POINTS: usize = 0x0010_0000;
9
10#[derive(Debug)]
12pub struct Splines {
13 pub quant_splines: Vec<QuantSpline>,
14 pub quant_adjust: i32,
15}
16
17impl Bundle<&FrameHeader> for Splines {
18 type Error = crate::Error;
19
20 fn parse(bitstream: &mut Bitstream, header: &FrameHeader) -> Result<Self> {
21 let mut decoder = jxl_coding::Decoder::parse(bitstream, 6)?;
22 decoder.begin(bitstream)?;
23
24 let num_splines = decoder.read_varint(bitstream, 2)? as usize;
25 let num_pixels = (header.width * header.height) as usize;
26 let max_num_splines = usize::min(MAX_NUM_SPLINES, num_pixels / 4);
27 if num_splines >= max_num_splines {
28 tracing::error!(num_splines, max_num_splines, "Too many splines");
29 return Err(jxl_bitstream::Error::ProfileConformance("too many splines").into());
30 }
31 let num_splines = num_splines + 1;
32
33 let mut start_points = vec![(0i64, 0i64); num_splines];
34 let mut prev_point = (
35 decoder.read_varint(bitstream, 1)? as i64,
36 decoder.read_varint(bitstream, 1)? as i64,
37 );
38 start_points[0] = prev_point;
39 for next_point in &mut start_points[1..] {
40 let x = decoder.read_varint(bitstream, 1)?;
41 let y = decoder.read_varint(bitstream, 1)?;
42 prev_point.0 += unpack_signed(x) as i64;
43 prev_point.1 += unpack_signed(y) as i64;
44 *next_point = prev_point;
45 }
46
47 let quant_adjust = unpack_signed(decoder.read_varint(bitstream, 0)?);
48
49 let mut splines: Vec<QuantSpline> = Vec::with_capacity(num_splines);
50 let mut acc_control_points = 0usize;
51 for start_point in start_points {
52 let spline = QuantSpline::parse(
53 bitstream,
54 QuantSplineParams::new(start_point, num_pixels, &mut decoder, acc_control_points),
55 )?;
56
57 acc_control_points += spline.quant_points.len();
58 splines.push(spline);
59 }
60
61 decoder.finalize()?;
62
63 Ok(Self {
64 quant_adjust,
65 quant_splines: splines,
66 })
67 }
68}
69
70impl Splines {
71 pub(crate) fn estimate_area(&self, base_correlation_xb: Option<(f32, f32)>) -> u64 {
72 let base_correlation_xb = base_correlation_xb.unwrap_or((0.0, 1.0));
73 let corr_x = base_correlation_xb.0.abs().ceil() as u64;
74 let corr_b = base_correlation_xb.1.abs().ceil() as u64;
75 let quant_adjust = self.quant_adjust;
76 let mut total_area = 0u64;
77
78 for quant_spline in &self.quant_splines {
79 let log_color = {
80 let mut color_xyb = quant_spline.xyb_dct.map(|quant_color_dct| -> u64 {
81 quant_color_dct
82 .into_iter()
83 .map(|q| div_ceil_qa(q.unsigned_abs(), quant_adjust))
84 .sum()
85 });
86
87 color_xyb[0] += corr_x * color_xyb[1];
88 color_xyb[2] += corr_b * color_xyb[1];
89 log2_ceil(1u64 + color_xyb.into_iter().max().unwrap()) as u64
90 };
91
92 let mut width_estimate = 0u64;
93 for quant_sigma_dct in quant_spline.sigma_dct {
94 let quant_sigma_dct = quant_sigma_dct.unsigned_abs();
95 let weight = 1 + div_ceil_qa(quant_sigma_dct, quant_adjust);
96 width_estimate += weight * weight * log_color;
97 }
98
99 total_area += width_estimate * quant_spline.manhattan_distance;
100 }
101
102 total_area
103 }
104}
105
/// Returns `ceil(log2(x))`; both `x == 0` and `x == 1` yield 0.
///
/// Inputs above `2^63` yield 64. A bare `next_power_of_two` would overflow for them, which
/// panics in debug builds (and wraps to 0, i.e. also 64 trailing zeros, in release builds).
#[inline]
fn log2_ceil(x: u64) -> u32 {
    x.checked_next_power_of_two()
        .map_or(u64::BITS, u64::trailing_zeros)
}
110
/// Scales a quantized coefficient magnitude by the spline quantization adjustment, rounding up.
///
/// For `quant_adjust >= 0` this is `ceil(dividend / (1 + quant_adjust / 8))`; for negative
/// values it is `dividend + ceil(dividend * |quant_adjust| / 8)`, i.e.
/// `ceil(dividend * (1 + |quant_adjust| / 8))`.
///
/// Every `i32` is accepted, including `i32::MIN`. Negating it directly would overflow, so the
/// magnitude is taken with `unsigned_abs`. All intermediate products fit in `u64`
/// (at most `(2^32 - 1) * 2^31 + 7`).
#[inline]
fn div_ceil_qa(dividend: u32, quant_adjust: i32) -> u64 {
    let dividend = u64::from(dividend);
    let abs_quant_adjust = u64::from(quant_adjust.unsigned_abs());
    if quant_adjust >= 0 {
        (8 * dividend + 7 + abs_quant_adjust) / (8 + abs_quant_adjust)
    } else {
        dividend + (dividend * abs_quant_adjust + 7) / 8
    }
}
122
123struct QuantSplineParams<'d> {
124 start_point: (i64, i64),
125 num_pixels: usize,
126 decoder: &'d mut Decoder,
127 acc_control_points: usize,
128}
129
130impl<'d> QuantSplineParams<'d> {
131 fn new(
132 start_point: (i64, i64),
133 num_pixels: usize,
134 decoder: &'d mut Decoder,
135 acc_control_points: usize,
136 ) -> Self {
137 Self {
138 start_point,
139 num_pixels,
140 decoder,
141 acc_control_points,
142 }
143 }
144}
145
/// A single spline with quantized control points and DCT coefficients.
#[derive(Clone, Debug, Default)]
pub struct QuantSpline {
    /// The spline's starting point followed by every decoded control point; parsing rejects
    /// a point equal to its predecessor.
    pub quant_points: Vec<(i64, i64)>,
    /// Sum of `|dx| + |dy|` over the steps between consecutive control points.
    pub manhattan_distance: u64,
    /// Quantized colour DCT coefficients: 32 per channel, channels in X, Y, B order.
    pub xyb_dct: [[i32; 32]; 3],
    /// 32 quantized sigma DCT coefficients; `Splines::estimate_area` derives the width
    /// estimate from them.
    pub sigma_dct: [i32; 32],
}
154
155impl Bundle<QuantSplineParams<'_>> for QuantSpline {
156 type Error = crate::Error;
157
158 fn parse(
159 bitstream: &mut Bitstream,
160 params: QuantSplineParams<'_>,
161 ) -> std::result::Result<Self, Self::Error> {
162 let QuantSplineParams {
163 start_point,
164 num_pixels,
165 decoder,
166 acc_control_points,
167 } = params;
168
169 let num_points = decoder.read_varint(bitstream, 3)? as usize;
170 let acc_num_points = acc_control_points + num_points;
171 let max_num_points = usize::min(MAX_NUM_CONTROL_POINTS, num_pixels / 2);
172 if acc_num_points > max_num_points {
173 tracing::error!(num_points, max_num_points, "Too many spline points");
174 return Err(jxl_bitstream::Error::ProfileConformance("too many spline points").into());
175 }
176
177 let mut quant_points = Vec::with_capacity(1 + num_points);
178 let mut cur_value = start_point;
179 let mut cur_delta = (0, 0);
180 let mut manhattan_distance = 0u64;
181 quant_points.push(cur_value);
182 for _ in 0..num_points {
183 let prev_value = cur_value;
184 let delta_x = unpack_signed(decoder.read_varint(bitstream, 4)?) as i64;
185 let delta_y = unpack_signed(decoder.read_varint(bitstream, 4)?) as i64;
186
187 cur_delta.0 += delta_x;
188 cur_delta.1 += delta_y;
189 manhattan_distance += (cur_delta.0.abs() + cur_delta.1.abs()) as u64;
190 cur_value.0 = cur_value.0.checked_add(cur_delta.0).ok_or(
191 jxl_bitstream::Error::ValidationFailed("control point overflowed"),
192 )?;
193 cur_value.1 = cur_value.1.checked_add(cur_delta.1).ok_or(
194 jxl_bitstream::Error::ValidationFailed("control point overflowed"),
195 )?;
196 if cur_value == prev_value {
197 return Err(jxl_bitstream::Error::ValidationFailed(
198 "two consecutive control points have the same value",
199 )
200 .into());
201 }
202 quant_points.push(cur_value);
203 }
204
205 let mut xyb_dct = [[0; 32]; 3];
206 for color_dct in &mut xyb_dct {
207 for i in color_dct {
208 *i = unpack_signed(decoder.read_varint(bitstream, 5)?);
209 }
210 }
211
212 let mut sigma_dct = [0; 32];
213 for i in &mut sigma_dct {
214 *i = unpack_signed(decoder.read_varint(bitstream, 5)?);
215 }
216
217 Ok(Self {
218 quant_points,
219 manhattan_distance,
220 xyb_dct,
221 sigma_dct,
222 })
223 }
224}