tikv_client/common/
security.rs

1// Copyright 2018 TiKV Project Authors. Licensed under Apache-2.0.
2
3use std::fs::File;
4use std::io::Read;
5use std::path::Path;
6use std::path::PathBuf;
7use std::time::Duration;
8
9use log::info;
10use regex::Regex;
11use tonic::transport::Certificate;
12use tonic::transport::Channel;
13use tonic::transport::ClientTlsConfig;
14use tonic::transport::Identity;
15
16use crate::internal_err;
17use crate::Result;
18
lazy_static::lazy_static! {
    // Matches an `http://` or `https://` scheme prefix, together with any whitespace in front of
    // it, anchored at the very start of the string. It is used to strip a caller-supplied scheme
    // from an address before a scheme is prepended again.
    static ref SCHEME_REG: Regex = Regex::new(r"^\s*(https?://)").unwrap();
}
22
23fn check_pem_file(tag: &str, path: &Path) -> Result<File> {
24    File::open(path)
25        .map_err(|e| internal_err!("failed to open {} to load {}: {:?}", path.display(), tag, e))
26}
27
28fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
29    let mut file = check_pem_file(tag, path)?;
30    let mut key = vec![];
31    file.read_to_end(&mut key)
32        .map_err(|e| {
33            internal_err!(
34                "failed to load {} from path {}: {:?}",
35                tag,
36                path.display(),
37                e
38            )
39        })
40        .map(|_| key)
41}
42
/// Holds the TLS material used when opening gRPC channels.
///
/// The `Default` value has an empty `ca`, which [`SecurityManager::connect`] treats as
/// "TLS not configured" and therefore opens a normal (non-TLS) connection.
#[derive(Default)]
pub struct SecurityManager {
    /// PEM-encoded CA certificate(s) handed to the client TLS configuration to verify the
    /// server. Empty means TLS is not configured.
    ca: Vec<u8>,
    /// PEM-encoded certificate chain used, together with the private key, as the identity
    /// this client presents when connecting.
    cert: Vec<u8>,
    /// Path to the file holding the PEM-encoded private key that belongs to `cert`. Only the
    /// path is stored; the key is read from disk each time a TLS connection is set up.
    key: PathBuf,
}
53
54impl SecurityManager {
55    /// Load TLS configuration from files.
56    pub fn load(
57        ca_path: impl AsRef<Path>,
58        cert_path: impl AsRef<Path>,
59        key_path: impl Into<PathBuf>,
60    ) -> Result<SecurityManager> {
61        let key_path = key_path.into();
62        check_pem_file("private key", &key_path)?;
63        Ok(SecurityManager {
64            ca: load_pem_file("ca", ca_path.as_ref())?,
65            cert: load_pem_file("certificate", cert_path.as_ref())?,
66            key: key_path,
67        })
68    }
69
70    /// Connect to gRPC server using TLS connection. If TLS is not configured, use normal connection.
71    pub async fn connect<Factory, Client>(
72        &self,
73        // env: Arc<Environment>,
74        addr: &str,
75        factory: Factory,
76    ) -> Result<Client>
77    where
78        Factory: FnOnce(Channel) -> Client,
79    {
80        let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
81
82        info!("connect to rpc server at endpoint: {:?}", addr);
83
84        let mut builder = Channel::from_shared(addr)?
85            .tcp_keepalive(Some(Duration::from_secs(10)))
86            .keep_alive_timeout(Duration::from_secs(3));
87
88        if !self.ca.is_empty() {
89            let tls = ClientTlsConfig::new()
90                .ca_certificate(Certificate::from_pem(&self.ca))
91                .identity(Identity::from_pem(
92                    &self.cert,
93                    load_pem_file("private key", &self.key)?,
94                ));
95            builder = builder.tls_config(tls)?;
96        };
97
98        let ch = builder.connect().await?;
99
100        Ok(factory(ch))
101    }
102}
103
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;

    /// `SecurityManager::load` must read the CA and certificate contents into memory, and the
    /// key file must stay loadable from the path that was supplied.
    #[test]
    fn test_security() {
        let temp = tempfile::tempdir().unwrap();
        let ca_path = temp.path().join("ca");
        let cert_path = temp.path().join("cert");
        let key_path = temp.path().join("key");

        // Give every file a distinct one-byte payload so a mix-up between them is detected.
        fs::write(&ca_path, [0u8]).unwrap();
        fs::write(&cert_path, [1u8]).unwrap();
        fs::write(&key_path, [2u8]).unwrap();

        let mgr = SecurityManager::load(&ca_path, &cert_path, &key_path).unwrap();
        assert_eq!(mgr.ca, vec![0]);
        assert_eq!(mgr.cert, vec![1]);
        let key = load_pem_file("private key", &key_path).unwrap();
        assert_eq!(key, vec![2]);
    }
}