err_derive/
lib.rs

1//! # `err-derive`
2//!
3//! ## Deriving error causes / sources
4//!
5//! Add an `#[error(source)]` attribute to the field:
6//!
7//! ```
8//! use std::io;
9//! use err_derive::Error;
10//!
11//! /// `MyError::source` will return a reference to the `io_error` field
12//! #[derive(Debug, Error)]
13//! #[error(display = "An error occurred.")]
14//! struct MyError {
15//!     #[error(source)]
16//!     io_error: io::Error,
17//! }
18//! #
19//! # fn main() {}
20//! ```
21//!
22//! ## Automatic From implementations
23//!
//! `From<Other>` is implemented automatically when the enum variant or struct has
//! a single field and that field is marked as a source (`#[source]` or
//! `#[error(source)]`). Marking a field `#[error(from)]` requests the
//! implementation explicitly.
//!
//! When multiple enum variants have a source field of the same type, all but one
//! of them must opt out of the automatic `From` implementation (see below).
30//!
31//! ```
32//! use std::io;
33//! use err_derive::Error;
34//!
35//! /// `From<io::Error>` will be implemented for `MyError`
36//! #[derive(Debug, Error)]
37//! #[error(display = "An error occurred.")]
38//! struct MyError {
39//!     #[error(from)]
40//!     io_error: io::Error,
41//! }
42//! #
43//! # fn main() {}
44//! ```
45//!
46//! ### Opt out of From implementation
47//!
//! Use the `#[error(no_from)]` attribute on either the enum or a single variant's
//! source field to opt out of the automatic `From` implementation.
//!
//! When `#[error(no_from)]` is set on the enum, you can opt individual variants back
//! in by marking their source field with `#[error(from)]`:
52//!
53//! ```rust
54//! use err_derive::Error;
55//! use std::{io, fmt};
56//!
57//! #[derive(Debug, Error)]
58//! enum ClientError {
59//!     #[error(display = "regular bad io error {}", _0)]
60//!     Io(#[source] io::Error),
61//!     #[error(display = "extra bad io error {}", _0)]
62//!     // Without #[no_from], this From impl would conflict with the normal Io error
63//!     ReallyBadIo(#[error(source, no_from)] io::Error)
64//! }
65//!
66//! #[derive(Debug, Error)]
67//! #[error(no_from)] // Don't impl From for any variants by default
68//! enum InnerError {
69//!     #[error(display = "an error")]
70//!     Io(#[source] io::Error),
71//!     #[error(display = "an error")]
72//!     // Opt-in impl From for a single variant
73//!     Formatting(#[error(source, from)] fmt::Error)
74//! }
75//! ```
76//!
77//! ### Auto-boxing From implementation
78//!
//! If a field that gets a `From` implementation has type `Box<T>`, an additional
//! `From<T>` implementation is generated automatically that wraps the value in
//! `Box::new`.
82//!
83//! ```rust
84//! use err_derive::Error;
85//! use std::{io, fmt};
86//!
87//! #[derive(Debug, Error)]
88//! enum ClientError {
//!     #[error(display = "io error in a box: {}", _0)]
90//!     Io(#[error(source)] Box<io::Error>),
91//! }
92//! ```
93//!
94//! ## Formatting fields
95//!
96//! ```rust
97//! use std::path::PathBuf;
98//! use err_derive::Error;
99//!
100//! #[derive(Debug, Error)]
101//! pub enum FormatError {
102//!     #[error(display = "invalid header (expected: {:?}, got: {:?})", expected, found)]
103//!     InvalidHeader {
104//!         expected: String,
105//!         found: String,
106//!     },
107//!     // Note that tuple fields need to be prefixed with `_`
108//!     #[error(display = "missing attribute: {:?}", _0)]
109//!     MissingAttribute(String),
110//!
111//! }
112//!
113//! #[derive(Debug, Error)]
114//! pub enum LoadingError {
115//!     #[error(display = "could not decode file")]
116//!     FormatError(#[error(source)] #[error(from)] FormatError),
117//!     #[error(display = "could not find file: {:?}", path)]
118//!     NotFound { path: PathBuf },
119//! }
120//! #
121//! # fn main() {}
122//! ```
123//!
124//! ## Printing the error
125//!
126//! ```
127//! use std::error::Error;
128//!
129//! fn print_error(e: &dyn Error) {
130//!     eprintln!("error: {}", e);
131//!     let mut cause = e.source();
132//!     while let Some(e) = cause {
133//!         eprintln!("caused by: {}", e);
134//!         cause = e.source();
135//!     }
136//! }
137//! ```
138//!
139
extern crate proc_macro;
extern crate proc_macro_error;
extern crate syn;

