1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
use rlua::prelude::*;
use rlua::{UserDataMethods, UserData, MetaMethod, Value, Table, Lua, FromLua};
use std::collections::HashSet;

/// A set of strings backed by a Rust [`HashSet`], exposed to Lua as userdata.
///
/// `Clone` lets rlua's blanket `FromLua` impl for userdata produce an owned `StringSet`
/// from a Lua value. `Debug`, `Default` and `PartialEq`/`Eq` make the type convenient to
/// inspect, construct and compare from Rust code (e.g. in tests).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
struct StringSet(HashSet<String>);

impl UserData for StringSet {
    fn add_methods<'lua, M: UserDataMethods<'lua, Self>>(methods: &mut M) {

        methods.add_method_mut("insert", |_, this, elem: String| {
            this.0.insert(elem);
            Ok(())
        });

        methods.add_method_mut("remove", |_, this, elem: String| {
            this.0.remove(&elem);
            Ok(())
        });

        methods.add_method("contains", |_, this, elem: String| {
            Ok(this.0.contains(&elem))
        });

        methods.add_method_mut("clear", |_, this, _: ()| {
            this.0.clear();
            Ok(())
        });

        methods.add_method("is_empty", |_, this, _: ()| {
            Ok(this.0.is_empty())
        });

        methods.add_method("difference", |_, this, other: StringSet| {
            let result: HashSet<String> = this.0.difference(&other.0).cloned().collect();
            Ok(StringSet(result))
        });

        methods.add_method("symmetric", |_, this, other: StringSet| {
            let result: HashSet<String> = this.0.symmetric_difference(&other.0).cloned().collect();
            Ok(StringSet(result))
        });

        methods.add_method("intersection", |_, this, other: StringSet| {
            let result: HashSet<String> = this.0.intersection(&other.0).cloned().collect();
            Ok(StringSet(result))
        });

        methods.add_method("union", |_, this, other: StringSet| {
            let result: HashSet<String> = this.0.union(&other.0).cloned().collect();
            Ok(StringSet(result))
        });

        methods.add_method("is_disjoint", |_, this, other: StringSet| {
            Ok(this.0.is_disjoint(&other.0))
        });

        methods.add_method("is_subset", |_, this, other: StringSet| {
            Ok(this.0.is_subset(&other.0))
        });

        methods.add_method("is_superset", |_, this, other: StringSet| {
            Ok(this.0.is_superset(&other.0))
        });

        methods.add_meta_method(MetaMethod::ToString, |_, this, _: ()| {
            Ok(format!("{:?}", this.0))
        });

        methods.add_meta_method(MetaMethod::Len, |_, this, _: ()| {
            Ok(this.0.len())
        });

        methods.add_method("clone", |_, this, _: ()| {
            Ok(StringSet(this.0.clone()))
        });

        methods.add_method("to_table", |lua, this, _: ()| {
            let table = lua.create_sequence_from(this.0.iter().cloned())?;
            Ok(table)
        });
    }
}

