pingora_core/connectors/
mod.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
// Copyright 2024 Cloudflare, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! Connecting to servers

pub mod http;
pub mod l4;
mod offload;

#[cfg(feature = "any_tls")]
mod tls;

#[cfg(not(feature = "any_tls"))]
use crate::tls::connectors as tls;

use crate::protocols::Stream;
use crate::server::configuration::ServerConf;
use crate::upstreams::peer::{Peer, ALPN};

pub use l4::Connect as L4Connect;
use l4::{connect as l4_connect, BindTo};
use log::{debug, error, warn};
use offload::OffloadRuntime;
use parking_lot::RwLock;
use pingora_error::{Error, ErrorType::*, OrErr, Result};
use pingora_pool::{ConnectionMeta, ConnectionPool};
use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc;
use tls::TlsConnector;
use tokio::sync::Mutex;

/// The options to configure a [TransportConnector]
#[derive(Clone)]
pub struct ConnectorOptions {
    /// Path to the CA file used to validate server certs.
    ///
    /// If `None`, the CA in the [default](https://www.openssl.org/docs/manmaster/man3/SSL_CTX_set_default_verify_paths.html)
    /// locations will be loaded
    pub ca_file: Option<String>,
    /// The default client cert and key to use for mTLS
    ///
    /// Each individual connection can use their own cert key to override this.
    pub cert_key_file: Option<(String, String)>,
    /// When enabled allows TLS keys to be written to a file specified by the SSLKEYLOG
    /// env variable. This can be used by tools like Wireshark to decrypt traffic
    /// for debugging purposes.
    pub debug_ssl_keylog: bool,
    /// How many connections to keepalive
    pub keepalive_pool_size: usize,
    /// Optionally offload the connection establishment to dedicated thread pools
    ///
    /// TCP and TLS connection establishment can be CPU intensive. Sometimes such tasks can slow
    /// down the entire service, which causes timeouts which leads to more connections which
    /// snowballs the issue. Use this option to isolate these CPU intensive tasks from impacting
    /// other traffic.
    ///
    /// Syntax: (#pools, #threads in each pool)
    pub offload_threadpool: Option<(usize, usize)>,
    /// Bind to any of the given source IPv4 addresses
    pub bind_to_v4: Vec<SocketAddr>,
    /// Bind to any of the given source IPv6 addresses
    pub bind_to_v6: Vec<SocketAddr>,
}

impl ConnectorOptions {
    /// Derive the [ConnectorOptions] from a [ServerConf]
    pub fn from_server_conf(server_conf: &ServerConf) -> Self {
        // if both pools and threads are Some(>0)
        let offload_threadpool = server_conf
            .upstream_connect_offload_threadpools
            .zip(server_conf.upstream_connect_offload_thread_per_pool)
            .filter(|(pools, threads)| *pools > 0 && *threads > 0);

        // create SocketAddrs with port 0 for src addr bind

        let bind_to_v4 = server_conf
            .client_bind_to_ipv4
            .iter()
            .map(|v4| {
                let ip = v4.parse().unwrap();
                SocketAddr::new(ip, 0)
            })
            .collect();

        let bind_to_v6 = server_conf
            .client_bind_to_ipv6
            .iter()
            .map(|v6| {
                let ip = v6.parse().unwrap();
                SocketAddr::new(ip, 0)
            })
            .collect();
        ConnectorOptions {
            ca_file: server_conf.ca_file.clone(),
            cert_key_file: None, // TODO: use it
            debug_ssl_keylog: server_conf.upstream_debug_ssl_keylog,
            keepalive_pool_size: server_conf.upstream_keepalive_pool_size,
            offload_threadpool,
            bind_to_v4,
            bind_to_v6,
        }
    }

    /// Create a new [ConnectorOptions] with the given keepalive pool size
    pub fn new(keepalive_pool_size: usize) -> Self {
        ConnectorOptions {
            ca_file: None,
            cert_key_file: None,
            debug_ssl_keylog: false,
            keepalive_pool_size,
            offload_threadpool: None,
            bind_to_v4: vec![],
            bind_to_v6: vec![],
        }
    }
}

/// [TransportConnector] provides APIs to connect to servers via TCP or TLS with connection reuse
pub struct TransportConnector {
    // TLS connector; its `ctx` is handed to `do_connect` for every new connection
    tls_ctx: tls::Connector,
    // Idle streams kept for reuse, looked up by the peer's reuse hash
    connection_pool: Arc<ConnectionPool<Arc<Mutex<Stream>>>>,
    // When set, new connections are established on a runtime obtained from here
    offload: Option<OffloadRuntime>,
    // Candidate source addresses passed to `l4::bind_to_random` for each new connection
    bind_to_v4: Vec<SocketAddr>,
    bind_to_v6: Vec<SocketAddr>,
    // Per-peer ALPN override, populated via `prefer_h1`
    preferred_http_version: PreferredHttpVersion,
}

