1#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
24
25use std::{
26 io, iter,
27 pin::Pin,
28 task::{Context, Poll},
29};
30
31use bytes::Bytes;
32use futures::{future::BoxFuture, prelude::*};
33use libp2p_core::{
34 upgrade::{InboundConnectionUpgrade, OutboundConnectionUpgrade},
35 UpgradeInfo,
36};
37use libp2p_identity as identity;
38use libp2p_identity::{PeerId, PublicKey};
39
40use crate::error::Error;
41
// Defines the crate's `Error` type, which both connection upgrades use as
// their `Error` associated type.
mod error;
// The plaintext handshake itself; `Config::handshake` delegates to
// `handshake::handshake`, which returns the socket, the remote public key and
// a byte buffer.
mod handshake;
/// Message types included from generated code (`generated/mod.rs`).
///
/// NOTE(review): presumably protobuf codegen output; `Exchange` is the only
/// item re-exported for the rest of the crate — confirm against the generator.
mod proto {
    // Lint suppressed for the included generated code, presumably because it
    // declares `pub` items that cannot be reached from outside this crate.
    #![allow(unreachable_pub)]
    include!("generated/mod.rs");
    pub(crate) use self::structs::Exchange;
}
49
50#[derive(Clone)]
52pub struct Config {
53 local_public_key: identity::PublicKey,
54}
55
56impl Config {
57 pub fn new(identity: &identity::Keypair) -> Self {
58 Self {
59 local_public_key: identity.public(),
60 }
61 }
62}
63
64impl UpgradeInfo for Config {
65 type Info = &'static str;
66 type InfoIter = iter::Once<Self::Info>;
67
68 fn protocol_info(&self) -> Self::InfoIter {
69 iter::once("/plaintext/2.0.0")
70 }
71}
72
73impl<C> InboundConnectionUpgrade<C> for Config
74where
75 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
76{
77 type Output = (PeerId, Output<C>);
78 type Error = Error;
79 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
80
81 fn upgrade_inbound(self, socket: C, _: Self::Info) -> Self::Future {
82 Box::pin(self.handshake(socket))
83 }
84}
85
86impl<C> OutboundConnectionUpgrade<C> for Config
87where
88 C: AsyncRead + AsyncWrite + Send + Unpin + 'static,
89{
90 type Output = (PeerId, Output<C>);
91 type Error = Error;
92 type Future = BoxFuture<'static, Result<Self::Output, Self::Error>>;
93
94 fn upgrade_outbound(self, socket: C, _: Self::Info) -> Self::Future {
95 Box::pin(self.handshake(socket))
96 }
97}
98
99impl Config {
100 async fn handshake<T>(self, socket: T) -> Result<(PeerId, Output<T>), Error>
101 where
102 T: AsyncRead + AsyncWrite + Send + Unpin + 'static,
103 {
104 tracing::debug!("Starting plaintext handshake.");
105 let (socket, remote_key, read_buffer) = handshake::handshake(socket, self).await?;
106 tracing::debug!("Finished plaintext handshake.");
107
108 Ok((
109 remote_key.to_peer_id(),
110 Output {
111 socket,
112 remote_key,
113 read_buffer,
114 },
115 ))
116 }
117}
118
119pub struct Output<S>
121where
122 S: AsyncRead + AsyncWrite + Unpin,
123{
124 pub socket: S,
126 pub remote_key: PublicKey,
128 read_buffer: Bytes,
132}
133
134impl<S: AsyncRead + AsyncWrite + Unpin> AsyncRead for Output<S> {
135 fn poll_read(
136 mut self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 buf: &mut [u8],
139 ) -> Poll<Result<usize, io::Error>> {
140 if !self.read_buffer.is_empty() {
141 let n = std::cmp::min(buf.len(), self.read_buffer.len());
142 let b = self.read_buffer.split_to(n);
143 buf[..n].copy_from_slice(&b[..]);
144 return Poll::Ready(Ok(n));
145 }
146 AsyncRead::poll_read(Pin::new(&mut self.socket), cx, buf)
147 }
148}
149
150impl<S: AsyncRead + AsyncWrite + Unpin> AsyncWrite for Output<S> {
151 fn poll_write(
152 mut self: Pin<&mut Self>,
153 cx: &mut Context<'_>,
154 buf: &[u8],
155 ) -> Poll<Result<usize, io::Error>> {
156 AsyncWrite::poll_write(Pin::new(&mut self.socket), cx, buf)
157 }
158
159 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
160 AsyncWrite::poll_flush(Pin::new(&mut self.socket), cx)
161 }
162
163 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
164 AsyncWrite::poll_close(Pin::new(&mut self.socket), cx)
165 }
166}