use std::io;
use std::mem;
use std::net::SocketAddr;
use std::time::Duration;
use futures::stream::{Fuse, Peekable, Stream};
use futures::sync::mpsc::{unbounded, UnboundedReceiver};
use futures::{Async, Future, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_tcp::TcpStream as TokioTcpStream;
use tokio_timer::Timeout;
use error::*;
use xfer::{BufStreamHandle, SerialMessage};
/// Progress of writing one length-prefixed outbound frame to the socket.
///
/// `TcpStream::poll` advances a frame through `LenBytes` -> `Bytes` -> `Flushing`,
/// after which the send state is cleared.
enum WriteTcpState {
    /// Writing the 2-byte length prefix that precedes the payload.
    LenBytes {
        /// Number of bytes of `length` already written to the socket.
        pos: usize,
        /// Big-endian 16-bit length prefix for `bytes`.
        length: [u8; 2],
        /// Payload to write once the prefix has been fully written.
        bytes: Vec<u8>,
    },
    /// Writing the payload itself.
    Bytes {
        /// Number of bytes of `bytes` already written to the socket.
        pos: usize,
        /// Payload being written.
        bytes: Vec<u8>,
    },
    /// Payload fully written; waiting for the socket flush to complete.
    Flushing,
}
/// Progress of reading one length-prefixed inbound frame from the socket.
pub enum ReadTcpState {
    /// Reading the 2-byte big-endian length prefix.
    LenBytes {
        /// Number of prefix bytes read so far.
        pos: usize,
        /// Buffer receiving the prefix (high byte first).
        bytes: [u8; 2],
    },
    /// Reading the payload whose length was announced by the prefix.
    Bytes {
        /// Number of payload bytes read so far.
        pos: usize,
        /// Buffer sized to the announced payload length.
        bytes: Vec<u8>,
    },
}
/// A stream of 2-byte-length-prefixed messages carried over the byte stream `S`.
///
/// Polling it writes messages queued through the paired `BufStreamHandle` and yields
/// each received message as a `SerialMessage` tagged with `peer_addr`.
#[must_use = "futures do nothing unless polled"]
pub struct TcpStream<S> {
    /// Underlying connection.
    socket: S,
    /// Messages queued for sending to the peer.
    outbound_messages: Peekable<Fuse<UnboundedReceiver<SerialMessage>>>,
    /// Frame currently being written, if any.
    send_state: Option<WriteTcpState>,
    /// Progress of the frame currently being read.
    read_state: ReadTcpState,
    /// Remote address: outbound messages must be addressed to it, and inbound
    /// messages are tagged with it.
    peer_addr: SocketAddr,
}
impl<S> TcpStream<S> {
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
}
impl TcpStream<TokioTcpStream> {
pub fn new<E>(
name_server: SocketAddr,
) -> (
Box<Future<Item = TcpStream<TokioTcpStream>, Error = io::Error> + Send>,
BufStreamHandle,
)
where
E: FromProtoError,
{
Self::with_timeout(name_server, Duration::from_secs(5))
}
pub fn with_timeout(
name_server: SocketAddr,
timeout: Duration,
) -> (
Box<Future<Item = TcpStream<TokioTcpStream>, Error = io::Error> + Send>,
BufStreamHandle,
) {
let (message_sender, outbound_messages) = unbounded();
let message_sender = BufStreamHandle::new(message_sender);
let tcp = TokioTcpStream::connect(&name_server);
let stream = Timeout::new(tcp, timeout)
.map_err(move |e| {
debug!("timed out connecting to: {}", name_server);
e.into_inner().unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::TimedOut,
format!("timed out connecting to: {}", name_server),
)
})
})
.map(move |tcp_stream| {
debug!("TCP connection established to: {}", name_server);
TcpStream {
socket: tcp_stream,
outbound_messages: outbound_messages.fuse().peekable(),
send_state: None,
read_state: ReadTcpState::LenBytes {
pos: 0,
bytes: [0u8; 2],
},
peer_addr: name_server,
}
});
(Box::new(stream), message_sender)
}
}
impl<S: AsyncRead + AsyncWrite> TcpStream<S> {
    /// Wraps an already-connected `stream` to `peer_addr`.
    ///
    /// Returns the framed stream together with a fresh handle for queueing outbound
    /// messages to it.
    pub fn from_stream(stream: S, peer_addr: SocketAddr) -> (Self, BufStreamHandle) {
        let (sender, receiver) = unbounded();
        let handle = BufStreamHandle::new(sender);
        (
            Self::from_stream_with_receiver(stream, peer_addr, receiver),
            handle,
        )
    }

