interceptor/twcc/
mod.rs

1#[cfg(test)]
2mod twcc_test;
3
4pub mod receiver;
5pub mod sender;
6
7use std::cmp::Ordering;
8
9use rtcp::transport_feedbacks::transport_layer_cc::{
10    PacketStatusChunk, RecvDelta, RunLengthChunk, StatusChunkTypeTcc, StatusVectorChunk,
11    SymbolSizeTypeTcc, SymbolTypeTcc, TransportLayerCc,
12};
13
/// Arrival record for a single packet, as stored by [`Recorder::record`].
#[derive(Clone, Debug, Default, PartialEq)]
struct PktInfo {
    /// Transport-wide sequence number extended to 32 bits: the low 16 bits
    /// are the on-the-wire value, the upper bits hold the wrap-around cycles.
    sequence_number: u32,
    /// Arrival time of the packet, as handed to [`Recorder::record`].
    arrival_time: i64,
}
19
20/// Recorder records incoming RTP packets and their delays and creates
21/// transport wide congestion control feedback reports as specified in
22/// <https://datatracker.ietf.org/doc/html/draft-holmer-rmcat-transport-wide-cc-extensions-01>
23#[derive(Default, Debug, PartialEq, Clone)]
24pub struct Recorder {
25    received_packets: Vec<PktInfo>,
26
27    cycles: u32,
28    last_sequence_number: u16,
29
30    sender_ssrc: u32,
31    media_ssrc: u32,
32    fb_pkt_cnt: u8,
33}
34
35impl Recorder {
36    /// new creates a new Recorder which uses the given sender_ssrc in the created
37    /// feedback packets.
38    pub fn new(sender_ssrc: u32) -> Self {
39        Recorder {
40            sender_ssrc,
41            ..Default::default()
42        }
43    }
44
45    /// record marks a packet with media_ssrc and a transport wide sequence number sequence_number as received at arrival_time.
46    pub fn record(&mut self, media_ssrc: u32, sequence_number: u16, arrival_time: i64) {
47        self.media_ssrc = media_ssrc;
48        if sequence_number < 0x0fff && self.last_sequence_number > 0xf000 {
49            self.cycles += 1 << 16;
50        }
51        self.received_packets.push(PktInfo {
52            sequence_number: self.cycles | sequence_number as u32,
53            arrival_time,
54        });
55        self.last_sequence_number = sequence_number;
56    }
57
58    /// build_feedback_packet creates a new RTCP packet containing a TWCC feedback report.
59    pub fn build_feedback_packet(&mut self) -> Vec<Box<dyn rtcp::packet::Packet + Send + Sync>> {
60        if self.received_packets.len() < 2 {
61            return vec![];
62        }
63        let mut feedback = Feedback::new(self.sender_ssrc, self.media_ssrc, self.fb_pkt_cnt);
64        self.fb_pkt_cnt = self.fb_pkt_cnt.wrapping_add(1);
65
66        self.received_packets
67            .sort_by(|a: &PktInfo, b: &PktInfo| -> Ordering {
68                a.sequence_number.cmp(&b.sequence_number)
69            });
70        feedback.set_base(
71            (self.received_packets[0].sequence_number & 0xffff) as u16,
72            self.received_packets[0].arrival_time,
73        );
74
75        let mut pkts = vec![];
76        for pkt in &self.received_packets {
77            let built =
78                feedback.add_received((pkt.sequence_number & 0xffff) as u16, pkt.arrival_time);
79            if !built {
80                let p: Box<dyn rtcp::packet::Packet + Send + Sync> = Box::new(feedback.get_rtcp());
81                pkts.push(p);
82                feedback = Feedback::new(self.sender_ssrc, self.media_ssrc, self.fb_pkt_cnt);
83                self.fb_pkt_cnt = self.fb_pkt_cnt.wrapping_add(1);
84                feedback.add_received((pkt.sequence_number & 0xffff) as u16, pkt.arrival_time);
85            }
86        }
87        self.received_packets.clear();
88        let p: Box<dyn rtcp::packet::Packet + Send + Sync> = Box::new(feedback.get_rtcp());
89        pkts.push(p);
90        pkts
91    }
92}
93
94#[derive(Default, Debug, PartialEq, Clone)]
95struct Feedback {
96    rtcp: TransportLayerCc,
97    base_sequence_number: u16,
98    ref_timestamp64ms: i64,
99    last_timestamp_us: i64,
100    next_sequence_number: u16,
101    sequence_number_count: u16,
102    len: usize,
103    last_chunk: Chunk,
104    chunks: Vec<PacketStatusChunk>,
105    deltas: Vec<RecvDelta>,
106}
107
108impl Feedback {
109    fn new(sender_ssrc: u32, media_ssrc: u32, fb_pkt_count: u8) -> Self {
110        Feedback {
111            rtcp: TransportLayerCc {
112                sender_ssrc,
113                media_ssrc,
114                fb_pkt_count,
115                ..Default::default()
116            },
117            ..Default::default()
118        }
119    }
120
121    fn set_base(&mut self, sequence_number: u16, time_us: i64) {
122        self.base_sequence_number = sequence_number;
123        self.next_sequence_number = self.base_sequence_number;
124        self.ref_timestamp64ms = time_us / 64000;
125        self.last_timestamp_us = self.ref_timestamp64ms * 64000;
126    }
127
128    fn get_rtcp(&mut self) -> TransportLayerCc {
129        self.rtcp.packet_status_count = self.sequence_number_count;
130        self.rtcp.reference_time = self.ref_timestamp64ms as u32;
131        self.rtcp.base_sequence_number = self.base_sequence_number;
132        while !self.last_chunk.deltas.is_empty() {
133            self.chunks.push(self.last_chunk.encode());
134        }
135        self.rtcp.packet_chunks.extend_from_slice(&self.chunks);
136        self.rtcp.recv_deltas.clone_from(&self.deltas);
137
138        self.rtcp.clone()
139    }
140
141    fn add_received(&mut self, sequence_number: u16, timestamp_us: i64) -> bool {
142        let delta_us = timestamp_us - self.last_timestamp_us;
143        let delta250us = delta_us / 250;
144        if delta250us < i16::MIN as i64 || delta250us > i16::MAX as i64 {
145            // delta doesn't fit into 16 bit, need to create new packet
146            return false;
147        }
148
149        while self.next_sequence_number != sequence_number {
150            if !self
151                .last_chunk
152                .can_add(SymbolTypeTcc::PacketNotReceived as u16)
153            {
154                self.chunks.push(self.last_chunk.encode());
155            }
156            self.last_chunk.add(SymbolTypeTcc::PacketNotReceived as u16);
157            self.sequence_number_count = self.sequence_number_count.wrapping_add(1);
158            self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
159        }
160
161        let recv_delta = if (0..=0xff).contains(&delta250us) {
162            self.len += 1;
163            SymbolTypeTcc::PacketReceivedSmallDelta
164        } else {
165            self.len += 2;
166            SymbolTypeTcc::PacketReceivedLargeDelta
167        };
168
169        if !self.last_chunk.can_add(recv_delta as u16) {
170            self.chunks.push(self.last_chunk.encode());
171        }
172        self.last_chunk.add(recv_delta as u16);
173        self.deltas.push(RecvDelta {
174            type_tcc_packet: recv_delta,
175            delta: delta_us,
176        });
177        self.last_timestamp_us = timestamp_us;
178        self.sequence_number_count = self.sequence_number_count.wrapping_add(1);
179        self.next_sequence_number = self.next_sequence_number.wrapping_add(1);
180        true
181    }
182}
183
/// Largest number of symbols a run length chunk can describe (13-bit counter).
const MAX_RUN_LENGTH_CAP: usize = (1 << 13) - 1;
/// Number of symbols held by a status vector chunk with one-bit symbols.
const MAX_ONE_BIT_CAP: usize = 14;
/// Number of symbols held by a status vector chunk with two-bit symbols.
const MAX_TWO_BIT_CAP: usize = MAX_ONE_BIT_CAP / 2;
187
/// Status symbols collected for the packet status chunk currently being built.
#[derive(Clone, Debug, Default, PartialEq)]
struct Chunk {
    /// True once any collected symbol is a "received, large delta" symbol.
    has_large_delta: bool,
    /// True once any collected symbol differs from the first one.
    has_different_types: bool,
    /// The collected status symbols, in sequence number order.
    deltas: Vec<u16>,
}
194
195impl Chunk {
196    fn can_add(&self, delta: u16) -> bool {
197        if self.deltas.len() < MAX_TWO_BIT_CAP {
198            return true;
199        }
200        if self.deltas.len() < MAX_ONE_BIT_CAP
201            && !self.has_large_delta
202            && delta != SymbolTypeTcc::PacketReceivedLargeDelta as u16
203        {
204            return true;
205        }
206        if self.deltas.len() < MAX_RUN_LENGTH_CAP
207            && !self.has_different_types
208            && delta == self.deltas[0]
209        {
210            return true;
211        }
212        false
213    }
214
215    fn add(&mut self, delta: u16) {
216        self.deltas.push(delta);
217        self.has_large_delta =
218            self.has_large_delta || delta == SymbolTypeTcc::PacketReceivedLargeDelta as u16;
219        self.has_different_types = self.has_different_types || delta != self.deltas[0];
220    }
221
222    fn encode(&mut self) -> PacketStatusChunk {
223        if !self.has_different_types {
224            let p = PacketStatusChunk::RunLengthChunk(RunLengthChunk {
225                type_tcc: StatusChunkTypeTcc::RunLengthChunk,
226                packet_status_symbol: self.deltas[0].into(),
227                run_length: self.deltas.len() as u16,
228            });
229            self.reset();
230            return p;
231        }
232        if self.deltas.len() == MAX_ONE_BIT_CAP {
233            let p = PacketStatusChunk::StatusVectorChunk(StatusVectorChunk {
234                type_tcc: StatusChunkTypeTcc::StatusVectorChunk,
235                symbol_size: SymbolSizeTypeTcc::OneBit,
236                symbol_list: self
237                    .deltas
238                    .iter()
239                    .map(|x| SymbolTypeTcc::from(*x))
240                    .collect::<Vec<SymbolTypeTcc>>(),
241            });
242            self.reset();
243            return p;
244        }
245
246        let min_cap = std::cmp::min(MAX_TWO_BIT_CAP, self.deltas.len());
247        let svc = PacketStatusChunk::StatusVectorChunk(StatusVectorChunk {
248            type_tcc: StatusChunkTypeTcc::StatusVectorChunk,
249            symbol_size: SymbolSizeTypeTcc::TwoBit,
250            symbol_list: self.deltas[..min_cap]
251                .iter()
252                .map(|x| SymbolTypeTcc::from(*x))
253                .collect::<Vec<SymbolTypeTcc>>(),
254        });
255        self.deltas.drain(..min_cap);
256        self.has_different_types = false;
257        self.has_large_delta = false;
258
259        if !self.deltas.is_empty() {
260            let tmp = self.deltas[0];
261            for d in &self.deltas {
262                if tmp != *d {
263                    self.has_different_types = true;
264                }
265                if *d == SymbolTypeTcc::PacketReceivedLargeDelta as u16 {
266                    self.has_large_delta = true;
267                }
268            }
269        }
270
271        svc
272    }
273
274    fn reset(&mut self) {
275        self.deltas = vec![];
276        self.has_large_delta = false;
277        self.has_different_types = false;
278    }
279}