ldap_rs/
channel.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
//! Low-level LDAP channel operations

use std::{io, net::ToSocketAddrs, time::Duration};

use futures::{
    channel::mpsc::{self, Receiver, Sender},
    future,
    sink::SinkExt,
    StreamExt, TryStreamExt,
};
use log::{debug, error};
use rasn_ldap::LdapMessage;
use tokio::{
    io::{AsyncRead, AsyncWrite},
    net::TcpStream,
};

use crate::{
    codec::LdapCodec,
    error::Error,
    options::{TlsKind, TlsOptions},
    TlsBackend,
};

/// Buffer size passed to `mpsc::channel` for each of the two message queues
/// (client -> socket and socket -> client) created by `make_channel`.
const CHANNEL_SIZE: usize = 1024;
/// Upper bound on how long a TCP connection attempt may take.
const CONNECT_TIMEOUT: Duration = Duration::from_secs(10);

/// Sending endpoint of an LDAP channel: messages sent here are written to the server socket.
pub type LdapMessageSender = Sender<LdapMessage>;
/// Receiving endpoint of an LDAP channel: yields messages decoded from the server socket.
pub type LdapMessageReceiver = Receiver<LdapMessage>;

/// Common interface for the encrypted streams produced by the TLS backends, so that
/// `tls_connect` can hand back either backend's stream behind one boxed type.
trait TlsStream: AsyncRead + AsyncWrite + Unpin + Send {}

#[cfg(feature = "tls-native-tls")]
impl<IO> TlsStream for tokio_native_tls::TlsStream<IO> where IO: AsyncRead + AsyncWrite + Unpin + Send {}

#[cfg(feature = "tls-rustls")]
impl<IO> TlsStream for tokio_rustls::client::TlsStream<IO> where IO: AsyncRead + AsyncWrite + Unpin + Send {}

/// Wrap an arbitrary error (or plain message) into an `io::Error` of kind `InvalidData`.
///
/// Used to bring codec, channel and resolution failures to a common `io::Error` type.
fn io_error<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    let source: Box<dyn std::error::Error + Send + Sync> = e.into();
    io::Error::new(io::ErrorKind::InvalidData, source)
}

fn make_channel<S>(stream: S) -> (LdapMessageSender, LdapMessageReceiver)
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    // construct framed instance based on LdapCodec
    let framed = tokio_util::codec::Framed::new(stream, LdapCodec);

    // The 'in' channel:
    // Messages received from the socket will be forwarded to tx_in
    // and received by the external client via rx_in endpoint
    let (tx_in, rx_in) = mpsc::channel(CHANNEL_SIZE);

    // The 'out' channel:
    // Messages sent to tx_out by external clients will be picked up on rx_out endpoint
    // and forwarded to socket
    let (tx_out, rx_out) = mpsc::channel(CHANNEL_SIZE);

    let channel = async move {
        // sink is the sending part, stream is the receiving part
        let (mut sink, stream) = framed.split();

        // we receive LdapMessage messages from the clients and convert to stream chunks
        let mut rx = rx_out.map(Ok::<_, Error>);

        // app -> socket
        let to_wire = sink.send_all(&mut rx);

        // convert incoming channel errors into io::Error
        let mut tx = tx_in.sink_map_err(io_error);

        // app <- socket
        let from_wire = stream.map_err(io_error).forward(&mut tx);

        // await for either of futures: terminating one side will drop the other
        future::select(to_wire, from_wire).await;
    };

    // spawn in the background
    tokio::spawn(channel);

    // we return (tx_out, rx_in) pair so that the consumer can send and receive messages
    (tx_out, rx_in)
}