/// Connection pool size used when [TransportConnector::new] is called without options
const DEFAULT_POOL_SIZE: usize = 128;

impl TransportConnector {
    /// Create a new [TransportConnector] with the given [ConnectorOptions]
    pub fn new(mut options: Option<ConnectorOptions>) -> Self {
        let pool_size = options
            .as_ref()
            .map_or(DEFAULT_POOL_SIZE, |c| c.keepalive_pool_size);
        // Take the offloading setting there because this layer has implement offloading,
        // so no need for stacks at lower layer to offload again.
        let offload = options.as_mut().and_then(|o| o.offload_threadpool.take());
        let bind_to_v4 = options
            .as_ref()
            .map_or_else(Vec::new, |o| o.bind_to_v4.clone());
        let bind_to_v6 = options
            .as_ref()
            .map_or_else(Vec::new, |o| o.bind_to_v6.clone());
        TransportConnector {
            tls_ctx: tls::Connector::new(options),
            connection_pool: Arc::new(ConnectionPool::new(pool_size)),
            offload: offload.map(|v| OffloadRuntime::new(v.0, v.1)),
            bind_to_v4,
            bind_to_v6,
            preferred_http_version: PreferredHttpVersion::new(),
        }
    }

    /// Establish a brand new connection to the given server [Peer]
    ///
    /// The connection pool is not consulted. When an offload runtime is configured, the
    /// connection is established in a task spawned on it; otherwise it is established inline.
    pub async fn new_stream<P: Peer + Send + Sync + 'static>(&self, peer: &P) -> Result<Stream> {
        let offload_rt = self
            .offload
            .as_ref()
            .map(|o| o.get_runtime(peer.reuse_hash()));
        let bind_to = l4::bind_to_random(peer, &self.bind_to_v4, &self.bind_to_v6);
        let alpn_override = self.preferred_http_version.get(peer);

