// sway_core/language/ty/module.rs

1use std::sync::Arc;
2
3use sway_error::handler::{ErrorEmitted, Handler};
4use sway_types::Span;
5
6use crate::{
7    decl_engine::{DeclEngine, DeclEngineGet, DeclId, DeclRef, DeclRefFunction},
8    language::{ty::*, HasModule, HasSubmodules, ModName},
9    transform::{self, AllowDeprecatedState},
10    Engines,
11};
12
13#[derive(Clone, Debug)]
14pub struct TyModule {
15    pub span: Span,
16    pub submodules: Vec<(ModName, TySubmodule)>,
17    pub all_nodes: Vec<TyAstNode>,
18    pub attributes: transform::AttributesMap,
19}
20
21impl TyModule {
22    /// Iter on all constants in this module, which means, globals constants and
23    /// local constants, but it does not enter into submodules.
24    pub fn iter_constants(&self, de: &DeclEngine) -> Vec<ConstantDecl> {
25        fn inside_code_block(de: &DeclEngine, block: &TyCodeBlock) -> Vec<ConstantDecl> {
26            block
27                .contents
28                .iter()
29                .flat_map(|node| inside_ast_node(de, node))
30                .collect::<Vec<_>>()
31        }
32
33        fn inside_ast_node(de: &DeclEngine, node: &TyAstNode) -> Vec<ConstantDecl> {
34            match &node.content {
35                TyAstNodeContent::Declaration(decl) => match decl {
36                    TyDecl::ConstantDecl(decl) => {
37                        vec![decl.clone()]
38                    }
39                    TyDecl::FunctionDecl(decl) => {
40                        let decl = de.get(&decl.decl_id);
41                        inside_code_block(de, &decl.body)
42                    }
43                    TyDecl::ImplSelfOrTrait(decl) => {
44                        let decl = de.get(&decl.decl_id);
45                        decl.items
46                            .iter()
47                            .flat_map(|item| match item {
48                                TyTraitItem::Fn(decl) => {
49                                    let decl = de.get(decl.id());
50                                    inside_code_block(de, &decl.body)
51                                }
52                                TyTraitItem::Constant(decl) => {
53                                    vec![ConstantDecl {
54                                        decl_id: *decl.id(),
55                                    }]
56                                }
57                                _ => vec![],
58                            })
59                            .collect()
60                    }
61                    _ => vec![],
62                },
63                _ => vec![],
64            }
65        }
66
67        self.all_nodes
68            .iter()
69            .flat_map(|node| inside_ast_node(de, node))
70            .collect::<Vec<_>>()
71    }
72
73    /// Recursively find all test function declarations.
74    pub fn test_fns_recursive<'a: 'b, 'b>(
75        &'b self,
76        decl_engine: &'a DeclEngine,
77    ) -> impl 'b + Iterator<Item = (Arc<TyFunctionDecl>, DeclRefFunction)> {
78        self.submodules_recursive()
79            .flat_map(|(_, submod)| submod.module.test_fns(decl_engine))
80            .chain(self.test_fns(decl_engine))
81    }
82}
83
84#[derive(Clone, Debug)]
85pub struct TySubmodule {
86    pub module: Arc<TyModule>,
87    pub mod_name_span: Span,
88}
89
90/// Iterator type for iterating over submodules.
91///
92/// Used rather than `impl Iterator` to enable recursive submodule iteration.
93pub struct SubmodulesRecursive<'module> {
94    submods: std::slice::Iter<'module, (ModName, TySubmodule)>,
95    current: Option<(
96        &'module (ModName, TySubmodule),
97        Box<SubmodulesRecursive<'module>>,
98    )>,
99}
100
101impl TyModule {
102    /// An iterator yielding all submodules recursively, depth-first.
103    pub fn submodules_recursive(&self) -> SubmodulesRecursive {
104        SubmodulesRecursive {
105            submods: self.submodules.iter(),
106            current: None,
107        }
108    }
109
110    /// All test functions within this module.
111    pub fn test_fns<'a: 'b, 'b>(
112        &'b self,
113        decl_engine: &'a DeclEngine,
114    ) -> impl 'b + Iterator<Item = (Arc<TyFunctionDecl>, DeclRefFunction)> {
115        self.all_nodes.iter().filter_map(|node| {
116            if let TyAstNodeContent::Declaration(TyDecl::FunctionDecl(FunctionDecl { decl_id })) =
117                &node.content
118            {
119                let fn_decl = decl_engine.get_function(decl_id);
120                let name = fn_decl.name.clone();
121                let span = fn_decl.span.clone();
122                if fn_decl.is_test() {
123                    return Some((fn_decl, DeclRef::new(name, *decl_id, span)));
124                }
125            }
126            None
127        })
128    }
129
130    /// All contract functions within this module.
131    pub fn contract_fns<'a: 'b, 'b>(
132        &'b self,
133        engines: &'a Engines,
134    ) -> impl 'b + Iterator<Item = DeclId<TyFunctionDecl>> {
135        self.all_nodes
136            .iter()
137            .flat_map(move |node| node.contract_fns(engines))
138    }
139
140    /// All contract supertrait functions within this module.
141    pub fn contract_supertrait_fns<'a: 'b, 'b>(
142        &'b self,
143        engines: &'a Engines,
144    ) -> impl 'b + Iterator<Item = DeclId<TyFunctionDecl>> {
145        self.all_nodes
146            .iter()
147            .flat_map(move |node| node.contract_supertrait_fns(engines))
148    }
149
150    pub(crate) fn check_deprecated(
151        &self,
152        engines: &Engines,
153        handler: &Handler,
154        allow_deprecated: &mut AllowDeprecatedState,
155    ) {
156        for (_, submodule) in self.submodules.iter() {
157            submodule
158                .module
159                .check_deprecated(engines, handler, allow_deprecated);
160        }
161
162        for node in self.all_nodes.iter() {
163            node.check_deprecated(engines, handler, allow_deprecated);
164        }
165    }
166
167    pub(crate) fn check_recursive(
168        &self,
169        engines: &Engines,
170        handler: &Handler,
171    ) -> Result<(), ErrorEmitted> {
172        handler.scope(|handler| {
173            for (_, submodule) in self.submodules.iter() {
174                let _ = submodule.module.check_recursive(engines, handler);
175            }
176
177            for node in self.all_nodes.iter() {
178                let _ = node.check_recursive(engines, handler);
179            }
180
181            Ok(())
182        })
183    }
184}
185
186impl<'module> Iterator for SubmodulesRecursive<'module> {
187    type Item = &'module (ModName, TySubmodule);
188    fn next(&mut self) -> Option<Self::Item> {
189        loop {
190            self.current = match self.current.take() {
191                None => match self.submods.next() {
192                    None => return None,
193                    Some(submod) => {
194                        Some((submod, Box::new(submod.1.module.submodules_recursive())))
195                    }
196                },
197                Some((submod, mut submods)) => match submods.next() {
198                    Some(next) => {
199                        self.current = Some((submod, submods));
200                        return Some(next);
201                    }
202                    None => return Some(submod),
203                },
204            }
205        }
206    }
207}
208
209impl HasModule<TyModule> for TySubmodule {
210    fn module(&self) -> &TyModule {
211        &self.module
212    }
213}
214
215impl HasSubmodules<TySubmodule> for TyModule {
216    fn submodules(&self) -> &[(ModName, TySubmodule)] {
217        &self.submodules
218    }
219}