/// Registers the string-set API in `lua`.
///
/// Installs a global `stringset` table with:
/// * `stringset.create()`, which returns a new, empty `StringSet`;
/// * `stringset.from_table(t)`, which returns a `StringSet` holding the sequence part of `t`
///   (`t[1]`, `t[2]`, … up to the first `nil`);
/// * `stringset.difference`, `stringset.symmetric`, `stringset.intersection` and
///   `stringset.union`, each taking two operands.
///
/// The same four set operations are also installed as globals (`difference`, `symmetric`,
/// `intersection`, `union`). Each operand may be a `StringSet` or a sequence table of
/// strings, and every call returns a new `StringSet`.
///
/// # Errors
///
/// Propagates any rlua error raised while creating or assigning the tables and functions.
pub fn init(lua: &Lua) -> Result<(), LuaError> {
    type Set = HashSet<String>;

    // Collects the sequence part of `table` into a set. Fails on the first element that
    // cannot be converted to a Rust `String`.
    fn from_table(table: Table) -> Result<Set, LuaError> {
        table.sequence_values::<String>().collect()
    }

    // Accepts either a sequence table or a `StringSet` userdata and returns an owned set.
    fn from_value<'lua>(value: Value<'lua>, lua: &'lua Lua) -> Result<Set, LuaError> {
        match value {
            Value::Table(t) => from_table(t),
            other => Ok(StringSet::from_lua(other, lua)?.0),
        }
    }

    fn set_difference(a: &Set, b: &Set) -> Set {
        a.difference(b).cloned().collect()
    }

    fn set_symmetric(a: &Set, b: &Set) -> Set {
        a.symmetric_difference(b).cloned().collect()
    }

    fn set_intersection(a: &Set, b: &Set) -> Set {
        a.intersection(b).cloned().collect()
    }

    fn set_union(a: &Set, b: &Set) -> Set {
        a.union(b).cloned().collect()
    }

    // Each (name, operation) pair is exposed as both `stringset.<name>` and global `<name>`.
    let operations: [(&str, fn(&Set, &Set) -> Set); 4] = [
        ("difference", set_difference),
        ("symmetric", set_symmetric),
        ("intersection", set_intersection),
        ("union", set_union),
    ];

    let module = lua.create_table()?;

    module.set(
        "create",
        lua.create_function(|_, _: ()| Ok(StringSet(HashSet::new())))?,
    )?;

    module.set(
        "from_table",
        lua.create_function(|_, t: Table| Ok(StringSet(from_table(t)?)))?,
    )?;

    let g = lua.globals();
    for &(name, op) in operations.iter() {
        let func = lua.create_function(move |lua, (a, b): (Value, Value)| {
            let a = from_value(a, lua)?;
            let b = from_value(b, lua)?;
            Ok(StringSet(op(&a, &b)))
        })?;
        // Namespaced copy, for scripts that prefer not to rely on globals.
        module.set(name, func.clone())?;
        // Global copy, kept for backward compatibility with existing scripts.
        g.set(name, func)?;
    }

    g.set("stringset", module)?;

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises every `StringSet` method through Lua method-call syntax.
    #[test]
    fn direct_methods() {
        let lua = Lua::new();
        init(&lua).unwrap();
        lua.exec::<_, Value>(r#"
            local a = stringset.create()
            a:insert("Colombia")
            a:insert("Canada")
            a:insert("China")

            local b = stringset.create()
            b:insert("Venezuela")
            b:insert("Colombia")
            b:insert("Brazil")

            local c = a:union(b)
            assert(#c == 5)
            assert(c:contains("Colombia"))
            assert(c:contains("Venezuela"))
            assert(c:contains("Canada"))

            c = a:intersection(b)
            assert(#c == 1)
            assert(c:contains("Colombia"))
            assert(not c:contains("Canada"))

            c = a:difference(b)
            assert(#c == 2)
            assert(c:contains("Canada"))
            assert(not c:contains("Colombia"))

            c = a:symmetric(b)
            assert(#c == 4)
            assert(c:contains("Canada"))
            assert(c:contains("Venezuela"))
            assert(not c:contains("Colombia"))

            -- `local` keeps the scratch set out of the Lua global table.
            local d = a:clone()
            d:remove("Canada")
            -- The clone must be independent of the original.
            assert(a:contains("Canada"))
            assert(not d:contains("Canada"))
            assert(a:is_superset(d))
            assert(d:is_subset(a))
            assert(not a:is_disjoint(b))
            assert(a:is_disjoint(b:difference(a)))

            d:clear()
            assert(d:is_empty())
            assert(#d == 0)

            local t = a:union(b):to_table()
            assert(#t == 5)
            local seen = {}
            for _, v in ipairs(t) do
                seen[v] = true
            end
            assert(seen["Colombia"] and seen["Canada"] and seen["China"])
            assert(seen["Venezuela"] and seen["Brazil"])
        "#, None).unwrap();
    }

    /// Exercises the global shortcut functions with mixed set/table operands.
    #[test]
    fn shortcut_syntax() {
        let lua = Lua::new();
        init(&lua).unwrap();
        lua.exec::<_, Value>(r#"
            local a = stringset.create()
            a:insert("Canada")
            a:insert("China")
            a:insert("Colombia")

            local b = {"Colombia", "Brazil", "Venezuela"}
            local c = intersection(a, b)
            assert(#c == 1)
            assert(c:contains("Colombia"))
            assert(not c:contains("Canada"))

            c = union(a, b)
            assert(#c == 5)
            assert(c:contains("Brazil") and c:contains("China"))

            c = difference(a, b)
            assert(#c == 2)
            assert(c:contains("Canada") and c:contains("China"))
            assert(not c:contains("Colombia"))

            c = symmetric(a, b)
            assert(#c == 4)
            assert(c:contains("Brazil") and c:contains("Canada"))
            assert(not c:contains("Colombia"))

            -- Both operands may be plain tables.
            c = union({"x", "y"}, {"y", "z"})
            assert(#c == 3)

            -- from_table de-duplicates the sequence part of a table.
            local f = stringset.from_table({"p", "q", "p"})
            assert(#f == 2)
            assert(f:contains("p") and f:contains("q"))
        "#, None).unwrap();
    }
}