    /// Wraps an already-connected `stream` to `peer_addr`.
    ///
    /// Outbound messages are taken from the caller-supplied `receiver`. Nothing is in
    /// flight initially, and reading starts at a frame's length prefix.
    pub fn from_stream_with_receiver(
        stream: S,
        peer_addr: SocketAddr,
        receiver: UnboundedReceiver<SerialMessage>,
    ) -> Self {
        let read_state = ReadTcpState::LenBytes {
            pos: 0,
            bytes: [0u8; 2],
        };
        TcpStream {
            socket: stream,
            outbound_messages: receiver.fuse().peekable(),
            send_state: None,
            read_state,
            peer_addr,
        }
    }
}
impl<S: AsyncRead + AsyncWrite> Stream for TcpStream<S> {
type Item = SerialMessage;
type Error = io::Error;
fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
loop {
if self.send_state.is_some() {
match self.send_state {
Some(WriteTcpState::LenBytes {
ref mut pos,
ref length,
..
}) => {
let wrote = try_nb!(self.socket.write(&length[*pos..]));
*pos += wrote;
}
Some(WriteTcpState::Bytes {
ref mut pos,
ref bytes,
}) => {
let wrote = try_nb!(self.socket.write(&bytes[*pos..]));
*pos += wrote;
}
Some(WriteTcpState::Flushing) => {
try_nb!(self.socket.flush());
}
_ => (),
}
let current_state = mem::replace(&mut self.send_state, None);
match current_state {
Some(WriteTcpState::LenBytes { pos, length, bytes }) => {
if pos < length.len() {
mem::replace(
&mut self.send_state,
Some(WriteTcpState::LenBytes {
pos: pos,
length: length,
bytes: bytes,
}),
);
} else {
mem::replace(
&mut self.send_state,
Some(WriteTcpState::Bytes {
pos: 0,
bytes: bytes,
}),
);
}
}
Some(WriteTcpState::Bytes { pos, bytes }) => {
if pos < bytes.len() {
mem::replace(
&mut self.send_state,
Some(WriteTcpState::Bytes {
pos: pos,
bytes: bytes,
}),
);
} else {
mem::replace(&mut self.send_state, Some(WriteTcpState::Flushing));
}
}
Some(WriteTcpState::Flushing) => {
mem::replace(&mut self.send_state, None);
}
None => (),
};
} else {
match self
.outbound_messages
.poll()
.map_err(|()| io::Error::new(io::ErrorKind::Other, "unknown"))?
{
Async::Ready(Some(message)) => {
let (buffer, dst) = message.unwrap();
let peer = self.peer_addr;
if peer != dst {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("mismatched peer: {} and dst: {}", peer, dst),
));
}
let len: [u8; 2] = [
(buffer.len() >> 8 & 0xFF) as u8,
(buffer.len() & 0xFF) as u8,
];
debug!("sending message len: {} to: {}", buffer.len(), dst);
self.send_state = Some(WriteTcpState::LenBytes {
pos: 0,
length: len,
bytes: buffer,
});
}
Async::NotReady => break,
Async::Ready(None) => {
debug!("no messages to send");
break;
}
}
}
}
let mut ret_buf: Option<Vec<u8>> = None;
while ret_buf.is_none() {
let new_state: Option<ReadTcpState> = match self.read_state {
ReadTcpState::LenBytes {
ref mut pos,
ref mut bytes,
} => {
let read = try_nb!(self.socket.read(&mut bytes[*pos..]));
if read == 0 {
debug!("zero bytes read, stream closed?");
if *pos == 0 {
return Ok(Async::Ready(None));
} else {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"closed while reading length",
));
}
}
debug!("in ReadTcpState::LenBytes: {}", pos);
*pos += read;
if *pos < bytes.len() {
debug!("remain ReadTcpState::LenBytes: {}", pos);
None
} else {
let length =
u16::from(bytes[0]) << 8 & 0xFF00 | u16::from(bytes[1]) & 0x00FF;
debug!("got length: {}", length);
let mut bytes = Vec::with_capacity(length as usize);
bytes.resize(length as usize, 0);
debug!("move ReadTcpState::Bytes: {}", bytes.len());
Some(ReadTcpState::Bytes {
pos: 0,
bytes: bytes,
})
}
}
ReadTcpState::Bytes {
ref mut pos,
ref mut bytes,
} => {
let read = try_nb!(self.socket.read(&mut bytes[*pos..]));
if read == 0 {
debug!("zero bytes read for message, stream closed?");
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"closed while reading message",
));
}
debug!("in ReadTcpState::Bytes: {}", bytes.len());
*pos += read;
if *pos < bytes.len() {
debug!("remain ReadTcpState::Bytes: {}", bytes.len());
None
} else {
debug!("reset ReadTcpState::LenBytes: {}", 0);
Some(ReadTcpState::LenBytes {
pos: 0,
bytes: [0u8; 2],
})
}
}
};
if let Some(state) = new_state {
if let ReadTcpState::Bytes { pos, bytes } =
mem::replace(&mut self.read_state, state)
{
debug!("returning bytes");
assert_eq!(pos, bytes.len());
ret_buf = Some(bytes);
}
}
}
if let Some(buffer) = ret_buf {
debug!("returning buffer");
let src_addr = self.peer_addr;
return Ok(Async::Ready(Some(SerialMessage::new(buffer, src_addr))));
} else {
debug!("bottomed out");
return Ok(Async::NotReady);
}
}
}
#[cfg(not(target_os = "linux"))]
#[cfg(test)]
use std::net::Ipv6Addr;
#[cfg(test)]
use std::net::{IpAddr, Ipv4Addr};
/// Round-trips framed messages against a local echo server over IPv4 loopback.
#[test]
fn test_tcp_client_stream_ipv4() {
    let loopback: IpAddr = Ipv4Addr::new(127, 0, 0, 1).into();
    tcp_client_stream_test(loopback)
}
/// Round-trips framed messages against a local echo server over IPv6 loopback.
/// This test is compiled out on Linux.
#[test]
#[cfg(not(target_os = "linux"))]
fn test_tcp_client_stream_ipv6() {
    let loopback: IpAddr = Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into();
    tcp_client_stream_test(loopback)
}
/// Payload length used by the round-trip test. The array type of `TEST_BYTES` is tied
/// to it, so changing the literal without updating this fails to compile.
#[cfg(test)]
const TEST_BYTES_LEN: usize = 8;
/// Payload sent by the client and echoed back by the test server.
#[cfg(test)]
const TEST_BYTES: &[u8; TEST_BYTES_LEN] = b"DEADBEEF";
/// Connects a `TcpStream` to a local echo server bound on `server_addr`.
///
/// The client sends `TEST_BYTES` several times, and each framed reply must equal the
/// payload.
#[cfg(test)]
fn tcp_client_stream_test(server_addr: IpAddr) {
    use std::io::{Read, Write};
    use tokio::runtime::current_thread::Runtime;
    use std;

    // Watchdog. A panic on a spawned thread neither fails the test nor unblocks the test
    // thread, so if success has not been flagged within ~15s, terminate the process.
    let succeeded = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
    let succeeded_clone = succeeded.clone();
    std::thread::Builder::new()
        .name("thread_killer".to_string())
        .spawn(move || {
            for _ in 0..15 {
                std::thread::sleep(std::time::Duration::from_secs(1));
                if succeeded_clone.load(std::sync::atomic::Ordering::Relaxed) {
                    return;
                }
            }
            eprintln!("timeout: tcp_client_stream_test did not complete in time");
            std::process::exit(1);
        })
        .unwrap();

    // Echo server: reads each framed message, validates it, and writes it back unchanged.
    let server = std::net::TcpListener::bind(SocketAddr::new(server_addr, 0)).unwrap();
    let server_addr = server.local_addr().unwrap();
    let send_recv_times = 4;
    let server_handle = std::thread::Builder::new()
        .name("test_tcp_client_stream:server".to_string())
        .spawn(move || {
            let (mut socket, _) = server.accept().expect("accept failed");
            socket
                .set_read_timeout(Some(std::time::Duration::from_secs(5)))
                .unwrap();
            socket
                .set_write_timeout(Some(std::time::Duration::from_secs(5)))
                .unwrap();
            for _ in 0..send_recv_times {
                let mut len_bytes = [0_u8; 2];
                socket
                    .read_exact(&mut len_bytes)
                    .expect("SERVER: receive failed");
                let length = u16::from(len_bytes[0]) << 8 & 0xFF00 | u16::from(len_bytes[1]) & 0x00FF;
                assert_eq!(length as usize, TEST_BYTES_LEN);
                let mut buffer = [0_u8; TEST_BYTES_LEN];
                socket.read_exact(&mut buffer).unwrap();
                assert_eq!(&buffer, TEST_BYTES);
                socket
                    .write_all(&len_bytes)
                    .expect("SERVER: send length failed");
                socket
                    .write_all(&buffer)
                    .expect("SERVER: send buffer failed");
                std::thread::yield_now();
            }
        })
        .unwrap();

    let mut io_loop = Runtime::new().unwrap();
    let (stream, sender) = TcpStream::new::<ProtoError>(server_addr);
    // Keep the io::Error in the panic message instead of discarding it via `.ok()`.
    let mut stream = io_loop
        .block_on(stream)
        .expect("run failed to get stream");
    for _ in 0..send_recv_times {
        sender
            .unbounded_send(SerialMessage::new(TEST_BYTES.to_vec(), server_addr))
            .expect("send failed");
        // `into_future` hands the stream back alongside the item (or the error). The stream
        // has no `Debug` impl, so drop it on error and keep only the io::Error.
        let (buffer, stream_tmp) = io_loop
            .block_on(stream.into_future())
            .map_err(|(e, _)| e)
            .expect("future iteration run failed");
        stream = stream_tmp;
        let message = buffer.expect("no buffer received");
        assert_eq!(message.bytes(), TEST_BYTES);
    }
    succeeded.store(true, std::sync::atomic::Ordering::Relaxed);
    server_handle.join().expect("server thread failed");
}