1#[cfg(test)]
2mod session_rtcp_test;
3#[cfg(test)]
4mod session_rtp_test;
5
6use std::collections::{HashMap, HashSet};
7use std::marker::{Send, Sync};
8use std::sync::Arc;
9
10use bytes::Bytes;
11use tokio::sync::{mpsc, Mutex};
12use util::conn::Conn;
13use util::marshal::*;
14
15use crate::config::*;
16use crate::context::*;
17use crate::error::{Error, Result};
18use crate::option::*;
19use crate::stream::*;
20
/// Window size handed to `srtp_replay_protection` for the remote (inbound) SRTP
/// context whenever the session config leaves `remote_rtp_options` unset.
const DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW: usize = 64;
/// Window size handed to `srtcp_replay_protection` for the remote (inbound) SRTCP
/// context whenever the session config leaves `remote_rtcp_options` unset.
const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64;
23
/// An SRTP (`is_rtp == true`) or SRTCP (`is_rtp == false`) session over one connection.
///
/// Outgoing packets are encrypted with `local_context` and sent on `udp_tx`. Incoming
/// packets are handled by a background task spawned in `Session::new`, which decrypts
/// them and dispatches them to per-SSRC `Stream`s kept in `streams_map`.
pub struct Session {
    /// Cryptographic context for outgoing packets; locked for the duration of each write.
    local_context: Arc<Mutex<Context>>,
    /// Registered streams keyed by SSRC; shared with the background receive task.
    streams_map: Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
    /// Streams created by the receive task for previously unknown SSRCs; drained by `accept`.
    new_stream_rx: Arc<Mutex<mpsc::Receiver<Arc<Stream>>>>,
    /// Cloned into every `Stream` created for this session; an SSRC received on the other
    /// end makes the receive task remove that stream from `streams_map`.
    /// NOTE(review): presumably `Stream` sends its SSRC here when closed — confirm in stream.rs.
    close_stream_tx: mpsc::Sender<u32>,
    /// Sending on (or dropping) this stops the background receive task.
    close_session_tx: mpsc::Sender<()>,
    /// Connection used to send encrypted packets.
    pub(crate) udp_tx: Arc<dyn Conn + Send + Sync>,
    /// Whether this session carries RTP (`true`) or RTCP (`false`).
    is_rtp: bool,
}
37
38impl Session {
39 pub async fn new(
40 conn: Arc<dyn Conn + Send + Sync>,
41 config: Config,
42 is_rtp: bool,
43 ) -> Result<Self> {
44 let local_context = Context::new(
45 &config.keys.local_master_key,
46 &config.keys.local_master_salt,
47 config.profile,
48 config.local_rtp_options,
49 config.local_rtcp_options,
50 )?;
51
52 let mut remote_context = Context::new(
53 &config.keys.remote_master_key,
54 &config.keys.remote_master_salt,
55 config.profile,
56 if config.remote_rtp_options.is_none() {
57 Some(srtp_replay_protection(
58 DEFAULT_SESSION_SRTP_REPLAY_PROTECTION_WINDOW,
59 ))
60 } else {
61 config.remote_rtp_options
62 },
63 if config.remote_rtcp_options.is_none() {
64 Some(srtcp_replay_protection(
65 DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW,
66 ))
67 } else {
68 config.remote_rtcp_options
69 },
70 )?;
71
72 let streams_map = Arc::new(Mutex::new(HashMap::new()));
73 let (mut new_stream_tx, new_stream_rx) = mpsc::channel(8);
74 let (close_stream_tx, mut close_stream_rx) = mpsc::channel(8);
75 let (close_session_tx, mut close_session_rx) = mpsc::channel(8);
76 let udp_tx = Arc::clone(&conn);
77 let udp_rx = Arc::clone(&conn);
78 let cloned_streams_map = Arc::clone(&streams_map);
79 let cloned_close_stream_tx = close_stream_tx.clone();
80
81 tokio::spawn(async move {
82 let mut buf = vec![0u8; 8192];
83
84 loop {
85 let incoming_stream = Session::incoming(
86 &udp_rx,
87 &mut buf,
88 &cloned_streams_map,
89 &cloned_close_stream_tx,
90 &mut new_stream_tx,
91 &mut remote_context,
92 is_rtp,
93 );
94 let close_stream = close_stream_rx.recv();
95 let close_session = close_session_rx.recv();
96
97 tokio::select! {
98 result = incoming_stream => match result{
99 Ok(()) => {},
100 Err(err) => log::info!("{}", err),
101 },
102 opt = close_stream => if let Some(ssrc) = opt {
103 Session::close_stream(&cloned_streams_map, ssrc).await
104 },
105 _ = close_session => break
106 }
107 }
108 });
109
110 Ok(Session {
111 local_context: Arc::new(Mutex::new(local_context)),
112 streams_map,
113 new_stream_rx: Arc::new(Mutex::new(new_stream_rx)),
114 close_stream_tx,
115 close_session_tx,
116 udp_tx,
117 is_rtp,
118 })
119 }
120
121 async fn close_stream(streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>, ssrc: u32) {
122 let mut streams = streams_map.lock().await;
123 streams.remove(&ssrc);
124 }
125
126 async fn incoming(
127 udp_rx: &Arc<dyn Conn + Send + Sync>,
128 buf: &mut [u8],
129 streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
130 close_stream_tx: &mpsc::Sender<u32>,
131 new_stream_tx: &mut mpsc::Sender<Arc<Stream>>,
132 remote_context: &mut Context,
133 is_rtp: bool,
134 ) -> Result<()> {
135 let n = udp_rx.recv(buf).await?;
136 if n == 0 {
137 return Err(Error::SessionEof);
138 }
139
140 let decrypted = if is_rtp {
141 remote_context.decrypt_rtp(&buf[0..n])?
142 } else {
143 remote_context.decrypt_rtcp(&buf[0..n])?
144 };
145
146 let mut buf = &decrypted[..];
147 let ssrcs = if is_rtp {
148 vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc]
149 } else {
150 let pkts = rtcp::packet::unmarshal(&mut buf)?;
151 destination_ssrc(&pkts)
152 };
153
154 for ssrc in ssrcs {
155 let (stream, is_new) =
156 Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc)
157 .await;
158 if is_new {
159 log::trace!(
160 "srtp session got new {} stream {}",
161 if is_rtp { "rtp" } else { "rtcp" },
162 ssrc
163 );
164 new_stream_tx.send(Arc::clone(&stream)).await?;
165 }
166
167 match stream.buffer.write(&decrypted).await {
168 Ok(_) => {}
169 Err(err) => {
170 if util::Error::ErrBufferFull != err {
172 return Err(err.into());
173 }
174 }
175 }
176 }
177
178 Ok(())
179 }
180
181 async fn get_or_create_stream(
182 streams_map: &Arc<Mutex<HashMap<u32, Arc<Stream>>>>,
183 close_stream_tx: mpsc::Sender<u32>,
184 is_rtp: bool,
185 ssrc: u32,
186 ) -> (Arc<Stream>, bool) {
187 let mut streams = streams_map.lock().await;
188
189 if let Some(stream) = streams.get(&ssrc) {
190 (Arc::clone(stream), false)
191 } else {
192 let stream = Arc::new(Stream::new(ssrc, close_stream_tx, is_rtp));
193 streams.insert(ssrc, Arc::clone(&stream));
194 (stream, true)
195 }
196 }
197
198 pub async fn open(&self, ssrc: u32) -> Arc<Stream> {
201 let (stream, _) = Session::get_or_create_stream(
202 &self.streams_map,
203 self.close_stream_tx.clone(),
204 self.is_rtp,
205 ssrc,
206 )
207 .await;
208
209 stream
210 }
211
212 pub async fn accept(&self) -> Result<Arc<Stream>> {
214 let mut new_stream_rx = self.new_stream_rx.lock().await;
215 let result = new_stream_rx.recv().await;
216 if let Some(stream) = result {
217 Ok(stream)
218 } else {
219 Err(Error::SessionSrtpAlreadyClosed)
220 }
221 }
222
223 pub async fn close(&self) -> Result<()> {
224 self.close_session_tx.send(()).await?;
225
226 Ok(())
227 }
228
229 pub async fn write(&self, buf: &Bytes, is_rtp: bool) -> Result<usize> {
230 if self.is_rtp != is_rtp {
231 return Err(Error::SessionRtpRtcpTypeMismatch);
232 }
233
234 let encrypted = {
235 let mut local_context = self.local_context.lock().await;
236
237 if is_rtp {
238 local_context.encrypt_rtp(buf)?
239 } else {
240 local_context.encrypt_rtcp(buf)?
241 }
242 };
243
244 Ok(self.udp_tx.send(&encrypted).await?)
245 }
246
247 pub async fn write_rtp(&self, pkt: &rtp::packet::Packet) -> Result<usize> {
248 let raw = pkt.marshal()?;
249 self.write(&raw, true).await
250 }
251
252 pub async fn write_rtcp(
253 &self,
254 pkt: &(dyn rtcp::packet::Packet + Send + Sync),
255 ) -> Result<usize> {
256 let raw = pkt.marshal()?;
257 self.write(&raw, false).await
258 }
259}
260
261fn destination_ssrc(pkts: &[Box<dyn rtcp::packet::Packet + Send + Sync>]) -> Vec<u32> {
264 let mut ssrc_set = HashSet::new();
265 for p in pkts {
266 for ssrc in p.destination_ssrc() {
267 ssrc_set.insert(ssrc);
268 }
269 }
270 ssrc_set.into_iter().collect()
271}