use ahash::RandomState;
use std::hash::Hash;
use std::marker::PhantomData;
use std::time::{Duration, Instant};
use tinyufo::TinyUfo;
mod read_through;
pub use read_through::{Lookup, MultiLookup, RTCache};
/// The outcome of a cache lookup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CacheStatus {
    /// A live (non-expired) entry was found.
    Hit,
    /// No entry was found for the key.
    Miss,
    /// An entry was found but its TTL had already passed.
    Expired,
    /// Counted as a hit by `is_hit`, but never produced by `MemoryCache` itself.
    /// NOTE(review): presumably reported by the read-through layer when a value was
    /// obtained while holding/waiting on a lookup lock — confirm in `read_through`.
    LockHit,
}
impl CacheStatus {
    /// Returns a lowercase, `snake_case` label for this status.
    ///
    /// The label is a string literal, so it is `'static` and may outlive `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Hit => "hit",
            Self::Miss => "miss",
            Self::Expired => "expired",
            Self::LockHit => "lock_hit",
        }
    }

    /// Returns `true` for [`Hit`](Self::Hit) and [`LockHit`](Self::LockHit), and
    /// `false` for [`Miss`](Self::Miss) and [`Expired`](Self::Expired).
    pub fn is_hit(&self) -> bool {
        // Deliberately exhaustive (no `_` arm) so adding a variant forces a decision here.
        match self {
            CacheStatus::Hit | CacheStatus::LockHit => true,
            CacheStatus::Miss | CacheStatus::Expired => false,
        }
    }
}
/// A cached value paired with its optional expiration deadline.
///
/// No trait bound is placed on `T` here; `derive(Clone)` adds `T: Clone` to the
/// generated `Clone` impl, and the impls that need cloning declare it themselves.
#[derive(Debug, Clone)]
struct Node<T> {
    /// The cached payload.
    pub value: T,
    /// Absolute deadline at or after which the entry counts as expired;
    /// `None` means the entry never expires.
    expire_on: Option<Instant>,
}
impl<T: Clone> Node<T> {
fn new(value: T, ttl: Option<Duration>) -> Self {
let expire_on = match ttl {
Some(t) => Instant::now().checked_add(t),
None => None,
};
Node { value, expire_on }
}
fn will_expire_at(&self, time: &Instant) -> bool {
match self.expire_on.as_ref() {
Some(t) => t <= time,
None => false,
}
}
fn is_expired(&self) -> bool {
self.will_expire_at(&Instant::now())
}
}
/// An in-memory key/value cache backed by [`TinyUfo`].
///
/// Keys of type `K` are never stored: each key is reduced to a `u64` using
/// `hasher`, and that hash alone indexes `store`. Each value is stored together
/// with its optional expiration deadline.
pub struct MemoryCache<K, T>
where
    K: Hash,
    T: Clone,
{
    store: TinyUfo<u64, Node<T>>,
    // Ties the key type to the cache without holding any `K` value.
    _key_type: PhantomData<K>,
    pub(crate) hasher: RandomState,
}
impl<K: Hash, T: Clone + Send + Sync + 'static> MemoryCache<K, T> {
pub fn new(size: usize) -> Self {
MemoryCache {
store: TinyUfo::new(size, size),
_key_type: PhantomData,
hasher: RandomState::new(),
}
}
pub fn get(&self, key: &K) -> (Option<T>, CacheStatus) {
let hashed_key = self.hasher.hash_one(key);
if let Some(n) = self.store.get(&hashed_key) {
if !n.is_expired() {
(Some(n.value), CacheStatus::Hit)
} else {
(None, CacheStatus::Expired)
}
} else {
(None, CacheStatus::Miss)
}
}
pub fn put(&self, key: &K, value: T, ttl: Option<Duration>) {
if let Some(t) = ttl {
if t.is_zero() {
return;
}
}
let hashed_key = self.hasher.hash_one(key);
let node = Node::new(value, ttl);
self.store.put(hashed_key, node, 1);
}
pub fn remove(&self, key: &K) {
let hashed_key = self.hasher.hash_one(key);
self.store.remove(&hashed_key);
}
pub(crate) fn force_put(&self, key: &K, value: T, ttl: Option<Duration>) {
if let Some(t) = ttl {
if t.is_zero() {
return;
}
}
let hashed_key = self.hasher.hash_one(key);
let node = Node::new(value, ttl);
self.store.force_put(hashed_key, node, 1);
}
pub fn multi_get<'a, I>(&self, keys: I) -> Vec<(Option<T>, CacheStatus)>
where
I: Iterator<Item = &'a K>,
K: 'a,
{
let mut resp = Vec::with_capacity(keys.size_hint().0);
for key in keys {
resp.push(self.get(key));
}
resp
}
pub fn multi_get_with_miss<'a, I>(&self, keys: I) -> (Vec<(Option<T>, CacheStatus)>, Vec<&'a K>)
where
I: Iterator<Item = &'a K>,
K: 'a,
{
let mut resp = Vec::with_capacity(keys.size_hint().0);
let mut missed = Vec::with_capacity(keys.size_hint().0 / 2);
for key in keys {
let (lookup, cache_status) = self.get(key);
if lookup.is_none() {
missed.push(key);
}
resp.push((lookup, cache_status));
}
(resp, missed)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread::sleep;

    #[test]
    fn test_get() {
        let cache: MemoryCache<i32, ()> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
    }

    #[test]
    fn test_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        cache.put(&1, 2, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_put_get_remove() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        cache.put(&1, 2, None);
        cache.put(&3, 4, None);
        cache.put(&5, 6, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res.unwrap(), 2);
        assert_eq!(hit, CacheStatus::Hit);
        cache.remove(&1);
        cache.remove(&3);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&3);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&5);
        assert_eq!(res.unwrap(), 6);
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_get_expired() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        // `Instant` is monotonic and `sleep` never returns early, so sleeping just
        // past the TTL deterministically observes expiry without a full-second wait.
        cache.put(&1, 2, Some(Duration::from_millis(50)));
        sleep(Duration::from_millis(60));
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Expired);
    }

    #[test]
    fn test_get_before_expiry() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&1, 2, Some(Duration::from_secs(60)));
        let (res, hit) = cache.get(&1);
        assert_eq!(res, Some(2));
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_put_zero_ttl_is_ignored() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&1, 2, Some(Duration::ZERO));
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
    }

    #[test]
    fn test_force_put_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.force_put(&1, 2, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, Some(2));
        assert_eq!(hit, CacheStatus::Hit);
        cache.force_put(&3, 4, Some(Duration::ZERO));
        let (res, hit) = cache.get(&3);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
    }

    #[test]
    fn test_eviction() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(2);
        cache.put(&1, 2, None);
        cache.put(&2, 4, None);
        cache.put(&3, 6, None);
        let (res, hit) = cache.get(&1);
        assert_eq!(res, None);
        assert_eq!(hit, CacheStatus::Miss);
        let (res, hit) = cache.get(&2);
        assert_eq!(res.unwrap(), 4);
        assert_eq!(hit, CacheStatus::Hit);
        let (res, hit) = cache.get(&3);
        assert_eq!(res.unwrap(), 6);
        assert_eq!(hit, CacheStatus::Hit);
    }

    #[test]
    fn test_multi_get() {
        let cache: MemoryCache<i32, i32> = MemoryCache::new(10);
        cache.put(&2, -2, None);
        let keys: Vec<i32> = vec![1, 2, 3];
        let resp = cache.multi_get(keys.iter());
        assert_eq!(resp[0].0, None);
        assert_eq!(resp[0].1, CacheStatus::Miss);
        assert_eq!(resp[1].0.unwrap(), -2);
        assert_eq!(resp[1].1, CacheStatus::Hit);
        assert_eq!(resp[2].0, None);
        assert_eq!(resp[2].1, CacheStatus::Miss);
        let (resp, missed) = cache.multi_get_with_miss(keys.iter());
        assert_eq!(resp[0].0, None);
        assert_eq!(resp[0].1, CacheStatus::Miss);
        assert_eq!(resp[1].0.unwrap(), -2);
        assert_eq!(resp[1].1, CacheStatus::Hit);
        assert_eq!(resp[2].0, None);
        assert_eq!(resp[2].1, CacheStatus::Miss);
        assert_eq!(missed[0], &1);
        assert_eq!(missed[1], &3);
    }

    #[test]
    fn test_cache_status_helpers() {
        assert_eq!(CacheStatus::Hit.as_str(), "hit");
        assert_eq!(CacheStatus::Miss.as_str(), "miss");
        assert_eq!(CacheStatus::Expired.as_str(), "expired");
        assert_eq!(CacheStatus::LockHit.as_str(), "lock_hit");
        assert!(CacheStatus::Hit.is_hit());
        assert!(CacheStatus::LockHit.is_hit());
        assert!(!CacheStatus::Miss.is_hit());
        assert!(!CacheStatus::Expired.is_hit());
    }
}