webrtc_srtp/session/mod.rs

1#[cfg(test)]
2mod session_rtcp_test;
3#[cfg(test)]
4mod session_rtp_test;
5
6use std::collections::{HashMap, HashSet};
7use std::marker::{Send, Sync};
8use std::sync::Arc;
9
10use bytes::Bytes;
11use tokio::sync::{mpsc, Mutex};
12use util::conn::Conn;
13use util::marshal::*;
14
15use crate::config::*;
16use crate::context::*;
17use crate::error::{Error, Result};
18use crate::option::*;
19use crate::stream::*;
20
/// Replay-protection window passed to `srtp_replay_protection` for incoming
/// packets when the session config supplies no remote RTP option.
const DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW: usize = 64;
/// Replay-protection window passed to `srtcp_replay_protection` for incoming
/// packets when the session config supplies no remote RTCP option.
const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
23
24/// Session implements io.ReadWriteCloser and provides a bi-directional SRTP session
25/// SRTP itself does not have a design like this, but it is common in most applications
26/// for local/remote to each have their own keying material. This provides those patterns
27/// instead of making everyone re-implement
28pub struct Session {
29    local_context: Arc<Mutex<Context>>,
30    streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
31    new_stream_rx: Arc<Mutex<mpsc::Receiver<Arc<Stream>>>>,
32    close_stream_tx: mpsc::Sender<u32>,
33    close_session_tx: mpsc::Sender<()>,
34    pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
35    is_rtp: bool,
36}
37
38impl Session {
39    pub async fn new(
40        conn: Arc<dyn Conn + Send + Sync>,
41        config: Config,
42        is_rtp: bool,
43    ) -> Result<Self> {
44        let local_context = Context::new(
45            &config.keys.local_master_key,
46            &config.keys.local_master_salt,
47            config.profile,
48            config.local_rtp_options,
49            config.local_rtcp_options,
50        )?;
51
52        let mut remote_context = Context::new(
53            &config.keys.remote_master_key,
54            &config.keys.remote_master_salt,
55            config.profile,
56            if config.remote_rtp_options.is_none() {
57                Some(srtp_replay_protection(
58                    DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW,
59                ))
60            } else {
61                config.remote_rtp_options
62            },
63            if config.remote_rtcp_options.is_none() {
64                Some(srtcp_replay_protection(
65                    DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW,
66                ))
67            } else {
68                config.remote_rtcp_options
69            },
70        )?;
71
72        let streams_map = Arc::new(Mutex::new(HashMap::new()));
73        let (mut new_stream_tx, new_stream_rx) = mpsc::channel(8);
74        let (close_stream_tx, mut close_stream_rx) = mpsc::channel(8);
75        let (close_session_tx, mut close_session_rx) = mpsc::channel(8);
76        let udp_tx = Arc::clone(&conn);
77        let udp_rx = Arc::clone(&conn);
78        let cloned_streams_map = Arc::clone(&streams_map);
79        let cloned_close_stream_tx = close_stream_tx.clone();
80
81        tokio::spawn(async move {
82            let mut buf = vec![0u8; 8192];
83
84            loop {
85                let incoming_stream = Session::incoming(
86                    &udp_rx,
87                    &mut buf,
88                    &cloned_streams_map,
89                    &cloned_close_stream_tx,
90                    &mut new_stream_tx,
91                    &mut remote_context,
92                    is_rtp,
93                );
94                let close_stream = close_stream_rx.recv();
95                let close_session = close_session_rx.recv();
96
97                tokio::select! {
98                    result = incoming_stream => match result{
99                        Ok(()) => {},
100                        Err(err) => log::info!("{}", err),
101                    },
102                    opt = close_stream => if let Some(ssrc) = opt {
103                        Session::close_stream(&cloned_streams_map, ssrc).await
104                    },
105                    _ = close_session => break
106                }
107            }
108        });
109
110        Ok(Session {
111            local_context: Arc::new(Mutex::new(local_context)),
112            streams_map,
113            new_stream_rx: Arc::new(Mutex::new(new_stream_rx)),
114            close_stream_tx,
115            close_session_tx,
116            udp_tx,
117            is_rtp,
118        })
119    }
120
121    async fn close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32) {
122        let mut streams = streams_map.lock().await;
123        streams.remove(&ssrc);
124    }
125
126    async fn incoming(
127        udp_rx: &Arc<dyn Conn + Send + Sync>,
128        buf: &mut [u8],
129        streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
130        close_stream_tx: &mpsc::Sender<u32>,
131        new_stream_tx: &mut mpsc::Sender<Arc<Stream>>,
132        remote_context: &mut Context,
133        is_rtp: bool,
134    ) -> Result<()> {
135        let n = udp_rx.recv(buf).await?;
136        if n == 0 {
137            return Err(Error::SessionEof);
138        }
139
140        let decrypted = if is_rtp {
141            remote_context.decrypt_rtp(&buf[0..n])?
142        } else {
143            remote_context.decrypt_rtcp(&buf[0..n])?
144        };
145
146        let mut buf = &decrypted[..];
147        let ssrcs = if is_rtp {
148            vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc]
149        } else {
150            let pkts = rtcp::packet::unmarshal(&mut buf)?;
151            destination_ssrc(&pkts)
152        };
153
154        for ssrc in ssrcs {
155            let (stream, is_new) =
156                Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
157                    .await;
158            if is_new {
159                log::trace!(
160                    "srtp session got new {} stream {}",
161                    if is_rtp { "rtp" } else { "rtcp" },
162                    ssrc
163                );
164                new_stream_tx.send(Arc::clone(&stream)).await?;
165            }
166
167            match stream.buffer.write(&decrypted).await {
168                Ok(_) => {}
169                Err(err) => {
170                    // Silently drop data when the buffer is full.
171                    if util::Error::ErrBufferFull != err {
172                        return Err(err.into());
173                    }
174                }
175            }
176        }
177
178        Ok(())
179    }
180
181    async fn get_or_create_stream(
182        streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
183        close_stream_tx: mpsc::Sender<u32>,
184        is_rtp: bool,
185        ssrc: u32,
186    ) -> (Arc<Stream>, bool) {
187        let mut streams = streams_map.lock().await;
188
189        if let Some(stream) = streams.get(&ssrc) {
190            (Arc::clone(stream), false)
191        } else {
192            let stream = Arc::new(Stream::new(ssrc, close_stream_tx, is_rtp));
193            streams.insert(ssrc, Arc::clone(&stream));
194            (stream, true)
195        }
196    }
197
198    /// open on the given SSRC to create a stream, it can be used
199    /// if you want a certain SSRC, but don't want to wait for Accept
200    pub async fn open(&self, ssrc: u32) -> Arc<Stream> {
201        let (stream, _) = Session::get_or_create_stream(
202            &self.streams_map,
203            self.close_stream_tx.clone(),
204            self.is_rtp,
205            ssrc,
206        )
207        .await;
208
209        stream
210    }
211
212    /// accept returns a stream to handle RTCP for a single SSRC
213    pub async fn accept(&self) -> Result<Arc<Stream>> {
214        let mut new_stream_rx = self.new_stream_rx.lock().await;
215        let result = new_stream_rx.recv().await;
216        if let Some(stream) = result {
217            Ok(stream)
218        } else {
219            Err(Error::SessionSrtpAlreadyClosed)
220        }
221    }
222
223    pub async fn close(&self) -> Result<()> {
224        self.close_session_tx.send(()).await?;
225
226        Ok(())
227    }
228
229    pub async fn write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize> {
230        if self.is_rtp != is_rtp {
231            return Err(Error::SessionRtpRtcpTypeMismatch);
232        }
233
234        let encrypted = {
235            let mut local_context = self.local_context.lock().await;
236
237            if is_rtp {
238                local_context.encrypt_rtp(buf)?
239            } else {
240                local_context.encrypt_rtcp(buf)?
241            }
242        };
243
244        Ok(self.udp_tx.send(&encrypted).await?)
245    }
246
247    pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize> {
248        let raw = pkt.marshal()?;
249        self.write(&raw, true).await
250    }
251
252    pub async fn write_rtcp(
253        &self,
254        pkt: &(dyn rtcp::packet::Packet + Send + Sync),
255    ) -> Result<usize> {
256        let raw = pkt.marshal()?;
257        self.write(&raw, false).await
258    }
259}
260
261/// create a list of Destination SSRCs
262/// that's a superset of all Destinations in the slice.
263fn destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32> {
264    let mut ssrc_set = HashSet::new();
265    for p in pkts {
266        for ssrc in p.destination_ssrc() {
267            ssrc_set.insert(ssrc);
268        }
269    }
270    ssrc_set.into_iter().collect()
271}