        match offload_rt {
            Some(rt) => {
                // the spawned task must own everything it touches
                let owned_peer = peer.clone();
                let tls_ctx = self.tls_ctx.clone();
                let offloaded = rt.spawn(async move {
                    do_connect(&owned_peer, bind_to, alpn_override, &tls_ctx.ctx).await
                });
                // a failure to join the task is reported as an internal error, while the
                // connect result itself is passed through untouched
                offloaded
                    .await
                    .or_err(InternalError, "offload runtime failure")?
            }
            None => do_connect(peer, bind_to, alpn_override, &self.tls_ctx.ctx).await,
        }
    }

    /// Try to find a reusable connection to the given server [Peer]
    ///
    /// The connection pool is looked up with `peer.reuse_hash()`. `None` is returned when
    /// the pool has nothing for that hash, when exclusive ownership of the pooled stream
    /// cannot be taken back, when `peer.matches_fd` (unix) / `peer.matches_sock` (windows)
    /// rejects the stream, or when the stream fails the `test_reusable_stream` check.
    /// In the latter cases no second lookup is attempted.
    pub async fn reused_stream<P: Peer + Send + Sync>(&self, peer: &P) -> Option<Stream> {
        match self.connection_pool.get(&peer.reuse_hash()) {
            Some(s) => {
                debug!("find reusable stream, trying to acquire it");
                // `release_stream` hands an owned lock guard of this mutex to the idle poll
                // task, so acquiring the lock here blocks until that guard is dropped. The
                // guard obtained here is discarded right away.
                {
                    let _ = s.lock().await;
                } // wait for the idle poll to release it
                // Taking the stream out of the Arc only works if this is the last reference
                match Arc::try_unwrap(s) {
                    Ok(l) => {
                        let mut stream = l.into_inner();
                        // test_reusable_stream: we assume server would never actively send data
                        // first on an idle stream.
                        #[cfg(unix)]
                        if peer.matches_fd(stream.id()) && test_reusable_stream(&mut stream) {
                            Some(stream)
                        } else {
                            None
                        }
                        #[cfg(windows)]
                        {
                            use std::os::windows::io::{AsRawSocket, RawSocket};
                            // Local adapter so that the stream id can be passed where an
                            // `AsRawSocket` implementor is expected
                            struct WrappedRawSocket(RawSocket);
                            impl AsRawSocket for WrappedRawSocket {
                                fn as_raw_socket(&self) -> RawSocket {
                                    self.0
                                }
                            }
                            if peer.matches_sock(WrappedRawSocket(stream.id() as RawSocket))
                                && test_reusable_stream(&mut stream)
                            {
                                Some(stream)
                            } else {
                                None
                            }
                        }
                    }
                    Err(_) => {
                        // someone else still holds a reference to the stream
                        error!("failed to acquire reusable stream");
                        None
                    }
                }
            }
            None => {
                debug!("No reusable connection found for {peer}");
                None
            }
        }
    }

    /// Return the [Stream] to the [TransportConnector] for connection reuse.
    ///
    /// Not all TCP/TLS connections can be reused. It is the caller's responsibility to make sure
    /// that protocol over the [Stream] supports connection reuse and the [Stream] itself is ready
    /// to be reused.
    ///
    /// If a [Stream] is dropped instead of being returned via this function. it will be closed.
    pub fn release_stream(
        &self,
        mut stream: Stream,
        key: u64, // usually peer.reuse_hash()
        idle_timeout: Option<std::time::Duration>,
    ) {
        if !test_reusable_stream(&mut stream) {
            return;
        }
        let id = stream.id();
        let meta = ConnectionMeta::new(key, id);
        debug!("Try to keepalive client session");
        let stream = Arc::new(Mutex::new(stream));
        let locked_stream = stream.clone().try_lock_owned().unwrap(); // safe as we just created it
        let (notify_close, watch_use) = self.connection_pool.put(&meta, stream);
        let pool = self.connection_pool.clone(); //clone the arc
        let rt = pingora_runtime::current_handle();
        rt.spawn(async move {
            pool.idle_poll(locked_stream, &meta, idle_timeout, notify_close, watch_use)
                .await;
        });
    }

    /// Obtain a stream to the given server [Peer], reusing a pooled one when possible
    ///
    /// A reusable [Stream] is looked for first. Only if none is available is a new
    /// connection made to the server.
    ///
    /// The boolean in the returned pair tells whether the stream was reused.
    pub async fn get_stream<P: Peer + Send + Sync + 'static>(
        &self,
        peer: &P,
    ) -> Result<(Stream, bool)> {
        let pooled = self.reused_stream(peer).await;
        match pooled {
            Some(stream) => Ok((stream, true)),
            None => {
                let stream = self.new_stream(peer).await?;
                Ok((stream, false))
            }
        }
    }

    /// Make the connector offer h1 via ALPN on future connections to the given peer.
    pub fn prefer_h1(&self, peer: &impl Peer) {
        // version number recorded for HTTP/1
        const H1: u8 = 1;
        self.preferred_http_version.add(peer, H1);
    }
}

// Perform the actual L4 and tls connection steps while respecting the peer's
// connection timeout if there is one
async fn do_connect<P: Peer + Send + Sync>(
    peer: &P,
    bind_to: Option<BindTo>,
    alpn_override: Option<ALPN>,
    tls_ctx: &TlsConnector,
) -> Result<Stream> {
    // Create the future that does the connections, but don't evaluate it until
    // we decide if we need a timeout or not
    let connect_future = do_connect_inner(peer, bind_to, alpn_override, tls_ctx);

    match peer.total_connection_timeout() {
        Some(t) => match pingora_timeout::timeout(t, connect_future).await {
            Ok(res) => res,
            Err(_) => Error::e_explain(
                ConnectTimedout,
                format!("connecting to server {peer}, total-connection timeout {t:?}"),
            ),
        },
        None => connect_future.await,
    }
}

// Run the L4 connect and, when the peer asks for it, the TLS handshake. No overall timeout
// is applied here.
async fn do_connect_inner<P: Peer + Send + Sync>(
    peer: &P,
    bind_to: Option<BindTo>,
    alpn_override: Option<ALPN>,
    tls_ctx: &TlsConnector,
) -> Result<Stream> {
    let l4_stream = l4_connect(peer, bind_to).await?;
    if !peer.tls() {
        // plain connection: the L4 stream is the final stream
        return Ok(Box::new(l4_stream));
    }
    let tls_stream = tls::connect(l4_stream, peer, alpn_override, tls_ctx).await?;
    Ok(Box::new(tls_stream))
}

// Remembers, per peer, which HTTP version should be offered via ALPN on new connections
struct PreferredHttpVersion {
    // TODO: shard to avoid the global lock
    versions: RwLock<HashMap<u64, u8>>, // <hash of peer, version>
}

// TODO: limit the size of this

impl PreferredHttpVersion {
    pub fn new() -> Self {
        PreferredHttpVersion {
            versions: RwLock::default(),
        }
    }

