#![warn(missing_docs)]
#![doc(html_root_url="https://docs.rs/hyper-native-tls/0.3")]
extern crate antidote;
extern crate hyper;
pub extern crate native_tls;
use antidote::Mutex;
use hyper::net::{SslClient, SslServer, NetworkStream};
use native_tls::{TlsAcceptor, TlsConnector, Identity};
use std::net::SocketAddr;
use std::time::Duration;
use std::error::Error;
use std::io::{self, Read};
use std::fs::File;
use std::ops::{Deref, DerefMut};
use std::sync::Arc;
use std::fmt;
use std::path::Path;
pub use native_tls::Certificate;
/// A TLS stream whose clones all refer to the same underlying connection.
///
/// The wrapped `native_tls::TlsStream` lives behind an `Arc<Mutex<_>>`, so
/// cloning a `TlsStream` produces another handle to the same stream rather
/// than a new connection.
#[derive(Debug, Clone)]
pub struct TlsStream<S>(Arc<Mutex<native_tls::TlsStream<S>>>);

impl<S> TlsStream<S>
    where S: io::Read + io::Write
{
    /// Locks the shared mutex and returns a guard that dereferences to the
    /// inner `native_tls::TlsStream`.
    ///
    /// The mutex stays locked for as long as the returned guard is alive.
    pub fn lock(&self) -> StreamGuard<S> {
        let guard = self.0.lock();
        StreamGuard(guard)
    }
}
impl<S> io::Read for TlsStream<S>
    where S: io::Read + io::Write
{
    /// Reads from the shared TLS stream; the mutex is held for the duration
    /// of this single call.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut stream = self.0.lock();
        stream.read(buf)
    }
}
impl<S> io::Write for TlsStream<S>
    where S: io::Read + io::Write
{
    /// Writes to the shared TLS stream; the mutex is held for the duration
    /// of this single call.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut stream = self.0.lock();
        stream.write(buf)
    }

    /// Flushes the shared TLS stream while holding the mutex.
    fn flush(&mut self) -> io::Result<()> {
        let mut stream = self.0.lock();
        stream.flush()
    }
}
// Each method locks the shared TLS stream and forwards the call to the
// transport stream it wraps (reached through `get_mut`).
impl<S> NetworkStream for TlsStream<S>
    where S: NetworkStream
{
    /// Returns the peer address reported by the wrapped transport stream.
    fn peer_addr(&mut self) -> io::Result<SocketAddr> {
        let mut tls = self.0.lock();
        tls.get_mut().peer_addr()
    }

    /// Forwards the read timeout to the wrapped transport stream.
    fn set_read_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        let mut tls = self.0.lock();
        tls.get_mut().set_read_timeout(dur)
    }

    /// Forwards the write timeout to the wrapped transport stream.
    fn set_write_timeout(&self, dur: Option<Duration>) -> io::Result<()> {
        let mut tls = self.0.lock();
        tls.get_mut().set_write_timeout(dur)
    }
}
/// A guard around a locked `native_tls::TlsStream`.
///
/// It wraps the `antidote::MutexGuard` for the stream, so the stream stays
/// locked until the guard is dropped.
pub struct StreamGuard<'a, T>(antidote::MutexGuard<'a, native_tls::TlsStream<T>>)
    where T: io::Read + io::Write + 'a;
impl<'a, T> Deref for StreamGuard<'a, T>
    where T: io::Read + io::Write + 'a
{
    type Target = native_tls::TlsStream<T>;

    /// Borrows the locked `native_tls::TlsStream` immutably.
    fn deref(&self) -> &Self::Target {
        // Explicit reborrow through the inner mutex guard.
        &*self.0
    }
}
impl<'a, T> DerefMut for StreamGuard<'a, T>
    where T: io::Read + io::Write + 'a
{
    /// Borrows the locked `native_tls::TlsStream` mutably.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Explicit mutable reborrow through the inner mutex guard.
        &mut *self.0
    }
}
/// An `SslClient` implementation backed by a `native_tls::TlsConnector`.
pub struct NativeTlsClient {
    connector: TlsConnector,
}

