1#![cfg(all(
35 unix,
36 not(target_os = "emscripten"),
37 any(feature = "tokio", feature = "async-std")
38))]
39#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
40
41use std::{
42 collections::VecDeque,
43 io,
44 path::PathBuf,
45 pin::Pin,
46 task::{Context, Poll},
47};
48
49use futures::{
50 future::{BoxFuture, Ready},
51 prelude::*,
52 stream::BoxStream,
53};
54use libp2p_core::{
55 multiaddr::{Multiaddr, Protocol},
56 transport::{DialOpts, ListenerId, TransportError, TransportEvent},
57 Transport,
58};
59
/// Boxed, `'static` stream of the items produced by one listener of transport `T`.
///
/// `Ok` items are `TransportEvent`s for that listener. In this crate the `Err`
/// variant carries the reason the listener closed: `Ok(())` when it was removed
/// deliberately, `Err(e)` when it failed (e.g. the bind failed). `poll` turns an
/// `Err` item into `TransportEvent::ListenerClosed` and drops the stream.
pub type Listener<T> = BoxStream<
    'static,
    Result<
        TransportEvent<<T as Transport>::ListenerUpgrade, <T as Transport>::Error>,
        Result<(), <T as Transport>::Error>,
    >,
>;
67
// Expands to a Unix-domain-socket `Transport` implementation for one async runtime.
//
// Parameters, in order:
// - `$feature_name`: a string naming the runtime's Cargo feature. It is matched
//   but not referenced anywhere in the expansion below.
// - `$uds_config`: the name of the generated public transport struct.
// - `$build_listener`: an expression called as `$build_listener(path)` with the
//   socket `PathBuf`. It must return a future resolving to `io::Result<L>`, where
//   `L` has an async `accept()` that yields `io::Result<($unix_stream, _)>`.
// - `$unix_stream`: the runtime's Unix stream type. It is the transport's `Output`
//   and is dialed through `<$unix_stream>::connect(&path)`.
// - `$($mut_or_not:tt)*`: optional trailing tokens after the required comma; they
//   are not referenced (the invocations in this file pass none).
macro_rules! codegen {
    ($feature_name:expr, $uds_config:ident, $build_listener:expr, $unix_stream:ty, $($mut_or_not:tt)*) => {
        /// Transport over Unix domain sockets.
        ///
        /// Only addresses of the form `/unix/<absolute path>` (optionally followed
        /// by a `/p2p/...` component) are supported for listening and dialing;
        /// anything else is rejected with `TransportError::MultiaddrNotSupported`.
        pub struct $uds_config {
            // Listeners tagged with the id they were registered under. `poll`
            // rotates through this queue (pop from the back, re-push at the front).
            listeners: VecDeque<(ListenerId, Listener<Self>)>,
        }

        impl $uds_config {
            /// Creates a transport with no active listeners.
            pub fn new() -> $uds_config {
                $uds_config {
                    listeners: VecDeque::new(),
                }
            }
        }

        impl Default for $uds_config {
            /// Same as `new()`.
            fn default() -> Self {
                Self::new()
            }
        }

        impl Transport for $uds_config {
            type Output = $unix_stream;
            type Error = io::Error;
            // Accepted streams are handed out as-is, so the inbound "upgrade" is an
            // already-resolved future.
            type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
            type Dial = BoxFuture<'static, Result<Self::Output, Self::Error>>;

            // Registers a listener on the socket path encoded in `addr` under `id`.
            //
            // `$build_listener(path)` is called here, but the future it returns (and
            // so, for the invocations in this file, the actual bind) is only driven
            // when `poll` first polls the new listener stream. That stream yields:
            // - one `NewAddress` event for `addr` after a successful bind;
            // - then one event per `accept()`: `Incoming` on success (both
            //   `local_addr` and `send_back_addr` are set to `addr`; the peer address
            //   returned by `accept()` is discarded), or `ListenerError` on failure,
            //   after which accepting continues.
            // If the bind fails, it instead yields a single `Err(Err(e))`, which `poll`
            // reports as `ListenerClosed`.
            //
            // Returns `MultiaddrNotSupported` if `multiaddr_to_path` rejects `addr`.
            fn listen_on(
                &mut self,
                id: ListenerId,
                addr: Multiaddr,
            ) -> Result<(), TransportError<Self::Error>> {
                if let Ok(path) = multiaddr_to_path(&addr) {
                    // The invocations pass closure literals, so this expands to an
                    // immediately-called closure, which clippy would otherwise flag.
                    #[allow(clippy::redundant_closure_call)]
                    let listener = $build_listener(path)
                        // A bind error becomes `Err(Err(e))`: "listener closed with error".
                        .map_err(Err)
                        .map_ok(move |listener| {
                            // First item: announce the listen address.
                            stream::once({
                                let addr = addr.clone();
                                async move {
                                    tracing::debug!(address=%addr, "Now listening on address");
                                    Ok(TransportEvent::NewAddress {
                                        listener_id: id,
                                        listen_addr: addr,
                                    })
                                }
                            })
                            // Then accept connections. The unfold closure always returns
                            // `Some`, so this part of the stream never ends on its own.
                            .chain(stream::unfold(
                                listener,
                                move |listener| {
                                    let addr = addr.clone();
                                    async move {
                                        let event = match listener.accept().await {
                                            Ok((stream, _)) => {
                                                tracing::debug!(address=%addr, "incoming connection on address");
                                                TransportEvent::Incoming {
                                                    upgrade: future::ok(stream),
                                                    local_addr: addr.clone(),
                                                    send_back_addr: addr.clone(),
                                                    listener_id: id,
                                                }
                                            }
                                            Err(error) => TransportEvent::ListenerError {
                                                listener_id: id,
                                                error,
                                            },
                                        };
                                        Some((Ok(event), listener))
                                    }
                                },
                            ))
                        })
                        .try_flatten_stream()
                        .boxed();
                    self.listeners.push_back((id, listener));
                    Ok(())
                } else {
                    Err(TransportError::MultiaddrNotSupported(addr))
                }
            }

            // Requests that listener `id` be closed. Its stream is swapped for one that
            // yields a single `Err(Ok(()))`, so the next `poll` that reaches it reports
            // `ListenerClosed { reason: Ok(()) }` and drops it. Returns `false` if no
            // listener with that id is registered.
            fn remove_listener(&mut self, id: ListenerId) -> bool {
                if let Some(index) = self
                    .listeners
                    .iter()
                    .position(|(listener_id, _)| listener_id == &id)
                {
                    // `index` was just produced by `position`, so it is in bounds.
                    let listener_stream = self.listeners.get_mut(index).unwrap();
                    let report_closed_stream = stream::once(async { Err(Ok(())) }).boxed();
                    *listener_stream = (id, report_closed_stream);
                    true
                } else {
                    false
                }
            }

            // Connects to the socket path encoded in `addr`; `_dial_opts` is ignored.
            // Unsupported addresses fail immediately with `MultiaddrNotSupported`;
            // otherwise the returned future performs `<$unix_stream>::connect`.
            fn dial(&mut self, addr: Multiaddr, _dial_opts: DialOpts) -> Result<Self::Dial, TransportError<Self::Error>> {
                if let Ok(path) = multiaddr_to_path(&addr) {
                    tracing::debug!(address=%addr, "Dialing address");
                    Ok(async move { <$unix_stream>::connect(&path).await }.boxed())
                } else {
                    Err(TransportError::MultiaddrNotSupported(addr))
                }
            }

            // Polls listeners round-robin, each at most once per call.
            //
            // Listeners are popped from the back and re-pushed at the front, so the one
            // just polled is visited last on the next call. Returns the first event
            // produced. A listener whose stream yields `Err(reason)` is not re-queued
            // and is reported as `ListenerClosed`. Returns `Pending` once every
            // listener has returned `Pending`.
            fn poll(
                mut self: Pin<&mut Self>,
                cx: &mut Context<'_>,
            ) -> Poll<TransportEvent<Self::ListenerUpgrade, Self::Error>> {
                // Listeners not yet polled in this call. Pending listeners are
                // re-queued, so without this counter the loop would revisit them.
                let mut remaining = self.listeners.len();
                while let Some((id, mut listener)) = self.listeners.pop_back() {
                    let event = match Stream::poll_next(Pin::new(&mut listener), cx) {
                        Poll::Pending => None,
                        // Not expected: the `listen_on` stream either never ends or
                        // yields `Err` first, and the `remove_listener` stream yields
                        // `Err` first; both are dropped on `Err` (next arm).
                        // NOTE(review): the message mentions a "sender", but these
                        // streams are not channel-backed.
                        Poll::Ready(None) => panic!("Alive listeners always have a sender."),
                        Poll::Ready(Some(Ok(event))) => Some(event),
                        Poll::Ready(Some(Err(reason))) => {
                            return Poll::Ready(TransportEvent::ListenerClosed {
                                listener_id: id,
                                reason,
                            })
                        }
                    };
                    self.listeners.push_front((id, listener));
                    if let Some(event) = event {
                        return Poll::Ready(event);
                    } else {
                        remaining -= 1;
                        if remaining == 0 {
                            break;
                        }
                    }
                }
                Poll::Pending
            }
        }
    };
}
207
// async-std flavour of the transport, exported as `UdsConfig`.
#[cfg(feature = "async-std")]
codegen!(
    "async-std",
    UdsConfig,
    |socket_path| async move {
        async_std::os::unix::net::UnixListener::bind(&socket_path).await
    },
    async_std::os::unix::net::UnixStream,
);
// tokio flavour of the transport, exported as `TokioUdsConfig`.
#[cfg(feature = "tokio")]
codegen!(
    "tokio",
    TokioUdsConfig,
    |socket_path| async move {
        // `bind` is not awaited here (it is a plain call); wrapping it in an async
        // block means it only runs once the listener future is polled.
        tokio::net::UnixListener::bind(&socket_path)
    },
    tokio::net::UnixStream,
);
222
223fn multiaddr_to_path(addr: &Multiaddr) -> Result<PathBuf, ()> {
229 let mut protocols = addr.iter();
230 match protocols.next() {
231 Some(Protocol::Unix(ref path)) => {
232 let path = PathBuf::from(path.as_ref());
233 if !path.is_absolute() {
234 return Err(());
235 }
236 match protocols.next() {
237 None | Some(Protocol::P2p(_)) => Ok(path),
238 Some(_) => Err(()),
239 }
240 }
241 _ => Err(()),
242 }
243}
244
#[cfg(all(test, feature = "async-std"))]
mod tests {
    use std::{borrow::Cow, path::Path};

