tikv_client/common/
security.rs1use std::fs::File;
4use std::io::Read;
5use std::path::Path;
6use std::path::PathBuf;
7use std::time::Duration;
8
9use log::info;
10use regex::Regex;
11use tonic::transport::Certificate;
12use tonic::transport::Channel;
13use tonic::transport::ClientTlsConfig;
14use tonic::transport::Identity;
15
16use crate::internal_err;
17use crate::Result;
18
// Matches optional leading whitespace followed by an `http://` or `https://` scheme at the very
// start of an address. `SecurityManager::connect` strips this match so it can prepend its own
// scheme; addresses without a scheme are left unchanged by the replacement.
lazy_static::lazy_static! {
    static ref SCHEME_REG: Regex = Regex::new(r"^\s*(https?://)").unwrap();
}
22
23fn check_pem_file(tag: &str, path: &Path) -> Result<File> {
24 File::open(path)
25 .map_err(|e| internal_err!("failed to open {} to load {}: {:?}", path.display(), tag, e))
26}
27
28fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
29 let mut file = check_pem_file(tag, path)?;
30 let mut key = vec![];
31 file.read_to_end(&mut key)
32 .map_err(|e| {
33 internal_err!(
34 "failed to load {} from path {}: {:?}",
35 tag,
36 path.display(),
37 e
38 )
39 })
40 .map(|_| key)
41}
42
/// TLS material used by `SecurityManager::connect` to open gRPC channels.
///
/// The CA and client certificate are kept in memory as raw file contents. Only the *path* of
/// the private key is stored; its contents are read from disk each time a connection is made.
/// The `Default` value has an empty `ca`, which `connect` treats as "TLS not configured".
#[derive(Default)]
pub struct SecurityManager {
    /// Raw contents of the CA certificate file (passed to `Certificate::from_pem`, so
    /// presumably PEM). Empty means TLS is disabled.
    ca: Vec<u8>,
    /// Raw contents of the client certificate file (passed to `Identity::from_pem`).
    cert: Vec<u8>,
    /// Filesystem path of the client private key; read on every `connect` call.
    key: PathBuf,
}
53
54impl SecurityManager {
55 pub fn load(
57 ca_path: impl AsRef<Path>,
58 cert_path: impl AsRef<Path>,
59 key_path: impl Into<PathBuf>,
60 ) -> Result<SecurityManager> {
61 let key_path = key_path.into();
62 check_pem_file("private key", &key_path)?;
63 Ok(SecurityManager {
64 ca: load_pem_file("ca", ca_path.as_ref())?,
65 cert: load_pem_file("certificate", cert_path.as_ref())?,
66 key: key_path,
67 })
68 }
69
70 pub async fn connect<Factory, Client>(
72 &self,
73 addr: &str,
75 factory: Factory,
76 ) -> Result<Client>
77 where
78 Factory: FnOnce(Channel) -> Client,
79 {
80 let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
81
82 info!("connect to rpc server at endpoint: {:?}", addr);
83
84 let mut builder = Channel::from_shared(addr)?
85 .tcp_keepalive(Some(Duration::from_secs(10)))
86 .keep_alive_timeout(Duration::from_secs(3));
87
88 if !self.ca.is_empty() {
89 let tls = ClientTlsConfig::new()
90 .ca_certificate(Certificate::from_pem(&self.ca))
91 .identity(Identity::from_pem(
92 &self.cert,
93 load_pem_file("private key", &self.key)?,
94 ));
95 builder = builder.tls_config(tls)?;
96 };
97
98 let ch = builder.connect().await?;
99
100 Ok(factory(ch))
101 }
102}
103
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Write;

    use super::*;

    /// `load` reads the CA and certificate eagerly and remembers the key path.
    #[test]
    fn test_security() {
        let temp = tempfile::tempdir().unwrap();
        let ca_path = temp.path().join("ca");
        let cert_path = temp.path().join("cert");
        let key_path = temp.path().join("key");
        for (id, f) in [&ca_path, &cert_path, &key_path].iter().enumerate() {
            File::create(f).unwrap().write_all(&[id as u8]).unwrap();
        }
        let mgr = SecurityManager::load(&ca_path, &cert_path, &key_path).unwrap();
        assert_eq!(mgr.ca, vec![0]);
        assert_eq!(mgr.cert, vec![1]);
        assert_eq!(mgr.key, key_path);
        let key = load_pem_file("private key", &key_path).unwrap();
        assert_eq!(key, vec![2]);
    }

    /// `load` must fail up front when the private key does not exist, even though the key
    /// contents are only read at connect time.
    #[test]
    fn test_security_missing_key() {
        let temp = tempfile::tempdir().unwrap();
        let ca_path = temp.path().join("ca");
        let cert_path = temp.path().join("cert");
        File::create(&ca_path).unwrap().write_all(b"ca").unwrap();
        File::create(&cert_path).unwrap().write_all(b"cert").unwrap();
        let missing_key = temp.path().join("missing-key");
        assert!(SecurityManager::load(&ca_path, &cert_path, &missing_key).is_err());
        assert!(load_pem_file("private key", &missing_key).is_err());
    }
}