1#![deny(
38 missing_docs,
39 unused_must_use,
40 unused_mut,
41 unused_imports,
42 unused_import_braces
43)]
44
45pub use tungstenite;
46
47mod compat;
48mod handshake;
49
50#[cfg(any(
51 feature = "async-tls",
52 feature = "async-native-tls",
53 feature = "tokio-native-tls",
54 feature = "tokio-rustls-manual-roots",
55 feature = "tokio-rustls-native-certs",
56 feature = "tokio-rustls-webpki-roots",
57 feature = "tokio-openssl",
58))]
59pub mod stream;
60
61use std::{
62 io::{Read, Write},
63 pin::Pin,
64 task::{ready, Context, Poll},
65};
66
67use compat::{cvt, AllowStd, ContextWaker};
68use futures_core::stream::{FusedStream, Stream};
69use futures_io::{AsyncRead, AsyncWrite};
70use log::*;
71
72#[cfg(feature = "handshake")]
73use tungstenite::{
74 client::IntoClientRequest,
75 handshake::{
76 client::{ClientHandshake, Response},
77 server::{Callback, NoCallback},
78 HandshakeError,
79 },
80};
81use tungstenite::{
82 error::Error as WsError,
83 protocol::{Message, Role, WebSocket, WebSocketConfig},
84};
85
86#[cfg(feature = "async-std-runtime")]
87pub mod async_std;
88#[cfg(feature = "async-tls")]
89pub mod async_tls;
90#[cfg(feature = "gio-runtime")]
91pub mod gio;
92#[cfg(feature = "tokio-runtime")]
93pub mod tokio;
94
95pub mod bytes;
96pub use bytes::ByteReader;
97#[cfg(feature = "futures-03-sink")]
98pub use bytes::ByteWriter;
99
100use tungstenite::protocol::CloseFrame;
101
#[cfg(feature = "handshake")]
/// Performs the client side of a WebSocket handshake over an already-connected `stream`.
///
/// `request` may be anything implementing [`IntoClientRequest`] (for example a URL string
/// or a prepared request carrying extra headers). The default protocol configuration is
/// used; see [`client_async_with_config`] to override it.
///
/// On success, resolves to the ready-to-use [`WebSocketStream`] together with the
/// server's handshake [`Response`].
///
/// # Errors
///
/// Returns the underlying [`WsError`] if the handshake fails; other handshake errors are
/// reported as [`WsError::Io`] with kind `Other`.
pub async fn client_async<'a, R, S>(
    request: R,
    stream: S,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    client_async_with_config(request, stream, default_config).await
}
125
#[cfg(feature = "handshake")]
/// Same as [`client_async`], but lets the caller supply a [`WebSocketConfig`].
///
/// The request is converted with [`IntoClientRequest`] and the handshake is driven by
/// tungstenite's [`ClientHandshake`] on top of `stream`.
///
/// # Errors
///
/// - [`HandshakeError::Failure`] is unwrapped and returned as its inner [`WsError`].
/// - Any other handshake error is flattened into [`WsError::Io`] (kind `Other`). The
///   error's text becomes the I/O error message.
pub async fn client_async_with_config<'a, R, S>(
    request: R,
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<(WebSocketStream<S>, Response), WsError>
where
    R: IntoClientRequest + Unpin,
    S: AsyncRead + AsyncWrite + Unpin,
{
    let outcome = handshake::client_handshake(stream, move |sock| {
        let req = request.into_client_request()?;
        ClientHandshake::start(sock, req, config)?.handshake()
    })
    .await;

    match outcome {
        Ok(established) => Ok(established),
        // A definitive failure already carries a protocol-level error.
        Err(HandshakeError::Failure(err)) => Err(err),
        // Everything else is surfaced as a generic I/O error with the original text.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
151
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection on `stream`, acting as the server.
///
/// Uses the default protocol configuration and no header callback (`NoCallback`).
/// Resolves to a ready-to-use [`WebSocketStream`] once the handshake completes.
///
/// # Errors
///
/// Fails with the same error mapping as [`accept_hdr_async_with_config`].
pub async fn accept_async<S>(stream: S) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async(stream, no_callback).await
}
170
#[cfg(feature = "handshake")]
/// Same as [`accept_async`], but lets the caller supply a [`WebSocketConfig`].
///
/// # Errors
///
/// Fails with the same error mapping as [`accept_hdr_async_with_config`].
pub async fn accept_async_with_config<S>(
    stream: S,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
{
    let no_callback = NoCallback;
    accept_hdr_async_with_config(stream, no_callback, config).await
}
183
#[cfg(feature = "handshake")]
/// Accepts a new WebSocket connection on `stream`, passing a [`Callback`] to the handshake.
///
/// The callback is forwarded to tungstenite's `accept_hdr_with_config`. Per tungstenite's
/// API, it is invoked with the client's handshake request. The default protocol
/// configuration is used.
///
/// # Errors
///
/// Fails with the same error mapping as [`accept_hdr_async_with_config`].
pub async fn accept_hdr_async<S, C>(stream: S, callback: C) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let default_config: Option<WebSocketConfig> = None;
    accept_hdr_async_with_config(stream, callback, default_config).await
}
197
#[cfg(feature = "handshake")]
/// Same as [`accept_hdr_async`], but lets the caller supply a [`WebSocketConfig`].
///
/// # Errors
///
/// - [`HandshakeError::Failure`] is unwrapped and returned as its inner [`WsError`].
/// - Any other handshake error is flattened into [`WsError::Io`] (kind `Other`). The
///   error's text becomes the I/O error message.
pub async fn accept_hdr_async_with_config<S, C>(
    stream: S,
    callback: C,
    config: Option<WebSocketConfig>,
) -> Result<WebSocketStream<S>, WsError>
where
    S: AsyncRead + AsyncWrite + Unpin,
    C: Callback + Unpin,
{
    let outcome = handshake::server_handshake(stream, move |sock| {
        tungstenite::accept_hdr_with_config(sock, callback, config)
    })
    .await;

    match outcome {
        Ok(ws) => Ok(ws),
        // A definitive failure already carries a protocol-level error.
        Err(HandshakeError::Failure(err)) => Err(err),
        // Everything else is surfaced as a generic I/O error with the original text.
        Err(other) => Err(WsError::Io(std::io::Error::new(
            std::io::ErrorKind::Other,
            other.to_string(),
        ))),
    }
}
221
/// A WebSocket connection running over an asynchronous byte stream `S`.
///
/// Incoming messages are read through the [`Stream`] implementation. Outgoing messages are
/// sent with [`WebSocketStream::send`], or through `futures_util::Sink` when the
/// `futures-03-sink` feature is enabled.
#[derive(Debug)]
pub struct WebSocketStream<S> {
    /// The tungstenite protocol object, operating on `S` through the `AllowStd` adapter.
    inner: WebSocket<AllowStd<S>>,
    /// Set by `Sink::poll_close` when closing hit `WouldBlock`, so later polls only flush
    /// instead of starting the close again.
    #[cfg(feature = "futures-03-sink")]
    closing: bool,
    /// Set once `poll_next` has seen a read error; the stream then yields only `None`.
    ended: bool,
    /// `false` after a write returned `WouldBlock`, meaning buffered data must be flushed
    /// before another message is accepted; reset to `true` once a flush completes.
    ready: bool,
}
243
244impl<S> WebSocketStream<S> {
245 pub async fn from_raw_socket(stream: S, role: Role, config: Option<WebSocketConfig>) -> Self
248 where
249 S: AsyncRead + AsyncWrite + Unpin,
250 {
251 handshake::without_handshake(stream, move |allow_std| {
252 WebSocket::from_raw_socket(allow_std, role, config)
253 })
254 .await
255 }
256
257 pub async fn from_partially_read(
260 stream: S,
261 part: Vec<u8>,
262 role: Role,
263 config: Option<WebSocketConfig>,
264 ) -> Self
265 where
266 S: AsyncRead + AsyncWrite + Unpin,
267 {
268 handshake::without_handshake(stream, move |allow_std| {
269 WebSocket::from_partially_read(allow_std, part, role, config)
270 })
271 .await
272 }
273
274 pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
275 Self {
276 inner: ws,
277 #[cfg(feature = "futures-03-sink")]
278 closing: false,
279 ended: false,
280 ready: true,
281 }
282 }
283
284 fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
285 where
286 S: Unpin,
287 F: FnOnce(&mut WebSocket<AllowStd<S>>) -> R,
288 AllowStd<S>: Read + Write,
289 {
290 #[cfg(feature = "verbose-logging")]
291 trace!("{}:{} WebSocketStream.with_context", file!(), line!());
292 if let Some((kind, ctx)) = ctx {
293 self.inner.get_mut().set_waker(kind, ctx.waker());
294 }
295 f(&mut self.inner)
296 }
297
298 pub fn get_ref(&self) -> &S
300 where
301 S: AsyncRead + AsyncWrite + Unpin,
302 {
303 self.inner.get_ref().get_ref()
304 }
305
306 pub fn get_mut(&mut self) -> &mut S
308 where
309 S: AsyncRead + AsyncWrite + Unpin,
310 {
311 self.inner.get_mut().get_mut()
312 }
313
314 pub fn get_config(&self) -> &WebSocketConfig {
316 self.inner.get_config()
317 }
318
319 pub async fn close(&mut self, msg: Option<CloseFrame>) -> Result<(), WsError>
321 where
322 S: AsyncRead + AsyncWrite + Unpin,
323 {
324 self.send(Message::Close(msg)).await
325 }
326}
327
328impl<T> Stream for WebSocketStream<T>
329where
330 T: AsyncRead + AsyncWrite + Unpin,
331{
332 type Item = Result<Message, WsError>;
333
334 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
335 #[cfg(feature = "verbose-logging")]
336 trace!("{}:{} Stream.poll_next", file!(), line!());
337
338 if self.ended {
342 return Poll::Ready(None);
343 }
344
345 match ready!(self.with_context(Some((ContextWaker::Read, cx)), |s| {
346 #[cfg(feature = "verbose-logging")]
347 trace!(
348 "{}:{} Stream.with_context poll_next -> read()",
349 file!(),
350 line!()
351 );
352 cvt(s.read())
353 })) {
354 Ok(v) => Poll::Ready(Some(Ok(v))),
355 Err(e) => {
356 self.ended = true;
357 if matches!(e, WsError::AlreadyClosed | WsError::ConnectionClosed) {
358 Poll::Ready(None)
359 } else {
360 Poll::Ready(Some(Err(e)))
361 }
362 }
363 }
364 }
365}
366
367impl<T> FusedStream for WebSocketStream<T>
368where
369 T: AsyncRead + AsyncWrite + Unpin,
370{
371 fn is_terminated(&self) -> bool {
372 self.ended
373 }
374}
375
#[cfg(feature = "futures-03-sink")]
impl<T> futures_util::Sink<Message> for WebSocketStream<T>
where
    T: AsyncRead + AsyncWrite + Unpin,
{
    type Error = WsError;

    /// Ready immediately unless a previous write returned `WouldBlock`.
    ///
    /// In that case, the buffered data is flushed first; `ready` is set again once the
    /// flush finishes, whatever its result.
    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if self.ready {
            return Poll::Ready(Ok(()));
        }

        let flushed = ready!(self.with_context(Some((ContextWaker::Write, cx)), |ws| cvt(ws.flush())));
        self.ready = true;
        Poll::Ready(flushed)
    }

    /// Hands `item` to the protocol layer without registering a waker.
    ///
    /// A `WouldBlock` from the write is not an error. The message is treated as accepted,
    /// and `ready` is cleared so the next `poll_ready`/`poll_flush` pushes it out. Any
    /// other outcome sets `ready` and is returned as-is (errors are logged at debug level).
    fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
        match self.with_context(None, |ws| ws.write(item)) {
            Err(WsError::Io(io_err)) if io_err.kind() == std::io::ErrorKind::WouldBlock => {
                self.ready = false;
                Ok(())
            }
            outcome => {
                self.ready = true;
                outcome.map_err(|e| {
                    debug!("websocket start_send error: {}", e);
                    e
                })
            }
        }
    }

    /// Flushes buffered data to the underlying stream.
    ///
    /// A `ConnectionClosed` result is reported as success.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let flushed = ready!(self.with_context(Some((ContextWaker::Write, cx)), |ws| cvt(ws.flush())));
        self.ready = true;
        Poll::Ready(match flushed {
            // The connection finished closing during this flush: nothing left to send.
            Err(WsError::ConnectionClosed) => Ok(()),
            other => other,
        })
    }

    /// Initiates the close (`close(None)`) on the first call.
    ///
    /// If that hits `WouldBlock`, later calls only flush until the close completes.
    /// `ConnectionClosed` is treated as a successful close.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.ready = true;
        let close_already_queued = self.closing;
        let res = self.with_context(Some((ContextWaker::Write, cx)), |ws| {
            if close_already_queued {
                ws.flush()
            } else {
                ws.close(None)
            }
        });

        match res {
            Ok(()) | Err(WsError::ConnectionClosed) => Poll::Ready(Ok(())),
            Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
                trace!("WouldBlock");
                self.closing = true;
                Poll::Pending
            }
            Err(err) => {
                debug!("websocket close error: {}", err);
                Poll::Ready(Err(err))
            }
        }
    }
}
454
455impl<S> WebSocketStream<S> {
456 pub async fn send(&mut self, msg: Message) -> Result<(), WsError>
458 where
459 S: AsyncRead + AsyncWrite + Unpin,
460 {
461 Send::new(self, msg).await
462 }
463}
464
/// Future behind [`WebSocketStream::send`]: writes one message, then flushes.
struct Send<'a, S> {
    /// The connection the message is written to.
    ws: &'a mut WebSocketStream<S>,
    /// The message still to be written; taken (`None`) once handed to the protocol layer,
    /// after which polling only drives the flush.
    msg: Option<Message>,
}
469
470impl<'a, S> Send<'a, S>
471where
472 S: AsyncRead + AsyncWrite + Unpin,
473{
474 fn new(ws: &'a mut WebSocketStream<S>, msg: Message) -> Self {
475 Self { ws, msg: Some(msg) }
476 }
477}
478
479impl<S> std::future::Future for Send<'_, S>
480where
481 S: AsyncRead + AsyncWrite + Unpin,
482{
483 type Output = Result<(), WsError>;
484
485 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
486 if self.msg.is_some() {
487 if !self.ws.ready {
488 let polled = self
490 .ws
491 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
492 .map(|r| {
493 self.ws.ready = true;
494 r
495 });
496 ready!(polled)?
497 }
498
499 let msg = self.msg.take().expect("unreachable");
500 match self.ws.with_context(None, |s| s.write(msg)) {
501 Ok(_) => Ok(()),
502 Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
503 self.ws.ready = false;
507 Ok(())
508 }
509 Err(e) => {
510 debug!("websocket start_send error: {}", e);
511 Err(e)
512 }
513 }?;
514 }
515
516 let polled = self
517 .ws
518 .with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush()))
519 .map(|r| {
520 self.ws.ready = true;
521 match r {
522 Err(WsError::ConnectionClosed) => Ok(()),
524 other => other,
525 }
526 });
527 ready!(polled)?;
528
529 Poll::Ready(Ok(()))
530 }
531}
532
#[cfg(any(
    feature = "async-tls",
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Extracts the host part of `request`'s URI as an owned string.
///
/// IPv6 literals appear in URIs wrapped in square brackets (`[::1]`). The brackets are not
/// part of the address, so a matching `[` … `]` pair is stripped. Any other host is
/// returned unchanged.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::NoHostName)` when the URI has no host.
#[inline]
pub(crate) fn domain(
    request: &tungstenite::handshake::client::Request,
) -> Result<String, tungstenite::Error> {
    let host = request.uri().host().ok_or(tungstenite::Error::Url(
        tungstenite::error::UrlError::NoHostName,
    ))?;

    // Strip only a matched bracket pair; unlike manual slicing, this cannot panic on a
    // degenerate host such as "[".
    let host = host
        .strip_prefix('[')
        .and_then(|inner| inner.strip_suffix(']'))
        .unwrap_or(host);

    Ok(host.to_owned())
}
565
#[cfg(any(
    feature = "async-std-runtime",
    feature = "tokio-runtime",
    feature = "gio-runtime"
))]
/// Returns the port to connect to for `request`.
///
/// Uses the port written in the URI if present. Otherwise it falls back to the scheme
/// default: 443 for `wss`, 80 for `ws`.
///
/// # Errors
///
/// Returns `tungstenite::Error::Url(UrlError::UnsupportedUrlScheme)` when no port is given
/// and the scheme is neither `ws` nor `wss`.
#[inline]
pub(crate) fn port(
    request: &tungstenite::handshake::client::Request,
) -> Result<u16, tungstenite::Error> {
    let uri = request.uri();
    if let Some(explicit) = uri.port_u16() {
        return Ok(explicit);
    }

    match uri.scheme_str() {
        Some("wss") => Ok(443),
        Some("ws") => Ok(80),
        _ => Err(tungstenite::Error::Url(
            tungstenite::error::UrlError::UnsupportedUrlScheme,
        )),
    }
}
588
#[cfg(test)]
mod tests {
    // Each test imports `IntoClientRequest` locally. A module-level `use` would be unused
    // (and rejected by `deny(unused_imports)`) when both tests are cfg'd out.

    /// `domain` must return a bracketed IPv6 host without its brackets.
    #[cfg(any(
        feature = "async-tls",
        feature = "async-std-runtime",
        feature = "tokio-runtime",
        feature = "gio-runtime"
    ))]
    #[test]
    fn domain_strips_ipv6_brackets() {
        use tungstenite::client::IntoClientRequest;

        let req = "ws://[::1]:80".into_client_request().unwrap();
        let host = crate::domain(&req).unwrap();
        assert_eq!(host, "::1");
    }

    /// Malformed bracketed hosts are rejected before a request is ever built.
    #[cfg(feature = "handshake")]
    #[test]
    fn requests_cannot_contain_invalid_uris() {
        use tungstenite::client::IntoClientRequest;

        for malformed in ["ws://[", "ws://[blabla/bla", "ws://[::1/bla"] {
            assert!(malformed.into_client_request().is_err());
        }
    }
}