use std::{error, fmt, mem, ops};
use std::cell::UnsafeCell;
use std::sync::{Arc, Mutex};
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::SeqCst;
use std::collections::HashMap;
use futures_core::{Future, Poll, Async};
use futures_core::task::{self, Wake, Waker, LocalMap};
#[must_use = "futures do nothing unless polled"]
pub struct Shared<F: Future> {
inner: Arc<Inner<F>>,
waiter: usize,
}
impl<F> fmt::Debug for Shared<F>
where F: Future + fmt::Debug,
F::Item: fmt::Debug,
F::Error: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Shared")
.field("inner", &self.inner)
.field("waiter", &self.waiter)
.finish()
}
}
struct Inner<F: Future> {
next_clone_id: AtomicUsize,
future: UnsafeCell<Option<(F, LocalMap)>>,
result: UnsafeCell<Option<Result<SharedItem<F::Item>, SharedError<F::Error>>>>,
notifier: Arc<Notifier>,
}
struct Notifier {
state: AtomicUsize,
waiters: Mutex<HashMap<usize, Waker>>,
}
const IDLE: usize = 0;
const POLLING: usize = 1;
const REPOLL: usize = 2;
const COMPLETE: usize = 3;
const POISONED: usize = 4;
pub fn new<F: Future>(future: F) -> Shared<F> {
Shared {
inner: Arc::new(Inner {
next_clone_id: AtomicUsize::new(1),
notifier: Arc::new(Notifier {
state: AtomicUsize::new(IDLE),
waiters: Mutex::new(HashMap::new()),
}),
future: UnsafeCell::new(Some((future, LocalMap::new()))),
result: UnsafeCell::new(None),
}),
waiter: 0,
}
}
impl<F> Shared<F> where F: Future {
pub fn peek(&self) -> Option<Result<SharedItem<F::Item>, SharedError<F::Error>>> {
match self.inner.notifier.state.load(SeqCst) {
COMPLETE => {
Some(unsafe { self.clone_result() })
}
POISONED => panic!("inner future panicked during poll"),
_ => None,
}
}
fn set_waiter(&mut self, cx: &mut task::Context) {
let mut waiters = self.inner.notifier.waiters.lock().unwrap();
waiters.insert(self.waiter, cx.waker().clone());
}
unsafe fn clone_result(&self) -> Result<SharedItem<F::Item>, SharedError<F::Error>> {
match *self.inner.result.get() {
Some(Ok(ref item)) => Ok(SharedItem { item: item.item.clone() }),
Some(Err(ref e)) => Err(SharedError { error: e.error.clone() }),
_ => unreachable!(),
}
}
fn complete(&self) {
unsafe { *self.inner.future.get() = None };
self.inner.notifier.state.store(COMPLETE, SeqCst);
Wake::wake(&self.inner.notifier);
}
}
impl<F> Future for Shared<F>
where F: Future
{
type Item = SharedItem<F::Item>;
type Error = SharedError<F::Error>;
fn poll(&mut self, cx: &mut task::Context) -> Poll<Self::Item, Self::Error> {
self.set_waiter(cx);
match self.inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) {
IDLE => {
}
POLLING | REPOLL => {
return Ok(Async::Pending);
}
COMPLETE => {
return unsafe { self.clone_result().map(Async::Ready) };
}
POISONED => panic!("inner future panicked during poll"),
_ => unreachable!(),
}
loop {
struct Reset<'a>(&'a AtomicUsize);
impl<'a> Drop for Reset<'a> {
fn drop(&mut self) {
use std::thread;
if thread::panicking() {
self.0.store(POISONED, SeqCst);
}
}
}
let _reset = Reset(&self.inner.notifier.state);
let res = unsafe {
let (ref mut f, ref mut data) = *(*self.inner.future.get()).as_mut().unwrap();
let waker = Waker::from(self.inner.notifier.clone());
let mut cx = task::Context::new(data, &waker, cx.executor());
f.poll(&mut cx)
};
match res {
Ok(Async::Pending) => {
match self.inner.notifier.state.compare_and_swap(POLLING, IDLE, SeqCst) {
POLLING => {
return Ok(Async::Pending);
}
REPOLL => {
let prev = self.inner.notifier.state.swap(POLLING, SeqCst);
assert_eq!(prev, REPOLL);
}
_ => unreachable!(),
}
}
Ok(Async::Ready(i)) => {
unsafe {
(*self.inner.result.get()) = Some(Ok(SharedItem { item: Arc::new(i) }));
}
break;
}
Err(e) => {
unsafe {
(*self.inner.result.get()) = Some(Err(SharedError { error: Arc::new(e) }));
}
break;
}
}
}
self.complete();
unsafe { self.clone_result().map(Async::Ready) }
}
}
impl<F> Clone for Shared<F> where F: Future {
fn clone(&self) -> Self {
let next_clone_id = self.inner.next_clone_id.fetch_add(1, SeqCst);
Shared {
inner: self.inner.clone(),
waiter: next_clone_id,
}
}
}
impl<F> Drop for Shared<F> where F: Future {
fn drop(&mut self) {
let mut waiters = self.inner.notifier.waiters.lock().unwrap();
waiters.remove(&self.waiter);
}
}
impl Wake for Notifier {
fn wake(arc_self: &Arc<Self>) {
arc_self.state.compare_and_swap(POLLING, REPOLL, SeqCst);
let waiters = mem::replace(&mut *arc_self.waiters.lock().unwrap(), HashMap::new());
for (_, waiter) in waiters {
waiter.wake();
}
}
}
unsafe impl<F> Sync for Inner<F>
where F: Future + Send,
F::Item: Send + Sync,
F::Error: Send + Sync
{}
unsafe impl<F> Send for Inner<F>
where F: Future + Send,
F::Item: Send + Sync,
F::Error: Send + Sync
{}
impl<F> fmt::Debug for Inner<F>
where F: Future + fmt::Debug,
F::Item: fmt::Debug,
F::Error: fmt::Debug,
{
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
fmt.debug_struct("Inner")
.finish()
}
}
#[derive(Clone, Debug)]
pub struct SharedItem<T> {
item: Arc<T>,
}
impl<T> SharedItem<T> {
pub fn into_inner(self) -> Arc<T> {
self.item
}
}
impl<T> ops::Deref for SharedItem<T> {
type Target = T;
fn deref(&self) -> &T {
&self.item.as_ref()
}
}
#[derive(Clone, Debug)]
pub struct SharedError<E> {
error: Arc<E>,
}
impl<E> SharedError<E> {
pub fn into_inner(self) -> Arc<E> {
self.error
}
}
impl<E> ops::Deref for SharedError<E> {
type Target = E;
fn deref(&self) -> &E {
&self.error.as_ref()
}
}
impl<E> fmt::Display for SharedError<E>
where E: fmt::Display,
{
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
self.error.fmt(f)
}
}
impl<E> error::Error for SharedError<E>
where E: error::Error,
{
fn description(&self) -> &str {
self.error.description()
}
fn cause(&self) -> Option<&error::Error> {
self.error.cause()
}
}