av1_grain/
diff.rs

1use anyhow::{ensure, Result};
2use num_rational::Rational64;
3use v_frame::{frame::Frame, pixel::Pixel};
4
5use self::solver::{FlatBlockFinder, NoiseModel};
6use crate::{util::frame_into_u8, GrainTableSegment};
7
8mod solver;
9
/// Side length of the square analysis blocks.
///
/// Not referenced by the code in this file; presumably consumed by the
/// `solver` submodule — confirm there.
const BLOCK_SIZE: usize = 32;
/// Element count of one `BLOCK_SIZE` x `BLOCK_SIZE` block.
const BLOCK_SIZE_SQUARED: usize = BLOCK_SIZE.pow(2);
12
/// Accumulates grain statistics by comparing source frames with their
/// denoised counterparts, one frame at a time, and yields the resulting
/// grain table segments from [`DiffGenerator::finish`].
pub struct DiffGenerator {
    /// Frame rate; used to convert frame indices into segment timestamps.
    fps: Rational64,
    /// Bit depth of the frames passed as `source` to `diff_frame`.
    source_bit_depth: usize,
    /// Bit depth of the frames passed as `denoised` to `diff_frame`.
    denoised_bit_depth: usize,
    /// Number of frames processed so far (i.e. the index of the next frame).
    frame_count: usize,
    /// Start timestamp of the grain segment currently being accumulated.
    prev_timestamp: u64,
    /// Finds flat blocks in the source luma plane of each frame.
    flat_block_finder: FlatBlockFinder,
    /// Noise model that is updated with every processed frame.
    noise_model: NoiseModel,
    /// Segments closed so far, pushed in increasing timestamp order.
    grain_table: Vec<GrainTableSegment>,
}
23
24impl DiffGenerator {
25    #[must_use]
26    pub fn new(fps: Rational64, source_bit_depth: usize, denoised_bit_depth: usize) -> Self {
27        Self {
28            frame_count: 0,
29            fps,
30            flat_block_finder: FlatBlockFinder::new(),
31            noise_model: NoiseModel::new(),
32            grain_table: Vec::new(),
33            prev_timestamp: 0,
34            source_bit_depth,
35            denoised_bit_depth,
36        }
37    }
38
39    /// Processes the next frame and adds the results to the state of this
40    /// `DiffGenerator`.
41    ///
42    /// # Errors
43    /// - If the frames do not have the same resolution
44    /// - If the frames do not have the same chroma subsampling
45    pub fn diff_frame<T: Pixel, U: Pixel>(
46        &mut self,
47        source: &Frame<T>,
48        denoised: &Frame<U>,
49    ) -> Result<()> {
50        self.diff_frame_internal(
51            &frame_into_u8(source, self.source_bit_depth),
52            &frame_into_u8(denoised, self.denoised_bit_depth),
53        )
54    }
55
56    /// Finalize the state of this `DiffGenerator` and return the resulting
57    /// grain table segments.
58    #[must_use]
59    pub fn finish(mut self) -> Vec<GrainTableSegment> {
60        log::debug!("Updating final parameters");
61        self.grain_table.push(
62            self.noise_model
63                .get_grain_parameters(self.prev_timestamp, i64::MAX as u64),
64        );
65
66        self.grain_table
67    }
68
69    fn diff_frame_internal(&mut self, source: &Frame<u8>, denoised: &Frame<u8>) -> Result<()> {
70        verify_dimensions_match(source, denoised)?;
71
72        let (flat_blocks, num_flat_blocks) = self.flat_block_finder.run(&source.planes[0]);
73        log::debug!("Num flat blocks: {}", num_flat_blocks);
74
75        log::debug!("Updating noise model");
76        let status = self.noise_model.update(source, denoised, &flat_blocks);
77
78        if status == NoiseStatus::DifferentType {
79            let cur_timestamp = self.frame_count as u64 * 10_000_000u64 * *self.fps.denom() as u64
80                / *self.fps.numer() as u64;
81            log::debug!(
82                "Updating parameters for times {} to {}",
83                self.prev_timestamp,
84                cur_timestamp
85            );
86            self.grain_table.push(
87                self.noise_model
88                    .get_grain_parameters(self.prev_timestamp, cur_timestamp),
89            );
90            self.noise_model.save_latest();
91            self.prev_timestamp = cur_timestamp;
92        }
93        log::debug!("Noise model updated for frame {}", self.frame_count);
94        self.frame_count += 1;
95
96        Ok(())
97    }
98}
99
100#[derive(Debug)]
101enum NoiseStatus {
102    Ok,
103    DifferentType,
104    Error(anyhow::Error),
105}
106
107impl PartialEq for NoiseStatus {
108    fn eq(&self, other: &Self) -> bool {
109        match (self, other) {
110            (&Self::Error(_), &Self::Error(_)) => true,
111            _ => core::mem::discriminant(self) == core::mem::discriminant(other),
112        }
113    }
114}
115
116fn verify_dimensions_match(source: &Frame<u8>, denoised: &Frame<u8>) -> Result<()> {
117    let res_1 = (source.planes[0].cfg.width, source.planes[0].cfg.height);
118    let res_2 = (denoised.planes[0].cfg.width, denoised.planes[0].cfg.height);
119    ensure!(
120        res_1 == res_2,
121        "Luma resolutions were not equal, {}x{} != {}x{}",
122        res_1.0,
123        res_1.1,
124        res_2.0,
125        res_2.1
126    );
127
128    let res_1 = (source.planes[1].cfg.width, source.planes[1].cfg.height);
129    let res_2 = (denoised.planes[1].cfg.width, denoised.planes[1].cfg.height);
130    ensure!(
131        res_1 == res_2,
132        "Chroma resolutions were not equal, {}x{} != {}x{}",
133        res_1.0,
134        res_1.1,
135        res_2.0,
136        res_2.1
137    );
138
139    Ok(())
140}