use proc_macro::TokenStream;
use proc_macro2::{Span, TokenStream as TokenStream2};
use proc_macro_error::{abort, proc_macro_error};
use quote::quote;
use syn::spanned::Spanned;
use syn::Attribute;
use synstructure::decl_derive;
154
155decl_derive!([Error, attributes(error, source, cause)] => #[proc_macro_error] error_derive);
156
157fn error_derive(s: synstructure::Structure) -> TokenStream {
158    let source_body = s.each_variant(|v| {
159        if let Some(source) = v.bindings().iter().find(|binding| {
160            has_attr(&binding.ast().attrs, "source") || has_attr(&binding.ast().attrs, "cause")
161        })
162        // TODO https://github.com/rust-lang/rust/issues/54140 deprecate cause with warning
163        {
164            quote!(return Some(#source as & ::std::error::Error))
165        } else {
166            quote!(return None)
167        }
168    });
169
170    let source_method = quote! {
171        #[allow(unreachable_code)]
172        fn source(&self) -> ::std::option::Option<&(::std::error::Error + 'static)> {
173            match *self { #source_body }
174            None
175        }
176    };
177
178    let cause_method = quote! {
179        #[allow(unreachable_code)]
180        fn cause(&self) -> ::std::option::Option<& ::std::error::Error> {
181            match *self { #source_body }
182            None
183        }
184    };
185
186    let error = if cfg!(feature = "std") {
187        s.unbound_impl(
188            quote!(::std::error::Error),
189            quote! {
190                fn description(&self) -> &str {
191                    "description() is deprecated; use Display"
192                }
193
194                #cause_method
195                #source_method
196            },
197        )
198    } else {
199        quote!()
200    };
201
202    let display_body = display_body(&s);
203    let display = s.unbound_impl(
204        quote!(::core::fmt::Display),
205        quote! {
206            #[allow(unreachable_code)]
207            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
208                match *self { #display_body }
209                write!(f, "An error has occurred.")
210            }
211        },
212    );
213
214    let from = from_body(&s);
215
216    quote!(#error #display #from).into()
217}
218
219fn display_body(s: &synstructure::Structure) -> TokenStream2 {
220    s.each_variant(|v| {
221        let span = v.ast().ident.span();
222        let msg = match find_error_msg(&v.ast().attrs) {
223            Some(msg) => msg,
224            None => abort!(span, "Variant is missing display attribute."),
225        };
226        if msg.nested.is_empty() {
227            abort!(span, "Expected at least one argument to error attribute");
228        }
229
230        let format_string = match msg.nested[0] {
231            syn::NestedMeta::Meta(syn::Meta::NameValue(ref nv))
232                if nv
233                    .path
234                    .get_ident()
235                    .map_or(false, |ident| ident == "display") =>
236            {
237                nv.lit.clone()
238            }
239            _ => abort!(
240                msg.nested.span(),
241                "Error attribute must begin `display = \"\"` to control the Display message."
242            ),
243        };
244        let args = msg.nested.iter().skip(1).map(|arg| match *arg {
245            syn::NestedMeta::Lit(syn::Lit::Int(ref i)) => {
246                let bi = &v.bindings()[i
247                    .base10_parse::<usize>()
248                    .unwrap_or_else(|_| abort!(i.span(), "integer literal overflows usize"))];
249                quote!(#bi)
250            }
251            syn::NestedMeta::Meta(syn::Meta::Path(ref path)) => {
252                let id = match path.get_ident() {
253                    Some(id) => id,
254                    // Allows std::u8::MAX (for example)
255                    None => return quote!(#arg),
256                };
257                let id_s = id.to_string();
258                if id_s.starts_with('_') {
259                    if let Ok(idx) = id_s[1..].parse::<usize>() {
260                        let bi = match v.bindings().get(idx) {
261                            Some(bi) => bi,
262                            None => {
263                                abort!(
264                                    id.span(),
265                                    "display attempted to access field `{}` in `{}::{}` which \
266                                     does not exist (there {} {} field{})",
267                                    idx,
268                                    s.ast().ident,
269                                    v.ast().ident,
270                                    if v.bindings().len() != 1 { "are" } else { "is" },
271                                    v.bindings().len(),
272                                    if v.bindings().len() != 1 { "s" } else { "" }
273                                );
274                            }
275                        };
276                        return quote!(#bi);
277                    }
278                }
279                for bi in v.bindings() {
280                    if bi.ast().ident.as_ref() == Some(id) {
281                        return quote!(#bi);
282                    }
283                }
284                // Arg is not a field - might be in global scope
285                return quote!(#id);
286            }
287            // Allows u8::max_value() (for example)
288            syn::NestedMeta::Meta(syn::Meta::List(ref list)) => return quote!(#list),
289            _ => abort!(msg.nested.span(), "Invalid argument to error attribute!"),
290        });
291
292        quote! {
293            return write!(f, #format_string #(, #args)*)
294        }
295    })
296}
297
298fn find_error_msg(attrs: &[syn::Attribute]) -> Option<syn::MetaList> {
299    let mut error_msg = None;
300    for attr in attrs {
301        if let Ok(meta) = attr.parse_meta() {
302            if meta
303                .path()
304                .get_ident()
305                .map_or(false, |ident| ident == "error")
306            {
307                let span = attr.span();
308                if error_msg.is_some() {
309                    abort!(span, "Cannot have two display attributes")
310                } else if let syn::Meta::List(list) = meta {
311                    error_msg = Some(list);
312                } else {
313                    abort!(span, "error attribute must take a list in parentheses")
314                }
315            }
316        }
317    }
318    error_msg
319}
320
321fn has_attr(attributes: &[Attribute], attr_name: &str) -> bool {
322    let mut found_attr = false;
323    for attr in attributes {
324        if let Ok(meta) = attr.parse_meta() {
325            if meta
326                .path()
327                .get_ident()
328                .map_or(false, |ident| ident == attr_name)
329            {
330                if found_attr {
331                    abort!(attr.span(), "Cannot have two `{}` attributes", attr_name);
332                }
333                found_attr = true;
334            }
335
336            if meta
337                .path()
338                .get_ident()
339                .map_or(false, |ident| ident == "error")
340            {
341                if let syn::Meta::List(ref list) = meta {
342                    for pair in list.nested.iter() {
343                        if let syn::NestedMeta::Meta(syn::Meta::Path(path)) = pair {
344                            let is_attr_name = |ident: &syn::Ident| {
345                                ident.to_string().split(", ").any(|part| part == attr_name)
346                            };
347
348                            if path.get_ident().map_or(false, is_attr_name) {
349                                if found_attr {
350                                    abort!(
351                                        path.span(),
352                                        "Cannot have two `{}` attributes",
353                                        attr_name
354                                    );
355                                }
356                                found_attr = true;
357                            }
358                        }
359                    }
360                }
361            }
362        }
363    }
364    found_attr
365}
366
367// If `ty` is a `Box<T>`, return `Some(T)`. Otherwise return `None`.
368fn box_content_type(ty: &syn::Type) -> Option<&syn::Type> {
369    let path = match ty {
370        syn::Type::Path(p) => p,
371        _ => return None,
372    };
373
374    if path.path.segments.len() != 1 {
375        return None;
376    }
377
378    let seg = &path.path.segments[0];
379    if seg.ident != "Box" {
380        return None;
381    }
382
383    let args = match &seg.arguments {
384        syn::PathArguments::AngleBracketed(args) => args,
385        _ => return None,
386    };
387
388    if args.args.len() != 1 {
389        return None;
390    }
391
392    let inner_ty = match &args.args[0] {
393        syn::GenericArgument::Type(ty) => ty,
394        _ => return None,
395    };
396
397    Some(inner_ty)
398}
399
400fn from_body(s: &synstructure::Structure) -> TokenStream2 {
401    let default_from = !has_attr(&s.ast().attrs, "no_from");
402    let mut from_types = Vec::new();
403    let froms = s.variants().iter().flat_map(|v| {
404        let span = v.ast().ident.span();
405        if let Some((from, is_explicit)) = v.bindings().iter().find_map(|binding| {
406            let is_explicit = has_attr(&binding.ast().attrs, "from");
407            let is_source = has_attr(&binding.ast().attrs, "source");
408            // TODO https://github.com/rust-lang/rust/issues/54140 deprecate cause with warning
409            let is_cause = has_attr(&binding.ast().attrs, "cause");
410            let exclude = has_attr(&binding.ast().attrs, "no_from");
411
412            if is_source && is_cause {
413                abort!(
414                    span,
415                    "#[error(cause)] is deprecated, use #[error(source)] instead"
416                )
417            }
418
419            let is_source = is_source || is_cause;
420
421            if ((default_from && is_source) || is_explicit) && !exclude {
422                Some((binding, is_explicit))
423            } else {
424                None
425            }
426        }) {
427            if v.bindings().len() > 1 {
428                if is_explicit {
429                    abort!(
430                        span,
431                        "Variants containing `from` can only contain a single field"
432                    );
433                } else {
434                    return vec![];
435                }
436            }
437
438            let from_ident = &from.ast().ty;
439
440            if from_types
441                .iter()
442                .any(|existing_from_type| *existing_from_type == from_ident)
443            {
444                abort!(
445                    from_ident.span(),
446                    "`from` can only be applied for a type once{}",
447                    if is_explicit {
448                        ""
449                    } else {
450                        ", hint: use #[error(no_from)] to disable automatic From derive"
451                    }
452                );
453            }
454
455            from_types.push(from_ident);
456            let construct = v.construct(|_, _| quote! {from});
457
458            let mut from_impls = vec![s.unbound_impl(
459                quote! {::core::convert::From<#from_ident>},
460                quote! {
461                    fn from(from: #from_ident) -> Self {
462                        #construct
463                    }
464                },
465            )];
466
467            if let Some(box_content_ty) = box_content_type(from_ident) {
468                from_impls.push(s.unbound_impl(
469                    quote! {::core::convert::From<#box_content_ty>},
470                    quote! {
471                        fn from(from: #box_content_ty) -> Self {
472                            Box::new(from).into()
473                        }
474                    },
475                ));
476            }
477
478            from_impls
479        } else {
480            vec![]
481        }
482    });
483
484    quote! {
485        #(#froms)*
486    }
487}