futures_select_macro/
lib.rs

//! The futures-rs `select!` macro implementation.
2
3#![recursion_limit="128"]
4#![warn(rust_2018_idioms, unreachable_pub)]
// `single_use_lifetimes` is only enabled for tests: it cannot be enabled in the published code because this lint has false positives in the minimum required Rust version.
6#![cfg_attr(test, warn(single_use_lifetimes))]
7#![warn(clippy::all)]
8
9extern crate proc_macro;
10
11use proc_macro::TokenStream;
12use proc_macro2::Span;
13use proc_macro_hack::proc_macro_hack;
14use quote::{format_ident, quote};
15use syn::{parenthesized, parse_quote, Expr, Ident, Pat, Token};
16use syn::parse::{Parse, ParseStream};
17
18mod kw {
19    syn::custom_keyword!(complete);
20    syn::custom_keyword!(futures_crate_path);
21}
22
23struct Select {
24    futures_crate_path: Option<syn::Path>,
25    // span of `complete`, then expression after `=> ...`
26    complete: Option<Expr>,
27    default: Option<Expr>,
28    normal_fut_exprs: Vec<Expr>,
29    normal_fut_handlers: Vec<(Pat, Expr)>,
30}
31
32#[allow(clippy::large_enum_variant)]
33enum CaseKind {
34    Complete,
35    Default,
36    Normal(Pat, Expr),
37}
38
39impl Parse for Select {
40    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
41        let mut select = Select {
42            futures_crate_path: None,
43            complete: None,
44            default: None,
45            normal_fut_exprs: vec![],
46            normal_fut_handlers: vec![],
47        };
48
49        // When `futures_crate_path(::path::to::futures::lib)` is provided,
50        // it sets the path through which futures library functions will be
51        // accessed.
52        if input.peek(kw::futures_crate_path) {
53            input.parse::<kw::futures_crate_path>()?;
54            let content;
55            parenthesized!(content in input);
56            select.futures_crate_path = Some(content.parse()?);
57        }
58
59        while !input.is_empty() {
60            let case_kind = if input.peek(kw::complete) {
61                // `complete`
62                if select.complete.is_some() {
63                    return Err(input.error("multiple `complete` cases found, only one allowed"));
64                }
65                input.parse::<kw::complete>()?;
66                CaseKind::Complete
67            } else if input.peek(Token![default]) {
68                // `default`
69                if select.default.is_some() {
70                    return Err(input.error("multiple `default` cases found, only one allowed"));
71                }
72                input.parse::<Ident>()?;
73                CaseKind::Default
74            } else {
75                // `<pat> = <expr>`
76                let pat = input.parse()?;
77                input.parse::<Token![=]>()?;
78                let expr = input.parse()?;
79                CaseKind::Normal(pat, expr)
80            };
81
82            // `=> <expr>`
83            input.parse::<Token![=>]>()?;
84            let expr = input.parse::<Expr>()?;
85
86            // Commas after the expression are only optional if it's a `Block`
87            // or it is the last branch in the `match`.
88            let is_block = match expr { Expr::Block(_) => true, _ => false };
89            if is_block || input.is_empty() {
90                input.parse::<Option<Token![,]>>()?;
91            } else {
92                input.parse::<Token![,]>()?;
93            }
94
95            match case_kind {
96                CaseKind::Complete => select.complete = Some(expr),
97                CaseKind::Default => select.default = Some(expr),
98                CaseKind::Normal(pat, fut_expr) => {
99                    select.normal_fut_exprs.push(fut_expr);
100                    select.normal_fut_handlers.push((pat, expr));
101                },
102            }
103        }
104
105        Ok(select)
106    }
107}
108
109// Enum over all the cases in which the `select!` waiting has completed and the result
110// can be processed.
111//
112// `enum __PrivResult<_1, _2, ...> { _1(_1), _2(_2), ..., Complete }`
113fn declare_result_enum(
114    result_ident: Ident,
115    variants: usize,
116    complete: bool,
117    span: Span
118) -> (Vec<Ident>, syn::ItemEnum) {
119    // "_0", "_1", "_2"
120    let variant_names: Vec<Ident> =
121        (0..variants)
122            .map(|num| format_ident!("_{}", num, span = span))
123            .collect();
124
125    let type_parameters = &variant_names;
126    let variants = &variant_names;
127
128    let complete_variant = if complete {
129        Some(quote!(Complete))
130    } else {
131        None
132    };
133
134    let enum_item = parse_quote! {
135        enum #result_ident<#(#type_parameters,)*> {
136            #(
137                #variants(#type_parameters),
138            )*
139            #complete_variant
140        }
141    };
142
143    (variant_names, enum_item)
144}
145
146/// The `select!` macro.
147#[proc_macro_hack]
148pub fn select(input: TokenStream) -> TokenStream {
149    let parsed = syn::parse_macro_input!(input as Select);
150
151    let futures_crate: syn::Path = parsed.futures_crate_path.unwrap_or_else(|| parse_quote!(::futures_util));
152
153    // should be def_site, but that's unstable
154    let span = Span::call_site();
155
156    let enum_ident = Ident::new("__PrivResult", span);
157
158    let (variant_names, enum_item) = declare_result_enum(
159        enum_ident.clone(),
160        parsed.normal_fut_exprs.len(),
161        parsed.complete.is_some(),
162        span,
163    );
164
165    // bind non-`Ident` future exprs w/ `let`
166    let mut future_let_bindings = Vec::with_capacity(parsed.normal_fut_exprs.len());
167    let bound_future_names: Vec<_> = parsed.normal_fut_exprs.into_iter()
168        .zip(variant_names.iter())
169        .map(|(expr, variant_name)| {
170            match expr {
171                // Don't bind futures that are already a path.
172                // This prevents creating redundant stack space
173                // for them.
174                syn::Expr::Path(path) => path,
175                _ => {
176                    future_let_bindings.push(quote! {
177                        let mut #variant_name = #expr;
178                    });
179                    parse_quote! { #variant_name }
180                }
181            }
182        })
183        .collect();
184
185    // For each future, make an `&mut dyn FnMut(&mut Context<'_>) -> Option<Poll<__PrivResult<...>>`
186    // to use for polling that individual future. These will then be put in an array.
187    let poll_functions = bound_future_names.iter().zip(variant_names.iter())
188        .map(|(bound_future_name, variant_name)| {
189            quote! {
190                let mut #variant_name = |__cx: &mut #futures_crate::task::Context<'_>| {
191                    if #futures_crate::future::FusedFuture::is_terminated(&#bound_future_name) {
192                        None
193                    } else {
194                        Some(#futures_crate::future::FutureExt::poll_unpin(
195                            &mut #bound_future_name,
196                            __cx,
197                        ).map(#enum_ident::#variant_name))
198                    }
199                };
200                let #variant_name: &mut dyn FnMut(
201                    &mut #futures_crate::task::Context<'_>
202                ) -> Option<#futures_crate::task::Poll<_>> = &mut #variant_name;
203            }
204        });
205
206    let none_polled = if parsed.complete.is_some() {
207        quote! {
208            #futures_crate::task::Poll::Ready(#enum_ident::Complete)
209        }
210    } else {
211        quote! {
212            panic!("all futures in select! were completed,\
213                    but no `complete =>` handler was provided")
214        }
215    };
216
217    let branches = parsed.normal_fut_handlers.into_iter()
218        .zip(variant_names.iter())
219        .map(|((pat, expr), variant_name)| {
220            quote! {
221                #enum_ident::#variant_name(#pat) => { #expr },
222            }
223        });
224    let branches = quote! { #( #branches )* };
225
226    let complete_branch = parsed.complete.map(|complete_expr| {
227        quote! {
228            #enum_ident::Complete => { #complete_expr },
229        }
230    });
231
232    let branches = quote! {
233        #branches
234        #complete_branch
235    };
236
237    let await_and_select = if let Some(default_expr) = parsed.default {
238        quote! {
239            if let #futures_crate::task::Poll::Ready(x) =
240                __poll_fn(&mut #futures_crate::task::Context::from_waker(
241                    #futures_crate::task::noop_waker_ref()
242                ))
243            {
244                match x { #branches }
245            } else {
246                #default_expr
247            };
248        }
249    } else {
250        quote! {
251            match #futures_crate::future::poll_fn(__poll_fn).await {
252                #branches
253            }
254        }
255    };
256
257    TokenStream::from(quote! { {
258        #enum_item
259        #( #future_let_bindings )*
260
261        let mut __poll_fn = |__cx: &mut #futures_crate::task::Context<'_>| {
262            let mut __any_polled = false;
263
264            #( #poll_functions )*
265
266            let mut __select_arr = [#( #variant_names ),*];
267            #futures_crate::async_await::shuffle(&mut __select_arr);
268            for poller in &mut __select_arr {
269                let poller: &mut &mut dyn FnMut(
270                    &mut #futures_crate::task::Context<'_>
271                ) -> Option<#futures_crate::task::Poll<_>> = poller;
272                match poller(__cx) {
273                    Some(x @ #futures_crate::task::Poll::Ready(_)) =>
274                        return x,
275                    Some(#futures_crate::task::Poll::Pending) => {
276                        __any_polled = true;
277                    }
278                    None => {}
279                }
280            }
281
282            if !__any_polled {
283                #none_polled
284            } else {
285                #futures_crate::task::Poll::Pending
286            }
287        };
288
289        #await_and_select
290    } })
291}