webrtc_util/buffer/
mod.rs

#[cfg(test)]
mod buffer_test;

use std::sync::Arc;

use tokio::sync::{Mutex, Notify};
use tokio::time::{timeout, Duration};

use crate::error::{Error, Result};

const MIN_SIZE: usize = 2048;
const CUTOFF_SIZE: usize = 128 * 1024;
const MAX_SIZE: usize = 4 * 1024 * 1024;

/// Buffer allows writing packets to an intermediate buffer, which can then be read from.
/// This is very similar to bytes.Buffer but avoids combining multiple writes into a single read.
#[derive(Debug)]
struct BufferInternal {
    data: Vec<u8>,
    // ring-buffer read/write positions into `data`
    head: usize,
    tail: usize,

    closed: bool,
    // true when a reader is parked waiting for data to arrive
    subs: bool,

    // number of packets currently buffered
    count: usize,
    // optional limits; 0 means no explicit limit (see set_limit_count / set_limit_size)
    limit_count: usize,
    limit_size: usize,
}

impl BufferInternal {
    /// available returns true if the buffer is large enough to fit a packet
    /// of the given size, taking overhead into account.
    fn available(&self, size: usize) -> bool {
        let mut available = self.head as isize - self.tail as isize;
        if available <= 0 {
            available += self.data.len() as isize;
        }
        // we interpret head=tail as empty, so always keep a byte free
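        // the +2 below accounts for the 2-byte length prefix stored before each packet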
        size as isize + 2 < available
    }

    /// grow increases the size of the buffer. If it returns Ok(()), then the
    /// buffer has been grown. It returns ErrBufferFull if a limit is hit.
    fn grow(&mut self) -> Result<()> {
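        // growth policy: double while below CUTOFF_SIZE (128 KiB), then grow by 25%
        // to avoid over-allocating for large buffers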
        let mut newsize = if self.data.len() < CUTOFF_SIZE {
            2 * self.data.len()
        } else {
            5 * self.data.len() / 4
        };

        if newsize < MIN_SIZE {
            newsize = MIN_SIZE
        }
        if (self.limit_size == 0/*|| sizeHardlimit*/) && newsize > MAX_SIZE {
            newsize = MAX_SIZE
        }

        // one byte slack
        if self.limit_size > 0 && newsize > self.limit_size + 1 {
            newsize = self.limit_size + 1
        }

        if newsize <= self.data.len() {
            return Err(Error::ErrBufferFull);
        }

        let mut newdata: Vec<u8> = vec![0; newsize];

        let mut n;
        if self.head <= self.tail {
            // data was contiguous
            n = self.tail - self.head;
            newdata[..n].copy_from_slice(&self.data[self.head..self.tail]);
        } else {
            // data was discontiguous
            n = self.data.len() - self.head;
            newdata[..n].copy_from_slice(&self.data[self.head..]);
            newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]);
            n += self.tail;
        }
        self.head = 0;
        self.tail = n;
        self.data = newdata;

        Ok(())
    }

    fn size(&self) -> usize {
        let mut size = self.tail as isize - self.head as isize;
        if size < 0 {
            size += self.data.len() as isize;
        }
        size as usize
    }
}

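/// A minimal usage sketch, assuming a Tokio runtime and that `Buffer` is
/// re-exported at the path used below (the import path is an assumption and
/// may need adjusting to this crate's actual layout):
///
/// ```ignore
/// use std::time::Duration;
/// use webrtc_util::buffer::Buffer; // assumed path
///
/// #[tokio::main]
/// async fn main() {
///     let buffer = Buffer::new(0, 0); // 0 = no packet-count / byte-size limit
///     buffer.write(b"hello").await.unwrap();
///
///     let mut packet = vec![0u8; 1500];
///     // wait at most 100 ms for a packet to become readable
///     let n = buffer
///         .read(&mut packet, Some(Duration::from_millis(100)))
///         .await
///         .unwrap();
///     assert_eq!(&packet[..n], b"hello");
/// }
/// ```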
#[derive(Debug, Clone)]
pub struct Buffer {
    buffer: Arc<Mutex<BufferInternal>>,
    notify: Arc<Notify>,
}

impl Buffer {
    pub fn new(limit_count: usize, limit_size: usize) -> Self {
        Buffer {
            buffer: Arc::new(Mutex::new(BufferInternal {
                data: vec![],
                head: 0,
                tail: 0,

                closed: false,
                subs: false,

                count: 0,
                limit_count,
                limit_size,
            })),
            notify: Arc::new(Notify::new()),
        }
    }

    /// write appends a copy of the packet data to the buffer.
    /// Returns ErrBufferFull if the packet doesn't fit and ErrBufferClosed if
    /// the buffer has been closed.
    /// Note that the packet size is limited to 65,535 bytes since v0.11.0 due
    /// to the 2-byte length prefix used internally; larger packets are
    /// rejected with ErrPacketTooBig.
    pub async fn write(&self, packet: &[u8]) -> Result<usize> {
        if packet.len() >= 0x10000 {
            return Err(Error::ErrPacketTooBig);
        }

        let mut b = self.buffer.lock().await;

        if b.closed {
            return Err(Error::ErrBufferClosed);
        }

        if (b.limit_count > 0 && b.count >= b.limit_count)
            || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size)
        {
            return Err(Error::ErrBufferFull);
        }

        // grow the buffer until the packet fits
        while !b.available(packet.len()) {
            b.grow()?;
        }

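        // Each packet is framed with a 2-byte big-endian length prefix so that a
        // single read returns exactly one packet.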
        // store the length of the packet
        let tail = b.tail;
        b.data[tail] = (packet.len() >> 8) as u8;
        b.tail += 1;
        if b.tail >= b.data.len() {
            b.tail = 0;
        }

        let tail = b.tail;
        b.data[tail] = packet.len() as u8;
        b.tail += 1;
        if b.tail >= b.data.len() {
            b.tail = 0;
        }

        // store the packet
        let end = std::cmp::min(b.data.len(), b.tail + packet.len());
        let n = end - b.tail;
        let tail = b.tail;
        b.data[tail..end].copy_from_slice(&packet[..n]);
        b.tail += n;
        if b.tail >= b.data.len() {
            // we reached the end, wrap around
            let m = packet.len() - n;
            b.data[..m].copy_from_slice(&packet[n..]);
            b.tail = m;
        }
        b.count += 1;

        if b.subs {
            // a reader is waiting for data, wake it up
            self.notify.notify_one();
            b.subs = false;
        }

        Ok(packet.len())
    }

    /// read populates the given byte slice, returning the number of bytes read.
    /// Blocks until data is available, the optional timeout elapses, or the buffer is closed.
    /// Returns ErrBufferShort if the slice is too small to hold the packet.
    /// Returns ErrTimeout if the timeout elapses, and ErrBufferClosed once the
    /// buffer is closed and drained.
    pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> {
        loop {
            {
                // scope the lock guard so it is dropped before we await the notification
                let mut b = self.buffer.lock().await;

                if b.head != b.tail {
                    // decode the packet size
                    let n1 = b.data[b.head];
                    b.head += 1;
                    if b.head >= b.data.len() {
                        b.head = 0;
                    }
                    let n2 = b.data[b.head];
                    b.head += 1;
                    if b.head >= b.data.len() {
                        b.head = 0;
                    }
                    let count = ((n1 as usize) << 8) | n2 as usize;

                    // determine the number of bytes we'll actually copy
                    let mut copied = count;
                    if copied > packet.len() {
                        copied = packet.len();
                    }

                    // copy the data
                    if b.head + copied < b.data.len() {
                        packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]);
                    } else {
                        let k = b.data.len() - b.head;
                        packet[..k].copy_from_slice(&b.data[b.head..]);
                        packet[k..copied].copy_from_slice(&b.data[..copied - k]);
                    }

                    // advance head, discarding any data that wasn't copied
                    b.head += count;
                    if b.head >= b.data.len() {
                        b.head -= b.data.len();
                    }

                    if b.head == b.tail {
                        // the buffer is empty, reset to beginning
                        // in order to improve cache locality.
                        b.head = 0;
                        b.tail = 0;
                    }

                    b.count -= 1;

                    if copied < count {
                        return Err(Error::ErrBufferShort);
                    }
                    return Ok(copied);
                } else {
                    // no data available, register interest before waiting
                    b.subs = true;
                }

                if b.closed {
                    return Err(Error::ErrBufferClosed);
                }
            }

            // Wait for signal.
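            // `Notify::notify_one` stores a permit when no task is currently waiting,
            // so a wake-up sent between dropping the lock above and awaiting here is
            // not lost.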
            if let Some(d) = duration {
                if timeout(d, self.notify.notified()).await.is_err() {
                    return Err(Error::ErrTimeout);
                }
            } else {
                self.notify.notified().await;
            }
        }
    }

    /// close will unblock any readers and prevent future writes.
    /// Data already in the buffer can still be read; ErrBufferClosed is
    /// returned once it has been fully drained.
    pub async fn close(&self) {
        // note: waiters are notified while the lock is still held; they
        // re-acquire it after waking and observe the `closed` flag.
        let mut b = self.buffer.lock().await;

        if b.closed {
            return;
        }

        b.closed = true;
        self.notify.notify_waiters();
    }

    pub async fn is_closed(&self) -> bool {
        let b = self.buffer.lock().await;

        b.closed
    }

    /// count returns the number of packets in the buffer.
    pub async fn count(&self) -> usize {
        let b = self.buffer.lock().await;

        b.count
    }

    /// set_limit_count controls the maximum number of packets that can be buffered.
    /// Causes write to return ErrBufferFull when this limit is reached.
    /// A zero value will disable this limit.
    pub async fn set_limit_count(&self, limit: usize) {
        let mut b = self.buffer.lock().await;

        b.limit_count = limit
    }

    /// size returns the total byte size of packets in the buffer.
    pub async fn size(&self) -> usize {
        let b = self.buffer.lock().await;

        b.size()
    }

    /// set_limit_size controls the maximum number of bytes that can be buffered.
    /// Causes write to return ErrBufferFull when this limit is reached.
    /// A zero value means the buffer may still grow up to the 4 MB MAX_SIZE cap (since v0.11.0).
    ///
    /// The upstream Go implementation exposes a packetioSizeHardlimit build tag that
    /// enforces the 4 MB hard limit regardless of the requested value; the
    /// corresponding check is left commented out in grow().
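    ///
    /// A sketch of the limit behavior (import path assumed, as above):
    ///
    /// ```ignore
    /// use webrtc_util::buffer::Buffer; // assumed path
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let buffer = Buffer::new(0, 0);
    ///     // allow at most 10 buffered bytes, including the 2-byte prefix per packet
    ///     buffer.set_limit_size(10).await;
    ///
    ///     // 8 bytes + 2-byte prefix == 10 bytes: still within the limit
    ///     assert!(buffer.write(&[0u8; 8]).await.is_ok());
    ///     // a second packet would exceed the limit and is rejected with ErrBufferFull
    ///     assert!(buffer.write(&[0u8; 8]).await.is_err());
    /// }
    /// ```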
    pub async fn set_limit_size(&self, limit: usize) {
        let mut b = self.buffer.lock().await;

        b.limit_size = limit
    }
}