impl NativeTlsClient {
    /// Creates a client from a `TlsConnector` built with the builder's
    /// default settings.
    ///
    /// # Errors
    ///
    /// Returns the `native_tls` error if the connector cannot be built.
    pub fn new() -> native_tls::Result<NativeTlsClient> {
        let connector = TlsConnector::builder().build()?;
        Ok(NativeTlsClient::from(connector))
    }
}
impl From<TlsConnector> for NativeTlsClient {
    /// Wraps an already-configured `TlsConnector` without modifying it.
    fn from(t: TlsConnector) -> NativeTlsClient {
        let connector = t;
        NativeTlsClient { connector: connector }
    }
}
impl<T> SslClient<T> for NativeTlsClient
where T: NetworkStream + Send + Clone + fmt::Debug + Sync
{
type Stream = TlsStream<T>;
fn wrap_client(&self, stream: T, host: &str) -> hyper::Result<TlsStream<T>> {
let stream = self.connector.connect(host, stream);
match stream {
Ok(s) => Ok(TlsStream(Arc::new(Mutex::new(s)))),
Err(e) => Err(hyper::Error::Ssl(Box::new(e))),
}
}
}
#[derive(Clone)]
pub struct NativeTlsServer(Arc<TlsAcceptor>);
impl NativeTlsServer {
pub fn new<P>(identity: P, password: &str) -> Result<NativeTlsServer, ServerError>
where P: AsRef<Path>
{
let mut buf = vec![];
try!(File::open(identity)
.and_then(|mut f| f.read_to_end(&mut buf))
.map_err(ServerError::Io));
let identity = try!(Identity::from_pkcs12(&buf, password).map_err(ServerError::Tls));
let acceptor = try!(TlsAcceptor::builder(identity)
.build()
.map_err(ServerError::Tls));
Ok(acceptor.into())
}
}
impl From<TlsAcceptor> for NativeTlsServer {
    /// Wraps an already-configured `TlsAcceptor`, placing it in an `Arc`.
    fn from(t: TlsAcceptor) -> NativeTlsServer {
        let shared = Arc::new(t);
        NativeTlsServer(shared)
    }
}
impl<T> SslServer<T> for NativeTlsServer
where T: NetworkStream + Send + Clone + fmt::Debug + Sync
{
type Stream = TlsStream<T>;
fn wrap_server(&self, stream: T) -> hyper::Result<TlsStream<T>> {
match self.0.accept(stream) {
Ok(s) => Ok(TlsStream(Arc::new(Mutex::new(s)))),
Err(e) => Err(hyper::Error::Ssl(Box::new(e))),
}
}
}
/// An error returned from `NativeTlsServer::new`.
#[derive(Debug)]
pub enum ServerError {
/// An I/O error; `NativeTlsServer::new` produces it when opening or reading
/// the identity file fails.
Io(io::Error),
/// A `native_tls` error; `NativeTlsServer::new` produces it when parsing the
/// identity or building the acceptor fails.
Tls(native_tls::Error),
}
impl fmt::Display for ServerError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
try!(fmt.write_str(self.description()));
match *self {
ServerError::Io(ref e) => write!(fmt, ": {}", e),
ServerError::Tls(ref e) => write!(fmt, ": {}", e),
}
}
}
impl Error for ServerError {
    /// Returns a short, static summary of which step failed.
    fn description(&self) -> &str {
        match *self {
            ServerError::Io(_) => "error reading identity",
            ServerError::Tls(_) => "error initializing acceptor",
        }
    }

    /// Returns the wrapped lower-level error.
    ///
    /// Implemented as `source` rather than the deprecated `cause` so that
    /// code walking the chain through `source()` sees the underlying error.
    /// The trait's default `cause()` delegates to `source()`, so callers of
    /// `cause()` still receive the same value.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match *self {
            ServerError::Io(ref e) => Some(e),
            ServerError::Tls(ref e) => Some(e),
        }
    }
}
// Integration-style tests. NOTE(review): they depend on outside resources:
// `client` performs a real HTTPS request and `server` reads fixture files
// from the `test/` directory.
#[cfg(test)]
mod test {
use hyper::{Client, Server};
use hyper::server::{Request, Response, Fresh};
use hyper::net::HttpsConnector;
use std::fs::File;
use std::io::Read;
use std::mem;
use super::*;
// Issues a GET to https://google.com through a default `NativeTlsClient` and
// checks that the response status is a success and the body can be read.
#[test]
fn client() {
let ssl = NativeTlsClient::new().unwrap();
let connector = HttpsConnector::new(ssl);
let client = Client::with_connector(connector);
let mut resp = client.get("https://google.com").send().unwrap();
assert!(resp.status.is_success());
let mut body = vec![];
resp.read_to_end(&mut body).unwrap();
}
// Starts an HTTPS server using `test/identity.p12`, then fetches from it with
// a client that trusts `test/root-ca.der`, and checks the body is "hello".
#[test]
fn server() {
let ssl = NativeTlsServer::new("test/identity.p12", "mypass").unwrap();
// Port 0 in the bind address; the actual port is read back from
// `listening.socket` below.
let server = Server::https("127.0.0.1:0", ssl).unwrap();
let listening = server
.handle(|_: Request, resp: Response<Fresh>| resp.send(b"hello").unwrap())
.unwrap();
let port = listening.socket.port();
// The listening handle is deliberately leaked so its destructor never runs.
// NOTE(review): presumably dropping it would block on or stop the server;
// confirm against hyper's `Listening` type.
mem::forget(listening);
let mut buf = Vec::new();
let _ = File::open("test/root-ca.der")
.unwrap()
.read_to_end(&mut buf)
.unwrap();
// Add the DER certificate as a root on the client's connector.
// NOTE(review): presumably this is the CA that signed `test/identity.p12`.
let cert = Certificate::from_der(&buf).unwrap();
let mut tls_connector_builder = TlsConnector::builder();
tls_connector_builder.add_root_certificate(cert);
let tls_connector = tls_connector_builder.build().unwrap();
let native_tls_client = NativeTlsClient::from(tls_connector);
let connector = HttpsConnector::new(native_tls_client);
let client = Client::with_connector(connector);
let mut resp = client
.get(&format!("https://localhost:{}", port))
.send()
.unwrap();
let mut body = vec![];
resp.read_to_end(&mut body).unwrap();
assert_eq!(body, b"hello");
}
}