    pub fn add(&self, peer: &impl Peer, version: u8) {
        let key = peer.reuse_hash();
        let mut v = self.versions.write();
        v.insert(key, version);
    }

    pub fn get(&self, peer: &impl Peer) -> Option<ALPN> {
        let key = peer.reuse_hash();
        let v = self.versions.read();
        v.get(&key)
            .copied()
            .map(|v| if v == 1 { ALPN::H1 } else { ALPN::H2H1 })
    }
}

use futures::future::FutureExt;
use tokio::io::AsyncReadExt;

/// Check that an idle stream can still be reused
///
/// A single read is polled once without waiting. The stream is reusable only if that read
/// is still pending. A read that completes means the stream is closed (0 bytes), has
/// unexpected data from the server (more than 0 bytes), or is broken (error).
fn test_reusable_stream(stream: &mut Stream) -> bool {
    let mut probe = [0u8; 1];
    match stream.read(&mut probe[..]).now_or_never() {
        // nothing to read and no error: the connection is idle as expected
        None => true,
        Some(Ok(0)) => {
            debug!("Idle connection is closed");
            false
        }
        Some(Ok(_)) => {
            warn!("Unexpected data read in idle connection");
            false
        }
        Some(Err(e)) => {
            debug!("Idle connection is broken: {e:?}");
            false
        }
    }
}

#[cfg(test)]
#[cfg(feature = "any_tls")]
mod tests {
    use pingora_error::ErrorType;
    use tls::Connector;

    use super::*;
    use crate::upstreams::peer::BasicPeer;
    use tokio::io::AsyncWriteExt;
    #[cfg(unix)]
    use tokio::net::UnixListener;

    // 192.0.2.1 is effectively a black hole
    const BLACK_HOLE: &str = "192.0.2.1:79";

    // A released plain TCP stream must be handed out again for the same peer
    #[tokio::test]
    async fn test_connect() {
        let connector = TransportConnector::new(None);
        let peer = BasicPeer::new("1.1.1.1:80");
        let pool_key = peer.reuse_hash();

        // open a fresh connection to 1.1.1.1 and give it straight back
        let fresh = connector.new_stream(&peer).await.unwrap();
        connector.release_stream(fresh, pool_key, None);

        let (_, reused) = connector.get_stream(&peer).await.unwrap();
        assert!(reused);
    }

    // A released TLS stream must be handed out again for the same peer
    #[tokio::test]
    async fn test_connect_tls() {
        let connector = TransportConnector::new(None);
        let mut peer = BasicPeer::new("1.1.1.1:443");
        // setting the SNI is what makes BasicPeer use tls
        peer.sni = String::from("one.one.one.one");
        let pool_key = peer.reuse_hash();

        // open a fresh connection to https://1.1.1.1 and give it straight back
        let fresh = connector.new_stream(&peer).await.unwrap();
        connector.release_stream(fresh, pool_key, None);

        let (_, reused) = connector.get_stream(&peer).await.unwrap();
        assert!(reused);
    }

    // Path of the Unix domain socket used by the mock server and the UDS connect test
    #[cfg(unix)]
    const MOCK_UDS_PATH: &str = "/tmp/test_unix_transport_connector.sock";

    // One-shot mock server: accepts a single connection on MOCK_UDS_PATH, sends a fixed
    // greeting and removes the socket file again
    #[cfg(unix)]
    async fn mock_connect_server() {
        // how long the connection is kept open after writing, so that the client can read
        let read_grace = std::time::Duration::from_millis(100);

        // a stale socket file from an earlier run would make the bind fail
        let _ = std::fs::remove_file(MOCK_UDS_PATH);
        let listener = UnixListener::bind(MOCK_UDS_PATH).unwrap();
        if let Ok((mut client, _peer_addr)) = listener.accept().await {
            client.write_all(b"it works!").await.unwrap();
            tokio::time::sleep(read_grace).await;
        }
        let _ = std::fs::remove_file(MOCK_UDS_PATH);
    }
    // Connect to the one-off Unix domain socket mock server, read its greeting and check
    // that the stream is handed out again after being released.
    //
    // Unix only: the mock server, `MOCK_UDS_PATH` and `UnixListener` are all compiled for
    // unix targets only, so without this gate the test module does not build elsewhere.
    #[cfg(unix)]
    #[tokio::test(flavor = "multi_thread")]
    async fn test_connect_uds() {
        // run the mock server, listening on MOCK_UDS_PATH, in the background
        // NOTE(review): nothing waits for the listener to be bound before connecting below;
        // looks like this could race on a slow machine — confirm
        tokio::spawn(async {
            mock_connect_server().await;
        });
        let connector = TransportConnector::new(None);
        let peer = BasicPeer::new_uds(MOCK_UDS_PATH).unwrap();
        // make a new connection to mock uds
        let mut stream = connector.new_stream(&peer).await.unwrap();
        let mut buf = [0; 9];
        let _ = stream.read(&mut buf).await.unwrap();
        assert_eq!(&buf, b"it works!");
        connector.release_stream(stream, peer.reuse_hash(), None);

        let (_, reused) = connector.get_stream(&peer).await.unwrap();
        assert!(reused);
    }

