1use crate::video::decode::Decoder;
10use crate::video::pixel::CastFromPrimitive;
11use crate::video::pixel::Pixel;
12use crate::video::ChromaWeight;
13use crate::video::{PlanarMetrics, VideoMetric};
14use crate::MetricsError;
15use std::error::Error;
16use std::mem::size_of;
17use v_frame::frame::Frame;
18use v_frame::plane::Plane;
19use v_frame::prelude::ChromaSampling;
20
21use super::FrameCompare;
22
23#[inline]
25pub fn calculate_video_psnr_hvs<D: Decoder, F: Fn(usize) + Send>(
26 decoder1: &mut D,
27 decoder2: &mut D,
28 frame_limit: Option<usize>,
29 progress_callback: F,
30) -> Result<PlanarMetrics, Box<dyn Error>> {
31 let cweight = Some(
32 decoder1
33 .get_video_details()
34 .chroma_sampling
35 .get_chroma_weight(),
36 );
37 PsnrHvs { cweight }.process_video(decoder1, decoder2, frame_limit, progress_callback)
38}
39
40#[inline]
42pub fn calculate_frame_psnr_hvs<T: Pixel>(
43 frame1: &Frame<T>,
44 frame2: &Frame<T>,
45 bit_depth: usize,
46 chroma_sampling: ChromaSampling,
47) -> Result<PlanarMetrics, Box<dyn Error>> {
48 let processor = PsnrHvs::default();
49 let result = processor.process_frame(frame1, frame2, bit_depth, chroma_sampling)?;
50 let cweight = chroma_sampling.get_chroma_weight();
51 Ok(PlanarMetrics {
52 y: log10_convert(result.y, 1.0),
53 u: log10_convert(result.u, 1.0),
54 v: log10_convert(result.v, 1.0),
55 avg: log10_convert(
56 result.y + cweight * (result.u + result.v),
57 1.0 + 2.0 * cweight,
58 ),
59 })
60}
61
/// PSNR-HVS metric configuration used by the frame- and video-level entry points.
#[derive(Debug, Clone, Copy, Default)]
struct PsnrHvs {
    /// Weight of each chroma plane in the combined `avg` score; luma always has weight 1.
    /// When `None` (the `Default`), `aggregate_frame_results` falls back to 1.0.
    pub cweight: Option<f64>,
}
66
67impl VideoMetric for PsnrHvs {
68 type FrameResult = PlanarMetrics;
69 type VideoResult = PlanarMetrics;
70
71 fn process_frame<T: Pixel>(
74 &self,
75 frame1: &Frame<T>,
76 frame2: &Frame<T>,
77 bit_depth: usize,
78 _chroma_sampling: ChromaSampling,
79 ) -> Result<Self::FrameResult, Box<dyn Error>> {
80 if (size_of::<T>() == 1 && bit_depth > 8) || (size_of::<T>() == 2 && bit_depth <= 8) {
81 return Err(Box::new(MetricsError::InputMismatch {
82 reason: "Bit depths does not match pixel width",
83 }));
84 }
85
86 frame1.can_compare(frame2)?;
87
88 let bit_depth = bit_depth;
89 let mut y = 0.0;
90 let mut u = 0.0;
91 let mut v = 0.0;
92
93 rayon::scope(|s| {
94 s.spawn(|_| {
95 y = calculate_plane_psnr_hvs(&frame1.planes[0], &frame2.planes[0], 0, bit_depth)
96 });
97 s.spawn(|_| {
98 u = calculate_plane_psnr_hvs(&frame1.planes[1], &frame2.planes[1], 1, bit_depth)
99 });
100 s.spawn(|_| {
101 v = calculate_plane_psnr_hvs(&frame1.planes[2], &frame2.planes[2], 2, bit_depth)
102 });
103 });
104
105 Ok(PlanarMetrics {
106 y,
107 u,
108 v,
109 avg: 0.,
111 })
112 }
113
114 fn aggregate_frame_results(
115 &self,
116 metrics: &[Self::FrameResult],
117 ) -> Result<Self::VideoResult, Box<dyn Error>> {
118 let cweight = self.cweight.unwrap_or(1.0);
119 let sum_y = metrics.iter().map(|m| m.y).sum::<f64>();
120 let sum_u = metrics.iter().map(|m| m.u).sum::<f64>();
121 let sum_v = metrics.iter().map(|m| m.v).sum::<f64>();
122 Ok(PlanarMetrics {
123 y: log10_convert(sum_y, 1. / metrics.len() as f64),
124 u: log10_convert(sum_u, 1. / metrics.len() as f64),
125 v: log10_convert(sum_v, 1. / metrics.len() as f64),
126 avg: log10_convert(
127 sum_y + cweight * (sum_u + sum_v),
128 (1. + 2. * cweight) * 1. / metrics.len() as f64,
129 ),
130 })
131 }
132}
133
/// Per-coefficient error weights for the 8x8 DCT of luma (plane 0) blocks.
///
/// Indexed `[row][column]`, matching the row-major coefficient layout produced by
/// `od_bin_fdct8x8`. `calculate_plane_psnr_hvs` multiplies each coefficient error by the
/// matching entry. It also squares the entry, scaled by its `CSF_MULTIPLIER`, to form the
/// masking weights. The name suggests a contrast-sensitivity-function (CSF) table — the
/// exact provenance of the values is not recorded here.
#[rustfmt::skip]
const CSF_Y: [[f64; 8]; 8] = [
    [1.6193873005, 2.2901594831, 2.08509755623, 1.48366094411, 1.00227514334, 0.678296995242, 0.466224900598, 0.3265091542],
    [2.2901594831, 1.94321815382, 2.04793073064, 1.68731108984, 1.2305666963, 0.868920337363, 0.61280991668, 0.436405793551],
    [2.08509755623, 2.04793073064, 1.34329019223, 1.09205635862, 0.875748795257, 0.670882927016, 0.501731932449, 0.372504254596],
    [1.48366094411, 1.68731108984, 1.09205635862, 0.772819797575, 0.605636379554, 0.48309405692, 0.380429446972, 0.295774038565],
    [1.00227514334, 1.2305666963, 0.875748795257, 0.605636379554, 0.448996256676, 0.352889268808, 0.283006984131, 0.226951348204],
    [0.678296995242, 0.868920337363, 0.670882927016, 0.48309405692, 0.352889268808, 0.27032073436, 0.215017739696, 0.17408067321],
    [0.466224900598, 0.61280991668, 0.501731932449, 0.380429446972, 0.283006984131, 0.215017739696, 0.168869545842, 0.136153931001],
    [0.3265091542, 0.436405793551, 0.372504254596, 0.295774038565, 0.226951348204, 0.17408067321, 0.136153931001, 0.109083846276]
];
148
/// Per-coefficient error weights for the 8x8 DCT of Cb (plane 1) blocks.
///
/// Indexed `[row][column]` like `CSF_Y`. `calculate_plane_psnr_hvs` selects this table
/// purely by plane index, whatever the frame's chroma sampling.
/// NOTE(review): the `420` suffix suggests the values were tuned for 4:2:0 content —
/// confirm whether other subsamplings should use different tables.
#[rustfmt::skip]
const CSF_CB420: [[f64; 8]; 8] = [
    [1.91113096927, 2.46074210438, 1.18284184739, 1.14982565193, 1.05017074788, 0.898018824055, 0.74725392039, 0.615105596242],
    [2.46074210438, 1.58529308355, 1.21363250036, 1.38190029285, 1.33100189972, 1.17428548929, 0.996404342439, 0.830890433625],
    [1.18284184739, 1.21363250036, 0.978712413627, 1.02624506078, 1.03145147362, 0.960060382087, 0.849823426169, 0.731221236837],
    [1.14982565193, 1.38190029285, 1.02624506078, 0.861317501629, 0.801821139099, 0.751437590932, 0.685398513368, 0.608694761374],
    [1.05017074788, 1.33100189972, 1.03145147362, 0.801821139099, 0.676555426187, 0.605503172737, 0.55002013668, 0.495804539034],
    [0.898018824055, 1.17428548929, 0.960060382087, 0.751437590932, 0.605503172737, 0.514674450957, 0.454353482512, 0.407050308965],
    [0.74725392039, 0.996404342439, 0.849823426169, 0.685398513368, 0.55002013668, 0.454353482512, 0.389234902883, 0.342353999733],
    [0.615105596242, 0.830890433625, 0.731221236837, 0.608694761374, 0.495804539034, 0.407050308965, 0.342353999733, 0.295530605237]
];
160
/// Per-coefficient error weights for the 8x8 DCT of Cr (plane 2) blocks.
///
/// Indexed `[row][column]` like `CSF_Y`. `calculate_plane_psnr_hvs` selects this table
/// purely by plane index, whatever the frame's chroma sampling.
/// NOTE(review): the `420` suffix suggests the values were tuned for 4:2:0 content —
/// confirm whether other subsamplings should use different tables.
#[rustfmt::skip]
const CSF_CR420: [[f64; 8]; 8] = [
    [2.03871978502, 2.62502345193, 1.26180942886, 1.11019789803, 1.01397751469, 0.867069376285, 0.721500455585, 0.593906509971],
    [2.62502345193, 1.69112867013, 1.17180569821, 1.3342742857, 1.28513006198, 1.13381474809, 0.962064122248, 0.802254508198],
    [1.26180942886, 1.17180569821, 0.944981930573, 0.990876405848, 0.995903384143, 0.926972725286, 0.820534991409, 0.706020324706],
    [1.11019789803, 1.3342742857, 0.990876405848, 0.831632933426, 0.77418706195, 0.725539939514, 0.661776842059, 0.587716619023],
    [1.01397751469, 1.28513006198, 0.995903384143, 0.77418706195, 0.653238524286, 0.584635025748, 0.531064164893, 0.478717061273],
    [0.867069376285, 1.13381474809, 0.926972725286, 0.725539939514, 0.584635025748, 0.496936637883, 0.438694579826, 0.393021669543],
    [0.721500455585, 0.962064122248, 0.820534991409, 0.661776842059, 0.531064164893, 0.438694579826, 0.375820256136, 0.330555063063],
    [0.593906509971, 0.802254508198, 0.706020324706, 0.587716619023, 0.478717061273, 0.393021669543, 0.330555063063, 0.285345396658]
];
172
173fn calculate_plane_psnr_hvs<T: Pixel>(
174 plane1: &Plane<T>,
175 plane2: &Plane<T>,
176 plane_idx: usize,
177 bit_depth: usize,
178) -> f64 {
179 const STEP: usize = 7;
180 let mut result = 0.0;
181 let mut pixels = 0usize;
182 let csf = match plane_idx {
183 0 => &CSF_Y,
184 1 => &CSF_CB420,
185 2 => &CSF_CR420,
186 _ => unreachable!(),
187 };
188
189 const CSF_MULTIPLIER: f64 = 0.3885746225901003;
206 let mut mask = [[0.0; 8]; 8];
207 for x in 0..8 {
208 for y in 0..8 {
209 mask[x][y] = (csf[x][y] * CSF_MULTIPLIER).powi(2);
210 }
211 }
212
213 let height = plane1.cfg.height;
214 let width = plane1.cfg.width;
215 let stride = plane1.cfg.stride;
216 let mut p1 = [0i16; 8 * 8];
217 let mut p2 = [0i16; 8 * 8];
218 let mut dct_p1 = [0i32; 8 * 8];
219 let mut dct_p2 = [0i32; 8 * 8];
220 assert!(plane1.data.len() >= stride * height);
221 assert!(plane2.data.len() >= stride * height);
222 for y in (0..(height - STEP)).step_by(STEP) {
223 for x in (0..(width - STEP)).step_by(STEP) {
224 let mut p1_means = [0.0; 4];
225 let mut p2_means = [0.0; 4];
226 let mut p1_vars = [0.0; 4];
227 let mut p2_vars = [0.0; 4];
228 let mut p1_gmean = 0.0;
229 let mut p2_gmean = 0.0;
230 let mut p1_gvar = 0.0;
231 let mut p2_gvar = 0.0;
232 let mut p1_mask = 0.0;
233 let mut p2_mask = 0.0;
234
235 for i in 0..8 {
236 for j in 0..8 {
237 p1[i * 8 + j] = i16::cast_from(plane1.data[(y + i) * stride + x + j]);
238 p2[i * 8 + j] = i16::cast_from(plane2.data[(y + i) * stride + x + j]);
239
240 let sub = ((i & 12) >> 2) + ((j & 12) >> 1);
241 p1_gmean += p1[i * 8 + j] as f64;
242 p2_gmean += p2[i * 8 + j] as f64;
243 p1_means[sub] += p1[i * 8 + j] as f64;
244 p2_means[sub] += p2[i * 8 + j] as f64;
245 }
246 }
247 p1_gmean /= 64.0;
248 p2_gmean /= 64.0;
249 for i in 0..4 {
250 p1_means[i] /= 16.0;
251 p2_means[i] /= 16.0;
252 }
253
254 for i in 0..8 {
255 for j in 0..8 {
256 let sub = ((i & 12) >> 2) + ((j & 12) >> 1);
257 p1_gvar +=
258 (p1[i * 8 + j] as f64 - p1_gmean) * (p1[i * 8 + j] as f64 - p1_gmean);
259 p2_gvar +=
260 (p2[i * 8 + j] as f64 - p2_gmean) * (p2[i * 8 + j] as f64 - p2_gmean);
261 p1_vars[sub] += (p1[i * 8 + j] as f64 - p1_means[sub])
262 * (p1[i * 8 + j] as f64 - p1_means[sub]);
263 p2_vars[sub] += (p2[i * 8 + j] as f64 - p2_means[sub])
264 * (p2[i * 8 + j] as f64 - p2_means[sub]);
265 }
266 }
267 p1_gvar *= 64.0 / 63.0;
268 p2_gvar *= 64.0 / 63.0;
269 for i in 0..4 {
270 p1_vars[i] *= 16.0 / 15.0;
271 p2_vars[i] *= 16.0 / 15.0;
272 }
273 if p1_gvar > 0.0 {
274 p1_gvar = p1_vars.iter().sum::<f64>() / p1_gvar;
275 }
276 if p2_gvar > 0.0 {
277 p2_gvar = p2_vars.iter().sum::<f64>() / p2_gvar;
278 }
279
280 p1.iter().copied().enumerate().for_each(|(i, v)| {
281 dct_p1[i] = v as i32;
282 });
283 p2.iter().copied().enumerate().for_each(|(i, v)| {
284 dct_p2[i] = v as i32;
285 });
286 od_bin_fdct8x8(&mut dct_p1);
287 od_bin_fdct8x8(&mut dct_p2);
288 for i in 0..8 {
289 for j in (i == 0) as usize..8 {
290 p1_mask += dct_p1[i * 8 + j].pow(2) as f64 * mask[i][j];
291 p2_mask += dct_p2[i * 8 + j].pow(2) as f64 * mask[i][j];
292 }
293 }
294 p1_mask = (p1_mask * p1_gvar).sqrt() / 32.0;
295 p2_mask = (p2_mask * p2_gvar).sqrt() / 32.0;
296 if p2_mask > p1_mask {
297 p1_mask = p2_mask;
298 }
299 for i in 0..8 {
300 for j in 0..8 {
301 let mut err = (dct_p1[i * 8 + j] - dct_p2[i * 8 + j]).abs() as f64;
302 if i != 0 || j != 0 {
303 let err_mask = p1_mask / mask[i][j];
304 err = if err < err_mask { 0.0 } else { err - err_mask };
305 }
306 result += (err * csf[i][j]).powi(2);
307 pixels += 1;
308 }
309 }
310 }
311 }
312
313 result /= pixels as f64;
314 let sample_max: usize = (1 << bit_depth) - 1;
315 result /= sample_max.pow(2) as f64;
316 result
317}
318
/// Converts a linear error into decibels: `-10 * log10(weight * score)`.
///
/// `weight` scales the error before conversion, e.g. `1 / frame_count` to turn a sum of
/// per-frame errors into a mean. A zero product yields `+inf`.
fn log10_convert(score: f64, weight: f64) -> f64 {
    let scaled_error = weight * score;
    -10.0 * scaled_error.log10()
}
322
/// Distance, in elements, between vertically adjacent entries of an 8x8 block stored
/// row-major in a flat slice. `od_bin_fdct8` steps through its input by this amount, and
/// `od_bin_fdct8x8` uses it as the row offset.
const DCT_STRIDE: usize = 8;
324
325fn od_bin_fdct8x8(data: &mut [i32]) {
327 assert!(data.len() >= 64);
328 let mut z = [0; 64];
329 for i in 0..8 {
330 od_bin_fdct8(&mut z[(DCT_STRIDE * i)..], &data[i..]);
331 }
332 for i in 0..8 {
333 od_bin_fdct8(&mut data[(DCT_STRIDE * i)..], &z[i..]);
334 }
335}
336
/// One-dimensional 8-point integer forward transform (a forward DCT, per the `fdct` name)
/// over a strided input vector.
///
/// Reads `x[0]`, `x[DCT_STRIDE]`, ..., `x[7 * DCT_STRIDE]` and writes eight coefficients to
/// `y[0..8]`. All arithmetic is in `i32`. Fractional multiplications use the fixed-point
/// form `(t * C + 2^(s-1)) >> s`, i.e. multiply by `C / 2^s` and round half up.
/// NOTE(review): the `od_` prefix suggests a port of Daala's `od_bin_fdct8` — confirm
/// against the reference before changing any step; the exact statement order matters.
///
/// # Panics
///
/// Panics if `y` has fewer than 8 elements or `x` has `7 * DCT_STRIDE` or fewer.
// `1 * DCT_STRIDE` below is kept for symmetry with the other loads, which trips
// clippy::identity_op.
#[allow(clippy::identity_op)]
fn od_bin_fdct8(y: &mut [i32], x: &[i32]) {
    assert!(y.len() >= 8);
    assert!(x.len() > 7 * DCT_STRIDE);
    let mut t = [0; 8];
    // `th[k]` keeps a halved copy of `t[k]` (via `od_dct_rshift`) for reuse in later stages.
    let mut th = [0; 8];
    // Load the inputs in the permuted order consumed by the stages below.
    t[0] = x[0];
    t[4] = x[1 * DCT_STRIDE];
    t[2] = x[2 * DCT_STRIDE];
    t[6] = x[3 * DCT_STRIDE];
    t[7] = x[4 * DCT_STRIDE];
    t[3] = x[5 * DCT_STRIDE];
    t[5] = x[6 * DCT_STRIDE];
    t[1] = x[7 * DCT_STRIDE];
    // Butterfly stage: sums/differences of input pairs, with halved values saved in `th`.
    t[1] = t[0] - t[1];
    th[1] = od_dct_rshift(t[1], 1);
    t[0] -= th[1];
    t[4] += t[5];
    th[4] = od_dct_rshift(t[4], 1);
    t[5] -= th[4];
    t[3] = t[2] - t[3];
    t[2] -= od_dct_rshift(t[3], 1);
    t[6] += t[7];
    th[6] = od_dct_rshift(t[6], 1);
    t[7] = th[6] - t[7];
    t[0] += th[6];
    t[6] = t[0] - t[6];
    t[2] = th[4] - t[2];
    t[4] = t[2] - t[4];
    // Three-step lifting rotations. The constants are fixed-point trig factors, e.g.
    // 13573/2^15 ~= tan(pi/8), 11585/2^14 ~= 1/sqrt(2), 21895/2^15 ~= tan(3pi/16),
    // 15137/2^14 ~= sin(3pi/8).
    // NOTE(review): the products are computed in i32; very large inputs could overflow
    // (panicking in debug builds) — presumably safe for the sample ranges used here.
    t[0] -= (t[4] * 13573 + 16384) >> 15;
    t[4] += (t[0] * 11585 + 8192) >> 14;
    t[0] -= (t[4] * 13573 + 16384) >> 15;
    t[6] -= (t[2] * 21895 + 16384) >> 15;
    t[2] += (t[6] * 15137 + 8192) >> 14;
    t[6] -= (t[2] * 21895 + 16384) >> 15;
    t[3] += (t[5] * 19195 + 16384) >> 15;
    t[5] += (t[3] * 11585 + 8192) >> 14;
    t[3] -= (t[5] * 7489 + 4096) >> 13;
    t[7] = od_dct_rshift(t[5], 1) - t[7];
    t[5] -= t[7];
    t[3] = th[1] - t[3];
    t[1] -= t[3];
    t[7] += (t[1] * 3227 + 16384) >> 15;
    t[1] -= (t[7] * 6393 + 16384) >> 15;
    t[7] += (t[1] * 3227 + 16384) >> 15;
    t[5] += (t[3] * 2485 + 4096) >> 13;
    t[3] -= (t[5] * 18205 + 16384) >> 15;
    t[5] += (t[3] * 2485 + 4096) >> 13;
    // Store the coefficients in natural order.
    y[0] = t[0];
    y[1] = t[1];
    y[2] = t[2];
    y[3] = t[3];
    y[4] = t[4];
    y[5] = t[5];
    y[6] = t[6];
    y[7] = t[7];
}
400
/// Right shift of `a` by `b` bits that rounds toward zero instead of toward negative
/// infinity.
///
/// The function adds the top `b` bits of `a`, reinterpreted as unsigned, before shifting.
/// For a negative `a` of small magnitude those bits are all ones, so `2^b - 1` is added and
/// the flooring arithmetic shift becomes truncation. The result therefore equals
/// `a / (1 << b)` whenever `a`'s top `b` bits all match its sign bit (e.g. for
/// `|a| < 2^(31 - b)`).
///
/// The bias addition wraps on overflow (two's-complement, like the C-style original) instead
/// of panicking in debug builds.
///
/// `b` must lie in `1..32`; `b == 32` would overflow the `>> b` shift.
#[inline(always)]
fn od_dct_rshift(a: i32, b: u32) -> i32 {
    debug_assert!(b > 0);
    debug_assert!(b < 32);

    let bias = (a as u32 >> (32 - b)) as i32;
    bias.wrapping_add(a) >> b
}