/// LDAP channel errors
#[derive(Debug, thiserror::Error)]
pub enum ChannelError {
    #[error(transparent)]
    IoError(#[from] io::Error),

    #[error(transparent)]
    ConnectTimeout(#[from] tokio::time::error::Elapsed),

    #[error("STARTTLS failed")]
    StartTlsFailed,

    #[cfg(feature = "tls-native-tls")]
    #[error(transparent)]
    NativeTls(#[from] native_tls::Error),

    #[cfg(feature = "tls-rustls")]
    #[error(transparent)]
    Rustls(#[from] rustls::Error),

    #[cfg(feature = "tls-rustls")]
    #[error(transparent)]
    DnsName(#[from] rustls_pki_types::InvalidDnsNameError),
}

pub type ChannelResult<T> = Result<T, ChannelError>;

/// LDAP TCP channel connector
///
/// Holds the server endpoint; call [`LdapChannel::connect`] to open the connection.
pub struct LdapChannel {
    /// Host name or IP address of the server; also used as the default TLS server name.
    address: String,
    /// TCP port of the server.
    port: u16,
}

impl LdapChannel {
    /// Create a client-side channel with a given server address and port
    pub fn for_client<S>(address: S, port: u16) -> Self
    where
        S: AsRef<str>,
    {
        LdapChannel {
            address: address.as_ref().to_owned(),
            port,
        }
    }

    /// Connect to a server
    /// Returns a pair of (sender, receiver) endpoints
    pub async fn connect(self, tls_options: TlsOptions) -> ChannelResult<(LdapMessageSender, LdapMessageReceiver)> {
        let mut addrs = (self.address.as_ref(), self.port).to_socket_addrs()?;
        let address = addrs.next().ok_or_else(|| io_error("Address resolution error"))?;

        // TCP connect with a timeout
        let stream = tokio::time::timeout(CONNECT_TIMEOUT, TcpStream::connect(&address)).await??;

        debug!("Connection established to {}", address);

        let channel = match tls_options.kind {
            TlsKind::Plain => make_channel(stream),
            #[cfg(tls)]
            TlsKind::Tls => make_channel(self.tls_connect(tls_options, stream).await?),
            #[cfg(tls)]
            TlsKind::StartTls => make_channel(self.starttls_connect(tls_options, stream).await?),
        };
        Ok(channel)
    }

    async fn tls_connect<S>(&self, tls_options: TlsOptions, stream: S) -> ChannelResult<Box<dyn TlsStream>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        match tls_options.backend.unwrap_or_default() {
            #[cfg(feature = "tls-native-tls")]
            TlsBackend::Native(connector) => Ok(Box::new(
                self.tls_connect_native_tls(tls_options.domain_name, connector, stream)
                    .await?,
            )),
            #[cfg(feature = "tls-rustls")]
            TlsBackend::Rustls(client_config) => Ok(Box::new(
                self.tls_connect_rustls(tls_options.domain_name, client_config, stream)
                    .await?,
            )),
        }
    }

    #[cfg(tls)]
    async fn starttls_connect<S>(
        &self,
        tls_options: TlsOptions,
        mut stream: S,
    ) -> ChannelResult<impl AsyncRead + AsyncWrite + Unpin + Send>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        use log::warn;
        use rasn_ldap::{ExtendedRequest, ProtocolOp, ResultCode};

        const STARTTLS_TIMEOUT: Duration = Duration::from_secs(30);

        debug!("Begin STARTTLS negotiation");
        let mut framed = tokio_util::codec::Framed::new(&mut stream, LdapCodec);
        let req = ExtendedRequest {
            request_name: crate::oid::STARTTLS_OID.into(),
            request_value: None,
        };
        framed
            .send(LdapMessage::new(1, ProtocolOp::ExtendedReq(req)))
            .await
            .map_err(|_| ChannelError::StartTlsFailed)?;
        match tokio::time::timeout(STARTTLS_TIMEOUT, framed.next()).await {
            Ok(Some(Ok(item))) => match item.protocol_op {
                ProtocolOp::ExtendedResp(resp) if resp.result_code == ResultCode::Success && item.message_id == 1 => {
                    debug!("End STARTTLS negotiation, switching protocols");
                    return self.tls_connect(tls_options, stream).await;
                }
                _ => {
                    warn!("STARTTLS negotiation failed");
                }
            },
            Err(_) => {
                warn!("Timeout occurred while waiting for STARTTLS reply");
            }
            _ => {
                warn!("Unexpected response while waiting for STARTTLS reply");
            }
        }
        Err(ChannelError::StartTlsFailed)
    }

    #[cfg(feature = "tls-native-tls")]
    async fn tls_connect_native_tls<S>(
        &self,
        domain_name: Option<String>,
        tls_connector: native_tls::TlsConnector,
        stream: S,
    ) -> ChannelResult<tokio_native_tls::TlsStream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        let domain = domain_name.as_deref().unwrap_or(&self.address);

        debug!("Performing TLS handshake using native-tls, SNI: {}", domain);

        let tokio_connector = tokio_native_tls::TlsConnector::from(tls_connector);

        let stream = tokio_connector
            .connect(domain, stream)
            .await
            .map_err(ChannelError::NativeTls)?;

        debug!("TLS handshake succeeded!");

        Ok(stream)
    }

    #[cfg(feature = "tls-rustls")]
    async fn tls_connect_rustls<S>(
        &self,
        domain_name: Option<String>,
        client_config: rustls::ClientConfig,
        stream: S,
    ) -> ChannelResult<tokio_rustls::client::TlsStream<S>>
    where
        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
    {
        use rustls_pki_types::ServerName;
        use std::sync::Arc;

        let domain = ServerName::try_from(domain_name.as_deref().unwrap_or(&self.address).to_owned())?;

        debug!("Performing TLS handshake using rustls, SNI: {:?}", domain);

        let tokio_connector = tokio_rustls::TlsConnector::from(Arc::new(client_config));
        let stream = tokio_connector.connect(domain, stream).await?;

        debug!("TLS handshake succeeded!");

        Ok(stream)
    }
}

#[cfg(test)]
mod tests {
    use std::net::SocketAddr;

    use rasn_ldap::{ProtocolOp, UnbindRequest};
    use tokio::net::TcpListener;
    use tokio_util::codec::Framed;

    use super::*;

    fn new_msg() -> LdapMessage {
        LdapMessage::new(1, ProtocolOp::UnbindRequest(UnbindRequest))
    }

    /// Start a one-shot echo server on an OS-assigned loopback port.
    ///
    /// The server echoes back the first `num_msgs` messages it receives and then closes
    /// the connection. Returns the address the server is listening on.
    async fn start_server(num_msgs: usize) -> SocketAddr {
        // Port 0 lets the OS pick a free port, so parallel or repeated runs cannot collide.
        let tcp = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = tcp.local_addr().unwrap();

        tokio::spawn(async move {
            if let Ok((stream, _)) = tcp.accept().await {
                let framed = Framed::new(stream, LdapCodec);
                let (mut sink, stream) = framed.split();
                sink.send_all(&mut stream.take(num_msgs)).await.unwrap();
            }
        });

        address
    }

    #[tokio::test]
    async fn test_connection_success() {
        let address = start_server(2).await;

        let (mut sender, mut receiver) = LdapChannel::for_client(address.ip().to_string(), address.port())
            .connect(TlsOptions::default())
            .await
            .unwrap();
        let msg = new_msg();

        sender.send(msg.clone()).await.unwrap();
        sender.send(msg.clone()).await.unwrap();

        // The receiver ends once the server closes the connection after echoing both messages.
        let mut received = 0;
        while let Some(m) = receiver.next().await {
            assert_eq!(msg, m);
            received += 1;
        }
        assert_eq!(received, 2);
    }

    #[tokio::test]
    async fn test_connection_fail() {
        // Reserve a free port and release it immediately, so nothing is listening there.
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let res = LdapChannel::for_client("127.0.0.1", port)
            .connect(TlsOptions::default())
            .await;

        assert!(res.is_err());
    }
}