#[allow(unused_imports)]
use core_foundation::array::{CFArray, CFArrayRef};
use core_foundation::base::{Boolean, TCFType};
#[cfg(feature = "alpn")]
use core_foundation::string::CFString;
use core_foundation_sys::base::{kCFAllocatorDefault, OSStatus};
use std::os::raw::c_void;
#[allow(unused_imports)]
use security_framework_sys::base::{
errSecBadReq, errSecIO, errSecNotTrusted, errSecSuccess, errSecTrustSettingDeny,
errSecUnimplemented,
};
use security_framework_sys::secure_transport::*;
use std::any::Any;
use std::cmp;
use std::fmt;
use std::io;
use std::io::prelude::*;
use std::marker::PhantomData;
use std::panic::{self, AssertUnwindSafe};
use std::ptr;
use std::result;
use std::slice;
use crate::base::{Error, Result};
use crate::certificate::SecCertificate;
use crate::cipher_suite::CipherSuite;
use crate::identity::SecIdentity;
use crate::import_export::Pkcs12ImportOptions;
use crate::policy::SecPolicy;
use crate::trust::SecTrust;
use crate::{cvt, AsInner};
use security_framework_sys::base::errSecParam;
/// Newtype over the raw `SSLProtocolSide` value; selects which end of a
/// session a context is created for (see `SslContext::new`).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SslProtocolSide(SSLProtocolSide);

impl SslProtocolSide {
    /// Wraps `kSSLServerSide`.
    pub const SERVER: SslProtocolSide = SslProtocolSide(kSSLServerSide);

    /// Wraps `kSSLClientSide`.
    pub const CLIENT: SslProtocolSide = SslProtocolSide(kSSLClientSide);
}
/// Newtype over the raw `SSLConnectionType` value; selects the transport
/// flavour a context is created for (see `SslContext::new`).
#[derive(Debug, Copy, Clone)]
pub struct SslConnectionType(SSLConnectionType);

impl SslConnectionType {
    /// Wraps `kSSLStreamType`.
    pub const STREAM: SslConnectionType = SslConnectionType(kSSLStreamType);

    /// Wraps `kSSLDatagramType`.
    pub const DATAGRAM: SslConnectionType = SslConnectionType(kSSLDatagramType);
}
#[derive(Debug)]
pub enum HandshakeError<S> {
Failure(Error),
Interrupted(MidHandshakeSslStream<S>),
}
impl<S> From<Error> for HandshakeError<S> {
#[inline(always)]
fn from(err: Error) -> Self {
Self::Failure(err)
}
}
#[derive(Debug)]
pub enum ClientHandshakeError<S> {
Failure(Error),
Interrupted(MidHandshakeClientBuilder<S>),
}
impl<S> From<Error> for ClientHandshakeError<S> {
#[inline(always)]
fn from(err: Error) -> Self {
Self::Failure(err)
}
}
#[derive(Debug)]
pub struct MidHandshakeSslStream<S> {
stream: SslStream<S>,
error: Error,
}
impl<S> MidHandshakeSslStream<S> {
#[inline(always)]
#[must_use]
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
}
#[inline(always)]
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
}
#[inline(always)]
#[must_use]
pub fn context(&self) -> &SslContext {
self.stream.context()
}
#[inline(always)]
pub fn context_mut(&mut self) -> &mut SslContext {
self.stream.context_mut()
}
#[inline(always)]
#[must_use]
pub fn server_auth_completed(&self) -> bool {
self.error.code() == errSSLPeerAuthCompleted
}
#[inline(always)]
#[must_use]
pub fn client_cert_requested(&self) -> bool {
self.error.code() == errSSLClientCertRequested
}
#[inline(always)]
#[must_use]
pub fn would_block(&self) -> bool {
self.error.code() == errSSLWouldBlock
}
#[inline(always)]
#[must_use]
pub fn error(&self) -> &Error {
&self.error
}
#[inline(always)]
pub fn handshake(self) -> result::Result<SslStream<S>, HandshakeError<S>> {
self.stream.handshake()
}
}
#[derive(Debug)]
pub struct MidHandshakeClientBuilder<S> {
stream: MidHandshakeSslStream<S>,
domain: Option<String>,
certs: Vec<SecCertificate>,
trust_certs_only: bool,
danger_accept_invalid_certs: bool,
}
impl<S> MidHandshakeClientBuilder<S> {
#[inline(always)]
#[must_use]
pub fn get_ref(&self) -> &S {
self.stream.get_ref()
}
#[inline(always)]
pub fn get_mut(&mut self) -> &mut S {
self.stream.get_mut()
}
#[inline(always)]
#[must_use]
pub fn error(&self) -> &Error {
self.stream.error()
}
pub fn handshake(self) -> result::Result<SslStream<S>, ClientHandshakeError<S>> {
let MidHandshakeClientBuilder {
stream,
domain,
certs,
trust_certs_only,
danger_accept_invalid_certs,
} = self;
let mut result = stream.handshake();
loop {
let stream = match result {
Ok(stream) => return Ok(stream),
Err(HandshakeError::Interrupted(stream)) => stream,
Err(HandshakeError::Failure(err)) => {
return Err(ClientHandshakeError::Failure(err))
}
};
if stream.would_block() {
let ret = MidHandshakeClientBuilder {
stream,
domain,
certs,
trust_certs_only,
danger_accept_invalid_certs,
};
return Err(ClientHandshakeError::Interrupted(ret));
}
if stream.server_auth_completed() {
if danger_accept_invalid_certs {
result = stream.handshake();
continue;
}
let mut trust = match stream.context().peer_trust2()? {
Some(trust) => trust,
None => {
result = stream.handshake();
continue;
}
};
trust.set_anchor_certificates(&certs)?;
trust.set_trust_anchor_certificates_only(self.trust_certs_only)?;
let policy = SecPolicy::create_ssl(SslProtocolSide::SERVER, domain.as_deref());
trust.set_policy(&policy)?;
trust.evaluate_with_error().map_err(|error| {
#[cfg(feature = "log")]
log::warn!("SecTrustEvaluateWithError: {}", error.to_string());
Error::from_code(error.code() as _)
})?;
result = stream.handshake();
continue;
}
let err = Error::from_code(stream.error().code());
return Err(ClientHandshakeError::Failure(err));
}
}
}
/// Newtype over the raw `SSLSessionState` value reported by
/// `SslContext::state`.
#[derive(Debug, PartialEq, Eq)]
pub struct SessionState(SSLSessionState);

impl SessionState {
    /// Wraps `kSSLIdle`.
    pub const IDLE: SessionState = SessionState(kSSLIdle);

    /// Wraps `kSSLHandshake`.
    pub const HANDSHAKE: SessionState = SessionState(kSSLHandshake);

    /// Wraps `kSSLConnected`.
    pub const CONNECTED: SessionState = SessionState(kSSLConnected);

    /// Wraps `kSSLClosed`.
    pub const CLOSED: SessionState = SessionState(kSSLClosed);

    /// Wraps `kSSLAborted`.
    pub const ABORTED: SessionState = SessionState(kSSLAborted);
}
/// Newtype over the raw `SSLAuthenticate` value accepted by
/// `SslContext::set_client_side_authenticate`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SslAuthenticate(SSLAuthenticate);

impl SslAuthenticate {
    /// Wraps `kNeverAuthenticate`.
    pub const NEVER: SslAuthenticate = SslAuthenticate(kNeverAuthenticate);

    /// Wraps `kAlwaysAuthenticate`.
    pub const ALWAYS: SslAuthenticate = SslAuthenticate(kAlwaysAuthenticate);

