use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::path::PathBuf;
use std::time::Duration;
use log::info;
use regex::Regex;
use tonic::transport::Certificate;
use tonic::transport::Channel;
use tonic::transport::ClientTlsConfig;
use tonic::transport::Identity;
use crate::internal_err;
use crate::Result;
lazy_static::lazy_static! {
static ref SCHEME_REG: Regex = Regex::new(r"^\s*(https?://)").unwrap();
}
fn check_pem_file(tag: &str, path: &Path) -> Result<File> {
File::open(path)
.map_err(|e| internal_err!("failed to open {} to load {}: {:?}", path.display(), tag, e))
}
fn load_pem_file(tag: &str, path: &Path) -> Result<Vec<u8>> {
let mut file = check_pem_file(tag, path)?;
let mut key = vec![];
file.read_to_end(&mut key)
.map_err(|e| {
internal_err!(
"failed to load {} from path {}: {:?}",
tag,
path.display(),
e
)
})
.map(|_| key)
}
#[derive(Default)]
pub struct SecurityManager {
ca: Vec<u8>,
cert: Vec<u8>,
key: PathBuf,
}
impl SecurityManager {
pub fn load(
ca_path: impl AsRef<Path>,
cert_path: impl AsRef<Path>,
key_path: impl Into<PathBuf>,
) -> Result<SecurityManager> {
let key_path = key_path.into();
check_pem_file("private key", &key_path)?;
Ok(SecurityManager {
ca: load_pem_file("ca", ca_path.as_ref())?,
cert: load_pem_file("certificate", cert_path.as_ref())?,
key: key_path,
})
}
pub async fn connect<Factory, Client>(
&self,
addr: &str,
factory: Factory,
) -> Result<Client>
where
Factory: FnOnce(Channel) -> Client,
{
let addr = "http://".to_string() + &SCHEME_REG.replace(addr, "");
info!("connect to rpc server at endpoint: {:?}", addr);
let mut builder = Channel::from_shared(addr)?
.tcp_keepalive(Some(Duration::from_secs(10)))
.keep_alive_timeout(Duration::from_secs(3));
if !self.ca.is_empty() {
let tls = ClientTlsConfig::new()
.ca_certificate(Certificate::from_pem(&self.ca))
.identity(Identity::from_pem(
&self.cert,
load_pem_file("private key", &self.key)?,
));
builder = builder.tls_config(tls)?;
};
let ch = builder.connect().await?;
Ok(factory(ch))
}
}
#[cfg(test)]
mod tests {
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
use tempfile;
use super::*;
#[test]
fn test_security() {
let temp = tempfile::tempdir().unwrap();
let example_ca = temp.path().join("ca");
let example_cert = temp.path().join("cert");
let example_pem = temp.path().join("key");
for (id, f) in [&example_ca, &example_cert, &example_pem]
.iter()
.enumerate()
{
File::create(f).unwrap().write_all(&[id as u8]).unwrap();
}
let cert_path: PathBuf = format!("{}", example_cert.display()).into();
let key_path: PathBuf = format!("{}", example_pem.display()).into();
let ca_path: PathBuf = format!("{}", example_ca.display()).into();
let mgr = SecurityManager::load(ca_path, cert_path, &key_path).unwrap();
assert_eq!(mgr.ca, vec![0]);
assert_eq!(mgr.cert, vec![1]);
let key = load_pem_file("private key", &key_path).unwrap();
assert_eq!(key, vec![2]);
}
}