webrtc_util/buffer/
mod.rs1#[cfg(test)]
2mod buffer_test;
3
4use std::sync::Arc;
5
6use tokio::sync::{Mutex, Notify};
7use tokio::time::{timeout, Duration};
8
9use crate::error::{Error, Result};
10
/// Smallest capacity, in bytes, that the backing storage is ever grown to (2 KiB).
const MIN_SIZE: usize = 1 << 11;
/// Capacity threshold, in bytes (128 KiB): storage smaller than this doubles when
/// it grows, storage at or above it grows by a quarter.
const CUTOFF_SIZE: usize = 1 << 17;
/// Capacity ceiling, in bytes (4 MiB), applied when growing a buffer that has no
/// explicit size limit configured.
const MAX_SIZE: usize = 1 << 22;
14
/// Lock-protected state of the packet buffer: a byte ring holding
/// length-prefixed packets, plus the bookkeeping needed to enforce limits.
#[derive(Debug)]
struct BufferInternal {
    /// Backing storage, used as a ring.
    data: Vec<u8>,
    /// Index into `data` of the next byte to be read.
    head: usize,
    /// Index into `data` at which the next byte will be written.
    tail: usize,

    /// Set once the buffer has been closed.
    closed: bool,
    /// Set by a reader that found the buffer empty and is about to wait;
    /// cleared by the writer that signals it.
    subs: bool,