    /// Wraps `kTryAuthenticate`.
    pub const TRY: SslAuthenticate = SslAuthenticate(kTryAuthenticate);
}
/// Newtype over the raw `SSLClientCertificateState` value reported by
/// `SslContext::client_certificate_state`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SslClientCertificateState(SSLClientCertificateState);

impl SslClientCertificateState {
    /// Wraps `kSSLClientCertNone`.
    pub const NONE: SslClientCertificateState = SslClientCertificateState(kSSLClientCertNone);

    /// Wraps `kSSLClientCertRequested`.
    pub const REQUESTED: SslClientCertificateState =
        SslClientCertificateState(kSSLClientCertRequested);

    /// Wraps `kSSLClientCertSent`.
    pub const SENT: SslClientCertificateState = SslClientCertificateState(kSSLClientCertSent);

    /// Wraps `kSSLClientCertRejected`.
    pub const REJECTED: SslClientCertificateState =
        SslClientCertificateState(kSSLClientCertRejected);
}
/// Newtype over the raw `SSLProtocol` value used by the protocol-version
/// getters and setters on `SslContext`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct SslProtocol(SSLProtocol);

impl SslProtocol {
    /// Wraps `kSSLProtocolUnknown`.
    pub const UNKNOWN: SslProtocol = SslProtocol(kSSLProtocolUnknown);

    /// Wraps `kSSLProtocol3`.
    pub const SSL3: SslProtocol = SslProtocol(kSSLProtocol3);

    /// Wraps `kTLSProtocol1`.
    pub const TLS1: SslProtocol = SslProtocol(kTLSProtocol1);

    /// Wraps `kTLSProtocol11`.
    pub const TLS11: SslProtocol = SslProtocol(kTLSProtocol11);

    /// Wraps `kTLSProtocol12`.
    pub const TLS12: SslProtocol = SslProtocol(kTLSProtocol12);

    /// Wraps `kTLSProtocol13`.
    pub const TLS13: SslProtocol = SslProtocol(kTLSProtocol13);

    /// Wraps `kSSLProtocol2`.
    pub const SSL2: SslProtocol = SslProtocol(kSSLProtocol2);

    /// Wraps `kDTLSProtocol1`.
    pub const DTLS1: SslProtocol = SslProtocol(kDTLSProtocol1);

    /// Wraps `kSSLProtocol3Only`.
    pub const SSL3_ONLY: SslProtocol = SslProtocol(kSSLProtocol3Only);

    /// Wraps `kTLSProtocol1Only`.
    pub const TLS1_ONLY: SslProtocol = SslProtocol(kTLSProtocol1Only);

    /// Wraps `kSSLProtocolAll`.
    pub const ALL: SslProtocol = SslProtocol(kSSLProtocolAll);
}
// Declares `SslContext` as the Rust wrapper type for the raw `SSLContextRef`
// handle. The macro comes from outside this file; presumably it generates a
// tuple struct, since the rest of the file reads the handle as `self.0`.
declare_TCFType! {
SslContext, SSLContextRef
}
// Implements the `TCFType` plumbing for `SslContext`, with
// `SSLContextGetTypeID` supplying its type id.
impl_TCFType!(SslContext, SSLContextRef, SSLContextGetTypeID);
impl fmt::Debug for SslContext {
    /// Formats the context as `SslContext { state: .. }`.
    ///
    /// The `state` field is left out when querying the session state fails.
    #[cold]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = fmt.debug_struct("SslContext");
        match self.state() {
            Ok(current) => {
                out.field("state", &current);
            }
            Err(_) => {}
        }
        out.finish()
    }
}
// Manually asserts that `SslContext` may be shared between and moved across
// threads.
// NOTE(review): nothing in this file demonstrates that the underlying
// `SSLContextRef` is thread-safe — confirm against the Secure Transport
// documentation before relying on concurrent access.
unsafe impl Sync for SslContext {}
unsafe impl Send for SslContext {}
impl AsInner for SslContext {
    type Inner = SSLContextRef;