    use futures::{channel::oneshot, prelude::*};
    use libp2p_core::{
        multiaddr::{Multiaddr, Protocol},
        transport::{DialOpts, ListenerId, PortUse},
        Endpoint, Transport,
    };

    use super::{multiaddr_to_path, UdsConfig};

    #[test]
    fn multiaddr_to_path_conversion() {
        assert!(
            multiaddr_to_path(&"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap()).is_err()
        );

        assert_eq!(
            multiaddr_to_path(&Multiaddr::from(Protocol::Unix("/tmp/foo".into()))),
            Ok(Path::new("/tmp/foo").to_owned())
        );
        assert_eq!(
            multiaddr_to_path(&Multiaddr::from(Protocol::Unix("/home/bar/baz".into()))),
            Ok(Path::new("/home/bar/baz").to_owned())
        );

        // Relative socket paths are rejected.
        assert!(multiaddr_to_path(&Multiaddr::from(Protocol::Unix("tmp/foo".into()))).is_err());

        // Non-`/p2p` components after the unix path are rejected.
        assert!(multiaddr_to_path(
            &Multiaddr::from(Protocol::Unix("/tmp/foo".into())).with(Protocol::Tcp(1234))
        )
        .is_err());
    }

    #[test]
    fn communicating_between_dialer_and_listener() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket = temp_dir.path().join("socket");
        let addr = Multiaddr::from(Protocol::Unix(Cow::Owned(
            socket.to_string_lossy().into_owned(),
        )));

        let (tx, rx) = oneshot::channel();

        // Listener side: announce the listen address, then read three bytes from
        // the first inbound connection.
        let listener = async move {
            let mut transport = UdsConfig::new().boxed();
            transport.listen_on(ListenerId::next(), addr).unwrap();

            let listen_addr = transport
                .select_next_some()
                .await
                .into_new_address()
                .expect("listen address");

            tx.send(listen_addr).unwrap();

            let (sock, _addr) = transport
                .select_next_some()
                .await
                .into_incoming()
                .expect("incoming stream");

            let mut sock = sock.await.unwrap();
            let mut buf = [0u8; 3];
            sock.read_exact(&mut buf).await.unwrap();
            assert_eq!(buf, [1, 2, 3]);
        };

        // Dialer side: wait for the listen address, connect, and send three bytes.
        let dialer = async move {
            let mut uds = UdsConfig::new();
            let addr = rx.await.unwrap();
            let mut socket = uds
                .dial(
                    addr,
                    DialOpts {
                        role: Endpoint::Dialer,
                        port_use: PortUse::Reuse,
                    },
                )
                .unwrap()
                .await
                .unwrap();
            // `write` may accept fewer bytes than requested; `write_all` does not.
            socket.write_all(&[1, 2, 3]).await.unwrap();
        };

        // Drive both sides to completion on the same task. Previously the listener
        // ran in a detached spawned task, so the test could return before its
        // assertion ran, and a panic there went unreported.
        async_std::task::block_on(future::join(listener, dialer));
    }

    #[test]
    // NOTE(review): ignored, presumably because `/unix/...` strings with extra path
    // segments do not parse the way this test expects; confirm before re-enabling.
    #[ignore]
    fn larger_addr_denied() {
        let mut uds = UdsConfig::new();

        let addr = "/unix//foo/bar".parse::<Multiaddr>().unwrap();
        assert!(uds.listen_on(ListenerId::next(), addr).is_err());
    }

    #[test]
    // NOTE(review): ignored, presumably for the same multiaddr parsing reason as
    // `larger_addr_denied`; confirm before re-enabling.
    #[ignore]
    fn relative_addr_denied() {
        assert!("/unix/./foo/bar".parse::<Multiaddr>().is_err());
    }
}