random_pick/
lib.rs

1/*!
2# Random Pick
3Pick an element from a slice randomly by given weights.
4
5## Examples
6
7```rust
8enum Prize {
9    Legendary,
10    Rare,
11    Enchanted,
12    Common,
13}
14
15let prize_list = [Prize::Legendary, Prize::Rare, Prize::Enchanted, Prize::Common]; // available prizes
16
17let slice = &prize_list;
18let weights = [1, 5, 15, 30]; // a scale of chance of picking each kind of prize
19
20let n = 1000000;
21let mut counter = [0usize; 4];
22
23for _ in 0..n {
24    let picked_item = random_pick::pick_from_slice(slice, &weights).unwrap();
25
26    match picked_item {
27        Prize::Legendary=>{
28            counter[0] += 1;
29           }
30        Prize::Rare=>{
31            counter[1] += 1;
32        }
33        Prize::Enchanted=>{
34            counter[2] += 1;
35        }
36        Prize::Common=>{
37            counter[3] += 1;
38        }
39    }
40}
41
42println!("{}", counter[0]); // Should be close to 20000
43println!("{}", counter[1]); // Should be close to 100000
44println!("{}", counter[2]); // Should be close to 300000
45println!("{}", counter[3]); // Should be close to 600000
46```
47
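The length of the slice is usually a positive integer multiple of the length of `weights`: the slice is split into `weights.len()` consecutive groups of (nearly) equal size, and each weight applies to one group.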
49
If your elements are spread across multiple slices, you don't need to spend extra memory concatenating them: use the `pick_from_multiple_slices` function instead of `pick_from_slice`.
51
Besides picking a single element from a slice or slices, you can also use the `pick_multiple_from_slice` and `pick_multiple_from_multiple_slices` functions. They have lower overhead than calling the single-pick functions repeatedly in a loop.
53*/
54
55use random_number::rand::thread_rng;
56use random_number::random;
57
/// The largest `usize` value (every bit set).
const MAX_NUMBER: usize = !0;
59
60/// Pick an element from a slice randomly by given weights.
61pub fn pick_from_slice<'a, T>(slice: &'a [T], weights: &'a [usize]) -> Option<&'a T> {
62    let slice_len = slice.len();
63
64    let index = gen_usize_with_weights(slice_len, weights)?;
65
66    Some(&slice[index])
67}
68
69/// Pick an element from multiple slices randomly by given weights.
70pub fn pick_from_multiple_slices<'a, T>(slices: &[&'a [T]], weights: &'a [usize]) -> Option<&'a T> {
71    let len: usize = slices.iter().map(|slice| slice.len()).sum();
72
73    let mut index = gen_usize_with_weights(len, weights)?;
74
75    for slice in slices {
76        let len = slice.len();
77
78        if index < len {
79            return Some(&slice[index]);
80        } else {
81            index -= len;
82        }
83    }
84
85    None
86}
87
88/// Pick multiple elements from a slice randomly by given weights.
89pub fn pick_multiple_from_slice<'a, T>(
90    slice: &'a [T],
91    weights: &'a [usize],
92    count: usize,
93) -> Vec<&'a T> {
94    let slice_len = slice.len();
95
96    gen_multiple_usize_with_weights(slice_len, weights, count)
97        .iter()
98        .map(|&index| &slice[index])
99        .collect()
100}
101
102/// Pick multiple elements from multiple slices randomly by given weights.
103pub fn pick_multiple_from_multiple_slices<'a, T>(
104    slices: &[&'a [T]],
105    weights: &'a [usize],
106    count: usize,
107) -> Vec<&'a T> {
108    let len: usize = slices.iter().map(|slice| slice.len()).sum();
109
110    gen_multiple_usize_with_weights(len, weights, count)
111        .iter()
112        .map(|index| {
113            let mut index = *index;
114
115            let mut s = slices[0];
116
117            for slice in slices {
118                let len = slice.len();
119
120                if index < len {
121                    s = slice;
122                    break;
123                } else {
124                    index -= len;
125                }
126            }
127
128            &s[index]
129        })
130        .collect()
131}
132
133/// Get a usize value by given weights.
134pub fn gen_usize_with_weights(high: usize, weights: &[usize]) -> Option<usize> {
135    let weights_len = weights.len();
136
137    if weights_len == 0 || high == 0 {
138        return None;
139    } else if weights_len == 1 {
140        if weights[0] == 0 {
141            return None;
142        }
143
144        return Some(random!(0..high));
145    } else {
146        let mut weights_sum = 0f64;
147        let mut max_weight = 0;
148
149        for w in weights.iter().copied() {
150            weights_sum += w as f64;
151            if w > max_weight {
152                max_weight = w;
153            }
154        }
155
156        if max_weight == 0 {
157            return None;
158        }
159
160        let mut rng = thread_rng();
161
162        let index_scale = (high as f64) / (weights_len as f64);
163
164        let weights_scale = (MAX_NUMBER as f64) / weights_sum;
165
166        let rnd = random!(0..=MAX_NUMBER, rng) as f64;
167
168        let mut temp = 0f64;
169
170        for (i, w) in weights.iter().copied().enumerate() {
171            temp += (w as f64) * weights_scale;
172            if temp > rnd {
173                let index = ((i as f64) * index_scale) as usize;
174
175                return Some(random!(index..((((i + 1) as f64) * index_scale) as usize), rng));
176            }
177        }
178    }
179
180    None
181}
182
183/// Get multiple usize values by given weights.
184pub fn gen_multiple_usize_with_weights(high: usize, weights: &[usize], count: usize) -> Vec<usize> {
185    let mut result: Vec<usize> = Vec::with_capacity(count);
186
187    let weights_len = weights.len();
188
189    if weights_len > 0 && high > 0 {
190        if weights_len == 1 {
191            if weights[0] != 0 {
192                let mut rng = thread_rng();
193
194                for _ in 0..count {
195                    result.push(random!(0..high, rng));
196                }
197            }
198        } else {
199            let mut weights_sum = 0f64;
200            let mut max_weight = 0;
201
202            for w in weights.iter().copied() {
203                weights_sum += w as f64;
204                if w > max_weight {
205                    max_weight = w;
206                }
207            }
208
209            if max_weight > 0 {
210                let index_scale = (high as f64) / (weights_len as f64);
211
212                let weights_scale = (MAX_NUMBER as f64) / weights_sum;
213
214                let mut rng = thread_rng();
215
216                for _ in 0..count {
217                    let rnd = random!(0..=MAX_NUMBER, rng) as f64;
218
219                    let mut temp = 0f64;
220
221                    for (i, w) in weights.iter().copied().enumerate() {
222                        temp += (w as f64) * weights_scale;
223                        if temp > rnd {
224                            let index = ((i as f64) * index_scale) as usize;
225
226                            result.push(random!(
227                                index..((((i + 1) as f64) * index_scale) as usize),
228                                rng
229                            ));
230                            break;
231                        }
232                    }
233                }
234            }
235        }
236    }
237
238    result
239}