cairo_lang_starknet/plugin/
events.rs

1use cairo_lang_defs::db::get_all_path_leaves;
2use cairo_lang_defs::plugin::PluginDiagnostic;
3use cairo_lang_starknet_classes::abi::EventFieldKind;
4use cairo_lang_syntax::node::db::SyntaxGroup;
5use cairo_lang_syntax::node::helpers::{GetIdentifier, QueryAttrs};
6use cairo_lang_syntax::node::{Terminal, TypedStablePtr, TypedSyntaxNode, ast};
7use const_format::formatcp;
8use smol_str::SmolStr;
9
10use super::consts::{EVENT_ATTR, EVENT_TRAIT, EVENT_TYPE_NAME};
11use super::starknet_module::StarknetModuleKind;
12
13/// Generated auxiliary data for the `#[derive(starknet::Event)]` attribute.
14#[derive(Clone, Debug, PartialEq, Eq)]
15pub enum EventData {
16    Struct { members: Vec<(SmolStr, EventFieldKind)> },
17    Enum { variants: Vec<(SmolStr, EventFieldKind)> },
18}
19
20/// The code for an empty event.
21pub const EMPTY_EVENT_CODE: &str = formatcp! {"\
22#[{EVENT_ATTR}]
23#[derive(Drop, {EVENT_TRAIT})]
24pub enum {EVENT_TYPE_NAME} {{}}
25"};
26
27/// Checks whether the given item is a starknet event, and if so - makes sure it's valid and returns
28/// its variants. Returns None if it's not a starknet event.
29pub fn get_starknet_event_variants(
30    db: &dyn SyntaxGroup,
31    diagnostics: &mut Vec<PluginDiagnostic>,
32    item: &ast::ModuleItem,
33    module_kind: StarknetModuleKind,
34) -> Option<Vec<SmolStr>> {
35    let (has_event_name, stable_ptr, variants) = match item {
36        ast::ModuleItem::Struct(strct) => (
37            strct.name(db).text(db) == EVENT_TYPE_NAME,
38            strct.name(db).stable_ptr().untyped(),
39            vec![],
40        ),
41        ast::ModuleItem::Enum(enm) => {
42            let has_event_name = enm.name(db).text(db) == EVENT_TYPE_NAME;
43            let variants = if has_event_name {
44                enm.variants(db).elements(db).into_iter().map(|v| v.name(db).text(db)).collect()
45            } else {
46                vec![]
47            };
48            (has_event_name, enm.name(db).stable_ptr().untyped(), variants)
49        }
50        ast::ModuleItem::Use(item) => {
51            for leaf in get_all_path_leaves(db, item) {
52                let stable_ptr = &leaf.stable_ptr();
53                if stable_ptr.identifier(db) == EVENT_TYPE_NAME {
54                    if !item.has_attr(db, EVENT_ATTR) {
55                        diagnostics.push(PluginDiagnostic::error(
56                            stable_ptr.untyped(),
57                            format!(
58                                "{} type that is named `{EVENT_TYPE_NAME}` must be marked with \
59                                 #[{EVENT_ATTR}].",
60                                module_kind.to_str_capital()
61                            ),
62                        ));
63                    }
64                    return Some(vec![]);
65                }
66            }
67            return None;
68        }
69        _ => return None,
70    };
71    let has_event_attr = item.has_attr(db, EVENT_ATTR);
72
73    match (has_event_attr, has_event_name) {
74        (true, false) => {
75            diagnostics.push(PluginDiagnostic::error(
76                stable_ptr,
77                format!(
78                    "{} type that is marked with #[{EVENT_ATTR}] must be named \
79                     `{EVENT_TYPE_NAME}`.",
80                    module_kind.to_str_capital()
81                ),
82            ));
83            None
84        }
85        (false, true) => {
86            diagnostics.push(PluginDiagnostic::error(
87                stable_ptr,
88                format!(
89                    "{} type that is named `{EVENT_TYPE_NAME}` must be marked with \
90                     #[{EVENT_ATTR}].",
91                    module_kind.to_str_capital()
92                ),
93            ));
94            // The attribute is missing, but this counts as an event - we can't create another
95            // (empty) event.
96            Some(variants)
97        }
98        (true, true) => Some(variants),
99        (false, false) => None,
100    }
101}