// pingora_load_balancing/selection/weighted.rs
use super::{Backend, BackendIter, BackendSelection, SelectionAlgorithm};
use fnv::FnvHasher;
use std::collections::BTreeSet;
use std::sync::Arc;
/// Weighted backend selection driven by the selection algorithm `H`.
///
/// The default algorithm is `FnvHasher`.
pub struct Weighted<H = FnvHasher> {
    // every distinct backend, stored once, in the iteration order of the set it was built from
    backends: Box<[Backend]>,
    // indices into `backends`; each backend's index appears once per unit of its `weight`.
    // `u16` keeps the table compact but bounds how many backends can be addressed.
    weighted: Box<[u16]>,
    // the algorithm that turns a key (or a previous index) into a new u64 index
    algorithm: H,
}
impl<H: SelectionAlgorithm> BackendSelection for Weighted<H> {
    type Iter = WeightedIterator<H>;

    /// Builds the selection table from `backends`.
    ///
    /// Every backend is stored once in `backends`. Its index is repeated `weight` times in
    /// `weighted`, so a uniform pick from `weighted` chooses a backend in proportion to its
    /// weight.
    ///
    /// # Panics
    ///
    /// Panics if there are more than 2^16 backends, because indices are stored as `u16`.
    fn build(backends: &BTreeSet<Backend>) -> Self {
        // Indices 0..=u16::MAX are representable, i.e. exactly 2^16 distinct backends.
        assert!(
            backends.len() <= usize::from(u16::MAX) + 1,
            "support up to 2^16 backends"
        );
        let backends: Box<[Backend]> = backends.iter().cloned().collect();

        // The table holds one entry per unit of weight, so reserve the total weight up front
        // rather than the number of backends.
        let total_weight: usize = backends.iter().map(|b| b.weight).sum();
        let mut weighted = Vec::with_capacity(total_weight);
        for (index, b) in backends.iter().enumerate() {
            // Cannot truncate: the assert above bounds `index` to at most `u16::MAX`.
            let index = index as u16;
            weighted.extend(std::iter::repeat(index).take(b.weight));
        }

        Weighted {
            backends,
            weighted: weighted.into_boxed_slice(),
            algorithm: H::new(),
        }
    }

    /// Returns an iterator over the backends whose starting point is derived by passing `key`
    /// to the selection algorithm.
    fn iter(self: &Arc<Self>, key: &[u8]) -> Self::Iter {
        WeightedIterator::new(key, Arc::clone(self))
    }
}
/// An iterator over the backends of a [`Weighted`] selection.
pub struct WeightedIterator<H> {
    // the current unbounded index; it is reduced modulo a table length on each pick
    index: u64,
    // the selection this iterator walks over
    backend: Arc<Weighted<H>>,
    // true until the first pick has been made; the first pick reads the weighted table
    first: bool,
}
impl<H: SelectionAlgorithm> WeightedIterator<H> {
fn new(input: &[u8], backend: Arc<Weighted<H>>) -> Self {
Self {
index: backend.algorithm.next(input),
backend,
first: true,
}
}
}
impl<H: SelectionAlgorithm> BackendIter for WeightedIterator<H> {
    /// Returns the next backend to try, or `None` if there are no backends.
    ///
    /// The first call picks from the weighted table using the seed computed in `new`, so
    /// backends are chosen in proportion to their weight. Each later call feeds the previous
    /// index back into the selection algorithm and picks from the unique backend list. While
    /// at least one backend exists, every call returns `Some`.
    fn next(&mut self) -> Option<&Backend> {
        if self.backend.backends.is_empty() {
            // short circuit: nothing to select from
            return None;
        }

        if self.first {
            self.first = false;
            let weighted = &self.backend.weighted;
            let position = if weighted.is_empty() {
                // Every backend has weight 0, so the weighted table is empty and `% len` would
                // divide by zero. Fall back to the unique list, as the later calls do.
                self.index as usize % self.backend.backends.len()
            } else {
                usize::from(weighted[self.index as usize % weighted.len()])
            };
            Some(&self.backend.backends[position])
        } else {
            // Deterministically derive the next index from the previous one.
            self.index = self.backend.algorithm.next(&self.index.to_le_bytes());
            let len = self.backend.backends.len();
            Some(&self.backend.backends[self.index as usize % len])
        }
    }
}
#[cfg(test)]
mod test {
    use super::super::algorithms::*;
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_fnv() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 10;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted> = Arc::new(Weighted::build(&backends));

        // The first pick is weighted; the remaining picks walk the unique list deterministically.
        let mut iter = hash.iter(b"test");
        for expected in [&b2, &b2, &b2, &b1, &b3, &b2, &b2, &b1, &b2, &b3, &b1] {
            assert_eq!(iter.next(), Some(expected));
        }

        // The first pick for each key is a fixed function of the key.
        let first_picks: [(&[u8], &Backend); 7] = [
            (b"test1", &b2),
            (b"test2", &b2),
            (b"test3", &b3),
            (b"test4", &b1),
            (b"test5", &b2),
            (b"test6", &b2),
            (b"test7", &b2),
        ];
        for (key, expected) in first_picks {
            let mut iter = hash.iter(key);
            assert_eq!(
                iter.next(),
                Some(expected),
                "key {:?}",
                String::from_utf8_lossy(key)
            );
        }
    }

    #[test]
    fn test_round_robin() {
        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<RoundRobin>> = Arc::new(Weighted::build(&backends));

        let mut iter = hash.iter(b"test");
        for expected in [&b2, &b3, &b1, &b2, &b3] {
            assert_eq!(iter.next(), Some(expected));
        }

        // RoundRobin is stateful, so fresh iterators for the same key yield different first
        // picks; the order of these calls matters.
        for expected in [&b2, &b2, &b2, &b3, &b1, &b2, &b2] {
            let mut iter = hash.iter(b"test1");
            assert_eq!(iter.next(), Some(expected));
        }
    }

    #[test]
    fn test_random() {
        // Enough trials that the acceptance window below is several standard deviations wide.
        const TRIALS: usize = 1000;

        let b1 = Backend::new("1.1.1.1:80").unwrap();
        let mut b2 = Backend::new("1.0.0.1:80").unwrap();
        b2.weight = 8;
        let b3 = Backend::new("1.0.0.255:80").unwrap();
        let backends = BTreeSet::from_iter([b1.clone(), b2.clone(), b3.clone()]);
        let hash: Arc<Weighted<Random>> = Arc::new(Weighted::build(&backends));

        let mut count: HashMap<Backend, usize> = HashMap::new();
        count.insert(b1.clone(), 0);
        count.insert(b2.clone(), 0);
        count.insert(b3.clone(), 0);
        for _ in 0..TRIALS {
            let mut iter = hash.iter(b"test");
            *count.get_mut(iter.next().unwrap()).unwrap() += 1;
        }

        // Assuming `Backend::new` defaults to weight 1, b2 owns 8 of the 10 weighted slots and
        // should win ~80% of first picks. With 1000 trials sigma is ~12.6, so 740..=860 (~4.7
        // sigma) makes spurious failures negligible. The previous 100 trials with 70..=90
        // (~2.5 sigma) failed on roughly 1% of runs.
        let b2_count = *count.get(&b2).unwrap();
        assert!(
            (740..=860).contains(&b2_count),
            "b2 picked {b2_count} times out of {TRIALS}"
        );
    }
}