    /// Number of packets currently stored.
    count: usize,
    /// Maximum number of stored packets; `0` disables the check.
    limit_count: usize,
    /// Maximum number of stored bytes, length prefixes included; `0` disables
    /// the check.
    limit_size: usize,
}
30
31impl BufferInternal {
32 fn available(&self, size: usize) -> bool {
35 let mut available = self.head as isize - self.tail as isize;
36 if available <= 0 {
37 available += self.data.len() as isize;
38 }
39 size as isize + 2 < available
41 }
42
43 fn grow(&mut self) -> Result<()> {
46 let mut newsize = if self.data.len() < CUTOFF_SIZE {
47 2 * self.data.len()
48 } else {
49 5 * self.data.len() / 4
50 };
51
52 if newsize < MIN_SIZE {
53 newsize = MIN_SIZE
54 }
55 if (self.limit_size == 0) && newsize > MAX_SIZE {
56 newsize = MAX_SIZE
57 }
58
59 if self.limit_size > 0 && newsize > self.limit_size + 1 {
61 newsize = self.limit_size + 1
62 }
63
64 if newsize <= self.data.len() {
65 return Err(Error::ErrBufferFull);
66 }
67
68 let mut newdata: Vec<u8> = vec![0; newsize];
69
70 let mut n;
71 if self.head <= self.tail {
72 n = self.tail - self.head;
74 newdata[..n].copy_from_slice(&self.data[self.head..self.tail]);
75 } else {
76 n = self.data.len() - self.head;
78 newdata[..n].copy_from_slice(&self.data[self.head..]);
79 newdata[n..n + self.tail].copy_from_slice(&self.data[..self.tail]);
80 n += self.tail;
81 }
82 self.head = 0;
83 self.tail = n;
84 self.data = newdata;
85
86 Ok(())
87 }
88
89 fn size(&self) -> usize {
90 let mut size = self.tail as isize - self.head as isize;
91 if size < 0 {
92 size += self.data.len() as isize;
93 }
94 size as usize
95 }
96}
97
98#[derive(Debug, Clone)]
99pub struct Buffer {
100 buffer: Arc<Mutex<BufferInternal>>,
101 notify: Arc<Notify>,
102}
103
104impl Buffer {
105 pub fn new(limit_count: usize, limit_size: usize) -> Self {
106 Buffer {
107 buffer: Arc::new(Mutex::new(BufferInternal {
108 data: vec![],
109 head: 0,
110 tail: 0,
111
112 closed: false,
113 subs: false,
114
115 count: 0,
116 limit_count,
117 limit_size,
118 })),
119 notify: Arc::new(Notify::new()),
120 }
121 }
122
123 pub async fn write(&self, packet: &[u8]) -> Result<usize> {
128 if packet.len() >= 0x10000 {
129 return Err(Error::ErrPacketTooBig);
130 }
131
132 let mut b = self.buffer.lock().await;
133
134 if b.closed {
135 return Err(Error::ErrBufferClosed);
136 }
137
138 if (b.limit_count > 0 && b.count >= b.limit_count)
139 || (b.limit_size > 0 && b.size() + 2 + packet.len() > b.limit_size)
140 {
141 return Err(Error::ErrBufferFull);
142 }
143
144 while !b.available(packet.len()) {
146 b.grow()?;
147 }
148
149 let tail = b.tail;
151 b.data[tail] = (packet.len() >> 8) as u8;
152 b.tail += 1;
153 if b.tail >= b.data.len() {
154 b.tail = 0;
155 }
156
157 let tail = b.tail;
158 b.data[tail] = packet.len() as u8;
159 b.tail += 1;
160 if b.tail >= b.data.len() {
161 b.tail = 0;
162 }
163
164 let end = std::cmp::min(b.data.len(), b.tail + packet.len());
166 let n = end - b.tail;
167 let tail = b.tail;
168 b.data[tail..end].copy_from_slice(&packet[..n]);
169 b.tail += n;
170 if b.tail >= b.data.len() {
171 let m = packet.len() - n;
173 b.data[..m].copy_from_slice(&packet[n..]);
174 b.tail = m;
175 }
176 b.count += 1;
177
178 if b.subs {
179 self.notify.notify_one();
181 b.subs = false;
182 }
183
184 Ok(packet.len())
185 }
186
187 pub async fn read(&self, packet: &mut [u8], duration: Option<Duration>) -> Result<usize> {
192 loop {
193 {
194 let mut b = self.buffer.lock().await;
196
197 if b.head != b.tail {
198 let n1 = b.data[b.head];
200 b.head += 1;
201 if b.head >= b.data.len() {
202 b.head = 0;
203 }
204 let n2 = b.data[b.head];
205 b.head += 1;
206 if b.head >= b.data.len() {
207 b.head = 0;
208 }
209 let count = ((n1 as usize) << 8) | n2 as usize;
210
211 let mut copied = count;
213 if copied > packet.len() {
214 copied = packet.len();
215 }
216
217 if b.head + copied < b.data.len() {
219 packet[..copied].copy_from_slice(&b.data[b.head..b.head + copied]);
220 } else {
221 let k = b.data.len() - b.head;
222 packet[..k].copy_from_slice(&b.data[b.head..]);
223 packet[k..copied].copy_from_slice(&b.data[..copied - k]);
224 }
225
226 b.head += count;
228 if b.head >= b.data.len() {
229 b.head -= b.data.len();
230 }
231
232 if b.head == b.tail {
233 b.head = 0;
236 b.tail = 0;
237 }
238
239 b.count -= 1;
240
241 if copied < count {
242 return Err(Error::ErrBufferShort);
243 }
244 return Ok(copied);
245 } else {
246 b.subs = true;
248 }
249
250 if b.closed {
251 return Err(Error::ErrBufferClosed);
252 }
253 }
254
255 if let Some(d) = duration {
257 if timeout(d, self.notify.notified()).await.is_err() {
258 return Err(Error::ErrTimeout);
259 }
260 } else {
261 self.notify.notified().await;
262 }
263 }
264 }
265
266 pub async fn close(&self) {
269 let mut b = self.buffer.lock().await;
272
273 if b.closed {
274 return;
275 }
276
277 b.closed = true;
278 self.notify.notify_waiters();
279 }
280
281 pub async fn is_closed(&self) -> bool {
282 let b = self.buffer.lock().await;
283
284 b.closed
285 }
286
287 pub async fn count(&self) -> usize {
289 let b = self.buffer.lock().await;
290
291 b.count
292 }
293
294 pub async fn set_limit_count(&self, limit: usize) {
298 let mut b = self.buffer.lock().await;
299
300 b.limit_count = limit
301 }
302
303 pub async fn size(&self) -> usize {
305 let b = self.buffer.lock().await;
306
307 b.size()
308 }
309
310 pub async fn set_limit_size(&self, limit: usize) {
318 let mut b = self.buffer.lock().await;
319
320 b.limit_size = limit
321 }
322}