    // Connecting to the black hole address with a 1ms connection timeout must fail with
    // ConnectTimedout, whatever connector options are in use
    async fn do_test_conn_timeout(conf: Option<ConnectorOptions>) {
        let connector = TransportConnector::new(conf);
        let mut peer = BasicPeer::new(BLACK_HOLE);
        peer.options.connection_timeout = Some(std::time::Duration::from_millis(1));

        if let Err(e) = connector.new_stream(&peer).await {
            assert_eq!(e.etype(), &ConnectTimedout);
        } else {
            panic!("should throw an error");
        }
    }

    // Connection timeout with a connector built without any options (no offloading)
    #[tokio::test]
    async fn test_conn_timeout() {
        do_test_conn_timeout(None).await;
    }

    #[tokio::test]
    async fn test_conn_timeout_with_offload() {
        let mut conf = ConnectorOptions::new(8);
        conf.offload_threadpool = Some((2, 2));
        do_test_conn_timeout(Some(conf)).await;
    }

    // Binding to localhost as the source address while connecting to a remote address
    // is expected to fail
    #[tokio::test]
    async fn test_connector_bind_to() {
        let peer = BasicPeer::new("240.0.0.1:80");
        let mut conf = ConnectorOptions::new(1);
        conf.bind_to_v4 = vec!["127.0.0.1:0".parse().unwrap()];
        let connector = TransportConnector::new(Some(conf));

        let error = connector.new_stream(&peer).await.unwrap_err();
        // XXX: some systems will allow the socket to bind and connect without error, only to timeout
        let etype = error.etype();
        assert!(etype == &ConnectError || etype == &ConnectTimedout)
    }

    /// Helper for testing the error handling of the `do_connect` function.
    /// It assumes that connecting to the peer fails and returns the error
    /// decomposed into its type and its context message (empty when there is none)
    async fn get_do_connect_failure_with_peer(peer: &BasicPeer) -> (ErrorType, String) {
        let tls_connector = Connector::new(None);
        let outcome = do_connect(peer, None, None, &tls_connector.ctx).await;
        let Err(e) = outcome else {
            panic!("should throw an error")
        };
        let etype = e.etype().clone();
        let context = e
            .context
            .as_ref()
            .map_or_else(String::new, |ctx| ctx.as_str().to_owned());
        (etype, context)
    }

    // With only a total connection timeout set, the failure must be reported as coming
    // from that timeout
    #[tokio::test]
    async fn test_do_connect_with_total_timeout() {
        let total = std::time::Duration::from_millis(1);
        let mut peer = BasicPeer::new(BLACK_HOLE);
        peer.options.total_connection_timeout = Some(total);

        let (etype, context) = get_do_connect_failure_with_peer(&peer).await;
        assert_eq!(etype, ConnectTimedout);
        assert!(context.contains("total-connection timeout"));
    }

    // When the connection timeout is shorter than the total connection timeout, the
    // reported timeout must not be the total one
    #[tokio::test]
    async fn test_tls_connect_timeout_supersedes_total() {
        let total = std::time::Duration::from_millis(10);
        let connect_only = std::time::Duration::from_millis(1);
        let mut peer = BasicPeer::new(BLACK_HOLE);
        peer.options.total_connection_timeout = Some(total);
        peer.options.connection_timeout = Some(connect_only);

        let (etype, context) = get_do_connect_failure_with_peer(&peer).await;
        assert_eq!(etype, ConnectTimedout);
        assert!(!context.contains("total-connection timeout"));
    }

    // Without a total connection timeout, the failure must never be reported as a
    // total-connection timeout
    #[tokio::test]
    async fn test_do_connect_without_total_timeout() {
        let peer = BasicPeer::new(BLACK_HOLE);
        let (etype, context) = get_do_connect_failure_with_peer(&peer).await;
        let is_total_timeout =
            etype == ConnectTimedout && context.contains("total-connection timeout");
        assert!(!is_total_timeout);
    }
}