wit_parser/
live.rs

1use crate::{
2    Function, FunctionKind, InterfaceId, Resolve, Type, TypeDef, TypeDefKind, TypeId, WorldId,
3    WorldItem,
4};
5use indexmap::IndexSet;
6
7#[derive(Default)]
8pub struct LiveTypes {
9    set: IndexSet<TypeId>,
10}
11
12impl LiveTypes {
13    pub fn iter(&self) -> impl Iterator<Item = TypeId> + '_ {
14        self.set.iter().copied()
15    }
16
17    pub fn len(&self) -> usize {
18        self.set.len()
19    }
20
21    pub fn add_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
22        self.visit_interface(resolve, iface);
23    }
24
25    pub fn add_world(&mut self, resolve: &Resolve, world: WorldId) {
26        self.visit_world(resolve, world);
27    }
28
29    pub fn add_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
30        self.visit_world_item(resolve, item);
31    }
32
33    pub fn add_func(&mut self, resolve: &Resolve, func: &Function) {
34        self.visit_func(resolve, func);
35    }
36
37    pub fn add_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
38        self.visit_type_id(resolve, ty);
39    }
40
41    pub fn add_type(&mut self, resolve: &Resolve, ty: &Type) {
42        self.visit_type(resolve, ty);
43    }
44}
45
46impl TypeIdVisitor for LiveTypes {
47    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
48        !self.set.contains(&id)
49    }
50
51    fn after_visit_type_id(&mut self, id: TypeId) {
52        assert!(self.set.insert(id));
53    }
54}
55
56/// Helper trait to walk the structure of a type and visit all `TypeId`s that
57/// it refers to, possibly transitively.
58pub trait TypeIdVisitor {
59    /// Callback invoked just before a type is visited.
60    ///
61    /// If this function returns `false` the type is not visited, otherwise it's
62    /// recursed into.
63    fn before_visit_type_id(&mut self, id: TypeId) -> bool {
64        let _ = id;
65        true
66    }
67
68    /// Callback invoked once a type is finished being visited.
69    fn after_visit_type_id(&mut self, id: TypeId) {
70        let _ = id;
71    }
72
73    fn visit_interface(&mut self, resolve: &Resolve, iface: InterfaceId) {
74        let iface = &resolve.interfaces[iface];
75        for (_, id) in iface.types.iter() {
76            self.visit_type_id(resolve, *id);
77        }
78        for (_, func) in iface.functions.iter() {
79            self.visit_func(resolve, func);
80        }
81    }
82
83    fn visit_world(&mut self, resolve: &Resolve, world: WorldId) {
84        let world = &resolve.worlds[world];
85        for (_, item) in world.imports.iter().chain(world.exports.iter()) {
86            self.visit_world_item(resolve, item);
87        }
88    }
89
90    fn visit_world_item(&mut self, resolve: &Resolve, item: &WorldItem) {
91        match item {
92            WorldItem::Interface { id, .. } => self.visit_interface(resolve, *id),
93            WorldItem::Function(f) => self.visit_func(resolve, f),
94            WorldItem::Type(t) => self.visit_type_id(resolve, *t),
95        }
96    }
97
98    fn visit_func(&mut self, resolve: &Resolve, func: &Function) {
99        match func.kind {
100            // This resource is live as it's attached to a static method but
101            // it's not guaranteed to be present in either params or results, so
102            // be sure to attach it here.
103            FunctionKind::Static(id) | FunctionKind::AsyncStatic(id) => {
104                self.visit_type_id(resolve, id)
105            }
106
107            // The resource these are attached to is in the params/results, so
108            // no need to re-add it here.
109            FunctionKind::Method(_)
110            | FunctionKind::AsyncMethod(_)
111            | FunctionKind::Constructor(_) => {}
112
113            FunctionKind::Freestanding | FunctionKind::AsyncFreestanding => {}
114        }
115
116        for (_, ty) in func.params.iter() {
117            self.visit_type(resolve, ty);
118        }
119        if let Some(ty) = &func.result {
120            self.visit_type(resolve, ty);
121        }
122    }
123
124    fn visit_type_id(&mut self, resolve: &Resolve, ty: TypeId) {
125        if self.before_visit_type_id(ty) {
126            self.visit_type_def(resolve, &resolve.types[ty]);
127            self.after_visit_type_id(ty);
128        }
129    }
130
131    fn visit_type_def(&mut self, resolve: &Resolve, ty: &TypeDef) {
132        match &ty.kind {
133            TypeDefKind::Type(t)
134            | TypeDefKind::List(t)
135            | TypeDefKind::Option(t)
136            | TypeDefKind::Future(Some(t))
137            | TypeDefKind::Stream(Some(t)) => self.visit_type(resolve, t),
138            TypeDefKind::Handle(handle) => match handle {
139                crate::Handle::Own(ty) => self.visit_type_id(resolve, *ty),
140                crate::Handle::Borrow(ty) => self.visit_type_id(resolve, *ty),
141            },
142            TypeDefKind::Resource => {}
143            TypeDefKind::Record(r) => {
144                for field in r.fields.iter() {
145                    self.visit_type(resolve, &field.ty);
146                }
147            }
148            TypeDefKind::Tuple(r) => {
149                for ty in r.types.iter() {
150                    self.visit_type(resolve, ty);
151                }
152            }
153            TypeDefKind::Variant(v) => {
154                for case in v.cases.iter() {
155                    if let Some(ty) = &case.ty {
156                        self.visit_type(resolve, ty);
157                    }
158                }
159            }
160            TypeDefKind::Result(r) => {
161                if let Some(ty) = &r.ok {
162                    self.visit_type(resolve, ty);
163                }
164                if let Some(ty) = &r.err {
165                    self.visit_type(resolve, ty);
166                }
167            }
168            TypeDefKind::Flags(_)
169            | TypeDefKind::Enum(_)
170            | TypeDefKind::Future(None)
171            | TypeDefKind::Stream(None) => {}
172            TypeDefKind::Unknown => unreachable!(),
173        }
174    }
175
176    fn visit_type(&mut self, resolve: &Resolve, ty: &Type) {
177        match ty {
178            Type::Id(id) => self.visit_type_id(resolve, *id),
179            _ => {}
180        }
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::{LiveTypes, Resolve};

    /// Parses `wit` into a fresh `Resolve`, looks up the type named `ty` in the
    /// last interface, and returns the names of every named type reachable from
    /// it, in the order `LiveTypes` recorded them.
    fn live(wit: &str, ty: &str) -> Vec<String> {
        let mut resolve = Resolve::default();
        resolve.push_str("test.wit", wit).unwrap();

        let root = {
            let (_, last_iface) = resolve.interfaces.iter().next_back().unwrap();
            last_iface.types[ty]
        };

        let mut live_types = LiveTypes::default();
        live_types.add_type_id(&resolve, root);

        let mut names = Vec::new();
        for id in live_types.iter() {
            if let Some(name) = &resolve.types[id].name {
                names.push(name.clone());
            }
        }
        names
    }

    #[test]
    fn no_deps() {
        let wit = "
                package foo:bar;

                interface foo {
                    type t = u32;
                }
            ";
        assert_eq!(live(wit, "t"), ["t"]);
    }

    #[test]
    fn one_dep() {
        let wit = "
                package foo:bar;

                interface foo {
                    type t = u32;
                    type u = t;
                }
            ";
        assert_eq!(live(wit, "u"), ["t", "u"]);
    }

    #[test]
    fn chain() {
        let wit = "
                package foo:bar;

                interface foo {
                    resource t1;
                    record t2 {
                        x: t1,
                    }
                    variant t3 {
                        x(t2),
                    }
                    flags t4 { a }
                    enum t5 { a }
                    type t6 = tuple<t5, t4, t3>;
                }
            ";
        assert_eq!(live(wit, "t6"), ["t5", "t4", "t1", "t2", "t3", "t6"]);
    }
}