    /// Returns the raw `SSLContextRef` held by this wrapper.
    #[inline(always)]
    fn as_inner(&self) -> Self::Inner {
        self.0
    }
}
// Generates a getter/setter pair for each boolean session option.
//
// Each entry has the form
//     $(#[attr])* const <option constant>: <getter name> & <setter name>,
// Any attributes on the entry (including doc comments) are copied onto both
// generated methods.
macro_rules! impl_options {
($($(#[$a:meta])* const $opt:ident: $get:ident & $set:ident,)*) => {
$(
$(#[$a])*
#[inline(always)]
// Setter: forwards `value` to `SSLSetSessionOption` as a `Boolean`.
pub fn $set(&mut self, value: bool) -> Result<()> {
unsafe { cvt(SSLSetSessionOption(self.0, $opt, value as Boolean)) }
}
$(#[$a])*
#[inline]
// Getter: reads the option through `SSLGetSessionOption`; any non-zero
// value is reported as `true`.
pub fn $get(&self) -> Result<bool> {
let mut value = 0;
unsafe { cvt(SSLGetSessionOption(self.0, $opt, &mut value))?; }
Ok(value != 0)
}
)*
}
}
impl SslContext {
    /// Creates a new context for the given protocol side and connection type.
    ///
    /// NOTE(review): the handle returned by `SSLCreateContext` is wrapped
    /// as-is, so this function never returns `Err` in its current form.
    #[inline]
    pub fn new(side: SslProtocolSide, type_: SslConnectionType) -> Result<Self> {
        unsafe {
            let ctx = SSLCreateContext(kCFAllocatorDefault, side.0, type_.0);
            Ok(Self(ctx))
        }
    }

    /// Sets the peer domain name on the context.
    ///
    /// The name is handed to `SSLSetPeerDomainName` as a pointer plus byte
    /// length.
    ///
    /// # Errors
    /// Returns the status reported by `SSLSetPeerDomainName` on failure.
    #[inline]
    pub fn set_peer_domain_name(&mut self, peer_name: &str) -> Result<()> {
        unsafe {
            cvt(SSLSetPeerDomainName(
                self.0,
                peer_name.as_ptr().cast(),
                peer_name.len(),
            ))
        }
    }

    /// Returns the peer domain name previously set on the context.
    ///
    /// # Errors
    /// Returns the status reported by either underlying call on failure.
    ///
    /// # Panics
    /// Panics if the bytes returned by Secure Transport are not valid UTF-8.
    pub fn peer_domain_name(&self) -> Result<String> {
        unsafe {
            let mut len = 0;
            cvt(SSLGetPeerDomainNameLength(self.0, &mut len))?;
            let mut buf = vec![0; len];
            cvt(SSLGetPeerDomainName(
                self.0,
                buf.as_mut_ptr().cast(),
                &mut len,
            ))?;
            // `len` is an in/out parameter: keep only the bytes actually
            // written so a shorter result does not carry zero padding.
            buf.truncate(len);
            Ok(String::from_utf8(buf).unwrap())
        }
    }

    /// Sets the identity and certificate chain presented by this context.
    ///
    /// The array passed to `SSLSetCertificate` holds `identity` first,
    /// followed by `certs` in the order given.
    pub fn set_certificate(
        &mut self,
        identity: &SecIdentity,
        certs: &[SecCertificate],
    ) -> Result<()> {
        let mut arr = vec![identity.as_CFType()];
        arr.extend(certs.iter().map(|c| c.as_CFType()));
        let certs = CFArray::from_CFTypes(&arr);
        unsafe { cvt(SSLSetCertificate(self.0, certs.as_concrete_TypeRef())) }
    }

    /// Sets the opaque peer ID bytes on the context via `SSLSetPeerID`.
    #[inline]
    pub fn set_peer_id(&mut self, peer_id: &[u8]) -> Result<()> {
        unsafe { cvt(SSLSetPeerID(self.0, peer_id.as_ptr().cast(), peer_id.len())) }
    }

    /// Returns the peer ID bytes, or `None` when `SSLGetPeerID` yields a null
    /// pointer.
    ///
    /// The returned slice points at memory owned by the context and borrows
    /// from `self`.
    pub fn peer_id(&self) -> Result<Option<&[u8]>> {
        unsafe {
            let mut ptr = ptr::null();
            let mut len = 0;
            cvt(SSLGetPeerID(self.0, &mut ptr, &mut len))?;
            if ptr.is_null() {
                Ok(None)
            } else {
                Ok(Some(slice::from_raw_parts(ptr.cast(), len)))
            }
        }
    }

    /// Returns the cipher suites supported by this context.
    ///
    /// # Errors
    /// Returns the status reported by either underlying call on failure.
    pub fn supported_ciphers(&self) -> Result<Vec<CipherSuite>> {
        unsafe {
            let mut num_ciphers = 0;
            cvt(SSLGetNumberSupportedCiphers(self.0, &mut num_ciphers))?;
            let mut ciphers = vec![0; num_ciphers];
            cvt(SSLGetSupportedCiphers(
                self.0,
                ciphers.as_mut_ptr(),
                &mut num_ciphers,
            ))?;
            // The count is written back by the call; drop any unused slots so
            // zero padding is never reported as a cipher suite.
            ciphers.truncate(num_ciphers);
            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
        }
    }

    /// Returns the cipher suites currently enabled on this context.
    ///
    /// # Errors
    /// Returns the status reported by either underlying call on failure.
    pub fn enabled_ciphers(&self) -> Result<Vec<CipherSuite>> {
        unsafe {
            let mut num_ciphers = 0;
            cvt(SSLGetNumberEnabledCiphers(self.0, &mut num_ciphers))?;
            let mut ciphers = vec![0; num_ciphers];
            cvt(SSLGetEnabledCiphers(
                self.0,
                ciphers.as_mut_ptr(),
                &mut num_ciphers,
            ))?;
            // The count is written back by the call; drop any unused slots so
            // zero padding is never reported as a cipher suite.
            ciphers.truncate(num_ciphers);
            Ok(ciphers.iter().map(|c| CipherSuite::from_raw(*c)).collect())
        }
    }

    /// Replaces the set of enabled cipher suites with `ciphers`.
    pub fn set_enabled_ciphers(&mut self, ciphers: &[CipherSuite]) -> Result<()> {
        let ciphers = ciphers.iter().map(|c| c.to_raw()).collect::<Vec<_>>();
        unsafe {
            cvt(SSLSetEnabledCiphers(
                self.0,
                ciphers.as_ptr(),
                ciphers.len(),
            ))
        }
    }

    /// Returns the cipher suite reported by `SSLGetNegotiatedCipher`.
    #[inline]
    pub fn negotiated_cipher(&self) -> Result<CipherSuite> {
        unsafe {
            let mut cipher = 0;
            cvt(SSLGetNegotiatedCipher(self.0, &mut cipher))?;
            Ok(CipherSuite::from_raw(cipher))
        }
    }

    /// Sets the client-authentication mode via `SSLSetClientSideAuthenticate`.
    #[inline]
    pub fn set_client_side_authenticate(&mut self, auth: SslAuthenticate) -> Result<()> {
        unsafe { cvt(SSLSetClientSideAuthenticate(self.0, auth.0)) }
    }

    /// Returns the state reported by `SSLGetClientCertificateState`.
    #[inline]
    pub fn client_certificate_state(&self) -> Result<SslClientCertificateState> {
        let mut state = 0;
        unsafe {
            cvt(SSLGetClientCertificateState(self.0, &mut state))?;
        }
        Ok(SslClientCertificateState(state))
    }

    /// Returns the peer's trust object, or `None` if `SSLCopyPeerTrust`
    /// yields a null pointer.
    ///
    /// # Errors
    /// Returns `errSecBadReq` when the session is still idle; the underlying
    /// call is not attempted in that state. Otherwise returns the status
    /// reported by the underlying calls.
    pub fn peer_trust2(&self) -> Result<Option<SecTrust>> {
        if self.state()? == SessionState::IDLE {
            return Err(Error::from_code(errSecBadReq));
        }
        unsafe {
            let mut trust = ptr::null_mut();
            cvt(SSLCopyPeerTrust(self.0, &mut trust))?;
            if trust.is_null() {
                Ok(None)
            } else {
                Ok(Some(SecTrust::wrap_under_create_rule(trust)))
            }
        }
    }

    /// Returns the current session state.
    #[inline]
    pub fn state(&self) -> Result<SessionState> {
        unsafe {
            let mut state = 0;
            cvt(SSLGetSessionState(self.0, &mut state))?;
            Ok(SessionState(state))
        }
    }

    /// Returns the protocol version reported by
    /// `SSLGetNegotiatedProtocolVersion`.
    #[inline]
    pub fn negotiated_protocol_version(&self) -> Result<SslProtocol> {
        unsafe {
            let mut version = 0;
            cvt(SSLGetNegotiatedProtocolVersion(self.0, &mut version))?;
            Ok(SslProtocol(version))
        }
    }

    /// Returns the maximum protocol version configured on the context.
    #[inline]
    pub fn protocol_version_max(&self) -> Result<SslProtocol> {
        unsafe {
            let mut version = 0;
            cvt(SSLGetProtocolVersionMax(self.0, &mut version))?;
            Ok(SslProtocol(version))
        }
    }

    /// Sets the maximum protocol version on the context.
    #[inline]
    pub fn set_protocol_version_max(&mut self, max_version: SslProtocol) -> Result<()> {
        unsafe { cvt(SSLSetProtocolVersionMax(self.0, max_version.0)) }
    }

    /// Returns the minimum protocol version configured on the context.
    #[inline]
    pub fn protocol_version_min(&self) -> Result<SslProtocol> {
        unsafe {
            let mut version = 0;
            cvt(SSLGetProtocolVersionMin(self.0, &mut version))?;
            Ok(SslProtocol(version))
        }
    }

    /// Sets the minimum protocol version on the context.
    #[inline]
    pub fn set_protocol_version_min(&mut self, min_version: SslProtocol) -> Result<()> {
        unsafe { cvt(SSLSetProtocolVersionMin(self.0, min_version.0)) }
    }

    /// Returns the ALPN protocols reported by `SSLCopyALPNProtocols`.
    ///
    /// A null array from the underlying call is reported as an empty vector.
    ///
    /// # Errors
    /// Returns `errSecUnimplemented` when the symbol has to be looked up at
    /// runtime (the `OSX_10_13` feature is off) and is not available;
    /// otherwise returns the status reported by the underlying call.
    #[cfg(feature = "alpn")]
    pub fn alpn_protocols(&self) -> Result<Vec<String>> {
        let mut array: CFArrayRef = ptr::null();
        unsafe {
            #[cfg(feature = "OSX_10_13")]
            {
                cvt(SSLCopyALPNProtocols(self.0, &mut array))?;
            }
            #[cfg(not(feature = "OSX_10_13"))]
            {
                dlsym! { fn SSLCopyALPNProtocols(SSLContextRef, *mut CFArrayRef) -> OSStatus }
                if let Some(f) = SSLCopyALPNProtocols.get() {
                    cvt(f(self.0, &mut array))?;
                } else {
                    return Err(Error::from_code(errSecUnimplemented));
                }
            }
            if array.is_null() {
                return Ok(vec![]);
            }
            // The array comes from a "Copy" function, so ownership is taken
            // under the create rule.
            let array = CFArray::<CFString>::wrap_under_create_rule(array);
            Ok(array.into_iter().map(|p| p.to_string()).collect())
        }
    }

    /// Sets the ALPN protocols to offer, in the order given.
    ///
    /// # Errors
    /// Returns `errSecUnimplemented` when the symbol has to be looked up at
    /// runtime (the `OSX_10_13` feature is off) and is not available;
    /// otherwise returns the status reported by the underlying call.
    #[cfg(feature = "alpn")]
    pub fn set_alpn_protocols(&mut self, protocols: &[&str]) -> Result<()> {
        let protocols = CFArray::from_CFTypes(
            &protocols
                .iter()
                .map(|proto| CFString::new(proto))
                .collect::<Vec<_>>(),
        );
        #[cfg(feature = "OSX_10_13")]
        {
            unsafe { cvt(SSLSetALPNProtocols(self.0, protocols.as_concrete_TypeRef())) }
        }
        #[cfg(not(feature = "OSX_10_13"))]
        {
            dlsym! { fn SSLSetALPNProtocols(SSLContextRef, CFArrayRef) -> OSStatus }
            if let Some(f) = SSLSetALPNProtocols.get() {
                unsafe { cvt(f(self.0, protocols.as_concrete_TypeRef())) }
            } else {
                Err(Error::from_code(errSecUnimplemented))
            }
        }
    }

    /// Enables or disables session tickets via `SSLSetSessionTicketsEnabled`.
    ///
    /// # Errors
    /// Returns `errSecUnimplemented` when the symbol has to be looked up at
    /// runtime (the `OSX_10_13` feature is off) and is not available;
    /// otherwise returns the status reported by the underlying call.
    #[cfg(feature = "session-tickets")]
    pub fn set_session_tickets_enabled(&mut self, enabled: bool) -> Result<()> {
        #[cfg(feature = "OSX_10_13")]
        {
            unsafe { cvt(SSLSetSessionTicketsEnabled(self.0, enabled as Boolean)) }
        }
        #[cfg(not(feature = "OSX_10_13"))]
        {
            dlsym! { fn SSLSetSessionTicketsEnabled(SSLContextRef, Boolean) -> OSStatus }
            if let Some(f) = SSLSetSessionTicketsEnabled.get() {
                unsafe { cvt(f(self.0, enabled as Boolean)) }
            } else {
                Err(Error::from_code(errSecUnimplemented))
            }
        }
    }

    /// Enables or disables a single protocol version via
    /// `SSLSetProtocolVersionEnabled`. Only compiled on macOS.
    #[cfg(target_os = "macos")]
    #[deprecated(note = "use `set_protocol_version_max`")]
    pub fn set_protocol_version_enabled(
        &mut self,
        protocol: SslProtocol,
        enabled: bool,
    ) -> Result<()> {
        unsafe {
            cvt(SSLSetProtocolVersionEnabled(
                self.0,
                protocol.0,
                Boolean::from(enabled),
            ))
        }
    }

    /// Returns the size reported by `SSLGetBufferedReadSize`.
    ///
    /// `SslStream::read` uses this value to cap how much it asks for in a
    /// single call.
    #[inline]
    pub fn buffered_read_size(&self) -> Result<usize> {
        unsafe {
            let mut size = 0;
            cvt(SSLGetBufferedReadSize(self.0, &mut size))?;
            Ok(size)
        }
    }

    impl_options! {
        /// Accessor for the `kSSLSessionOptionBreakOnServerAuth` session option.
        const kSSLSessionOptionBreakOnServerAuth: break_on_server_auth & set_break_on_server_auth,
        /// Accessor for the `kSSLSessionOptionBreakOnCertRequested` session option.
        const kSSLSessionOptionBreakOnCertRequested: break_on_cert_requested & set_break_on_cert_requested,
        /// Accessor for the `kSSLSessionOptionBreakOnClientAuth` session option.
        const kSSLSessionOptionBreakOnClientAuth: break_on_client_auth & set_break_on_client_auth,
        /// Accessor for the `kSSLSessionOptionFalseStart` session option.
        const kSSLSessionOptionFalseStart: false_start & set_false_start,
        /// Accessor for the `kSSLSessionOptionSendOneByteRecord` session option.
        const kSSLSessionOptionSendOneByteRecord: send_one_byte_record & set_send_one_byte_record,
    }

    /// Binds `stream` to this context and returns the resulting `SslStream`
    /// without starting a handshake.
    ///
    /// Installs the `read_func`/`write_func` callbacks for `S`, then registers
    /// a heap-allocated `Connection<S>` as the context's connection pointer.
    /// That allocation is released again by `SslStream`'s `Drop` impl.
    fn into_stream<S>(self, stream: S) -> Result<SslStream<S>>
    where
        S: Read + Write,
    {
        unsafe {
            let ret = SSLSetIOFuncs(self.0, read_func::<S>, write_func::<S>);
            if ret != errSecSuccess {
                return Err(Error::from_code(ret));
            }
            let stream = Connection {
                stream,
                err: None,
                panic: None,
            };
            let stream = Box::into_raw(Box::new(stream));
            let ret = SSLSetConnection(self.0, stream.cast());
            if ret != errSecSuccess {
                // Registration failed, so no `SslStream` will ever free the
                // box: reclaim it here.
                let _conn = Box::from_raw(stream);
                return Err(Error::from_code(ret));
            }
            Ok(SslStream {
                ctx: self,
                _m: PhantomData,
            })
        }
    }

    /// Binds `stream` to this context and performs the handshake.
    ///
    /// # Errors
    /// Returns `HandshakeError::Failure` if binding the stream or the
    /// handshake fails, and `HandshakeError::Interrupted` if the handshake
    /// pauses part-way.
    pub fn handshake<S>(self, stream: S) -> result::Result<SslStream<S>, HandshakeError<S>>
    where
        S: Read + Write,
    {
        self.into_stream(stream)
            .map_err(HandshakeError::Failure)
            .and_then(SslStream::handshake)
    }
}
// Per-stream state registered with the context as its connection pointer and
// accessed by the `read_func`/`write_func` callbacks.
struct Connection<S> {
// The user-supplied transport the callbacks read from and write to.
stream: S,
// Last I/O error recorded by a callback; taken back out by
// `SslStream::get_error`.
err: Option<io::Error>,
// Payload of a panic caught inside a callback; re-raised by
// `SslStream::check_panic`.
panic: Option<Box<dyn Any + Send>>,
}
/// Maps an `io::Error` from the underlying stream to the `OSStatus` handed
/// back to Secure Transport by the I/O callbacks.
///
/// * `NotFound` -> `errSSLClosedGraceful`
/// * `ConnectionReset` -> `errSSLClosedAbort`
/// * `WouldBlock` / `NotConnected` -> `errSSLWouldBlock`
/// * anything else -> `errSecIO`
#[cold]
fn translate_err(e: &io::Error) -> OSStatus {
    use std::io::ErrorKind::{ConnectionReset, NotConnected, NotFound, WouldBlock};

    let kind = e.kind();
    match kind {
        WouldBlock | NotConnected => errSSLWouldBlock,
        ConnectionReset => errSSLClosedAbort,
        NotFound => errSSLClosedGraceful,
        _ => errSecIO,
    }
}
// Read callback installed via `SSLSetIOFuncs` in `SslContext::into_stream`.
//
// `connection` is the `Connection<S>` pointer registered with
// `SSLSetConnection`. `data` points at a buffer of `*data_length` bytes; on
// return `*data_length` holds the number of bytes actually read.
//
// The callback keeps reading until the buffer is full or something stops it:
//   * a zero-length read yields `errSSLClosedNoNotify`;
//   * an I/O error is translated by `translate_err` and stored in `conn.err`;
//   * a panic inside `read` is caught, stored in `conn.panic`, and reported
//     as `errSecIO`, so it never unwinds through the C caller.
unsafe extern "C" fn read_func<S>(
connection: SSLConnectionRef,
data: *mut c_void,
data_length: *mut usize,
) -> OSStatus
where
S: Read,
{
let conn: &mut Connection<S> = &mut *(connection as *mut _);
let data = slice::from_raw_parts_mut(data.cast::<u8>(), *data_length);
// Number of bytes filled so far.
let mut start = 0;
let mut ret = errSecSuccess;
while start < data.len() {
match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.read(&mut data[start..]))) {
Ok(Ok(0)) => {
ret = errSSLClosedNoNotify;
break;
}
Ok(Ok(len)) => start += len,
Ok(Err(e)) => {
ret = translate_err(&e);
conn.err = Some(e);
break;
}
Err(e) => {
ret = errSecIO;
conn.panic = Some(e);
break;
}
}
}
// Report the byte count even when stopping early with an error status.
*data_length = start;
ret
}
// Write callback installed via `SSLSetIOFuncs` in `SslContext::into_stream`.
//
// `connection` is the `Connection<S>` pointer registered with
// `SSLSetConnection`. `data` points at `*data_length` bytes to send; on
// return `*data_length` holds the number of bytes actually written.
//
// The callback keeps writing until everything is sent or something stops it:
//   * a zero-length write yields `errSSLClosedNoNotify`;
//   * an I/O error is translated by `translate_err` and stored in `conn.err`;
//   * a panic inside `write` is caught, stored in `conn.panic`, and reported
//     as `errSecIO`, so it never unwinds through the C caller.
unsafe extern "C" fn write_func<S>(
connection: SSLConnectionRef,
data: *const c_void,
data_length: *mut usize,
) -> OSStatus
where
S: Write,
{
let conn: &mut Connection<S> = &mut *(connection as *mut _);
let data = slice::from_raw_parts(data as *mut u8, *data_length);
// Number of bytes written so far.
let mut start = 0;
let mut ret = errSecSuccess;
while start < data.len() {
match panic::catch_unwind(AssertUnwindSafe(|| conn.stream.write(&data[start..]))) {
Ok(Ok(0)) => {
ret = errSSLClosedNoNotify;
break;
}
Ok(Ok(len)) => start += len,
Ok(Err(e)) => {
ret = translate_err(&e);
conn.err = Some(e);
break;
}
Err(e) => {
ret = errSecIO;
conn.panic = Some(e);
break;
}
}
}
// Report the byte count even when stopping early with an error status.
*data_length = start;
ret
}
/// A stream of type `S` bound to an `SslContext`.
///
/// The transport itself is not stored in this struct: it lives in a boxed
/// `Connection<S>` registered with the context by `SslContext::into_stream`,
/// is fetched back with `SSLGetConnection`, and is freed in `Drop`.
pub struct SslStream<S> {
/// The context driving this stream.
ctx: SslContext,
/// Records the transport type, which is otherwise only reachable through
/// the context's untyped connection pointer.
_m: PhantomData<S>,
}
impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
    /// Formats the stream as `SslStream { context: .., stream: .. }`.
    #[cold]
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &S = self.get_ref();
        let mut out = fmt.debug_struct("SslStream");
        out.field("context", &self.ctx);
        out.field("stream", inner);
        out.finish()
    }
}
impl<S> Drop for SslStream<S> {
    /// Frees the boxed `Connection<S>` that `SslContext::into_stream`
    /// registered with the context.
    ///
    /// # Panics
    /// Panics if `SSLGetConnection` does not return `errSecSuccess`.
    fn drop(&mut self) {
        let mut raw = ptr::null();
        let status = unsafe { SSLGetConnection(self.ctx.0, &mut raw) };
        assert!(status == errSecSuccess);
        // Rebuild the box so its destructor runs and the allocation is freed.
        let reclaimed = unsafe { Box::<Connection<S>>::from_raw(raw as *mut _) };
        drop(reclaimed);
    }
}
impl<S> SslStream<S> {
fn handshake(mut self) -> result::Result<Self, HandshakeError<S>> {
match unsafe { SSLHandshake(self.ctx.0) } {
errSecSuccess => Ok(self),
reason @ errSSLPeerAuthCompleted
| reason @ errSSLClientCertRequested
| reason @ errSSLWouldBlock
| reason @ errSSLClientHelloReceived => {
Err(HandshakeError::Interrupted(MidHandshakeSslStream {
stream: self,
error: Error::from_code(reason),
}))
}
err => {
self.check_panic();
Err(HandshakeError::Failure(Error::from_code(err)))
}
}
}
#[inline(always)]
#[must_use]
pub fn get_ref(&self) -> &S {
&self.connection().stream
}
#[inline(always)]
pub fn get_mut(&mut self) -> &mut S {
&mut self.connection_mut().stream
}
#[inline(always)]
#[must_use]
pub fn context(&self) -> &SslContext {
&self.ctx
}
#[inline(always)]
pub fn context_mut(&mut self) -> &mut SslContext {
&mut self.ctx
}
pub fn close(&mut self) -> result::Result<(), io::Error> {
unsafe {
let ret = SSLClose(self.ctx.0);
if ret == errSecSuccess {
Ok(())
} else {
Err(self.get_error(ret))
}
}
}
fn connection(&self) -> &Connection<S> {
unsafe {
let mut conn = ptr::null();
let ret = SSLGetConnection(self.ctx.0, &mut conn);
assert!(ret == errSecSuccess);
&mut *(conn as *mut Connection<S>)
}
}
fn connection_mut(&mut self) -> &mut Connection<S> {
unsafe {
let mut conn = ptr::null();
let ret = SSLGetConnection(self.ctx.0, &mut conn);
assert!(ret == errSecSuccess);
&mut *(conn as *mut Connection<S>)
}
}
#[cold]
fn check_panic(&mut self) {
let conn = self.connection_mut();
if let Some(err) = conn.panic.take() {
panic::resume_unwind(err);
}
}
#[cold]
fn get_error(&mut self, ret: OSStatus) -> io::Error {
self.check_panic();
if let Some(err) = self.connection_mut().err.take() {
err
} else {
io::Error::new(io::ErrorKind::Other, Error::from_code(ret))
}
}
}
impl<S: Read + Write> Read for SslStream<S> {
    /// Reads decrypted bytes into `buf`.
    ///
    /// * An empty `buf` returns `Ok(0)` without calling `SSLRead`.
    /// * When the context reports buffered data, the request is capped at
    ///   that amount; a failed size query is treated as "nothing buffered".
    /// * Any bytes produced are returned even if `SSLRead` also reported a
    ///   non-success status.
    /// * The three `errSSLClosed*` statuses are reported as end of stream.
    /// * `errSSLPeerAuthCompleted` simply retries the read.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let want = match self.context().buffered_read_size() {
            Ok(pending) if pending > 0 => cmp::min(pending, buf.len()),
            _ => buf.len(),
        };

        let mut got = 0;
        let status = unsafe { SSLRead(self.ctx.0, buf.as_mut_ptr().cast(), want, &mut got) };
        if got > 0 {
            return Ok(got);
        }

        match status {
            errSSLPeerAuthCompleted => self.read(buf),
            errSSLClosedGraceful | errSSLClosedAbort | errSSLClosedNoNotify => Ok(0),
            other => Err(self.get_error(other)),
        }
    }
}
impl<S: Read + Write> Write for SslStream<S> {
    /// Encrypts and sends bytes from `buf` via `SSLWrite`.
    ///
    /// An empty `buf` returns `Ok(0)` without calling `SSLWrite`. If at least
    /// one byte was accepted the count is returned regardless of the status;
    /// otherwise the status is converted into an `io::Error`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let mut sent = 0;
        let status = unsafe { SSLWrite(self.ctx.0, buf.as_ptr().cast(), buf.len(), &mut sent) };
        if sent == 0 {
            return Err(self.get_error(status));
        }
        Ok(sent)
    }

    /// Flushes the wrapped transport stream.
    fn flush(&mut self) -> io::Result<()> {
        let conn = self.connection_mut();
        conn.stream.flush()
    }
}
/// Collects client-side TLS settings and performs handshakes with them.
///
/// The settings are applied to a fresh `SslContext` each time `handshake`
/// is called.
#[derive(Debug)]
pub struct ClientBuilder {
/// Client identity passed to `SslContext::set_certificate`, if any.
identity: Option<SecIdentity>,
/// Anchor certificates installed on the peer trust object during validation.
certs: Vec<SecCertificate>,
/// Certificate chain sent alongside `identity`.
chain: Vec<SecCertificate>,
/// Minimum protocol version to set on the context, if any.
protocol_min: Option<SslProtocol>,
/// Maximum protocol version to set on the context, if any.
protocol_max: Option<SslProtocol>,
/// Forwarded to `set_trust_anchor_certificates_only` on the trust object.
trust_certs_only: bool,
/// When `true`, the domain is set as the context's peer domain name.
use_sni: bool,
/// When `true`, the peer trust evaluation step is skipped entirely.
danger_accept_invalid_certs: bool,
/// When `true`, no domain is passed to the SSL policy used for validation.
danger_accept_invalid_hostnames: bool,
/// If non-empty, the starting set of enabled ciphers.
whitelisted_ciphers: Vec<CipherSuite>,
/// Ciphers removed from the enabled set.
blacklisted_ciphers: Vec<CipherSuite>,
/// ALPN protocols to offer, if any.
#[cfg(feature = "alpn")]
alpn: Option<Vec<String>>,
/// When `true`, the peer ID is set to the domain and session tickets are
/// enabled on the context.
#[cfg(feature = "session-tickets")]
enable_session_tickets: bool,
}
impl Default for ClientBuilder {
#[inline(always)]
fn default() -> Self {
Self::new()
}
}
impl ClientBuilder {
    /// Creates a builder with default settings: no identity, no extra anchors,
    /// no protocol bounds, SNI enabled, and all `danger_*` switches off.
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        Self {
            // Identity and trust material.
            identity: None,
            certs: Vec::new(),
            chain: Vec::new(),
            // Protocol bounds.
            protocol_min: None,
            protocol_max: None,
            // Validation switches.
            trust_certs_only: false,
            use_sni: true,
            danger_accept_invalid_certs: false,
            danger_accept_invalid_hostnames: false,
            // Cipher filters.
            whitelisted_ciphers: Vec::new(),
            blacklisted_ciphers: Vec::new(),
            #[cfg(feature = "alpn")]
            alpn: None,
            #[cfg(feature = "session-tickets")]
            enable_session_tickets: false,
        }
    }

    /// Replaces the anchor certificates used during peer trust evaluation.
    #[inline]
    pub fn anchor_certificates(&mut self, certs: &[SecCertificate]) -> &mut Self {
        self.certs = certs.to_vec();
        self
    }

    /// Appends one anchor certificate to the existing list.
    #[inline]
    pub fn add_anchor_certificate(&mut self, certs: &SecCertificate) -> &mut Self {
        let anchor = certs.clone();
        self.certs.push(anchor);
        self
    }

    /// Sets the flag forwarded to `set_trust_anchor_certificates_only` on the
    /// peer trust object.
    #[inline(always)]
    pub fn trust_anchor_certificates_only(&mut self, only: bool) -> &mut Self {
        self.trust_certs_only = only;
        self
    }

    /// When set, peer trust evaluation is skipped during the handshake.
    #[inline(always)]
    pub fn danger_accept_invalid_certs(&mut self, noverify: bool) -> &mut Self {
        self.danger_accept_invalid_certs = noverify;
        self
    }

    /// Controls whether the domain is set as the context's peer domain name.
    #[inline(always)]
    pub fn use_sni(&mut self, use_sni: bool) -> &mut Self {
        self.use_sni = use_sni;
        self
    }

    /// When set, no domain is passed to the SSL policy used for validation.
    #[inline(always)]
    pub fn danger_accept_invalid_hostnames(
        &mut self,
        danger_accept_invalid_hostnames: bool,
    ) -> &mut Self {
        self.danger_accept_invalid_hostnames = danger_accept_invalid_hostnames;
        self
    }

    /// Sets the ciphers to start from; an empty list means "use the context's
    /// currently enabled ciphers".
    pub fn whitelist_ciphers(&mut self, whitelisted_ciphers: &[CipherSuite]) -> &mut Self {
        self.whitelisted_ciphers = whitelisted_ciphers.to_vec();
        self
    }

    /// Sets the ciphers to remove from the enabled set.
    pub fn blacklist_ciphers(&mut self, blacklisted_ciphers: &[CipherSuite]) -> &mut Self {
        self.blacklisted_ciphers = blacklisted_ciphers.to_vec();
        self
    }

    /// Sets the client identity and the certificate chain sent with it.
    pub fn identity(&mut self, identity: &SecIdentity, chain: &[SecCertificate]) -> &mut Self {
        self.chain = chain.to_vec();
        self.identity = Some(identity.clone());
        self
    }

    /// Sets the minimum protocol version.
    #[inline(always)]
    pub fn protocol_min(&mut self, min: SslProtocol) -> &mut Self {
        self.protocol_min = Some(min);
        self
    }

    /// Sets the maximum protocol version.
    #[inline(always)]
    pub fn protocol_max(&mut self, max: SslProtocol) -> &mut Self {
        self.protocol_max = Some(max);
        self
    }

    /// Sets the ALPN protocols to offer, in the order given.
    #[cfg(feature = "alpn")]
    pub fn alpn_protocols(&mut self, protocols: &[&str]) -> &mut Self {
        let owned: Vec<String> = protocols.iter().map(|proto| String::from(*proto)).collect();
        self.alpn = Some(owned);
        self
    }

    /// Enables or disables session tickets for handshakes from this builder.
    #[cfg(feature = "session-tickets")]
    #[inline(always)]
    pub fn enable_session_tickets(&mut self, enable: bool) -> &mut Self {
        self.enable_session_tickets = enable;
        self
    }

    /// Performs a client handshake with `domain` over `stream`.
    ///
    /// A fresh context is configured from this builder, then the handshake is
    /// driven by `MidHandshakeClientBuilder::handshake`, which also performs
    /// certificate validation.
    ///
    /// # Errors
    /// Returns `ClientHandshakeError::Failure` on configuration, handshake or
    /// validation errors, and `ClientHandshakeError::Interrupted` when the
    /// underlying stream would block.
    pub fn handshake<S>(
        &self,
        domain: &str,
        stream: S,
    ) -> result::Result<SslStream<S>, ClientHandshakeError<S>>
    where
        S: Read + Write,
    {
        let pending = MidHandshakeSslStream {
            stream: self.ctx_into_stream(domain, stream)?,
            error: Error::from(errSecSuccess),
        };
        let anchors = self.certs.clone();
        // Hostname checking is disabled by withholding the domain from the
        // validation policy.
        let policy_domain = if self.danger_accept_invalid_hostnames {
            None
        } else {
            Some(domain.to_string())
        };
        let resumable = MidHandshakeClientBuilder {
            stream: pending,
            domain: policy_domain,
            certs: anchors,
            trust_certs_only: self.trust_certs_only,
            danger_accept_invalid_certs: self.danger_accept_invalid_certs,
        };
        resumable.handshake()
    }

    /// Builds a configured client context and binds `stream` to it without
    /// starting the handshake.
    ///
    /// Settings are applied in this order: peer domain name (if SNI is on),
    /// identity, ALPN, session tickets, break-on-server-auth, protocol bounds,
    /// ciphers.
    fn ctx_into_stream<S>(&self, domain: &str, stream: S) -> Result<SslStream<S>>
    where
        S: Read + Write,
    {
        let mut ctx = SslContext::new(SslProtocolSide::CLIENT, SslConnectionType::STREAM)?;

        if self.use_sni {
            ctx.set_peer_domain_name(domain)?;
        }
        if let Some(identity) = self.identity.as_ref() {
            ctx.set_certificate(identity, &self.chain)?;
        }
        #[cfg(feature = "alpn")]
        {
            if let Some(protocols) = self.alpn.as_ref() {
                let borrowed: Vec<&str> = protocols.iter().map(String::as_str).collect();
                ctx.set_alpn_protocols(&borrowed)?;
            }
        }
        #[cfg(feature = "session-tickets")]
        {
            if self.enable_session_tickets {
                ctx.set_peer_id(domain.as_bytes())?;
                ctx.set_session_tickets_enabled(true)?;
            }
        }
        // Validation is done by this crate, so ask the handshake to pause
        // after server authentication.
        ctx.set_break_on_server_auth(true)?;
        self.configure_protocols(&mut ctx)?;
        self.configure_ciphers(&mut ctx)?;

        ctx.into_stream(stream)
    }

    /// Applies the configured minimum and maximum protocol versions, if any.
    fn configure_protocols(&self, ctx: &mut SslContext) -> Result<()> {
        if let Some(lowest) = self.protocol_min {
            ctx.set_protocol_version_min(lowest)?;
        }
        if let Some(highest) = self.protocol_max {
            ctx.set_protocol_version_max(highest)?;
        }
        Ok(())
    }

    /// Computes and applies the enabled cipher list.
    ///
    /// Starts from the whitelist (or the context's enabled ciphers when the
    /// whitelist is empty) and removes every blacklisted cipher.
    fn configure_ciphers(&self, ctx: &mut SslContext) -> Result<()> {
        let mut allowed = match self.whitelisted_ciphers.as_slice() {
            [] => ctx.enabled_ciphers()?,
            explicit => explicit.to_vec(),
        };
        // With an empty blacklist this keeps every entry.
        allowed.retain(|candidate| !self.blacklisted_ciphers.contains(candidate));
        ctx.set_enabled_ciphers(&allowed)
    }
}
#[derive(Debug)]
pub struct ServerBuilder {
identity: SecIdentity,
certs: Vec<SecCertificate>,
}
impl ServerBuilder {
#[must_use]
pub fn new(identity: &SecIdentity, certs: &[SecCertificate]) -> Self {
Self {
identity: identity.clone(),
certs: certs.to_owned(),
}
}
pub fn from_pkcs12(pkcs12_der: &[u8], passphrase: &str) -> Result<Self> {
let mut identities: Vec<(SecIdentity, Vec<SecCertificate>)> = Pkcs12ImportOptions::new()
.passphrase(passphrase)
.import(pkcs12_der)?
.into_iter()
.filter_map(|idendity| {
let certs = idendity.cert_chain.unwrap_or_default();
idendity.identity.map(|identity| (identity, certs))
})
.collect();
if identities.len() == 1 {
let (identity, certs) = identities.pop().unwrap();
Ok(ServerBuilder::new(&identity, &certs))
} else {
Err(Error::from_code(errSecParam))
}
}
pub fn new_ssl_context(&self) -> Result<SslContext> {
let mut ctx = SslContext::new(SslProtocolSide::SERVER, SslConnectionType::STREAM)?;
ctx.set_certificate(&self.identity, &self.certs)?;
Ok(ctx)
}
pub fn handshake<S>(&self, stream: S) -> Result<SslStream<S>>
where
S: Read + Write,
{
match self.new_ssl_context()?.handshake(stream) {
Ok(stream) => Ok(stream),
Err(HandshakeError::Interrupted(stream)) => Err(*stream.error()),
Err(HandshakeError::Failure(err)) => Err(err),
}
}
}
// Tests for the Secure Transport wrappers.
//
// Most of these open real TCP connections to public hosts (google.com,
// expired.badssl.com), so they need network access.
// NOTE(review): `p!` is a macro defined elsewhere in the crate; it presumably
// unwraps a `Result` and panics on `Err` — confirm against its definition.
#[cfg(test)]
mod test {
use std::io::prelude::*;
use std::net::TcpStream;
use super::*;
// Importing the bundled PKCS#12 archive yields a `ServerBuilder`.
#[test]
fn server_builder_from_pkcs12() {
let pkcs12_der = include_bytes!("../test/server.p12");
ServerBuilder::from_pkcs12(pkcs12_der, "password123").unwrap();
}
// A plain context-level handshake against a matching domain succeeds.
#[test]
fn connect() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
p!(ctx.handshake(stream));
}
// A context-level handshake with a mismatched peer domain name must fail.
#[test]
fn connect_bad_domain() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("foobar.com"));
let stream = p!(TcpStream::connect("google.com:443"));
match ctx.handshake(stream) {
Ok(_) => panic!("expected failure"),
Err(_) => {}
}
}
// After a context-level handshake, a request can be written and the response
// read to the end.
#[test]
fn load_page() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let mut stream = p!(ctx.handshake(stream));
p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
p!(stream.flush());
let mut buf = vec![];
p!(stream.read_to_end(&mut buf));
println!("{}", String::from_utf8_lossy(&buf));
}
// Without session tickets, both connections must pause for server auth.
#[test]
fn client_no_session_ticket_resumption() {
for _ in 0..2 {
let stream = p!(TcpStream::connect("google.com:443"));
let stream = MidHandshakeSslStream {
stream: ClientBuilder::new()
.ctx_into_stream("google.com", stream)
.unwrap(),
error: Error::from(errSecSuccess),
};
let mut result = stream.handshake();
if let Err(HandshakeError::Interrupted(stream)) = result {
assert!(stream.server_auth_completed());
result = stream.handshake();
} else {
panic!("Unexpectedly skipped server auth");
}
assert!(result.is_ok());
}
}
// With session tickets, only the first connection pauses for server auth; the
// second is expected to skip it.
#[test]
#[cfg(feature = "session-tickets")]
fn client_session_ticket_resumption() {
for i in 0..2 {
let stream = p!(TcpStream::connect("google.com:443"));
let mut builder = ClientBuilder::new();
builder.enable_session_tickets(true);
let stream = MidHandshakeSslStream {
stream: builder.ctx_into_stream("google.com", stream).unwrap(),
error: Error::from(errSecSuccess),
};
let mut result = stream.handshake();
if let Err(HandshakeError::Interrupted(stream)) = result {
assert!(stream.server_auth_completed());
assert_eq!(
i, 0,
"Session ticket resumption did not work, server auth was not skipped"
);
result = stream.handshake();
} else {
assert_eq!(i, 1, "Unexpectedly skipped server auth");
}
assert!(result.is_ok());
}
}
// Offering "h2" results in "h2" being reported back after the handshake.
#[test]
#[cfg(feature = "alpn")]
fn client_alpn_accept() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
p!(ctx.set_alpn_protocols(&vec!["h2"]));
let stream = p!(TcpStream::connect("google.com:443"));
let stream = ctx.handshake(stream).unwrap();
assert_eq!(vec!["h2"], stream.context().alpn_protocols().unwrap());
}
// Offering "h2c" is expected to make the later protocol query fail.
#[test]
#[cfg(feature = "alpn")]
fn client_alpn_reject() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
p!(ctx.set_alpn_protocols(&vec!["h2c"]));
let stream = p!(TcpStream::connect("google.com:443"));
let stream = ctx.handshake(stream).unwrap();
assert!(stream.context().alpn_protocols().is_err());
}
// Trusting only the (empty) anchor list must make validation fail.
#[test]
fn client_no_anchor_certs() {
let stream = p!(TcpStream::connect("google.com:443"));
assert!(ClientBuilder::new()
.trust_anchor_certificates_only(true)
.handshake("google.com", stream)
.is_err());
}
// A builder handshake with a mismatched domain must fail.
#[test]
fn client_bad_domain() {
let stream = p!(TcpStream::connect("google.com:443"));
assert!(ClientBuilder::new()
.handshake("foobar.com", stream)
.is_err());
}
// The same mismatch is accepted when hostname checking is disabled.
#[test]
fn client_bad_domain_ignored() {
let stream = p!(TcpStream::connect("google.com:443"));
ClientBuilder::new()
.danger_accept_invalid_hostnames(true)
.handshake("foobar.com", stream)
.unwrap();
}
// Connecting to expired.badssl.com succeeds when certificate validation is
// disabled.
#[test]
fn connect_no_verify_ssl() {
let stream = p!(TcpStream::connect("expired.badssl.com:443"));
let mut builder = ClientBuilder::new();
builder.danger_accept_invalid_certs(true);
builder.handshake("expired.badssl.com", stream).unwrap();
}
// After a builder handshake, a request can be written and the response read
// to the end.
#[test]
fn load_page_client() {
let stream = p!(TcpStream::connect("google.com:443"));
let mut stream = p!(ClientBuilder::new().handshake("google.com", stream));
p!(stream.write_all(b"GET / HTTP/1.0\r\n\r\n"));
p!(stream.flush());
let mut buf = vec![];
p!(stream.read_to_end(&mut buf));
println!("{}", String::from_utf8_lossy(&buf));
}
// Enabling every other cipher is reflected by `enabled_ciphers`.
// Ignored on iOS, tvOS, watchOS and visionOS.
#[test]
#[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] fn cipher_configuration() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::SERVER,
SslConnectionType::STREAM
));
let ciphers = p!(ctx.enabled_ciphers());
let ciphers = ciphers
.iter()
.enumerate()
.filter_map(|(i, c)| if i % 2 == 0 { Some(*c) } else { None })
.collect::<Vec<_>>();
p!(ctx.set_enabled_ciphers(&ciphers));
assert_eq!(ciphers, p!(ctx.enabled_ciphers()));
}
// Whitelisting a single cipher leaves exactly one cipher enabled.
#[test]
fn test_builder_whitelist_ciphers() {
let stream = p!(TcpStream::connect("google.com:443"));
let ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
assert!(p!(ctx.enabled_ciphers()).len() > 1);
let ciphers = p!(ctx.enabled_ciphers());
let cipher = ciphers.first().unwrap();
let stream = p!(ClientBuilder::new()
.whitelist_ciphers(&[*cipher])
.ctx_into_stream("google.com", stream));
assert_eq!(1, p!(stream.context().enabled_ciphers()).len());
}
// Blacklisting a single cipher removes exactly one from the enabled set.
// Ignored on iOS, tvOS, watchOS and visionOS.
#[test]
#[cfg_attr(any(target_os = "ios", target_os = "tvos", target_os = "watchos", target_os = "visionos"), ignore)] fn test_builder_blacklist_ciphers() {
let stream = p!(TcpStream::connect("google.com:443"));
let ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
let num = p!(ctx.enabled_ciphers()).len();
assert!(num > 1);
let ciphers = p!(ctx.enabled_ciphers());
let cipher = ciphers.first().unwrap();
let stream = p!(ClientBuilder::new()
.blacklist_ciphers(&[*cipher])
.ctx_into_stream("google.com", stream));
assert_eq!(num - 1, p!(stream.context().enabled_ciphers()).len());
}
// `peer_trust2` on a context that has not started a session returns an error.
#[test]
fn idle_context_peer_trust() {
let ctx = p!(SslContext::new(
SslProtocolSide::SERVER,
SslConnectionType::STREAM
));
assert!(ctx.peer_trust2().is_err());
}
// The peer ID starts out absent and round-trips once set.
#[test]
fn peer_id() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::SERVER,
SslConnectionType::STREAM
));
assert!(p!(ctx.peer_id()).is_none());
p!(ctx.set_peer_id(b"foobar"));
assert_eq!(p!(ctx.peer_id()), Some(&b"foobar"[..]));
}
// The peer domain name starts out empty and round-trips once set.
#[test]
fn peer_domain_name() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
assert_eq!("", p!(ctx.peer_domain_name()));
p!(ctx.set_peer_domain_name("foobar.com"));
assert_eq!("foobar.com", p!(ctx.peer_domain_name()));
}
// A panic raised inside the stream's `write` must resurface in the caller
// with its original message.
#[test]
#[should_panic(expected = "blammo")]
fn write_panic() {
struct ExplodingStream(TcpStream);
impl Read for ExplodingStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}
impl Write for ExplodingStream {
fn write(&mut self, _: &[u8]) -> io::Result<usize> {
panic!("blammo");
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let _ = ctx.handshake(ExplodingStream(stream));
}
// A panic raised inside the stream's `read` must resurface in the caller
// with its original message.
#[test]
#[should_panic(expected = "blammo")]
fn read_panic() {
struct ExplodingStream(TcpStream);
impl Read for ExplodingStream {
fn read(&mut self, _: &mut [u8]) -> io::Result<usize> {
panic!("blammo");
}
}
impl Write for ExplodingStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}
fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let _ = ctx.handshake(ExplodingStream(stream));
}
// Zero-length reads and writes return `Ok(0)` rather than an error.
#[test]
fn zero_length_buffers() {
let mut ctx = p!(SslContext::new(
SslProtocolSide::CLIENT,
SslConnectionType::STREAM
));
p!(ctx.set_peer_domain_name("google.com"));
let stream = p!(TcpStream::connect("google.com:443"));
let mut stream = ctx.handshake(stream).unwrap();
assert_eq!(stream.write(b"").unwrap(), 0);
assert_eq!(stream.read(&mut []).unwrap(), 0);
}
}